diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index 081bb2e9..119ba9de 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -152,6 +152,7 @@ def build_hparams(model_name, merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1) merged = config_dict.ConfigDict(merged_dict) + merged['eval_use_ema'] = 'False' merged.lock() # Subconfig "opt_hparams" and "lr_hparams" are allowed to add new fields. diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index c98b4d44..44571c6b 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -15,6 +15,7 @@ """Getter function for selecting optimizers.""" +import functools from absl import logging import flax from init2winit.model_lib.model_utils import ParameterType # pylint: disable=g-importing-member @@ -27,6 +28,7 @@ from init2winit.optimizer_lib import utils from init2winit.optimizer_lib.hessian_free import CGIterationTrackingMethod from init2winit.optimizer_lib.hessian_free import hessian_free +from init2winit.optimizer_lib.kitchen_sink._src.transform import compute_params_ema_for_eval import jax import optax