@@ -10,28 +10,30 @@ def get_optimizer(name, net_params, optim_params):
     lr = optim_params['learning_rate']
     momentum = optim_params['momentum']
     weight_decay = optim_params['weight_decay']
+    # see https://www.codeleading.com/article/44815584159/
+    param_group = [{'params': net_params, 'initial_lr': lr}]
     if(keyword_match(name, "SGD")):
-        return optim.SGD(net_params, lr,
+        return optim.SGD(param_group, lr,
             momentum = momentum, weight_decay = weight_decay)
     elif(keyword_match(name, "Adam")):
-        return optim.Adam(net_params, lr, weight_decay = weight_decay)
+        return optim.Adam(param_group, lr, weight_decay = weight_decay)
     elif(keyword_match(name, "SparseAdam")):
-        return optim.SparseAdam(net_params, lr)
+        return optim.SparseAdam(param_group, lr)
     elif(keyword_match(name, "Adadelta")):
-        return optim.Adadelta(net_params, lr, weight_decay = weight_decay)
+        return optim.Adadelta(param_group, lr, weight_decay = weight_decay)
     elif(keyword_match(name, "Adagrad")):
-        return optim.Adagrad(net_params, lr, weight_decay = weight_decay)
+        return optim.Adagrad(param_group, lr, weight_decay = weight_decay)
     elif(keyword_match(name, "Adamax")):
-        return optim.Adamax(net_params, lr, weight_decay = weight_decay)
+        return optim.Adamax(param_group, lr, weight_decay = weight_decay)
     elif(keyword_match(name, "ASGD")):
-        return optim.ASGD(net_params, lr, weight_decay = weight_decay)
+        return optim.ASGD(param_group, lr, weight_decay = weight_decay)
     elif(keyword_match(name, "LBFGS")):
-        return optim.LBFGS(net_params, lr)
+        return optim.LBFGS(param_group, lr)
     elif(keyword_match(name, "RMSprop")):
-        return optim.RMSprop(net_params, lr, momentum = momentum,
+        return optim.RMSprop(param_group, lr, momentum = momentum,
             weight_decay = weight_decay)
     elif(keyword_match(name, "Rprop")):
-        return optim.Rprop(net_params, lr)
+        return optim.Rprop(param_group, lr)
     else:
         raise ValueError("unsupported optimizer {0:}".format(name))
 
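Adding 'initial_lr' to every param group lets a PyTorch LR scheduler be constructed with a non-default last_epoch, e.g. when resuming training. A minimal sketch of how the updated get_optimizer might be exercised, assuming this module is importable; the network, hyperparameter values, and the StepLR choice are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

# hypothetical setup; the dict keys mirror those read at the top of get_optimizer
net = nn.Linear(8, 2)
optim_params = {'learning_rate': 0.01, 'momentum': 0.9, 'weight_decay': 1e-4}
opt = get_optimizer("SGD", net.parameters(), optim_params)

# Because each param group now carries 'initial_lr', a scheduler can be attached
# with last_epoch > -1 (resuming) without the KeyError about a missing 'initial_lr'.
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5, last_epoch=4)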
@@ -57,7 +59,7 @@ def get_lr_scheduler(optimizer, sched_params):
     elif(keyword_match(name, "CosineAnnealingLR")):
         epoch_max = sched_params["iter_max"] / val_it
         epoch_last = sched_params["last_iter"] / val_it
-        lr_min = sched_params.get("lr_min", 0)
+        lr_min    = sched_params.get("lr_min", 0)
         scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
             epoch_max, lr_min, epoch_last)
     else:
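For the CosineAnnealingLR branch, iteration counts are converted into scheduler epochs via the validation interval. A hedged sketch of the equivalent direct construction, where val_it and all concrete numbers are assumed purely for illustration:

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

# assumed settings, for illustration only
val_it = 100                                      # iterations per validation round
sched_params = {"iter_max": 10000, "last_iter": 2000, "lr_min": 1e-6}

epoch_max  = sched_params["iter_max"] / val_it    # total scheduler steps (100)
epoch_last = sched_params["last_iter"] / val_it   # steps already taken when resuming (20)
lr_min     = sched_params.get("lr_min", 0)

# 'initial_lr' must be present in the param groups because last_epoch != -1
net = nn.Linear(8, 2)
optimizer = torch.optim.SGD([{'params': net.parameters(), 'initial_lr': 0.01}], lr=0.01)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epoch_max, lr_min, epoch_last)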