Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion init2winit/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def build_hparams(model_name,
hparam_file,
hparam_overrides,
input_pipeline_hps=None,
allowed_unrecognized_hparams=None):
allowed_unrecognized_hparams=None,
algoperf_submission_name=None):
"""Build experiment hyperparameters.

Args:
Expand All @@ -121,6 +122,7 @@ def build_hparams(model_name,
hparams from an error to a warning can be useful when trying to tune using
a shared search space over multiple workloads that don't all support the
same set of hyperparameters.
algoperf_submission_name: The name of the algoperf submission.

Returns:
A ConfigDict of experiment hyperparameters.
Expand Down Expand Up @@ -163,6 +165,10 @@ def build_hparams(model_name,
for key in ['opt_hparams', 'lr_hparams']:
merged[key].unlock()

if algoperf_submission_name:
with merged.unlocked():
merged['algoperf_submission_name'] = algoperf_submission_name

if hparam_file:
logging.info('Loading hparams from %s', hparam_file)
with gfile.GFile(hparam_file, 'r') as f:
Expand Down
98 changes: 68 additions & 30 deletions init2winit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,66 +57,101 @@
flags.DEFINE_string('trainer', 'standard', 'Name of the trainer to use.')
flags.DEFINE_string('model', 'fully_connected', 'Name of the model to train.')
flags.DEFINE_string('loss', 'cross_entropy', 'Loss function.')
flags.DEFINE_string('metrics', 'classification_metrics',
'Metrics to be used for evaluation.')
flags.DEFINE_string(
'algoperf_submission_name', '', 'AlgoPerf submission module lookup name.'
)
flags.DEFINE_string(
'metrics', 'classification_metrics', 'Metrics to be used for evaluation.'
)
flags.DEFINE_string('initializer', 'noop', 'Must be in [noop, meta_init].')
flags.DEFINE_string('experiment_dir', None,
'Path to save weights and other results. Each trial '
'directory will have path experiment_dir/worker_id/.')
flags.DEFINE_string(
'experiment_dir',
None,
'Path to save weights and other results. Each trial '
'directory will have path experiment_dir/worker_id/.',
)
flags.DEFINE_string('dataset', 'mnist', 'Which dataset to train on.')
flags.DEFINE_string('data_selector', 'noop', 'Which data selector to use.')
flags.DEFINE_integer('num_train_steps', None, 'The number of steps to train.')
flags.DEFINE_integer(
'num_tf_data_prefetches', -1, 'The number of batches to to prefetch from '
'network to host at each step. Set to -1 for tf.data.AUTOTUNE.')
'num_tf_data_prefetches',
-1,
'The number of batches to to prefetch from '
'network to host at each step. Set to -1 for tf.data.AUTOTUNE.',
)
flags.DEFINE_integer(
'num_device_prefetches', 0, 'The number of batches to to prefetch from '
'host to device at each step.')
'num_device_prefetches',
0,
'The number of batches to to prefetch from host to device at each step.',
)
flags.DEFINE_integer(
'num_tf_data_map_parallel_calls', -1, 'The number of parallel calls to '
'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.'
'num_tf_data_map_parallel_calls',
-1,
'The number of parallel calls to '
'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.',
)
flags.DEFINE_integer('eval_batch_size', None, 'Batch size for evaluation.')
flags.DEFINE_bool('eval_use_ema', None, 'If True evals will use ema of params.')
flags.DEFINE_integer(
'eval_num_batches', None,
'eval_num_batches',
None,
'Number of batches for evaluation. Leave None to evaluate '
'on the entire validation and test set.')
'on the entire validation and test set.',
)
flags.DEFINE_integer(
'test_num_batches', None,
'test_num_batches',
None,
'Number of batches for eval on test set. Leave None to evaluate '
'on the entire test set.')
flags.DEFINE_integer('eval_train_num_batches', None,
'Number of batches when evaluating on the training set.')
'on the entire test set.',
)
flags.DEFINE_integer(
'eval_train_num_batches',
None,
'Number of batches when evaluating on the training set.',
)
flags.DEFINE_integer('eval_frequency', 1000, 'Evaluate every k steps.')
flags.DEFINE_string(
'hparam_overrides', '', 'JSON representation of a flattened dict of hparam '
'hparam_overrides',
'',
'JSON representation of a flattened dict of hparam '
'overrides. For nested dictionaries, the override key '
'should be specified as lr_hparams.base_lr.')
'should be specified as lr_hparams.base_lr.',
)
flags.DEFINE_string(
'callback_configs', '', 'JSON representation of a list of dictionaries '
'which specify general callbacks to be run during eval of training.')
'callback_configs',
'',
'JSON representation of a list of dictionaries '
'which specify general callbacks to be run during eval of training.',
)
flags.DEFINE_list(
'checkpoint_steps', [], 'List of steps to checkpoint the'
'checkpoint_steps',
[],
'List of steps to checkpoint the'
' model. The checkpoints will be saved in a separate'
'directory train_dir/checkpoints. Note these checkpoints'
'will be in addition to the normal checkpointing that'
'occurs during training for preemption purposes.')
flags.DEFINE_string('external_checkpoint_path', None,
'If this argument is set, the trainer will initialize'
'the parameters, batch stats, optimizer state, and training'
'metrics by loading them from the checkpoint at this path.')
'occurs during training for preemption purposes.',
)
flags.DEFINE_string(
'external_checkpoint_path',
None,
'If this argument is set, the trainer will initialize'
'the parameters, batch stats, optimizer state, and training'
'metrics by loading them from the checkpoint at this path.',
)

flags.DEFINE_string(
'early_stopping_target_name',
None,
'A string naming the metric to use to perform early stopping. If this '
'metric reaches the value `early_stopping_target_value`, training will '
'stop. Must include the dataset split (ex: validation/error_rate).')
'stop. Must include the dataset split (ex: validation/error_rate).',
)
flags.DEFINE_float(
'early_stopping_target_value',
None,
'A float indicating the value at which to stop training.')
'A float indicating the value at which to stop training.',
)
flags.DEFINE_enum(
'early_stopping_mode',
None,
Expand Down Expand Up @@ -198,6 +233,7 @@ def _run(
initializer_name,
model_name,
loss_name,
algoperf_submission_name,
metrics_name,
num_train_steps,
experiment_dir,
Expand Down Expand Up @@ -225,7 +261,8 @@ def _run(
hparam_file=hparam_file,
hparam_overrides=hparam_overrides,
input_pipeline_hps=input_pipeline_hps,
allowed_unrecognized_hparams=allowed_unrecognized_hparams)
allowed_unrecognized_hparams=allowed_unrecognized_hparams,
algoperf_submission_name=algoperf_submission_name)

# Note that one should never tune an RNG seed!!! The seed is only included in
# the hparams for convenience of running hparam trials with multiple seeds per
Expand Down Expand Up @@ -358,6 +395,7 @@ def main(unused_argv):
initializer_name=FLAGS.initializer,
model_name=FLAGS.model,
loss_name=FLAGS.loss,
algoperf_submission_name=FLAGS.algoperf_submission_name,
metrics_name=FLAGS.metrics,
num_train_steps=FLAGS.num_train_steps,
experiment_dir=experiment_dir,
Expand Down
60 changes: 60 additions & 0 deletions init2winit/trainer_lib/i2w_workload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2024 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Init2winit workload."""

