Skip to content

Commit cf03f4d

Browse files
committed
Updated Python BCF interface
1 parent 98ebce1 commit cf03f4d

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

stochtree/bcf.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def sample(
168168
* `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
169169
* `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the prognostic (`mu(X)`) forest. Defaults to `None`.
170170
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the prognostic (`mu(X)`) forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
171+
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
171172
172173
treatment_effect_forest_params : dict, optional
173174
Dictionary of treatment effect forest model parameters, each of which has a default value processed internally, so this argument is optional.
@@ -184,6 +185,7 @@ def sample(
184185
* `delta_max` (`float`): Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Defaults to `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`.
185186
* `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the treatment effect (`tau(X)`) forest. Defaults to `None`.
186187
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the treatment effect (`tau(X)`) forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
188+
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
187189
188190
variance_forest_params : dict, optional
189191
Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional.
@@ -199,6 +201,7 @@ def sample(
199201
* `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set here.
200202
* `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the variance forest. Defaults to `None`.
201203
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
204+
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
202205
203206
Returns
204207
-------
@@ -242,6 +245,7 @@ def sample(
242245
"sigma2_leaf_scale": None,
243246
"keep_vars": None,
244247
"drop_vars": None,
248+
"num_features_subsample": None,
245249
}
246250
prognostic_forest_params_updated = _preprocess_params(
247251
prognostic_forest_params_default, prognostic_forest_params
@@ -261,6 +265,7 @@ def sample(
261265
"delta_max": 0.9,
262266
"keep_vars": None,
263267
"drop_vars": None,
268+
"num_features_subsample": None,
264269
}
265270
treatment_effect_forest_params_updated = _preprocess_params(
266271
treatment_effect_forest_params_default, treatment_effect_forest_params
@@ -279,6 +284,7 @@ def sample(
279284
"var_forest_prior_scale": None,
280285
"keep_vars": None,
281286
"drop_vars": None,
287+
"num_features_subsample": None,
282288
}
283289
variance_forest_params_updated = _preprocess_params(
284290
variance_forest_params_default, variance_forest_params
@@ -316,6 +322,7 @@ def sample(
316322
b_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_scale"]
317323
keep_vars_mu = prognostic_forest_params_updated["keep_vars"]
318324
drop_vars_mu = prognostic_forest_params_updated["drop_vars"]
325+
num_features_subsample_mu = prognostic_forest_params_updated["num_features_subsample"]
319326

320327
# 3. Tau forest parameters
321328
num_trees_tau = treatment_effect_forest_params_updated["num_trees"]
@@ -334,6 +341,7 @@ def sample(
334341
delta_max = treatment_effect_forest_params_updated["delta_max"]
335342
keep_vars_tau = treatment_effect_forest_params_updated["keep_vars"]
336343
drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"]
344+
num_features_subsample_tau = treatment_effect_forest_params_updated["num_features_subsample"]
337345

338346
# 4. Variance forest parameters
339347
num_trees_variance = variance_forest_params_updated["num_trees"]
@@ -349,6 +357,7 @@ def sample(
349357
b_forest = variance_forest_params_updated["var_forest_prior_scale"]
350358
keep_vars_variance = variance_forest_params_updated["keep_vars"]
351359
drop_vars_variance = variance_forest_params_updated["drop_vars"]
360+
num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"]
352361

353362
# Override keep_gfr if there are no MCMC samples
354363
if num_mcmc == 0:
@@ -744,6 +753,19 @@ def sample(
744753
if not isinstance(keep_gfr, bool):
745754
raise ValueError("keep_gfr must be a bool")
746755

756+
# Covariate preprocessing
757+
self._covariate_preprocessor = CovariatePreprocessor()
758+
self._covariate_preprocessor.fit(X_train)
759+
X_train_processed = self._covariate_preprocessor.transform(X_train)
760+
if X_test is not None:
761+
X_test_processed = self._covariate_preprocessor.transform(X_test)
762+
feature_types = np.asarray(
763+
self._covariate_preprocessor._processed_feature_types
764+
)
765+
original_var_indices = (
766+
self._covariate_preprocessor.fetch_original_feature_indices()
767+
)
768+
747769
# Standardize the keep variable lists to numeric indices
748770
if keep_vars_mu is not None:
749771
if isinstance(keep_vars_mu, list):
@@ -1052,18 +1074,13 @@ def sample(
10521074
else:
10531075
variable_subset_variance = [i for i in range(X_train.shape[1])]
10541076

1055-
# Covariate preprocessing
1056-
self._covariate_preprocessor = CovariatePreprocessor()
1057-
self._covariate_preprocessor.fit(X_train)
1058-
X_train_processed = self._covariate_preprocessor.transform(X_train)
1059-
if X_test is not None:
1060-
X_test_processed = self._covariate_preprocessor.transform(X_test)
1061-
feature_types = np.asarray(
1062-
self._covariate_preprocessor._processed_feature_types
1063-
)
1064-
original_var_indices = (
1065-
self._covariate_preprocessor.fetch_original_feature_indices()
1066-
)
1077+
# Set num_features_subsample to default, ncol(X_train), if not already set
1078+
if num_features_subsample_mu is None:
1079+
num_features_subsample_mu = X_train.shape[1]
1080+
if num_features_subsample_tau is None:
1081+
num_features_subsample_tau = X_train.shape[1]
1082+
if num_features_subsample_variance is None:
1083+
num_features_subsample_variance = X_train.shape[1]
10671084

10681085
# Determine whether a test set is provided
10691086
self.has_test = X_test is not None
@@ -1519,6 +1536,7 @@ def sample(
15191536
leaf_model_type=leaf_model_mu,
15201537
leaf_model_scale=current_leaf_scale_mu,
15211538
cutpoint_grid_size=cutpoint_grid_size,
1539+
num_features_subsample=num_features_subsample_mu,
15221540
)
15231541
forest_sampler_mu = ForestSampler(
15241542
forest_dataset_train,
@@ -1539,6 +1557,7 @@ def sample(
15391557
leaf_model_type=leaf_model_tau,
15401558
leaf_model_scale=current_leaf_scale_tau,
15411559
cutpoint_grid_size=cutpoint_grid_size,
1560+
num_features_subsample=num_features_subsample_tau,
15421561
)
15431562
forest_sampler_tau = ForestSampler(
15441563
forest_dataset_train,
@@ -1561,6 +1580,7 @@ def sample(
15611580
cutpoint_grid_size=cutpoint_grid_size,
15621581
variance_forest_shape=a_forest,
15631582
variance_forest_scale=b_forest,
1583+
num_features_subsample=num_features_subsample_variance,
15641584
)
15651585
forest_sampler_variance = ForestSampler(
15661586
forest_dataset_train, global_model_config, forest_model_config_variance

0 commit comments

Comments
 (0)