@@ -91,27 +91,32 @@ def flags_as_args():
9191
9292def machine_config (num_gpus = 1 , use_tpu = False , master_type = None ):
9393 """Return dict specifying machine config for trainingInput."""
94- scale_tier = 'BASIC_GPU'
9594 if use_tpu :
96- scale_tier = 'BASIC_TPU '
95+ master_type = 'standard_tpu '
9796 elif num_gpus <= 0 :
98- scale_tier = 'BASIC'
99- elif num_gpus > 1 :
100- scale_tier = 'CUSTOM'
101-
102- config = {'scaleTier' : scale_tier }
103-
104- if scale_tier == 'CUSTOM' :
105- assert num_gpus > 1
106- if num_gpus not in [4 , 8 ]:
97+ master_type = master_type or 'standard'
98+ cpu_types = ['standard' , 'large_model' , 'complex_model_s' ,
99+ 'complex_model_m' , 'complex_model_l' ]
100+ if master_type not in cpu_types :
101+ raise ValueError ('Expected `cloudml_engine_master_type` to be one of %s '
102+ 'when `worker_gpu` <= 0, found %s.' , str (cpu_types ),
103+ master_type )
104+ elif num_gpus >= 1 :
105+ if num_gpus == 1 :
106+ if master_type != 'standard_gpu' :
107+ master_type = 'standard_p100'
108+ elif num_gpus == 4 :
109+ if master_type != 'complex_model_m_gpu' :
110+ master_type = 'complex_model_m_p100'
111+ elif num_gpus == 8 :
112+ master_type = 'complex_model_l_gpu'
113+ else :
107114 raise ValueError ('Must use exactly 1, 4, or 8 GPUs.' )
108- config ['masterType' ] = ('complex_model_m_gpu'
109- if num_gpus == 4 else 'complex_model_l_gpu' )
110-
111- if master_type :
112- config ['masterType' ] = master_type
113-
114- return config
115+ assert master_type
116+ return {
117+ 'scaleTier' : 'CUSTOM' ,
118+ 'masterType' : master_type
119+ }
115120
116121
117122def configure_job ():
@@ -145,9 +150,6 @@ def configure_job():
145150 FLAGS .autotune_parallel_trials ,
146151 )
147152
148- if training_input ['scaleTier' ] == 'CUSTOM' :
149- assert 'masterType' in training_input
150-
151153 timestamp = datetime .datetime .now ().strftime ('%Y%m%d_%H%M%S' )
152154 job_name = '%s_%s_t2t_%s' % (FLAGS .model , FLAGS .problems , timestamp )
153155 job_spec = {'jobId' : job_name , 'trainingInput' : training_input }
0 commit comments