From 4815d0b6885183ba7dc0bf370c64e3ad5796eb69 Mon Sep 17 00:00:00 2001 From: init2winit Team Date: Tue, 11 Jun 2024 16:35:36 -0700 Subject: [PATCH] Adding schedule free with adam and nadam PiperOrigin-RevId: 642432327 --- init2winit/optimizer_lib/optimizers.py | 38 ++++++++++++++++++++++++++ init2winit/trainer_lib/base_trainer.py | 10 ++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 5cad993f..afbb6e1c 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -394,6 +394,44 @@ def get_optimizer(hps, model=None, batch_axis_name=None): eps=hps.opt_hparams['eps'], weight_decay=hps.opt_hparams['weight_decay'], ) + elif hps.optimizer == 'schedule_free_adam': + base_opt = utils.static_inject_hyperparams(optax.adamw)( + learning_rate=0.0, + b1=0., + b2=hps.opt_hparams['beta2'], + eps=hps.opt_hparams['epsilon'], + weight_decay=hps.opt_hparams['weight_decay'], + ) + opt_init, opt_update = utils.static_inject_hyperparams( + optax.contrib.schedule_free + )( + learning_rate=0.0, + base_optimizer=base_opt, + b1=hps.opt_hparams['beta1'], + weight_lr_power=hps.opt_hparams['weight_lr_power'], + ) + elif hps.optimizer == 'schedule_free_nadam': + base_opt = utils.static_inject_hyperparams(kitchen_sink.nadamw)( + learning_rate=0.0, + b1=0.0, + b2=hps.opt_hparams['beta2'], + eps=hps.opt_hparams['epsilon'], + eps_root=hps.opt_hparams.get('epsilon_root', 0.0), + debias=hps.opt_hparams.get('debias', True), + weight_decay=hps.opt_hparams['weight_decay'], + # NOTE(dsuo): we provide this wiring, but specifying a weight decay + # mask in a config file / serializing properly is not completely + # straightforward. 
+ weight_decay_mask=hps.opt_hparams.get('weight_decay_mask', None), + ) + opt_init, opt_update = utils.static_inject_hyperparams( + optax.contrib.schedule_free + )( + learning_rate=0.0, + base_optimizer=base_opt, + b1=hps.opt_hparams['beta1'], + weight_lr_power=hps.opt_hparams['weight_lr_power'], + ) elif hps.optimizer == 'kitchen_sink': opt_init, opt_update = utils.static_inject_hyperparams( diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 088b8030..83d6638f 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -549,7 +549,15 @@ def _eval( ) if self._eval_use_ema: - if isinstance( + if ( + schedule_free_state := optax.tree_utils.tree_get( + self._optimizer_state, 'ScheduleFreeState' + ) + ) is not None: + eval_params = optax.contrib.schedule_free_eval_params( + schedule_free_state, self._params + ) + elif isinstance( self._optimizer_state, optax.InjectStatefulHyperparamsState ): eval_params = self._optimizer_state.inner_state[0][0].ema