Skip to content

Commit 6afe604

Browse files
committed
add custom loss, optim, metrics for model_sweep
1 parent 023db27 commit 6afe604

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

src/pytorch_tabular/tabular_model_sweep.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _validate_args(
9292
experiment_config: Optional[Union[ExperimentConfig, str]] = None,
9393
common_model_args: Optional[dict] = {},
9494
rank_metric: Optional[str] = "loss",
95+
custom_fit_params: Optional[dict] = {},
9596
):
9697
assert task in [
9798
"classification",
@@ -149,6 +150,8 @@ def _validate_args(
149150
"lower_is_better",
150151
"higher_is_better",
151152
], "rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
153+
if "metrics" in custom_fit_params.keys():
154+
assert rank_metric[0] == "loss", "only loss is supported as the rank_metric when using custom metrics"
152155

153156

154157
def model_sweep(
@@ -172,6 +175,7 @@ def model_sweep(
172175
progress_bar: bool = True,
173176
verbose: bool = True,
174177
suppress_lightning_logger: bool = True,
178+
custom_fit_params: Optional[dict] = {},
175179
):
176180
"""Compare multiple models on the same dataset.
177181
@@ -231,6 +235,10 @@ def model_sweep(
231235
232236
suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.
233237
238+
custom_fit_params (dict, optional): A dict specifying custom loss, metrics and optimizer.
239+
The behaviour of these custom parameters is similar to those passed through the `fit` method
240+
of `TabularModel`.
241+
234242
Returns:
235243
results: Training results.
236244
@@ -252,6 +260,7 @@ def model_sweep(
252260
experiment_config=experiment_config,
253261
common_model_args=common_model_args,
254262
rank_metric=rank_metric,
263+
custom_fit_params=custom_fit_params,
255264
)
256265
if suppress_lightning_logger:
257266
suppress_lightning_logs()
@@ -326,7 +335,7 @@ def _init_tabular_model(m):
326335
name = tabular_model.name
327336
if verbose:
328337
logger.info(f"Training {name}")
329-
model = tabular_model.prepare_model(datamodule)
338+
model = tabular_model.prepare_model(datamodule, **custom_fit_params)
330339
if progress_bar:
331340
progress.update(task_p, description=f"Training {name}", advance=1)
332341
with OutOfMemoryHandler(handle_oom=True) as handler:

tests/test_common.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ def test_tta_regression(
11431143

11441144

11451145
def _run_model_compare(
1146-
task, model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric
1146+
task, model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric, custom_fit_params={},
11471147
):
11481148
model_list = copy.deepcopy(model_list)
11491149
if isinstance(model_list, list):
@@ -1161,6 +1161,7 @@ def _run_model_compare(
11611161
metrics_params=metric[1],
11621162
metrics_prob_input=metric[2],
11631163
rank_metric=rank_metric,
1164+
custom_fit_params=custom_fit_params,
11641165
)
11651166

11661167

@@ -1248,6 +1249,50 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12481249
# best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
12491250
# assert best_model.model._get_name() in best_models
12501251

1252+
@pytest.mark.parametrize("model_list", ["lite", MODEL_CONFIG_MODEL_SWEEP_TEST])
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
@pytest.mark.parametrize(
    "metric",
    [
        (["mean_squared_error"], [{}], [False]),
    ],
)
@pytest.mark.parametrize("rank_metric", [("loss", "lower_is_better")])
@pytest.mark.parametrize(
    "custom_fit_params",
    [
        {
            "loss": torch.nn.L1Loss(),
            "metrics": [fake_metric],
            "metrics_prob_inputs": [True],
            "optimizer": torch.optim.Adagrad,
        },
    ],
)
def test_model_compare_custom(
    regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params
):
    """Run `model_sweep` on a regression task with a custom loss, metrics and
    optimizer, and check that the sweep produces one result row per model and
    reports the custom metric column when one was supplied."""
    (train, test, target) = regression_data
    data_config = DataConfig(
        target=target,
        continuous_cols=continuous_cols,
        categorical_cols=categorical_cols,
        handle_missing_values=True,
        handle_unknown_categories=True,
    )
    # Keep the sweep cheap: CPU only, fast_dev_run, no checkpointing.
    trainer_config = TrainerConfig(
        max_epochs=3,
        checkpoints=None,
        early_stopping=None,
        accelerator="cpu",
        fast_dev_run=True,
    )
    optimizer_config = OptimizerConfig()
    comp_df, best_model = _run_model_compare(
        "regression",
        model_list,
        data_config,
        trainer_config,
        optimizer_config,
        train,
        test,
        metric,
        rank_metric,
        custom_fit_params=custom_fit_params,
    )
    # The "lite" preset sweeps exactly 3 models; otherwise one row per config.
    if model_list == "lite":
        assert len(comp_df) == 3
    else:
        assert len(comp_df) == len(model_list)
    # Fix: the key is "metrics" (a list), not "metric" (a scalar) — the old
    # `get("metric", None) == fake_metric` guard was always False, so the
    # assertion below was dead code.
    if fake_metric in custom_fit_params.get("metrics", []):
        # Fix: `DataFrame.columns` is a property, not a callable — calling it
        # raised TypeError instead of checking membership.
        assert "test_fake_metric" in comp_df.columns
12511296

12521297
@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
12531298
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])

0 commit comments

Comments
 (0)