from init2winit.experiments.mlcommons.workloads import mlcommons_targets
from init2winit.experiments.mlcommons.workloads import mlcommons_workload_info
from init2winit.trainer_lib import spec


class Init2winitWorkload(spec.Workload):
  """Adapter exposing an init2winit model/hparams pair as a `spec.Workload`.

  Bridges init2winit models to the MLCommons workload interface so that
  per-workload metadata (step budgets, validation targets) can be looked up
  by name.
  """

  def initialize(self, model, hps):
    """Binds the model and hyperparameters this workload wraps.

    Args:
      model: An init2winit model providing `param_shapes` and `param_types`.
      hps: A hyperparameter config; `dataset`, `model`, and (optionally)
        `workload_name` are read to derive the workload's lookup name.
    """
    self._model = model
    self._hps = hps

  @property
  def workload_name(self):
    """The name used to key into the MLCommons metadata tables.

    Falls back to '<dataset>_<model>' when `hps.workload_name` is unset.
    """
    explicit = self._hps.workload_name
    if explicit:
      self._workload_name = explicit
    else:
      self._workload_name = self._hps.dataset + '_' + self._hps.model
    return self._workload_name

  @property
  def param_shapes(self):
    """Shapes of the wrapped model's parameters."""
    return self._model.param_shapes

  @property
  def model_params_types(self):
    """Types of the wrapped model's parameters."""
    return self._model.param_types

  @property
  def step_hint(self):
    """Training step budget for this workload.

    Raises:
      ValueError: If the workload has no entry in `num_train_steps`.
    """
    step_table = mlcommons_workload_info.num_train_steps
    name = self.workload_name
    if name not in step_table:
      raise ValueError(
          f'Workload {self.workload_name} not found in num_train_steps.')
    return step_table[name]

  @property
  def target_metric_name(self):
    """Name of the validation metric used as the workload target.

    Raises:
      ValueError: If the workload has no entry in `validation_targets`.
    """
    target_table = mlcommons_targets.validation_targets
    name = self.workload_name
    if name not in target_table:
      raise ValueError(
          f'Workload {self.workload_name} not found in validation targets.')

    return target_table[name]['metric']
Loading