2626import subprocess
2727from warnings import warn
2828
29- from ..base import CommandLine , isdefined , CommandLineInputSpec , traits
29+ from ..base import CommandLine , CommandLineInputSpec , traits , Undefined
3030from ...utils .filemanip import split_filename
3131
3232
@@ -47,8 +47,9 @@ def no_nifty_package(cmd='reg_f3d'):
4747class NiftyRegCommandInputSpec (CommandLineInputSpec ):
4848 """Input Spec for niftyreg interfaces."""
4949 # Set the number of omp thread to use
50- omp_core_val = traits .Int (desc = 'Number of openmp thread to use' ,
51- argstr = '-omp %i' )
50+ omp_core_val = traits .Int (int (os .environ .get ('OMP_NUM_THREADS' , '1' )),
51+ desc = 'Number of openmp thread to use' ,
52+ argstr = '-omp %i' , usedefault = True )
5253
5354
5455class NiftyRegCommand (CommandLine ):
@@ -58,7 +59,10 @@ class NiftyRegCommand(CommandLine):
5859 _suffix = '_nr'
5960 _min_version = '1.5.30'
6061
62+ input_spec = NiftyRegCommandInputSpec
63+
6164 def __init__ (self , required_version = None , ** inputs ):
65+ self .num_threads = 1
6266 super (NiftyRegCommand , self ).__init__ (** inputs )
6367 self .required_version = required_version
6468 _version = self .get_version ()
@@ -73,6 +77,29 @@ def __init__(self, required_version=None, **inputs):
7377 msg = 'The version of NiftyReg differs from the required'
7478 msg += '(%s != %s)'
7579 warn (msg % (_version , self .required_version ))
80+ self .inputs .on_trait_change (self ._omp_update , 'omp_core_val' )
81+ self .inputs .on_trait_change (self ._environ_update , 'environ' )
82+ self ._omp_update ()
83+
84+ def _omp_update (self ):
85+ if self .inputs .omp_core_val :
86+ self .inputs .environ ['OMP_NUM_THREADS' ] = \
87+ str (self .inputs .omp_core_val )
88+ self .num_threads = self .inputs .omp_core_val
89+ else :
90+ if 'OMP_NUM_THREADS' in self .inputs .environ :
91+ del self .inputs .environ ['OMP_NUM_THREADS' ]
92+ self .num_threads = 1
93+
94+ def _environ_update (self ):
95+ if self .inputs .environ :
96+ if 'OMP_NUM_THREADS' in self .inputs .environ :
97+ self .inputs .omp_core_val = \
98+ int (self .inputs .environ ['OMP_NUM_THREADS' ])
99+ else :
100+ self .inputs .omp_core_val = Undefined
101+ else :
102+ self .inputs .omp_core_val = Undefined
76103
77104 def check_version (self ):
78105 _version = self .get_version ()
@@ -102,13 +129,6 @@ def version(self):
102129 def exists (self ):
103130 return self .get_version () is not None
104131
105- def _run_interface (self , runtime ):
106- # Update num threads estimate from OMP_NUM_THREADS env var
107- # Default to 1 if not set
108- if not isdefined (self .inputs .environ ['OMP_NUM_THREADS' ]):
109- self .inputs .environ ['OMP_NUM_THREADS' ] = self .num_threads
110- return super (NiftyRegCommand , self )._run_interface (runtime )
111-
112132 def _format_arg (self , name , spec , value ):
113133 if name == 'omp_core_val' :
114134 self .numthreads = value
0 commit comments