@@ -1143,7 +1143,7 @@ def test_tta_regression(
11431143
11441144
11451145def _run_model_compare (
1146- task , model_list , data_config , trainer_config , optimizer_config , train , test , metric , rank_metric
1146+ task , model_list , data_config , trainer_config , optimizer_config , train , test , metric , rank_metric , custom_fit_params = {},
11471147):
11481148 model_list = copy .deepcopy (model_list )
11491149 if isinstance (model_list , list ):
@@ -1161,6 +1161,7 @@ def _run_model_compare(
11611161 metrics_params = metric [1 ],
11621162 metrics_prob_input = metric [2 ],
11631163 rank_metric = rank_metric ,
1164+ custom_fit_params = custom_fit_params ,
11641165 )
11651166
11661167
@@ -1248,6 +1249,50 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12481249 # best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
12491250 # assert best_model.model._get_name() in best_models
12501251
@pytest.mark.parametrize("model_list", ["lite", MODEL_CONFIG_MODEL_SWEEP_TEST])
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
@pytest.mark.parametrize(
    "metric",
    [
        (["mean_squared_error"], [{}], [False]),
    ],
)
@pytest.mark.parametrize("rank_metric", [("loss", "lower_is_better")])
@pytest.mark.parametrize(
    "custom_fit_params",
    [
        {
            "loss": torch.nn.L1Loss(),
            "metrics": [fake_metric],
            "metrics_prob_inputs": [True],
            "optimizer": torch.optim.Adagrad,
        },
    ],
)
def test_model_compare_custom(
    regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params
):
    """Run a regression model sweep with custom fit params (loss/metrics/optimizer).

    Verifies that ``custom_fit_params`` are forwarded through ``_run_model_compare``
    and that a custom metric supplied there shows up as a ``test_<metric>`` column
    in the comparison DataFrame.
    """
    (train, test, target) = regression_data
    data_config = DataConfig(
        target=target,
        continuous_cols=continuous_cols,
        categorical_cols=categorical_cols,
        handle_missing_values=True,
        handle_unknown_categories=True,
    )
    # Keep the sweep cheap: fast_dev_run + CPU, no checkpoints/early stopping.
    trainer_config = TrainerConfig(
        max_epochs=3,
        checkpoints=None,
        early_stopping=None,
        accelerator="cpu",
        fast_dev_run=True,
    )
    optimizer_config = OptimizerConfig()
    comp_df, best_model = _run_model_compare(
        "regression",
        model_list,
        data_config,
        trainer_config,
        optimizer_config,
        train,
        test,
        metric,
        rank_metric,
        custom_fit_params=custom_fit_params,
    )
    # "lite" is a preset of exactly 3 models; otherwise one row per config.
    if model_list == "lite":
        assert len(comp_df) == 3
    else:
        assert len(comp_df) == len(model_list)
    # BUGFIX: the key is "metrics" (a list), not "metric", so the original
    # condition was always False and the assertion never ran; also
    # DataFrame.columns is a property, not a callable.
    if fake_metric in custom_fit_params.get("metrics", []):
        assert "test_fake_metric" in comp_df.columns
12511296
12521297@pytest .mark .parametrize ("model_config_class" , MODEL_CONFIG_SAVE_TEST )
12531298@pytest .mark .parametrize ("continuous_cols" , [list (DATASET_CONTINUOUS_COLUMNS )])
0 commit comments