From 08aab14f4ff0d6f4694ba19f55208119c4990669 Mon Sep 17 00:00:00 2001 From: Ran Tian Date: Tue, 14 May 2024 15:57:36 -0700 Subject: [PATCH] Internal. PiperOrigin-RevId: 633733338 --- init2winit/hyperparameters.py | 1 + init2winit/optimizer_lib/optimizers.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index 081bb2e9..119ba9de 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -152,6 +152,7 @@ def build_hparams(model_name, merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1) merged = config_dict.ConfigDict(merged_dict) + merged['eval_use_ema'] = 'False' merged.lock() # Subconfig "opt_hparams" and "lr_hparams" are allowed to add new fields. diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index c98b4d44..44571c6b 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -15,6 +15,7 @@ """Getter function for selecting optimizers.""" +import functools from absl import logging import flax from init2winit.model_lib.model_utils import ParameterType # pylint: disable=g-importing-member @@ -27,6 +28,7 @@ from init2winit.optimizer_lib import utils from init2winit.optimizer_lib.hessian_free import CGIterationTrackingMethod from init2winit.optimizer_lib.hessian_free import hessian_free +from init2winit.optimizer_lib.kitchen_sink._src.transform import compute_params_ema_for_eval import jax import optax