Skip to content

Commit 20514cb

Browse files
committed
add tests for ensemble init
1 parent 497f8f7 commit 20514cb

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

autoPyTorch/api/base_task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,8 +1914,9 @@ def _init_ensemble_builder(
19141914
dataset_name=str(self.dataset.dataset_name),
19151915
output_type=STRING_TO_OUTPUT_TYPES[self.dataset.output_type],
19161916
task_type=STRING_TO_TASK_TYPES[self.task_type],
1917-
metrics=[self._metric] if self._metric is not None else get_metrics(
1918-
dataset_properties=required_dataset_properties, names=[optimize_metric]),
1917+
metrics=get_metrics(
1918+
dataset_properties=required_dataset_properties,
1919+
names=[optimize_metric]),
19191920
opt_metric=optimize_metric,
19201921
ensemble_size=ensemble_size,
19211922
ensemble_nbest=ensemble_nbest,

test/test_api/test_base_api.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
from autoPyTorch.api.base_task import BaseTask, _pipeline_predict
1414
from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION
15+
from autoPyTorch.datasets.base_dataset import BaseDataset
1516
from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes
17+
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
1618
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
19+
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
1720

1821

1922
# ====
@@ -201,3 +204,35 @@ def test_pipeline_get_budget_forecasting(fit_dictionary_forecasting, min_budget,
201204
assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config
202205
assert list(smac_mock.call_args)[1]['max_budget'] == max_budget
203206
assert list(smac_mock.call_args)[1]['initial_budget'] == min_budget
207+
208+
209+
def test_init_ensemble_builder(backend):
210+
BaseTask.__abstractmethods__ = set()
211+
estimator = BaseTask(
212+
backend=backend,
213+
ensemble_size=0,
214+
)
215+
216+
# Setup pre-requisites normally set by search()
217+
estimator._logger = estimator._get_logger('test')
218+
estimator.task_type = "tabular_classification"
219+
estimator._memory_limit = 60
220+
estimator.dataset = MagicMock(spec=BaseDataset)
221+
estimator.dataset.output_type = 'binary'
222+
estimator.dataset.dataset_name = 'dummy'
223+
224+
proc_ensemble = estimator._init_ensemble_builder(
225+
time_left_for_ensembles=60,
226+
optimize_metric='accuracy',
227+
ensemble_nbest=10,
228+
ensemble_size=5
229+
)
230+
231+
assert isinstance(proc_ensemble, EnsembleBuilderManager)
232+
assert proc_ensemble.opt_metric == 'accuracy'
233+
assert proc_ensemble.metrics[0] == accuracy
234+
235+
estimator._close_dask_client()
236+
estimator._clean_logger()
237+
238+
del estimator

0 commit comments

Comments
 (0)