@@ -38,20 +38,28 @@ def get_optimizer(name, net_params, optim_params):
 
 def get_lr_scheduler(optimizer, sched_params):
     name = sched_params["lr_scheduler"]
+    val_it = sched_params["iter_valid"]
     if(name is None):
         return None
-    lr_gamma = sched_params["lr_gamma"]
     if(keyword_match(name, "ReduceLROnPlateau")):
         patience_it = sched_params["ReduceLROnPlateau_patience".lower()]
-        val_it = sched_params["iter_valid"]
-        patience = patience_it / val_it
+        patience = patience_it / val_it
+        lr_gamma = sched_params["lr_gamma"]
         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
             mode="max", factor=lr_gamma, patience=patience)
     elif(keyword_match(name, "MultiStepLR")):
         lr_milestones = sched_params["lr_milestones"]
-        last_iter = sched_params["last_iter"]
+        lr_milestones = [int(item / val_it) for item in lr_milestones]
+        epoch_last = sched_params["last_iter"] / val_it
+        lr_gamma = sched_params["lr_gamma"]
         scheduler = lr_scheduler.MultiStepLR(optimizer,
-            lr_milestones, lr_gamma, last_iter)
+            lr_milestones, lr_gamma, epoch_last)
+    elif(keyword_match(name, "CosineAnnealingLR")):
+        epoch_max = sched_params["iter_max"] / val_it
+        epoch_last = sched_params["last_iter"] / val_it
+        lr_min = sched_params.get("lr_min", 0)
+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
+            epoch_max, lr_min, epoch_last)
     else:
         raise ValueError("unsupported lr scheduler {0:}".format(name))
     return scheduler
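
A minimal usage sketch of the new CosineAnnealingLR branch (not part of the commit). It assumes that `lr_scheduler` refers to `torch.optim.lr_scheduler`, that the config keys match those read in the diff (`iter_valid`, `iter_max`, `last_iter`, `lr_min`), and that the surrounding training loop steps the scheduler once per validation round, which is why iteration counts are divided by `iter_valid`:

import torch
from torch import nn
from torch.optim import lr_scheduler

model     = nn.Linear(4, 2)                    # placeholder network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

sched_params = {
    "lr_scheduler": "CosineAnnealingLR",
    "iter_valid":   100,      # iterations between validations (one scheduler step each)
    "iter_max":     10000,    # total training iterations
    "last_iter":    0,        # resume point, in iterations
    "lr_min":       1e-5,     # eta_min of the cosine schedule
}

# Mirrors the added branch: convert iteration counts to scheduler steps.
val_it    = sched_params["iter_valid"]
epoch_max = sched_params["iter_max"] // val_it          # T_max in scheduler steps
lr_min    = sched_params.get("lr_min", 0)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epoch_max, lr_min)

for it in range(sched_params["iter_max"]):
    optimizer.step()                                    # stand-in for a training iteration
    if (it + 1) % val_it == 0:
        scheduler.step()                                # one step per validation round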