Skip to content

Commit 707b275

Browse files
committed
fix flake and issue #299
1 parent 23075d6 commit 707b275

File tree

3 files changed

+63
-69
lines changed

3 files changed

+63
-69
lines changed

autoPyTorch/api/base_task.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def __init__(
204204
self._scoring_functions: Optional[List[autoPyTorchMetric]] = None
205205
self._logger: Optional[PicklableClientLogger] = None
206206
self.dataset_name: Optional[str] = None
207+
self.dataset: Optional[BaseDataset] = None  # annotate the type and start unset; `= Optional[...]` stored the typing object itself
207208
self.cv_models_: Dict = {}
208209

209210
self._results_manager = ResultsManager()
@@ -632,20 +633,7 @@ def _load_best_individual_model(self) -> SingleBest:
632633
run_history=self.run_history,
633634
backend=self._backend,
634635
)
635-
if self._logger is None:
636-
warnings.warn(
637-
"No valid ensemble was created. Please check the log"
638-
"file for errors. Default to the best individual estimator:{}".format(
639-
ensemble.identifiers_
640-
)
641-
)
642-
else:
643-
self._logger.exception(
644-
"No valid ensemble was created. Please check the log"
645-
"file for errors. Default to the best individual estimator:{}".format(
646-
ensemble.identifiers_
647-
)
648-
)
636+
649637

650638
return ensemble
651639

@@ -1273,7 +1261,6 @@ def _search(
12731261
if proc_ensemble is not None:
12741262
self._collect_results_ensemble(proc_ensemble)
12751263

1276-
12771264
self._logger.info("Closing the dask infrastructure")
12781265
self._close_dask_client()
12791266
self._logger.info("Finished closing the dask infrastructure")
@@ -1283,6 +1270,14 @@ def _search(
12831270
self._load_models()
12841271
self._logger.info("Finished loading models...")
12851272

1273+
if isinstance(self.ensemble_, SingleBest) and ensemble_size > 0:
1274+
self._logger.exception(
1275+
"No valid ensemble was created. Please check the log"
1276+
"file for errors. Default to the best individual estimator:{}".format(
1277+
self.ensemble_.identifiers_
1278+
)
1279+
)
1280+
12861281
self._cleanup()
12871282

12881283
return self

examples/40_advanced/example_posthoc_ensemble_fit.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,59 @@
2424
from autoPyTorch.api.tabular_classification import TabularClassificationTask
2525

2626

27-
if __name__ == '__main__':
27+
############################################################################
28+
# Data Loading
29+
# ============
30+
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
31+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
32+
X,
33+
y,
34+
random_state=42,
35+
)
2836

29-
############################################################################
30-
# Data Loading
31-
# ============
32-
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
33-
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
34-
X,
35-
y,
36-
random_state=42,
37-
)
37+
############################################################################
38+
# Build and fit a classifier
39+
# ==========================
40+
api = TabularClassificationTask(
41+
seed=42,
42+
)
3843

39-
############################################################################
40-
# Build and fit a classifier
41-
# ==========================
42-
api = TabularClassificationTask(
43-
ensemble_size=0,
44-
seed=42,
45-
)
44+
############################################################################
45+
# Search for the best neural network
46+
# ==================================
47+
api.search(
48+
X_train=X_train,
49+
y_train=y_train,
50+
X_test=X_test.copy(),
51+
y_test=y_test.copy(),
52+
optimize_metric='accuracy',
53+
total_walltime_limit=100,
54+
func_eval_time_limit_secs=50,
55+
ensemble_size=0,
56+
)
4657

47-
############################################################################
48-
# Search for the best neural network
49-
# ==================================
50-
api.search(
51-
X_train=X_train,
52-
y_train=y_train,
53-
X_test=X_test.copy(),
54-
y_test=y_test.copy(),
55-
optimize_metric='accuracy',
56-
total_walltime_limit=100,
57-
func_eval_time_limit_secs=50
58-
)
58+
############################################################################
59+
# Print the final performance of the incumbent neural network
60+
# ===========================================================
61+
print(api.run_history, api.trajectory)
62+
y_pred = api.predict(X_test)
63+
score = api.score(y_pred, y_test)
64+
print(score)
5965

60-
############################################################################
61-
# Print the final performance of the incumbent neural network
62-
# ===========================================================
63-
print(api.run_history, api.trajectory)
64-
y_pred = api.predict(X_test)
65-
score = api.score(y_pred, y_test)
66-
print(score)
66+
############################################################################
67+
# Fit an ensemble with the neural networks fitted during the search
68+
# =================================================================
6769

68-
############################################################################
69-
# Fit an ensemble with the neural networks fitted during the search
70-
# =================================================================
70+
api.fit_ensemble(ensemble_size=5,
71+
# Set the enable_traditional_pipeline=True
72+
# to also include traditional models
73+
# in the ensemble
74+
enable_traditional_pipeline=False)
75+
# Print the final ensemble built by AutoPyTorch
76+
y_pred = api.predict(X_test)
77+
score = api.score(y_pred, y_test)
78+
print(score)
79+
print(api.show_models())
7180

72-
api.fit_ensemble(ensemble_size=5,
73-
# Set the enable_traditional_pipeline=True
74-
# to also include traditional models
75-
# in the ensemble
76-
enable_traditional_pipeline=False)
77-
# Print the final ensemble built by AutoPyTorch
78-
y_pred = api.predict(X_test)
79-
score = api.score(y_pred, y_test)
80-
print(score)
81-
print(api.show_models())
82-
api._cleanup()
81+
# Print statistics from search
82+
print(api.sprint_statistics())

test/test_api/test_base_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from autoPyTorch.datasets.base_dataset import BaseDataset
1616
from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes
1717
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
18-
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
1918
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
19+
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
2020

2121

2222
# ====
@@ -182,13 +182,12 @@ def test_init_ensemble_builder(backend):
182182
time_left_for_ensembles=60,
183183
optimize_metric='accuracy',
184184
ensemble_nbest=10,
185-
ensemble_size=5
186-
)
185+
ensemble_size=5)
187186

188187
assert isinstance(proc_ensemble, EnsembleBuilderManager)
189188
assert proc_ensemble.opt_metric == 'accuracy'
190189
assert proc_ensemble.metrics[0] == accuracy
191190

192191
estimator._cleanup()
193192

194-
del estimator
193+
del estimator

0 commit comments

Comments
 (0)