Skip to content

Commit c0fb82e

Browse files
authored
Fix 361 (#367)
* check if N==0, and handle this case * change position of comment * Address comments from shuhei
1 parent f612f46 commit c0fb82e

File tree

4 files changed

+89
-1
lines changed

4 files changed

+89
-1
lines changed

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
293293
writer=writer,
294294
)
295295

296+
# its fine if train_loss is None due to `is_max_time_reached()`
297+
if train_loss is None:
298+
if self.budget_tracker.is_max_time_reached():
299+
break
300+
else:
301+
raise RuntimeError("Got an unexpected None in `train_loss`.")
302+
296303
val_loss, val_metrics, test_loss, test_metrics = None, {}, None, {}
297304
if self.eval_valid_each_epoch(X):
298305
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
@@ -334,6 +341,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
334341
if 'cuda' in X['device']:
335342
torch.cuda.empty_cache()
336343

344+
if self.run_summary.is_empty():
345+
raise RuntimeError("Budget exhausted without finishing an epoch.")
346+
337347
# wrap up -- add score if not evaluating every epoch
338348
if not self.eval_valid_each_epoch(X):
339349
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,16 @@ def repr_last_epoch(self) -> str:
179179
string += '=' * 40
180180
return string
181181

182+
def is_empty(self) -> bool:
183+
"""
184+
Checks if the object is empty or not
185+
186+
Returns:
187+
bool
188+
"""
189+
# if train_loss is empty, we can be sure that RunSummary is empty.
190+
return not bool(self.performance_tracker['train_loss'])
191+
182192

183193
class BaseTrainerComponent(autoPyTorchTrainingComponent):
184194

@@ -277,7 +287,7 @@ def _scheduler_step(
277287

278288
def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
279289
writer: Optional[SummaryWriter],
280-
) -> Tuple[float, Dict[str, float]]:
290+
) -> Tuple[Optional[float], Dict[str, float]]:
281291
"""
282292
Train the model for a single epoch.
283293
@@ -317,6 +327,9 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
317327
epoch * len(train_loader) + step,
318328
)
319329

330+
if N == 0:
331+
return None, {}
332+
320333
self._scheduler_step(step_interval=StepIntervalUnit.epoch, loss=loss_sum / N)
321334

322335
if self.metrics_during_training:

test/test_pipeline/components/training/test_training.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,43 @@ def test_train_step(self):
236236
lr = optimizer.param_groups[0]['lr']
237237
assert lr == target_lr
238238

239+
def test_train_epoch_no_step(self):
240+
"""
241+
This test checks if max runtime is reached
242+
for an epoch before any train_step has been
243+
completed. In this case we would like to
244+
return None for train_loss and an empty
245+
dictionary for the metrics.
246+
"""
247+
device = torch.device('cpu')
248+
model = torch.nn.Linear(1, 1).to(device)
249+
optimizer = torch.optim.Adam(model.parameters(), lr=1)
250+
data_loader = unittest.mock.MagicMock(spec=torch.utils.data.DataLoader)
251+
ms = [3, 5, 6]
252+
params = {
253+
'metrics': [],
254+
'device': device,
255+
'task_type': constants.TABULAR_REGRESSION,
256+
'labels': torch.Tensor([]),
257+
'metrics_during_training': False,
258+
'budget_tracker': BudgetTracker(budget_type='runtime', max_runtime=0),
259+
'criterion': torch.nn.MSELoss,
260+
'optimizer': optimizer,
261+
'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=ms, gamma=2),
262+
'model': model,
263+
'step_interval': StepIntervalUnit.epoch
264+
}
265+
trainer = StandardTrainer()
266+
trainer.prepare(**params)
267+
268+
loss, metrics = trainer.train_epoch(
269+
train_loader=data_loader,
270+
epoch=0,
271+
writer=None
272+
)
273+
assert loss is None
274+
assert metrics == {}
275+
239276

240277
class TestStandardTrainer(BaseTraining):
241278
def test_regression_epoch_training(self, n_samples):

test/test_pipeline/test_tabular_classification.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import re
33
import unittest
4+
import unittest.mock
45

56
from ConfigSpace.hyperparameters import (
67
CategoricalHyperparameter,
@@ -491,3 +492,30 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):
491492

492493
# More than 200 epochs would have pass in 5 seconds for this dataset
493494
assert len(run_summary.performance_tracker['start_time']) > 100
495+
496+
497+
@pytest.mark.parametrize("fit_dictionary_tabular_dummy", ["classification"], indirect=True)
498+
def test_train_pipeline_with_runtime_max_reached(fit_dictionary_tabular_dummy):
499+
"""
500+
This test makes sure that the pipeline raises an
501+
error in case no epoch has finished successfully
502+
due to max runtime reached
503+
"""
504+
505+
# Convert the training to runtime
506+
fit_dictionary_tabular_dummy.pop('epochs', None)
507+
fit_dictionary_tabular_dummy['budget_type'] = 'runtime'
508+
fit_dictionary_tabular_dummy['runtime'] = 5
509+
fit_dictionary_tabular_dummy['early_stopping'] = -1
510+
511+
pipeline = TabularClassificationPipeline(
512+
dataset_properties=fit_dictionary_tabular_dummy['dataset_properties'])
513+
514+
cs = pipeline.get_hyperparameter_search_space()
515+
config = cs.get_default_configuration()
516+
pipeline.set_hyperparameters(config)
517+
518+
with unittest.mock.patch('autoPyTorch.pipeline.components.training.trainer.BudgetTracker') as patch:
519+
patch.is_max_time_reached.return_value = True
520+
with pytest.raises(RuntimeError):
521+
pipeline.fit(fit_dictionary_tabular_dummy)

0 commit comments

Comments
 (0)