From 99956e1871a4e764164787be7d113f9e464b8bbc Mon Sep 17 00:00:00 2001 From: Ran Tian Date: Tue, 14 May 2024 15:57:36 -0700 Subject: [PATCH] Proposal to make 'eval_use_ema' a hyper-parameter visible to optimizers. PiperOrigin-RevId: 633733339 --- init2winit/hyperparameters.py | 1 + init2winit/optimizer_lib/optimizers.py | 1 + 2 files changed, 2 insertions(+) diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index 081bb2e9..119ba9de 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -152,6 +152,7 @@ def build_hparams(model_name, merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1) merged = config_dict.ConfigDict(merged_dict) + merged['eval_use_ema'] = 'False' merged.lock() # Subconfig "opt_hparams" and "lr_hparams" are allowed to add new fields. diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index c98b4d44..fd0efb04 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -27,6 +27,7 @@ from init2winit.optimizer_lib import utils from init2winit.optimizer_lib.hessian_free import CGIterationTrackingMethod from init2winit.optimizer_lib.hessian_free import hessian_free +from init2winit.optimizer_lib.kitchen_sink._src.transform import compute_params_ema_for_eval import jax import optax