@@ -234,6 +234,8 @@ def create_optimizer(
             foreach: Optional[bool] = None,
             weight_decay_exclude_1d: bool = True,
             layer_decay: Optional[float] = None,
+            layer_decay_min_scale: Optional[float] = None,
+            layer_decay_no_opt_scale: Optional[float] = None,
             param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
             **kwargs: Any,
     ) -> torch.optim.Optimizer:
@@ -248,6 +250,8 @@ def create_optimizer(
             foreach: Enable/disable foreach operation
             weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
             layer_decay: Layer-wise learning rate decay
+            layer_decay_min_scale: Minimum layer scale factor clamp value
+            layer_decay_no_opt_scale: Layer scale below which optimization is disabled
             param_group_fn: Optional custom parameter grouping function
             **kwargs: Additional optimizer-specific arguments
 
@@ -273,6 +277,8 @@ def create_optimizer(
                 layer_decay=layer_decay,
                 no_weight_decay_list=no_weight_decay,
                 weight_decay_exclude_1d=weight_decay_exclude_1d,
+                min_scale=layer_decay_min_scale,
+                no_opt_scale=layer_decay_no_opt_scale,
             )
             weight_decay = 0.
         elif weight_decay and weight_decay_exclude_1d:
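Within create_optimizer, the two new arguments are forwarded to param_groups_layer_decay as min_scale and no_opt_scale. A minimal, illustrative sketch of what those knobs mean for the usual per-layer learning-rate multiplier (layer_decay ** (num_layers - layer_id)); the actual clamping/freezing logic lives in param_groups_layer_decay, and the values below are made up:

    num_layers = 12
    layer_decay = 0.65
    layer_decay_min_scale = 0.01     # docstring: "minimum layer scale factor clamp value"
    layer_decay_no_opt_scale = 0.02  # docstring: "layer scale below which optimization is disabled"

    for layer_id in range(num_layers + 1):
        scale = layer_decay ** (num_layers - layer_id)
        scale = max(scale, layer_decay_min_scale)   # clamp the per-layer LR multiplier at a floor
        frozen = scale < layer_decay_no_opt_scale   # such layers would be excluded from optimization
        print(f'layer {layer_id:2d}  scale {scale:.4f}  frozen={frozen}')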
@@ -1140,6 +1146,8 @@ def create_optimizer_v2(
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
+        layer_decay_min_scale: float = 0.0,
+        layer_decay_no_opt_scale: Optional[float] = None,
         param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
 ) -> torch.optim.Optimizer:
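Both parameters are also surfaced on create_optimizer_v2 (with a 0.0 default for the clamp) and forwarded unchanged, so a call site would look roughly like the hedged sketch below; the model name and hyperparameter values are illustrative, not part of this commit:

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('vit_base_patch16_224', pretrained=False)
    optimizer = create_optimizer_v2(
        model,
        opt='adamw',
        lr=1e-3,
        weight_decay=0.05,
        layer_decay=0.75,               # existing layer-wise LR decay factor
        layer_decay_min_scale=0.05,     # new: floor for the per-layer LR scale
        layer_decay_no_opt_scale=0.01,  # new: scale below which a layer is not optimized
    )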
@@ -1215,31 +1223,36 @@ def create_optimizer_v2(
         foreach=foreach,
         weight_decay_exclude_1d=filter_bias_and_bn,
         layer_decay=layer_decay,
+        layer_decay_min_scale=layer_decay_min_scale,
+        layer_decay_no_opt_scale=layer_decay_no_opt_scale,
         param_group_fn=param_group_fn,
         **kwargs
     )
 
 
 def optimizer_kwargs(cfg):
-    """ cfg/argparse to kwargs helper
-    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
-    """
-    kwargs = dict(
-        opt=cfg.opt,
-        lr=cfg.lr,
-        weight_decay=cfg.weight_decay,
-        momentum=cfg.momentum,
-    )
-    if getattr(cfg, 'opt_eps', None) is not None:
-        kwargs['eps'] = cfg.opt_eps
-    if getattr(cfg, 'opt_betas', None) is not None:
-        kwargs['betas'] = cfg.opt_betas
-    if getattr(cfg, 'layer_decay', None) is not None:
-        kwargs['layer_decay'] = cfg.layer_decay
-    if getattr(cfg, 'opt_args', None) is not None:
-        kwargs.update(cfg.opt_args)
-    if getattr(cfg, 'opt_foreach', None) is not None:
-        kwargs['foreach'] = cfg.opt_foreach
+    """Convert argparse-style `cfg` object to kwargs for an optimizer factory."""
+    kwargs = {
+        'opt': cfg.opt,
+        'lr': cfg.lr,
+        'weight_decay': cfg.weight_decay,
+        'momentum': cfg.momentum,
+    }
+    if (eps := getattr(cfg, 'opt_eps', None)) is not None:
+        kwargs['eps'] = eps
+    if (betas := getattr(cfg, 'opt_betas', None)) is not None:
+        kwargs['betas'] = betas
+    if (layer_decay := getattr(cfg, 'layer_decay', None)) is not None:
+        kwargs['layer_decay'] = layer_decay
+    if (ld_min := getattr(cfg, 'layer_decay_min_scale', None)) is not None:
+        kwargs['layer_decay_min_scale'] = ld_min
+    if (ld_no_opt := getattr(cfg, 'layer_decay_no_opt_scale', None)) is not None:
+        kwargs['layer_decay_no_opt_scale'] = ld_no_opt
+    if (opt_args := getattr(cfg, 'opt_args', None)) is not None:
+        kwargs.update(opt_args)
+    if (foreach := getattr(cfg, 'opt_foreach', None)) is not None:
+        kwargs['foreach'] = foreach
+
     return kwargs
 
 
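As a quick, hedged round-trip check of the reworked helper (attribute values are made up; a SimpleNamespace stands in for parsed argparse args):

    from types import SimpleNamespace
    from timm.optim import optimizer_kwargs

    cfg = SimpleNamespace(
        opt='adamw',
        lr=1e-3,
        weight_decay=0.05,
        momentum=0.9,
        opt_eps=1e-8,
        layer_decay=0.75,
        layer_decay_min_scale=0.05,
        layer_decay_no_opt_scale=None,  # None values are skipped, so this key is omitted
    )
    print(optimizer_kwargs(cfg))
    # {'opt': 'adamw', 'lr': 0.001, 'weight_decay': 0.05, 'momentum': 0.9,
    #  'eps': 1e-08, 'layer_decay': 0.75, 'layer_decay_min_scale': 0.05}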