diff --git a/init2winit/dataset_lib/fake_dataset.py b/init2winit/dataset_lib/fake_dataset.py index bbcf039a..059156aa 100644 --- a/init2winit/dataset_lib/fake_dataset.py +++ b/init2winit/dataset_lib/fake_dataset.py @@ -79,12 +79,8 @@ def train_iterator_fn(): while True: yield fake_train_batch - def valid_epoch(epoch, num_batches=None): - del num_batches - del epoch - # Note that we do // beacuse we do not support partial batching for the fake - # dataset. - for _ in range(hps.valid_size // eval_batch_size): + def valid_epoch(num_batches): + for _ in range(num_batches): yield fake_test_batch # pylint: disable=unreachable @@ -105,4 +101,3 @@ def test_epoch(*args, **kwargs): return data_utils.Dataset( train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) -