Skip to content

Commit c00e480

Browse files
committed
Reformatting R and Python code
1 parent 5d6535c commit c00e480

File tree

3 files changed

+54
-35
lines changed

3 files changed

+54
-35
lines changed

R/bart.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,9 +1971,7 @@ predict.bartmodel <- function(
19711971
if ((predict_rfx) && (is.null(rfx_basis)) && (!rfx_intercept)) {
19721972
stop("Random effects basis (rfx_basis) must be provided for this model")
19731973
}
1974-
if (
1975-
(object$model_params$num_rfx_basis > 0) && (!rfx_intercept)
1976-
) {
1974+
if ((object$model_params$num_rfx_basis > 0) && (!rfx_intercept)) {
19771975
if (ncol(rfx_basis) != object$model_params$num_rfx_basis) {
19781976
stop(
19791977
"Random effects basis has a different dimension than the basis used to train this model"
@@ -2020,7 +2018,7 @@ predict.bartmodel <- function(
20202018
}
20212019
}
20222020
}
2023-
2021+
20242022
# Create prediction dataset
20252023
if (!is.null(leaf_basis)) {
20262024
prediction_dataset <- createForestDataset(covariates, leaf_basis)
@@ -2072,7 +2070,9 @@ predict.bartmodel <- function(
20722070
} else {
20732071
# Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
20742072
if (!rfx_intercept) {
2075-
stop("rfx_basis must be provided for random effects models with random slopes")
2073+
stop(
2074+
"rfx_basis must be provided for random effects models with random slopes"
2075+
)
20762076
}
20772077

20782078
# Extract the raw RFX samples and scale by train set outcome standard deviation
@@ -2093,7 +2093,7 @@ predict.bartmodel <- function(
20932093
rfx_beta_draws[, rfx_group_ids[i], ]
20942094
}
20952095

2096-
# Intercept-only model, so the random effect prediction is simply the
2096+
# Intercept-only model, so the random effect prediction is simply the
20972097
# value of the respective group's intercept coefficient for each observation
20982098
rfx_predictions = rfx_predictions_raw[, 1, ]
20992099
}

stochtree/bart.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,9 @@ def sample(
260260
"variance_prior_shape": 1.0,
261261
"variance_prior_scale": 1.0,
262262
}
263-
rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params)
263+
rfx_params_updated = _preprocess_params(
264+
rfx_params_default, random_effects_params
265+
)
264266

265267
### Unpack all parameter values
266268
# 1. General parameters
@@ -1459,7 +1461,9 @@ def sample(
14591461
forest_dataset_train
14601462
)
14611463
if self.has_rfx:
1462-
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1464+
rfx_pred = rfx_model.predict(
1465+
rfx_dataset_train, rfx_tracker
1466+
)
14631467
outcome_pred = outcome_pred + rfx_pred
14641468
mu0 = outcome_pred[y_train[:, 0] == 0]
14651469
mu1 = outcome_pred[y_train[:, 0] == 1]
@@ -1835,8 +1839,10 @@ def predict(
18351839
if rfx_basis.shape[0] != covariates.shape[0]:
18361840
raise ValueError("X and rfx_basis must have the same number of rows")
18371841
if rfx_basis.shape[1] != self.num_rfx_basis:
1838-
raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
1839-
1842+
raise ValueError(
1843+
"rfx_basis must have the same number of columns as the random effects basis used to sample this model"
1844+
)
1845+
18401846
# Random effects predictions
18411847
if predict_rfx or predict_rfx_intermediate:
18421848
if rfx_basis is not None:
@@ -1849,10 +1855,10 @@ def predict(
18491855
raise ValueError(
18501856
"rfx_basis must be provided for random effects models with random slopes"
18511857
)
1852-
1858+
18531859
# Extract the raw RFX samples and scale by train set outcome standard deviation
18541860
rfx_samples_raw = self.rfx_container.extract_parameter_samples()
1855-
rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std
1861+
rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std
18561862

18571863
# Construct an array with the appropriate group random effects arranged for each observation
18581864
n_train = covariates.shape[0]
@@ -1861,13 +1867,15 @@ def predict(
18611867
"BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices"
18621868
)
18631869
else:
1864-
rfx_predictions_raw = np.empty(shape=(n_train, 1, rfx_beta_draws.shape[1]))
1870+
rfx_predictions_raw = np.empty(
1871+
shape=(n_train, 1, rfx_beta_draws.shape[1])
1872+
)
18651873
for i in range(n_train):
18661874
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[
18671875
rfx_group_ids[i], :
18681876
]
18691877
rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :])
1870-
1878+
18711879
# Combine into y hat predictions
18721880
if probability_scale:
18731881
if predict_y_hat and has_mean_forest and has_rfx:
@@ -2583,9 +2591,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
25832591
self.probit_outcome_model = json_object_default.get_boolean(
25842592
"probit_outcome_model"
25852593
)
2586-
self.rfx_model_spec = json_object_default.get_string(
2587-
"rfx_model_spec"
2588-
)
2594+
self.rfx_model_spec = json_object_default.get_string("rfx_model_spec")
25892595

25902596
# Unpack number of samples
25912597
for i in range(len(json_object_list)):

stochtree/bcf.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ def sample(
317317
"variance_prior_shape": 1.0,
318318
"variance_prior_scale": 1.0,
319319
}
320-
rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params)
320+
rfx_params_updated = _preprocess_params(
321+
rfx_params_default, random_effects_params
322+
)
321323

322324
### Unpack all parameter values
323325
# 1. General parameters
@@ -413,8 +415,14 @@ def sample(
413415
# Check random effects specification
414416
if not isinstance(self.rfx_model_spec, str):
415417
raise ValueError("rfx_model_spec must be a string")
416-
if self.rfx_model_spec not in ["custom", "intercept_only", "intercept_plus_treatment"]:
417-
raise ValueError("type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'")
418+
if self.rfx_model_spec not in [
419+
"custom",
420+
"intercept_only",
421+
"intercept_plus_treatment",
422+
]:
423+
raise ValueError(
424+
"type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'"
425+
)
418426

419427
# Override keep_gfr if there are no MCMC samples
420428
if num_mcmc == 0:
@@ -2295,7 +2303,7 @@ def predict(
22952303
) -> Union[dict[str, np.array], np.array]:
22962304
"""Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation.
22972305
Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function.
2298-
When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or
2306+
When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or
22992307
partially included in mu_x and partially included in tau_x `rfx_model_spec == "intercept_plus_treatment"`.
23002308
23012309
Parameters
@@ -2508,9 +2516,13 @@ def predict(
25082516
if rfx_basis.ndim == 1:
25092517
rfx_basis = np.expand_dims(rfx_basis, 1)
25102518
if rfx_basis.shape[0] != X.shape[0]:
2511-
raise ValueError("X and rfx_basis must have the same number of rows")
2519+
raise ValueError(
2520+
"X and rfx_basis must have the same number of rows"
2521+
)
25122522
if rfx_basis.shape[1] != self.num_rfx_basis:
2513-
raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
2523+
raise ValueError(
2524+
"rfx_basis must have the same number of columns as the random effects basis used to sample this model"
2525+
)
25142526

25152527
# Random effects predictions
25162528
if predict_rfx or predict_rfx_intermediate:
@@ -2522,26 +2534,28 @@ def predict(
25222534
if predict_rfx_raw:
25232535
# Extract the raw RFX samples and scale by train set outcome standard deviation
25242536
rfx_samples_raw = self.rfx_container.extract_parameter_samples()
2525-
rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std
2537+
rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std
25262538

25272539
# Construct an array with the appropriate group random effects arranged for each observation
25282540
if rfx_beta_draws.ndim == 3:
2529-
rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2]))
2541+
rfx_predictions_raw = np.empty(
2542+
shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2])
2543+
)
25302544
for i in range(X.shape[0]):
25312545
rfx_predictions_raw[i, :, :] = rfx_beta_draws[
25322546
:, rfx_group_ids[i], :
25332547
]
25342548
elif rfx_beta_draws.ndim == 2:
2535-
rfx_predictions_raw = np.empty(shape=(X.shape[0], 1, rfx_beta_draws.shape[1]))
2549+
rfx_predictions_raw = np.empty(
2550+
shape=(X.shape[0], 1, rfx_beta_draws.shape[1])
2551+
)
25362552
for i in range(X.shape[0]):
2537-
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[
2538-
rfx_group_ids[i], :
2539-
]
2553+
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_ids[i], :]
25402554
else:
25412555
raise ValueError(
25422556
"Unexpected number of dimensions in extracted random effects samples"
25432557
)
2544-
2558+
25452559
# Add raw RFX predictions to mu and tau if warranted by the RFX model spec
25462560
if predict_mu_forest or predict_mu_forest_intermediate:
25472561
if rfx_intercept and predict_rfx_raw:
@@ -2553,7 +2567,6 @@ def predict(
25532567
tau_x = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :])
25542568
else:
25552569
tau_x = tau_x_forest
2556-
25572570

25582571
# Combine into y hat predictions
25592572
needs_mean_term_preds = (
@@ -3308,7 +3321,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
33083321
self.has_rfx = json_object_default.get_boolean("has_rfx")
33093322
self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis")
33103323
self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis")
3311-
self.multivariate_treatment = json_object_default.get_boolean("multivariate_treatment")
3324+
self.multivariate_treatment = json_object_default.get_boolean(
3325+
"multivariate_treatment"
3326+
)
33123327
if self.has_rfx:
33133328
self.rfx_container = RandomEffectsContainer()
33143329
for i in range(len(json_object_list)):
@@ -3344,9 +3359,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
33443359
self.probit_outcome_model = json_object_default.get_boolean(
33453360
"probit_outcome_model"
33463361
)
3347-
self.rfx_model_spec = json_object_default.get_string(
3348-
"rfx_model_spec"
3349-
)
3362+
self.rfx_model_spec = json_object_default.get_string("rfx_model_spec")
33503363

33513364
# Unpack number of samples
33523365
for i in range(len(json_object_list)):

0 commit comments

Comments
 (0)