|
12 | 12 |
|
13 | 13 | from autoPyTorch.api.base_task import BaseTask, _pipeline_predict |
14 | 14 | from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION |
| 15 | +from autoPyTorch.datasets.base_dataset import BaseDataset |
15 | 16 | from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes |
| 17 | +from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager |
16 | 18 | from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline |
| 19 | +from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy |
17 | 20 |
|
18 | 21 |
|
19 | 22 | # ==== |
@@ -160,3 +163,33 @@ def test_no_resampling_error(backend): |
160 | 163 | seed=42, |
161 | 164 | ensemble_size=1 |
162 | 165 | ) |
| 166 | +def test_init_ensemble_builder(backend): |
| 167 | + BaseTask.__abstractmethods__ = set() |
| 168 | + estimator = BaseTask( |
| 169 | + backend=backend, |
| 170 | + ensemble_size=0, |
| 171 | + ) |
| 172 | + |
| 173 | + # Setup pre-requisites normally set by search() |
| 174 | + estimator._logger = estimator._get_logger('test') |
| 175 | + estimator.task_type = "tabular_classification" |
| 176 | + estimator._memory_limit = 60 |
| 177 | + estimator.dataset = MagicMock(spec=BaseDataset) |
| 178 | + estimator.dataset.output_type = 'binary' |
| 179 | + estimator.dataset.dataset_name = 'dummy' |
| 180 | + |
| 181 | + proc_ensemble = estimator._init_ensemble_builder( |
| 182 | + time_left_for_ensembles=60, |
| 183 | + optimize_metric='accuracy', |
| 184 | + ensemble_nbest=10, |
| 185 | + ensemble_size=5 |
| 186 | + ) |
| 187 | + |
| 188 | + assert isinstance(proc_ensemble, EnsembleBuilderManager) |
| 189 | + assert proc_ensemble.opt_metric == 'accuracy' |
| 190 | + assert proc_ensemble.metrics[0] == accuracy |
| 191 | + |
| 192 | + estimator._close_dask_client() |
| 193 | + estimator._clean_logger() |
| 194 | + |
| 195 | + del estimator |
0 commit comments