From 7ea5bc88d6f9111daef9336fddde54c1846441c1 Mon Sep 17 00:00:00 2001 From: Sourabh Medapati Date: Tue, 22 Apr 2025 14:04:56 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 750320849 --- init2winit/hyperparameters.py | 8 +- init2winit/main.py | 98 ++++-- init2winit/trainer_lib/i2w_workload.py | 60 ++++ init2winit/trainer_lib/spec.py | 282 +++++++++++++++++ .../adamw_jax_paper_baseline.py | 298 ++++++++++++++++++ .../submissions_lib/submissions.py | 35 ++ init2winit/trainer_lib/trainer.py | 61 ++-- 7 files changed, 793 insertions(+), 49 deletions(-) create mode 100644 init2winit/trainer_lib/i2w_workload.py create mode 100644 init2winit/trainer_lib/spec.py create mode 100644 init2winit/trainer_lib/submissions_lib/adamw_jax_paper_baseline.py create mode 100644 init2winit/trainer_lib/submissions_lib/submissions.py diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index 081bb2e9..14a9ae83 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -102,7 +102,8 @@ def build_hparams(model_name, hparam_file, hparam_overrides, input_pipeline_hps=None, - allowed_unrecognized_hparams=None): + allowed_unrecognized_hparams=None, + algoperf_submission_name=None): """Build experiment hyperparameters. Args: @@ -121,6 +122,7 @@ def build_hparams(model_name, hparams from an error to a warning can be useful when trying to tune using a shared search space over multiple workloads that don't all support the same set of hyperparameters. + algoperf_submission_name: The name of the algoperf submission. Returns: A ConfigDict of experiment hyperparameters. @@ -163,6 +165,10 @@ def build_hparams(model_name, for key in ['opt_hparams', 'lr_hparams']: merged[key].unlock() + if algoperf_submission_name: + with merged.unlocked(): + merged['algoperf_submission_name'] = algoperf_submission_name + if hparam_file: logging.info('Loading hparams from %s', hparam_file) with gfile.GFile(hparam_file, 'r') as f: diff --git a/init2winit/main.py b/init2winit/main.py index f899896b..f5014446 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -57,66 +57,101 @@ flags.DEFINE_string('trainer', 'standard', 'Name of the trainer to use.') flags.DEFINE_string('model', 'fully_connected', 'Name of the model to train.') flags.DEFINE_string('loss', 'cross_entropy', 'Loss function.') -flags.DEFINE_string('metrics', 'classification_metrics', - 'Metrics to be used for evaluation.') +flags.DEFINE_string( + 'algoperf_submission_name', '', 'AlgoPerf submission module lookup name.' +) +flags.DEFINE_string( + 'metrics', 'classification_metrics', 'Metrics to be used for evaluation.' +) flags.DEFINE_string('initializer', 'noop', 'Must be in [noop, meta_init].') -flags.DEFINE_string('experiment_dir', None, - 'Path to save weights and other results. Each trial ' - 'directory will have path experiment_dir/worker_id/.') +flags.DEFINE_string( + 'experiment_dir', + None, + 'Path to save weights and other results. Each trial ' + 'directory will have path experiment_dir/worker_id/.', +) flags.DEFINE_string('dataset', 'mnist', 'Which dataset to train on.') flags.DEFINE_string('data_selector', 'noop', 'Which data selector to use.') flags.DEFINE_integer('num_train_steps', None, 'The number of steps to train.') flags.DEFINE_integer( - 'num_tf_data_prefetches', -1, 'The number of batches to to prefetch from ' - 'network to host at each step. Set to -1 for tf.data.AUTOTUNE.') + 'num_tf_data_prefetches', + -1, + 'The number of batches to to prefetch from ' + 'network to host at each step. 
Set to -1 for tf.data.AUTOTUNE.', +) flags.DEFINE_integer( - 'num_device_prefetches', 0, 'The number of batches to to prefetch from ' - 'host to device at each step.') + 'num_device_prefetches', + 0, + 'The number of batches to to prefetch from host to device at each step.', +) flags.DEFINE_integer( - 'num_tf_data_map_parallel_calls', -1, 'The number of parallel calls to ' - 'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.' + 'num_tf_data_map_parallel_calls', + -1, + 'The number of parallel calls to ' + 'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.', ) flags.DEFINE_integer('eval_batch_size', None, 'Batch size for evaluation.') flags.DEFINE_bool('eval_use_ema', None, 'If True evals will use ema of params.') flags.DEFINE_integer( - 'eval_num_batches', None, + 'eval_num_batches', + None, 'Number of batches for evaluation. Leave None to evaluate ' - 'on the entire validation and test set.') + 'on the entire validation and test set.', +) flags.DEFINE_integer( - 'test_num_batches', None, + 'test_num_batches', + None, 'Number of batches for eval on test set. Leave None to evaluate ' - 'on the entire test set.') -flags.DEFINE_integer('eval_train_num_batches', None, - 'Number of batches when evaluating on the training set.') + 'on the entire test set.', +) +flags.DEFINE_integer( + 'eval_train_num_batches', + None, + 'Number of batches when evaluating on the training set.', +) flags.DEFINE_integer('eval_frequency', 1000, 'Evaluate every k steps.') flags.DEFINE_string( - 'hparam_overrides', '', 'JSON representation of a flattened dict of hparam ' + 'hparam_overrides', + '', + 'JSON representation of a flattened dict of hparam ' 'overrides. For nested dictionaries, the override key ' - 'should be specified as lr_hparams.base_lr.') + 'should be specified as lr_hparams.base_lr.', +) flags.DEFINE_string( - 'callback_configs', '', 'JSON representation of a list of dictionaries ' - 'which specify general callbacks to be run during eval of training.') + 'callback_configs', + '', + 'JSON representation of a list of dictionaries ' + 'which specify general callbacks to be run during eval of training.', +) flags.DEFINE_list( - 'checkpoint_steps', [], 'List of steps to checkpoint the' + 'checkpoint_steps', + [], + 'List of steps to checkpoint the' ' model. The checkpoints will be saved in a separate' 'directory train_dir/checkpoints. Note these checkpoints' 'will be in addition to the normal checkpointing that' - 'occurs during training for preemption purposes.') -flags.DEFINE_string('external_checkpoint_path', None, - 'If this argument is set, the trainer will initialize' - 'the parameters, batch stats, optimizer state, and training' - 'metrics by loading them from the checkpoint at this path.') + 'occurs during training for preemption purposes.', +) +flags.DEFINE_string( + 'external_checkpoint_path', + None, + 'If this argument is set, the trainer will initialize' + 'the parameters, batch stats, optimizer state, and training' + 'metrics by loading them from the checkpoint at this path.', +) flags.DEFINE_string( 'early_stopping_target_name', None, 'A string naming the metric to use to perform early stopping. If this ' 'metric reaches the value `early_stopping_target_value`, training will ' - 'stop. Must include the dataset split (ex: validation/error_rate).') + 'stop. 
Must include the dataset split (ex: validation/error_rate).', +) flags.DEFINE_float( 'early_stopping_target_value', None, - 'A float indicating the value at which to stop training.') + 'A float indicating the value at which to stop training.', +) flags.DEFINE_enum( 'early_stopping_mode', None, @@ -198,6 +233,7 @@ def _run( initializer_name, model_name, loss_name, + algoperf_submission_name, metrics_name, num_train_steps, experiment_dir, @@ -225,7 +261,8 @@ def _run( hparam_file=hparam_file, hparam_overrides=hparam_overrides, input_pipeline_hps=input_pipeline_hps, - allowed_unrecognized_hparams=allowed_unrecognized_hparams) + allowed_unrecognized_hparams=allowed_unrecognized_hparams, + algoperf_submission_name=algoperf_submission_name) # Note that one should never tune an RNG seed!!! The seed is only included in # the hparams for convenience of running hparam trials with multiple seeds per @@ -358,6 +395,7 @@ def main(unused_argv): initializer_name=FLAGS.initializer, model_name=FLAGS.model, loss_name=FLAGS.loss, + algoperf_submission_name=FLAGS.algoperf_submission_name, metrics_name=FLAGS.metrics, num_train_steps=FLAGS.num_train_steps, experiment_dir=experiment_dir, diff --git a/init2winit/trainer_lib/i2w_workload.py b/init2winit/trainer_lib/i2w_workload.py new file mode 100644 index 00000000..57b28e72 --- /dev/null +++ b/init2winit/trainer_lib/i2w_workload.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2024 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Init2winit workload.""" + +from init2winit.experiments.mlcommons.workloads import mlcommons_targets +from init2winit.experiments.mlcommons.workloads import mlcommons_workload_info +from init2winit.trainer_lib import spec + + +class Init2winitWorkload(spec.Workload): + """Init2winit workload.""" + + def initialize(self, model, hps): + self._model = model + self._hps = hps + + @property + def workload_name(self): + if not self._hps.workload_name: + self._workload_name = self._hps.dataset + '_' + self._hps.model + else: + self._workload_name = self._hps.workload_name + + return self._workload_name + + @property + def param_shapes(self): + return self._model.param_shapes + + @property + def model_params_types(self): + return self._model.param_types + + @property + def step_hint(self): + if self.workload_name not in mlcommons_workload_info.num_train_steps: + raise ValueError( + f'Workload {self.workload_name} not found in num_train_steps.') + return mlcommons_workload_info.num_train_steps[self.workload_name] + + @property + def target_metric_name(self): + if self.workload_name not in mlcommons_targets.validation_targets: + raise ValueError( + f'Workload {self.workload_name} not found in validation targets.') + + return mlcommons_targets.validation_targets[self.workload_name]['metric'] diff --git a/init2winit/trainer_lib/spec.py b/init2winit/trainer_lib/spec.py new file mode 100644 index 00000000..a2de365b --- /dev/null +++ b/init2winit/trainer_lib/spec.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2024 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MLPerf™ Algorithmic Efficiency API.""" + +import abc +import enum +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + + +class LossType(enum.Enum): + SOFTMAX_CROSS_ENTROPY = 0 + SIGMOID_CROSS_ENTROPY = 1 + MEAN_SQUARED_ERROR = 2 + CTC_LOSS = 3 + MEAN_ABSOLUTE_ERROR = 4 + + +class ForwardPassMode(enum.Enum): + TRAIN = 0 + EVAL = 1 + + +class ParameterType(enum.Enum): + """Types of model parameters.""" + WEIGHT = 0 + BIAS = 1 + CONV_WEIGHT = 2 + BATCH_NORM_SCALE = 3 + BATCH_NORM_BIAS = 4 + LAYER_NORM_SCALE = 5 + LAYER_NORM_BIAS = 6 + EMBEDDING = 7 + ATTENTION_Q = 8 + ATTENTION_K = 9 + ATTENTION_V = 10 + ATTENTION_OUT = 11 + ATTENTION_QKV = 12 # This is used for implementations that fuse QKV together. + ATTENTION_KV = 13 # This is used for implementations that fuse KV together. + # We sometimes need to split this out because otherwise fused models will have + # a different number of biases. + ATTENTION_BIAS = 14 + + +# Of course, Tensor knows its shape and dtype. +# Tensor = Union[jnp.array, np.array, tf.Tensor, torch.Tensor, ...] +Tensor = Any + + +# Define this so that if using pytree iteration utilities, can iterate over the +# model shapes pytree without iterating over the shape tuples. 
+class ShapeTuple: + + def __init__(self, shape_tuple): + self.shape_tuple = shape_tuple + + def __repr__(self): + return f'ShapeTuple({self.shape_tuple})' + + def __eq__(self, other): + return self.shape_tuple == other.shape_tuple + + +Shape = Union[Tuple[int], + Tuple[int, int], + Tuple[int, int, int], + Tuple[int, int, int, int], + ShapeTuple] +ParameterShapeTree = Dict[str, Dict[str, Shape]] + +# If necessary, these can be zipped together easily given they have the same +# structure, to get an iterator over pairs of leaves. +ParameterKey = str +# Dicts can be arbitrarily nested. +ParameterContainer = Union[Dict[ParameterKey, Dict[ParameterKey, Tensor]]] +ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]] + +RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...] + +OptimizerState = Union[Dict[str, Any], Tuple[Any, Any]] +Hyperparameters = Any +Timing = int +Steps = int + +# BN EMAs. +ModelAuxiliaryState = Any +ModelInitState = Tuple[ParameterContainer, ModelAuxiliaryState] + + +class Workload(metaclass=abc.ABCMeta): + """Base class for workloads.""" + + def __init__(self, *args, **kwargs) -> None: + del args + del kwargs + self._param_shapes: Optional[ParameterShapeTree] = None + self._param_types: Optional[ParameterTypeTree] = None + self._eval_iters: Dict[str, Iterator[Dict[str, Any]]] = {} + self.metrics_logger = None + + @property + @abc.abstractmethod + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + + @property + @abc.abstractmethod + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + + @property + def param_shapes(self): + """The shapes of the parameters in the workload model.""" + if self._param_shapes is None: + raise ValueError( + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_shapes!') + return self._param_shapes + + @property + def model_params_types(self): + """The types of the parameters in the workload model.""" + if self._param_types is None: + raise ValueError( + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_types!') + return self._param_types + + +class TrainingCompleteError(Exception): + pass + + +# Training algorithm track submission functions, to be filled in by the +# submitter. + +InitOptimizerFn = Callable[[ + Workload, + ParameterContainer, + ModelAuxiliaryState, + Hyperparameters, + RandomState +], OptimizerState] + + +# pylint: disable=unused-argument +def init_optimizer_state(workload: Workload, + model_params: ParameterContainer, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + rng: RandomState) -> OptimizerState: + # return initial_optimizer_state + pass + + +UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState] +UpdateParamsFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + Dict[str, Tensor], + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState, + Optional[Dict[str, Any]] +], + UpdateReturn] + + +# Each call to this function is considered a "step". +# Can raise a TrainingCompleteError if it believes it has achieved the goal and +# wants to end the run and receive a final free eval. It will not be restarted, +# and if has not actually achieved the goal then it will be considered as not +# achieved the goal and get an infinite time score. 
Most submissions will likely +# wait until the next free eval and not use this functionality. +def update_params(workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + batch: Dict[str, Tensor], + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + pass + + +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], UpdateReturn] + + +# Prepare model and optimizer for evaluation. +def prepare_for_eval(workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState) -> UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + pass + + +DataSelectionFn = Callable[[ + Workload, + Iterator[Dict[str, Any]], + OptimizerState, + ParameterContainer, + LossType, + Hyperparameters, + int, + RandomState +], Tuple[Tensor, Tensor]] + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimzier state. +def data_selection(workload: Workload, + input_queue: Iterator[Dict[str, Any]], + optimizer_state: OptimizerState, + current_param_container: ParameterContainer, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + global_step: int, + rng: RandomState) -> Dict[str, Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. + + Args: + workload: The workload being trained. + input_queue: An iterator over the training data. + optimizer_state: The current optimizer state. + current_param_container: The current model parameters. + model_state: The current model state (e.g., for batch norm). + hyperparameters: The current hyperparameters. + global_step: The current training step. + rng: The current random number generator state. + + Returns: + A batch of training data. + """ + # return next(input_queue) + pass + + +def get_batch_size(workload_name: str) -> int: + """Return the global batch size to use for a given workload.""" + pass diff --git a/init2winit/trainer_lib/submissions_lib/adamw_jax_paper_baseline.py b/init2winit/trainer_lib/submissions_lib/adamw_jax_paper_baseline.py new file mode 100644 index 00000000..a84c165a --- /dev/null +++ b/init2winit/trainer_lib/submissions_lib/adamw_jax_paper_baseline.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Copyright 2024 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from init2winit.trainer_lib import spec +import jax +import jax.numpy as jnp +# pylint: disable=g-importing-member +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +import optax + + +_GRAD_CLIP_EPS = 1e-6 + + +def map_i2w_hparams_to_algoperf_hparams(merged_hps): + submission_hps = merged_hps + submission_hps.unlock() + + submission_hps.learning_rate = merged_hps.lr_hparams.base_lr + submission_hps.warmup_factor = merged_hps.lr_hparams.warmup_steps_fraction + submission_hps.weight_decay = merged_hps.opt_hparams.weight_decay + submission_hps.one_minus_beta1 = 1.0 - merged_hps.opt_hparams.beta1 + submission_hps.beta2 = merged_hps.opt_hparams.beta2 + + submission_hps.lock() + return submission_hps + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates an AdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = optax.adamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_init_fn, opt_update_fn + + +# pylint: disable=missing-function-docstring +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + + # Compute local loss and gradients + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + # Set up mesh and sharding + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # 
rng + replicated, # grad_clip + replicated, # label_smoothing + ), + out_shardings=( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ), + ) + # print(batch) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + ) + ) + + # Log loss, grad_norm. + if global_step % 1 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + """Return the global batch size.""" + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + elif workload_name == 'cifar': + return 32 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +# pylint: disable=g-doc-args +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. + Returns: + A batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch + diff --git a/init2winit/trainer_lib/submissions_lib/submissions.py b/init2winit/trainer_lib/submissions_lib/submissions.py new file mode 100644 index 00000000..ac840de5 --- /dev/null +++ b/init2winit/trainer_lib/submissions_lib/submissions.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2024 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Submissions library for the init2winit project.""" + +import collections +from init2winit.trainer_lib.submissions_lib import adamw_jax_paper_baseline + +_Submission = collections.namedtuple( + 'Submission', ( + 'map_i2w_hparams_to_algoperf_hparams', + 'init_optimizer_state')) + + +def get_submission_module(submission_name: str) -> _Submission: + """Returns the submission module for the given submission name.""" + if submission_name == 'adamw_jax_paper_baseline': + return _Submission( + adamw_jax_paper_baseline.map_i2w_hparams_to_algoperf_hparams, + adamw_jax_paper_baseline.init_optimizer_state, + ) + else: + raise ValueError(f'Unknown submission name: {submission_name}') diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index 716f30b5..1ebd2232 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -22,7 +22,9 @@ from init2winit.model_lib import model_utils from init2winit.optimizer_lib import optimizers from init2winit.trainer_lib import base_trainer +from init2winit.trainer_lib import i2w_workload from init2winit.trainer_lib import trainer_utils +from init2winit.trainer_lib.submissions_lib import submissions import jax import jax.numpy as jnp import optax @@ -137,24 +139,47 @@ def __init__(self, *args, **kwargs): self._update_jitted = None def init_optimizer_state(self, model, params, batch_stats, hps, rng): - del batch_stats - del rng - - stretch_factor = 1 - if hps.get('total_accumulated_batch_size') is not None: - stretch_factor = (hps.total_accumulated_batch_size // hps.batch_size) - - self._lr_fn = schedules.get_schedule_fn( - self._hps.lr_hparams, - max_training_updates=self._num_train_steps // stretch_factor, - stretch_factor=stretch_factor) - - self._optimizer_init_fn, self._optimizer_update_fn = ( - optimizers.get_optimizer( - hps, model, batch_axis_name='batch' - ) - ) - unreplicated_optimizer_state = self._optimizer_init_fn(params) + if hps.get('algoperf_submission_name', None): + print('inside algoperf code path ') + + workload = i2w_workload.Init2winitWorkload(model, hps) + submission_module = submissions.get_submission_module( + hps.algoperf_submission_name + ) + + submission_hps = submission_module.map_i2w_hparams_to_algoperf_hparams( + hps + ) + + ( + unreplicated_optimizer_state, + self._optimizer_init_fn, + self._optimizer_update_fn, + ) = submission_module.init_optimizer_state( + workload, params, batch_stats, submission_hps, rng + ) + + else: + # Init2winit optimizer_lib code path. 
+ del batch_stats + del rng + + stretch_factor = 1 + if hps.get('total_accumulated_batch_size') is not None: + stretch_factor = (hps.total_accumulated_batch_size // hps.batch_size) + + self._lr_fn = schedules.get_schedule_fn( + self._hps.lr_hparams, + max_training_updates=self._num_train_steps // stretch_factor, + stretch_factor=stretch_factor) + + self._optimizer_init_fn, self._optimizer_update_fn = ( + optimizers.get_optimizer( + hps, model, batch_axis_name='batch' + ) + ) + unreplicated_optimizer_state = self._optimizer_init_fn(params) + return unreplicated_optimizer_state, self._optimizer_update_fn def update_params(self,
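For reference, a minimal usage sketch of the new submissions_lib API introduced by this patch, exercised outside the trainer. It assumes the patch is applied and init2winit, optax, and JAX are importable; the stub workload, parameter shapes, and hyperparameter values below are illustrative placeholders, not part of the change.

import types

import jax
from init2winit.trainer_lib import spec
from init2winit.trainer_lib.submissions_lib import submissions


class _StubWorkload(spec.Workload):
  """Minimal workload exposing only what init_optimizer_state reads."""

  @property
  def step_hint(self):
    # Hypothetical step budget; real values come from
    # mlcommons_workload_info.num_train_steps.
    return 1000

  @property
  def target_metric_name(self):
    # Hypothetical metric name; real values come from
    # mlcommons_targets.validation_targets.
    return 'error_rate'

  @property
  def param_shapes(self):
    # Tiny illustrative parameter tree; leaves must expose .shape_tuple,
    # as init_optimizer_state builds zeros via jnp.zeros(s.shape_tuple).
    return {'dense': {'kernel': spec.ShapeTuple((8, 4)),
                      'bias': spec.ShapeTuple((4,))}}


submission = submissions.get_submission_module('adamw_jax_paper_baseline')
# Attribute names mirror what map_i2w_hparams_to_algoperf_hparams produces;
# the values here are arbitrary examples.
hparams = types.SimpleNamespace(
    learning_rate=1e-3, warmup_factor=0.05, weight_decay=1e-2,
    one_minus_beta1=0.1, beta2=0.999)
opt_state, opt_init_fn, opt_update_fn = submission.init_optimizer_state(
    workload=_StubWorkload(), model_params=None, model_state=None,
    hyperparameters=hparams, rng=jax.random.PRNGKey(0))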