Skip to content

Commit 8b2d62e

Browse files
committed
Added support for random intercept specification in python BART
1 parent 784fac9 commit 8b2d62e

File tree

2 files changed

+105
-33
lines changed

2 files changed

+105
-33
lines changed

stochtree/bart.py

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def sample(
8585
general_params: Optional[Dict[str, Any]] = None,
8686
mean_forest_params: Optional[Dict[str, Any]] = None,
8787
variance_forest_params: Optional[Dict[str, Any]] = None,
88-
rfx_params: Optional[Dict[str, Any]] = None,
88+
random_effects_params: Optional[Dict[str, Any]] = None,
8989
previous_model_json: Optional[str] = None,
9090
previous_model_warmstart_sample_num: Optional[int] = None,
9191
) -> None:
@@ -170,9 +170,10 @@ def sample(
170170
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
171171
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
172172
173-
rfx_params : dict, optional
173+
random_effects_params : dict, optional
174174
Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional.
175175
176+
* `model_spec`: Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored.
176177
* `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
177178
* `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
178179
* `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
@@ -251,14 +252,15 @@ def sample(
251252

252253
# Update random effects parameters
253254
rfx_params_default = {
255+
"model_spec": "custom",
254256
"working_parameter_prior_mean": None,
255257
"group_parameter_prior_mean": None,
256258
"working_parameter_prior_cov": None,
257259
"group_parameter_prior_cov": None,
258260
"variance_prior_shape": 1.0,
259261
"variance_prior_scale": 1.0,
260262
}
261-
rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params)
263+
rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params)
262264

263265
### Unpack all parameter values
264266
# 1. General parameters
@@ -312,6 +314,7 @@ def sample(
312314
]
313315

314316
# 4. Random effects parameters
317+
self.rfx_model_spec = rfx_params_updated["model_spec"]
315318
rfx_working_parameter_prior_mean = rfx_params_updated[
316319
"working_parameter_prior_mean"
317320
]
@@ -325,6 +328,12 @@ def sample(
325328
rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"]
326329
rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"]
327330

331+
# Check random effects specification
332+
if not isinstance(self.rfx_model_spec, str):
333+
raise ValueError("rfx_model_spec must be a string")
334+
if self.rfx_model_spec not in ["custom", "intercept_only"]:
335+
raise ValueError("type must either be 'custom' or 'intercept_only'")
336+
328337
# Override keep_gfr if there are no MCMC samples
329338
if num_mcmc == 0:
330339
keep_gfr = True
@@ -980,24 +989,35 @@ def sample(
980989
"All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train"
981990
)
982991

983-
# Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
984-
has_basis_rfx = False
992+
# Handle the rfx basis matrices
993+
self.has_rfx_basis = False
994+
self.num_rfx_basis = 0
985995
if self.has_rfx:
986-
if rfx_basis_train is None:
987-
rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1))
988-
else:
989-
has_basis_rfx = True
996+
if self.rfx_model_spec == "custom":
997+
if rfx_basis_train is None:
998+
raise ValueError(
999+
"rfx_basis_train must be provided when rfx_model_spec = 'custom'"
1000+
)
1001+
elif self.rfx_model_spec == "intercept_only":
1002+
if rfx_basis_train is None:
1003+
rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1))
1004+
self.has_rfx_basis = True
1005+
self.num_rfx_basis = rfx_basis_train.shape[1]
9901006
num_rfx_groups = np.unique(rfx_group_ids_train).shape[0]
9911007
num_rfx_components = rfx_basis_train.shape[1]
992-
# TODO warn if num_rfx_groups is 1
1008+
if num_rfx_groups == 1:
1009+
warnings.warn(
1010+
"Only one group was provided for random effect sampling, so the random effects model is likely overkill"
1011+
)
9931012
if has_rfx_test:
994-
if rfx_basis_test is None:
995-
if has_basis_rfx:
1013+
if self.rfx_model_spec == "custom":
1014+
if rfx_basis_test is None:
9961015
raise ValueError(
997-
"Random effects basis provided for training set, must also be provided for the test set"
1016+
"rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided"
9981017
)
999-
rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1))
1000-
1018+
elif self.rfx_model_spec == "intercept_only":
1019+
if rfx_basis_test is None:
1020+
rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1))
10011021
# Set up random effects structures
10021022
if self.has_rfx:
10031023
# Prior parameters
@@ -1676,6 +1696,8 @@ def predict(
16761696
predict_mean = type == "mean"
16771697

16781698
# Handle prediction terms
1699+
rfx_model_spec = self.rfx_model_spec
1700+
rfx_intercept = rfx_model_spec == "intercept_only"
16791701
if not isinstance(terms, str) and not isinstance(terms, list):
16801702
raise ValueError("type must be a string or list of strings")
16811703
num_terms = 1 if isinstance(terms, str) else len(terms)
@@ -1801,12 +1823,51 @@ def predict(
18011823
)
18021824
mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar
18031825

1826+
# Random effects data checks
1827+
if has_rfx:
1828+
if rfx_group_ids is None:
1829+
raise ValueError(
1830+
"rfx_group_ids must be provided if rfx_basis is provided"
1831+
)
1832+
if rfx_basis is not None:
1833+
if rfx_basis.ndim == 1:
1834+
rfx_basis = np.expand_dims(rfx_basis, 1)
1835+
if rfx_basis.shape[0] != covariates.shape[0]:
1836+
raise ValueError("X and rfx_basis must have the same number of rows")
1837+
if rfx_basis.shape[1] != self.num_rfx_basis:
1838+
raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
1839+
18041840
# Random effects predictions
18051841
if predict_rfx or predict_rfx_intermediate:
1806-
rfx_predictions = (
1807-
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
1808-
)
1809-
1842+
if rfx_basis is not None:
1843+
rfx_predictions = (
1844+
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
1845+
)
1846+
else:
1847+
# Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
1848+
if not rfx_intercept:
1849+
raise ValueError(
1850+
"rfx_basis must be provided for random effects models with random slopes"
1851+
)
1852+
1853+
# Extract the raw RFX samples and scale by train set outcome standard deviation
1854+
rfx_samples_raw = self.rfx_container.extract_parameter_samples()
1855+
rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std
1856+
1857+
# Construct an array with the appropriate group random effects arranged for each observation
1858+
n_train = covariates.shape[0]
1859+
if rfx_beta_draws.ndim != 2:
1860+
raise ValueError(
1861+
"BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices"
1862+
)
1863+
else:
1864+
rfx_predictions_raw = np.empty(shape=(n_train, 1, rfx_beta_draws.shape[1]))
1865+
for i in range(n_train):
1866+
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[
1867+
rfx_group_ids[i], :
1868+
]
1869+
rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :])
1870+
18101871
# Combine into y hat predictions
18111872
if probability_scale:
18121873
if predict_y_hat and has_mean_forest and has_rfx:

stochtree/bcf.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def sample(
208208
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
209209
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
210210
211-
rfx_params : dict, optional
211+
random_effects_params : dict, optional
212212
Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional.
213213
214214
* `model_spec`: Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored.
@@ -2504,13 +2504,13 @@ def predict(
25042504
raise ValueError(
25052505
"rfx_group_ids must be provided if rfx_basis is provided"
25062506
)
2507-
if rfx_basis is not None:
2508-
if rfx_basis.ndim == 1:
2509-
rfx_basis = np.expand_dims(rfx_basis, 1)
2510-
if rfx_basis.shape[0] != X.shape[0]:
2511-
raise ValueError("X and rfx_basis must have the same number of rows")
2512-
if rfx_basis.shape[1] != self.num_rfx_basis:
2513-
raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
2507+
if rfx_basis is not None:
2508+
if rfx_basis.ndim == 1:
2509+
rfx_basis = np.expand_dims(rfx_basis, 1)
2510+
if rfx_basis.shape[0] != X.shape[0]:
2511+
raise ValueError("X and rfx_basis must have the same number of rows")
2512+
if rfx_basis.shape[1] != self.num_rfx_basis:
2513+
raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
25142514

25152515
# Random effects predictions
25162516
if predict_rfx or predict_rfx_intermediate:
@@ -2524,12 +2524,23 @@ def predict(
25242524
rfx_samples_raw = self.rfx_container.extract_parameter_samples()
25252525
rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std
25262526

2527-
# Construct a matrix with the appropriate group random effects arranged for each observation
2528-
rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2]))
2529-
for i in range(X.shape[0]):
2530-
rfx_predictions_raw[i, :, :] = rfx_beta_draws[
2531-
:, rfx_group_ids[i], :
2532-
]
2527+
# Construct an array with the appropriate group random effects arranged for each observation
2528+
if rfx_beta_draws.ndim == 3:
2529+
rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2]))
2530+
for i in range(X.shape[0]):
2531+
rfx_predictions_raw[i, :, :] = rfx_beta_draws[
2532+
:, rfx_group_ids[i], :
2533+
]
2534+
elif rfx_beta_draws.ndim == 2:
2535+
rfx_predictions_raw = np.empty(shape=(X.shape[0], 1, rfx_beta_draws.shape[1]))
2536+
for i in range(X.shape[0]):
2537+
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[
2538+
rfx_group_ids[i], :
2539+
]
2540+
else:
2541+
raise ValueError(
2542+
"Unexpected number of dimensions in extracted random effects samples"
2543+
)
25332544

25342545
# Add raw RFX predictions to mu and tau if warranted by the RFX model spec
25352546
if predict_mu_forest or predict_mu_forest_intermediate:

0 commit comments

Comments
 (0)