Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions init2winit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,6 @@ def _run(
training_metrics_config=training_metrics_config,
callback_configs=callback_configs,
external_checkpoint_path=external_checkpoint_path,
dataset_meta_data=dataset_meta_data,
loss_name=loss_name,
metrics_name=metrics_name,
data_selector=data_selector,
).train()
)
Expand Down
15 changes: 0 additions & 15 deletions init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def __init__(
training_metrics_config=None,
callback_configs=None,
external_checkpoint_path=None,
dataset_meta_data=None,
loss_name=None,
metrics_name=None,
data_selector=None,
training_algorithm_class=training_algorithm.OptaxTrainingAlgorithm,
):
Expand Down Expand Up @@ -125,22 +122,10 @@ def __init__(
external_checkpoint_path: (str) If this argument is set, we will load the
optimizer_state, params, batch_stats, and training_metrics from the
checkpoint at this location.
dataset_meta_data: meta_data about the dataset. It is not directly used in
the base trainer. Users are expected to overwrite the initialization
method in a customimzed trainer to access it.
loss_name: name of the loss function. Not directly used in base trainer.
Users are expected to overwrite the initialization method in a
customimzed trainer to access it.
metrics_name: Not directly used in the base trainer. Users are expected to
overwrite the initialization method in a customimzed trainer to access
it.
data_selector: data selection function returned by
datasets.get_data_selector.
training_algorithm_class: Class of training algorithm to use.
"""
del dataset_meta_data
del loss_name
del metrics_name
self._train_dir = train_dir
self._model = model
self._dataset_builder = dataset_builder
Expand Down