|
12 | 12 |
|
13 | 13 | from autoPyTorch.api.base_task import BaseTask, _pipeline_predict |
14 | 14 | from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION |
| 15 | +from autoPyTorch.datasets.base_dataset import BaseDataset |
15 | 16 | from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes |
| 17 | +from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager |
16 | 18 | from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline |
| 19 | +from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy |
17 | 20 |
|
18 | 21 |
|
19 | 22 | # ==== |
@@ -201,3 +204,35 @@ def test_pipeline_get_budget_forecasting(fit_dictionary_forecasting, min_budget, |
201 | 204 | assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config |
202 | 205 | assert list(smac_mock.call_args)[1]['max_budget'] == max_budget |
203 | 206 | assert list(smac_mock.call_args)[1]['initial_budget'] == min_budget |
| 207 | + |
| 208 | + |
| 209 | +def test_init_ensemble_builder(backend): |
| 210 | + BaseTask.__abstractmethods__ = set() |
| 211 | + estimator = BaseTask( |
| 212 | + backend=backend, |
| 213 | + ensemble_size=0, |
| 214 | + ) |
| 215 | + |
| 216 | + # Setup pre-requisites normally set by search() |
| 217 | + estimator._logger = estimator._get_logger('test') |
| 218 | + estimator.task_type = "tabular_classification" |
| 219 | + estimator._memory_limit = 60 |
| 220 | + estimator.dataset = MagicMock(spec=BaseDataset) |
| 221 | + estimator.dataset.output_type = 'binary' |
| 222 | + estimator.dataset.dataset_name = 'dummy' |
| 223 | + |
| 224 | + proc_ensemble = estimator._init_ensemble_builder( |
| 225 | + time_left_for_ensembles=60, |
| 226 | + optimize_metric='accuracy', |
| 227 | + ensemble_nbest=10, |
| 228 | + ensemble_size=5 |
| 229 | + ) |
| 230 | + |
| 231 | + assert isinstance(proc_ensemble, EnsembleBuilderManager) |
| 232 | + assert proc_ensemble.opt_metric == 'accuracy' |
| 233 | + assert proc_ensemble.metrics[0] == accuracy |
| 234 | + |
| 235 | + estimator._close_dask_client() |
| 236 | + estimator._clean_logger() |
| 237 | + |
| 238 | + del estimator |
0 commit comments