Skip to content

Commit fdac86c

Browse files
committed
add tests for ensemble init
1 parent 0e15972 commit fdac86c

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

autoPyTorch/api/base_task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,8 +1836,9 @@ def _init_ensemble_builder(
18361836
dataset_name=str(self.dataset.dataset_name),
18371837
output_type=STRING_TO_OUTPUT_TYPES[self.dataset.output_type],
18381838
task_type=STRING_TO_TASK_TYPES[self.task_type],
1839-
metrics=[self._metric] if self._metric is not None else get_metrics(
1840-
dataset_properties=required_dataset_properties, names=[optimize_metric]),
1839+
metrics=get_metrics(
1840+
dataset_properties=required_dataset_properties,
1841+
names=[optimize_metric]),
18411842
opt_metric=optimize_metric,
18421843
ensemble_size=ensemble_size,
18431844
ensemble_nbest=ensemble_nbest,

test/test_api/test_base_api.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
from autoPyTorch.api.base_task import BaseTask, _pipeline_predict
1414
from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION
15+
from autoPyTorch.datasets.base_dataset import BaseDataset
1516
from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes
17+
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
1618
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
19+
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
1720

1821

1922
# ====
@@ -160,3 +163,33 @@ def test_no_resampling_error(backend):
160163
seed=42,
161164
ensemble_size=1
162165
)
166+
def test_init_ensemble_builder(backend):
167+
BaseTask.__abstractmethods__ = set()
168+
estimator = BaseTask(
169+
backend=backend,
170+
ensemble_size=0,
171+
)
172+
173+
# Setup pre-requisites normally set by search()
174+
estimator._logger = estimator._get_logger('test')
175+
estimator.task_type = "tabular_classification"
176+
estimator._memory_limit = 60
177+
estimator.dataset = MagicMock(spec=BaseDataset)
178+
estimator.dataset.output_type = 'binary'
179+
estimator.dataset.dataset_name = 'dummy'
180+
181+
proc_ensemble = estimator._init_ensemble_builder(
182+
time_left_for_ensembles=60,
183+
optimize_metric='accuracy',
184+
ensemble_nbest=10,
185+
ensemble_size=5
186+
)
187+
188+
assert isinstance(proc_ensemble, EnsembleBuilderManager)
189+
assert proc_ensemble.opt_metric == 'accuracy'
190+
assert proc_ensemble.metrics[0] == accuracy
191+
192+
estimator._close_dask_client()
193+
estimator._clean_logger()
194+
195+
del estimator

0 commit comments

Comments
 (0)