@@ -195,7 +195,9 @@ def train_valid(self):
195195 ckpt_dir = self .config ['training' ]['ckpt_save_dir' ]
196196 if (ckpt_dir [- 1 ] == "/" ):
197197 ckpt_dir = ckpt_dir [:- 1 ]
198- ckpt_prefx = ckpt_dir .split ('/' )[- 1 ]
198+ ckpt_prefix = self .config ['training' ].get ('ckpt_prefix' , None )
199+ if (ckpt_prefix is None ):
200+ ckpt_prefix = ckpt_dir .split ('/' )[- 1 ]
199201 iter_start = self .config ['training' ]['iter_start' ]
200202 iter_max = self .config ['training' ]['iter_max' ]
201203 iter_valid = self .config ['training' ]['iter_valid' ]
@@ -206,7 +208,7 @@ def train_valid(self):
206208 self .best_model_wts = None
207209 self .checkpoint = None
208210 if (iter_start > 0 ):
209- checkpoint_file = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefx , iter_start )
211+ checkpoint_file = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefix , iter_start )
210212 self .checkpoint = torch .load (checkpoint_file , map_location = self .device )
211213 assert (self .checkpoint ['iteration' ] == iter_start )
212214 self .net .load_state_dict (self .checkpoint ['model_state_dict' ])
@@ -237,9 +239,9 @@ def train_valid(self):
237239 'valid_pred' : valid_scalars [metrics ],
238240 'model_state_dict' : self .net .state_dict (),
239241 'optimizer_state_dict' : self .optimizer .state_dict ()}
240- save_name = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefx , glob_it )
242+ save_name = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefix , glob_it )
241243 torch .save (save_dict , save_name )
242- txt_file = open ("{0:}/{1:}_latest.txt" .format (ckpt_dir , ckpt_prefx ), 'wt' )
244+ txt_file = open ("{0:}/{1:}_latest.txt" .format (ckpt_dir , ckpt_prefix ), 'wt' )
243245 txt_file .write (str (glob_it ))
244246 txt_file .close ()
245247
@@ -248,9 +250,9 @@ def train_valid(self):
248250 'valid_pred' : self .max_val_score ,
249251 'model_state_dict' : self .best_model_wts ,
250252 'optimizer_state_dict' : self .optimizer .state_dict ()}
251- save_name = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefx , self .max_val_it )
253+ save_name = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefix , self .max_val_it )
252254 torch .save (save_dict , save_name )
253- txt_file = open ("{0:}/{1:}_best.txt" .format (ckpt_dir , ckpt_prefx ), 'wt' )
255+ txt_file = open ("{0:}/{1:}_best.txt" .format (ckpt_dir , ckpt_prefix ), 'wt' )
254256 txt_file .write (str (self .max_val_it ))
255257 txt_file .close ()
256258 logging .info ('The best perfroming iter is {0:}, valid {1:} {2:}' .format (\
0 commit comments