Skip to content

Commit c83e7d6

Browse files
committed
autoPyTorch/api/
1 parent 2569602 commit c83e7d6

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

test/test_api/test_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,6 @@ def test_tabular_input_support(openml_id, backend):
428428
estimator = TabularClassificationTask(
429429
backend=backend,
430430
resampling_strategy=HoldoutValTypes.holdout_validation,
431-
ensemble_size=0,
432431
)
433432

434433
estimator._do_dummy_prediction = unittest.mock.MagicMock()
@@ -443,6 +442,7 @@ def test_tabular_input_support(openml_id, backend):
443442
func_eval_time_limit_secs=50,
444443
enable_traditional_pipeline=False,
445444
load_models=False,
445+
ensemble_size=0,
446446
)
447447

448448

@@ -452,7 +452,6 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
452452
estimator = TabularClassificationTask(
453453
backend=backend,
454454
resampling_strategy=HoldoutValTypes.holdout_validation,
455-
ensemble_size=0,
456455
)
457456

458457
# Setup pre-requisites normally set by search()

test/test_api/test_base_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_set_pipeline_config():
117117
])
118118
def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, budget_type, expected):
119119
BaseTask.__abstractmethods__ = set()
120-
estimator = BaseTask(task_type='tabular_classification', ensemble_size=0)
120+
estimator = BaseTask(task_type='tabular_classification')
121121

122122
# Fixture pipeline config
123123
default_pipeline_config = {
@@ -140,7 +140,7 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud
140140
smac_mock.return_value = smac
141141
estimator._search(optimize_metric='accuracy', dataset=dataset, tae_func=pipeline_fit,
142142
min_budget=min_budget, max_budget=max_budget, budget_type=budget_type,
143-
enable_traditional_pipeline=False,
143+
ensemble_size=0, enable_traditional_pipeline=False,
144144
total_walltime_limit=20, func_eval_time_limit_secs=10,
145145
load_models=False)
146146
assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config
@@ -152,7 +152,6 @@ def test_init_ensemble_builder(backend):
152152
BaseTask.__abstractmethods__ = set()
153153
estimator = BaseTask(
154154
backend=backend,
155-
ensemble_size=0,
156155
)
157156

158157
# Setup pre-requisites normally set by search()

0 commit comments

Comments
 (0)