Skip to content

Commit 2569602

Browse files
committed
fix flake and issue #299
1 parent 4d16352 commit 2569602

File tree

3 files changed

+63
-69
lines changed

3 files changed

+63
-69
lines changed

autoPyTorch/api/base_task.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def __init__(
188188
self._scoring_functions: Optional[List[autoPyTorchMetric]] = None
189189
self._logger: Optional[PicklableClientLogger] = None
190190
self.dataset_name: Optional[str] = None
191+
self.dataset = Optional[BaseDataset]
191192
self.cv_models_: Dict = {}
192193

193194
self._results_manager = ResultsManager()
@@ -616,20 +617,7 @@ def _load_best_individual_model(self) -> SingleBest:
616617
run_history=self.run_history,
617618
backend=self._backend,
618619
)
619-
if self._logger is None:
620-
warnings.warn(
621-
"No valid ensemble was created. Please check the log"
622-
"file for errors. Default to the best individual estimator:{}".format(
623-
ensemble.identifiers_
624-
)
625-
)
626-
else:
627-
self._logger.exception(
628-
"No valid ensemble was created. Please check the log"
629-
"file for errors. Default to the best individual estimator:{}".format(
630-
ensemble.identifiers_
631-
)
632-
)
620+
633621

634622
return ensemble
635623

@@ -1257,7 +1245,6 @@ def _search(
12571245
if proc_ensemble is not None:
12581246
self._collect_results_ensemble(proc_ensemble)
12591247

1260-
12611248
self._logger.info("Closing the dask infrastructure")
12621249
self._close_dask_client()
12631250
self._logger.info("Finished closing the dask infrastructure")
@@ -1267,6 +1254,14 @@ def _search(
12671254
self._load_models()
12681255
self._logger.info("Finished loading models...")
12691256

1257+
if isinstance(self.ensemble_, SingleBest) and ensemble_size > 0:
1258+
self._logger.exception(
1259+
"No valid ensemble was created. Please check the log"
1260+
"file for errors. Default to the best individual estimator:{}".format(
1261+
self.ensemble_.identifiers_
1262+
)
1263+
)
1264+
12701265
self._cleanup()
12711266

12721267
return self

examples/40_advanced/example_posthoc_ensemble_fit.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,59 @@
2424
from autoPyTorch.api.tabular_classification import TabularClassificationTask
2525

2626

27-
if __name__ == '__main__':
27+
############################################################################
28+
# Data Loading
29+
# ============
30+
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
31+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
32+
X,
33+
y,
34+
random_state=42,
35+
)
2836

29-
############################################################################
30-
# Data Loading
31-
# ============
32-
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
33-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
34-
X,
35-
y,
36-
random_state=42,
37-
)
37+
############################################################################
38+
# Build and fit a classifier
39+
# ==========================
40+
api = TabularClassificationTask(
41+
seed=42,
42+
)
3843

39-
############################################################################
40-
# Build and fit a classifier
41-
# ==========================
42-
api = TabularClassificationTask(
43-
ensemble_size=0,
44-
seed=42,
45-
)
44+
############################################################################
45+
# Search for the best neural network
46+
# ==================================
47+
api.search(
48+
X_train=X_train,
49+
y_train=y_train,
50+
X_test=X_test.copy(),
51+
y_test=y_test.copy(),
52+
optimize_metric='accuracy',
53+
total_walltime_limit=100,
54+
func_eval_time_limit_secs=50,
55+
ensemble_size=0,
56+
)
4657

47-
############################################################################
48-
# Search for the best neural network
49-
# ==================================
50-
api.search(
51-
X_train=X_train,
52-
y_train=y_train,
53-
X_test=X_test.copy(),
54-
y_test=y_test.copy(),
55-
optimize_metric='accuracy',
56-
total_walltime_limit=100,
57-
func_eval_time_limit_secs=50
58-
)
58+
############################################################################
59+
# Print the final performance of the incumbent neural network
60+
# ===========================================================
61+
print(api.run_history, api.trajectory)
62+
y_pred = api.predict(X_test)
63+
score = api.score(y_pred, y_test)
64+
print(score)
5965

60-
############################################################################
61-
# Print the final performance of the incumbent neural network
62-
# ===========================================================
63-
print(api.run_history, api.trajectory)
64-
y_pred = api.predict(X_test)
65-
score = api.score(y_pred, y_test)
66-
print(score)
66+
############################################################################
67+
# Fit an ensemble with the neural networks fitted during the search
68+
# =================================================================
6769

68-
############################################################################
69-
# Fit an ensemble with the neural networks fitted during the search
70-
# =================================================================
70+
api.fit_ensemble(ensemble_size=5,
71+
# Set the enable_traditional_pipeline=True
72+
# to also include traditional models
73+
# in the ensemble
74+
enable_traditional_pipeline=False)
75+
# Print the final ensemble built by AutoPyTorch
76+
y_pred = api.predict(X_test)
77+
score = api.score(y_pred, y_test)
78+
print(score)
79+
print(api.show_models())
7180

72-
api.fit_ensemble(ensemble_size=5,
73-
# Set the enable_traditional_pipeline=True
74-
# to also include traditional models
75-
# in the ensemble
76-
enable_traditional_pipeline=False)
77-
# Print the final ensemble built by AutoPyTorch
78-
y_pred = api.predict(X_test)
79-
score = api.score(y_pred, y_test)
80-
print(score)
81-
print(api.show_models())
82-
api._cleanup()
81+
# Print statistics from search
82+
print(api.sprint_statistics())

test/test_api/test_base_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION
1515
from autoPyTorch.datasets.base_dataset import BaseDataset
1616
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
17-
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
1817
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
18+
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
1919

2020

2121
# ====
@@ -167,13 +167,12 @@ def test_init_ensemble_builder(backend):
167167
time_left_for_ensembles=60,
168168
optimize_metric='accuracy',
169169
ensemble_nbest=10,
170-
ensemble_size=5
171-
)
170+
ensemble_size=5)
172171

173172
assert isinstance(proc_ensemble, EnsembleBuilderManager)
174173
assert proc_ensemble.opt_metric == 'accuracy'
175174
assert proc_ensemble.metrics[0] == accuracy
176175

177176
estimator._cleanup()
178177

179-
del estimator
178+
del estimator

0 commit comments

Comments
 (0)