diff --git a/init2winit/main.py b/init2winit/main.py index 48b15982..fa9132be 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -284,9 +284,6 @@ def _run( training_metrics_config=training_metrics_config, callback_configs=callback_configs, external_checkpoint_path=external_checkpoint_path, - dataset_meta_data=dataset_meta_data, - loss_name=loss_name, - metrics_name=metrics_name, data_selector=data_selector, ).train() ) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index c8a12064..691cba95 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -63,9 +63,6 @@ def __init__( training_metrics_config=None, callback_configs=None, external_checkpoint_path=None, - dataset_meta_data=None, - loss_name=None, - metrics_name=None, data_selector=None, training_algorithm_class=training_algorithm.OptaxTrainingAlgorithm, ): @@ -125,22 +122,10 @@ def __init__( external_checkpoint_path: (str) If this argument is set, we will load the optimizer_state, params, batch_stats, and training_metrics from the checkpoint at this location. - dataset_meta_data: meta_data about the dataset. It is not directly used in - the base trainer. Users are expected to overwrite the initialization - method in a customimzed trainer to access it. - loss_name: name of the loss function. Not directly used in base trainer. - Users are expected to overwrite the initialization method in a - customimzed trainer to access it. - metrics_name: Not directly used in the base trainer. Users are expected to - overwrite the initialization method in a customimzed trainer to access - it. data_selector: data selection function returned by datasets.get_data_selector. training_algorithm_class: Class of training algorithm to use. """ - del dataset_meta_data - del loss_name - del metrics_name self._train_dir = train_dir self._model = model self._dataset_builder = dataset_builder