Skip to content

Commit 169e3fa

Browse files
committed
Updated python predict methods and tests
1 parent e96268f commit 169e3fa

File tree

4 files changed

+79
-51
lines changed

4 files changed

+79
-51
lines changed

stochtree/bart.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,13 +1803,6 @@ def predict(
18031803
# rfx_predictions = np.mean(rfx_predictions, axis = 1)
18041804

18051805
# Combine into y hat predictions
1806-
if predict_y_hat and has_mean_forest and has_rfx:
1807-
y_hat = mean_forest_predictions + rfx_predictions
1808-
elif predict_y_hat and has_mean_forest:
1809-
y_hat = mean_forest_predictions
1810-
elif predict_y_hat and has_rfx:
1811-
y_hat = rfx_predictions
1812-
18131806
if probability_scale:
18141807
if predict_y_hat and has_mean_forest and has_rfx:
18151808
y_hat = norm.ppf(mean_forest_predictions + rfx_predictions)

stochtree/bcf.py

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,14 +2400,28 @@ def predict(
24002400
forest_dataset_test.add_covariates(X_combined)
24012401
forest_dataset_test.add_basis(Z)
24022402

2403-
# Compute predicted outcome and decomposed outcome model terms
2403+
# Compute predictions from the variance forest (if included)
2404+
if predict_variance_forest:
2405+
sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict(
2406+
forest_dataset_test.dataset_cpp
2407+
)
2408+
if self.sample_sigma2_global:
2409+
sigma2_x = np.empty_like(sigma2_x_raw)
2410+
for i in range(self.num_samples):
2411+
sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i]
2412+
else:
2413+
sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std
2414+
if predict_mean:
2415+
sigma2_x = np.mean(sigma2_x, axis=1)
2416+
2417+
# Prognostic forest predictions
24042418
if predict_mu_forest or predict_mu_forest_intermediate:
24052419
mu_raw = self.forest_container_mu.forest_container_cpp.Predict(
24062420
forest_dataset_test.dataset_cpp
24072421
)
24082422
mu_x = mu_raw * self.y_std + self.y_bar
2409-
if predict_mean:
2410-
mu_x = np.mean(mu_x, axis=1)
2423+
2424+
# Treatment effect forest predictions
24112425
if predict_tau_forest or predict_tau_forest_intermediate:
24122426
tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw(
24132427
forest_dataset_test.dataset_cpp
@@ -2422,43 +2436,65 @@ def predict(
24222436
treatment_term = np.multiply(
24232437
np.atleast_3d(Z).swapaxes(1, 2), tau_x
24242438
).sum(axis=2)
2425-
if predict_mean:
2426-
treatment_term = np.mean(treatment_term, axis=1)
2427-
tau_x = np.mean(tau_x, axis=2)
24282439
else:
24292440
treatment_term = Z * np.squeeze(tau_x)
2430-
if predict_mean:
2431-
treatment_term = np.mean(treatment_term, axis=1)
2432-
tau_x = np.mean(tau_x, axis=1)
24332441

2442+
# Random effects predictions
24342443
if predict_rfx or predict_rfx_intermediate:
24352444
rfx_preds = (
24362445
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
24372446
)
24382447
if predict_mean:
24392448
rfx_preds = np.mean(rfx_preds, axis=1)
24402449

2450+
# Combine into y hat predictions
24412451
if predict_y_hat and has_mu_forest and has_rfx:
24422452
y_hat = mu_x + treatment_term + rfx_preds
24432453
elif predict_y_hat and has_mu_forest:
24442454
y_hat = mu_x + treatment_term
24452455
elif predict_y_hat and has_rfx:
24462456
y_hat = rfx_preds
2447-
2448-
# Compute predictions from the variance forest (if included)
2449-
if predict_variance_forest:
2450-
sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict(
2451-
forest_dataset_test.dataset_cpp
2452-
)
2453-
if self.sample_sigma2_global:
2454-
sigma2_x = np.empty_like(sigma2_x_raw)
2455-
for i in range(self.num_samples):
2456-
sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i]
2457+
2458+
needs_mean_term_preds = predict_y_hat or \
2459+
predict_mu_forest or \
2460+
predict_tau_forest or \
2461+
predict_rfx
2462+
if needs_mean_term_preds:
2463+
if probability_scale:
2464+
if has_rfx:
2465+
if predict_y_hat:
2466+
y_hat = norm.cdf(mu_x + treatment_term + rfx_preds)
2467+
if predict_rfx:
2468+
rfx_preds = norm.cdf(rfx_preds)
2469+
else:
2470+
if predict_y_hat:
2471+
y_hat = norm.cdf(mu_x + treatment_term)
2472+
if predict_mu_forest:
2473+
mu_x = norm.cdf(mu_x)
2474+
if predict_tau_forest:
2475+
tau_x = norm.cdf(tau_x)
24572476
else:
2458-
sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std
2459-
if predict_mean:
2460-
sigma2_x = np.mean(sigma2_x, axis=1)
2477+
if has_rfx:
2478+
if predict_y_hat:
2479+
y_hat = mu_x + treatment_term + rfx_preds
2480+
else:
2481+
if predict_y_hat:
2482+
y_hat = mu_x + treatment_term
24612483

2484+
# Collapse to posterior mean predictions if requested
2485+
if predict_mean:
2486+
if predict_mu_forest:
2487+
mu_x = np.mean(mu_x, axis=1)
2488+
if predict_tau_forest:
2489+
if Z.shape[1] > 1:
2490+
tau_x = np.mean(tau_x, axis=2)
2491+
else:
2492+
tau_x = np.mean(tau_x, axis=1)
2493+
if predict_rfx:
2494+
rfx_preds = np.mean(rfx_preds, axis=1)
2495+
if predict_y_hat:
2496+
y_hat = np.mean(y_hat, axis=1)
2497+
24622498
if predict_count == 1:
24632499
if predict_y_hat:
24642500
return y_hat
@@ -2754,6 +2790,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
27542790
self.internal_propensity_model = json_object_default.get_boolean(
27552791
"internal_propensity_model"
27562792
)
2793+
self.probit_outcome_model = json_object_default.get_boolean(
2794+
"probit_outcome_model"
2795+
)
27572796

27582797
# Unpack number of samples
27592798
for i in range(len(json_object_list)):

test/python/test_bcf.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def test_binary_bcf(self):
7676
assert mu_hat.shape == (n_test, num_mcmc)
7777
assert y_hat.shape == (n_test, num_mcmc)
7878

79-
# Check treatment effect prediction method
80-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
79+
# Check that we can predict just treatment effects
80+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
8181
assert tau_hat.shape == (n_test, num_mcmc)
8282

8383
# Run BCF without test set and with propensity score
@@ -106,7 +106,7 @@ def test_binary_bcf(self):
106106
assert mu_hat.shape == (n_test, num_mcmc)
107107
assert y_hat.shape == (n_test, num_mcmc)
108108
# Check treatment effect prediction method
109-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
109+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
110110
assert tau_hat.shape == (n_test, num_mcmc)
111111

112112
# Run BCF with test set and without propensity score
@@ -142,7 +142,7 @@ def test_binary_bcf(self):
142142
assert y_hat.shape == (n_test, num_mcmc)
143143

144144
# Check treatment effect prediction method
145-
tau_hat = bcf_model.predict_tau(X_test, Z_test)
145+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate")
146146
assert tau_hat.shape == (n_test, num_mcmc)
147147

148148
# Run BCF without test set and without propensity score
@@ -172,7 +172,7 @@ def test_binary_bcf(self):
172172
assert y_hat.shape == (n_test, num_mcmc)
173173

174174
# Check treatment effect prediction method
175-
tau_hat = bcf_model.predict_tau(X_test, Z_test)
175+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate")
176176

177177
def test_continuous_univariate_bcf(self):
178178
# RNG
@@ -245,7 +245,7 @@ def test_continuous_univariate_bcf(self):
245245
assert y_hat.shape == (n_test, num_mcmc)
246246

247247
# Check treatment effect prediction method
248-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
248+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
249249
assert tau_hat.shape == (n_test, num_mcmc)
250250

251251
# Run second BCF model with test set and propensity score
@@ -281,7 +281,7 @@ def test_continuous_univariate_bcf(self):
281281
assert y_hat_2.shape == (n_test, num_mcmc)
282282

283283
# Check treatment effect prediction method
284-
tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test, pi_test)
284+
tau_hat_2 = bcf_model_2.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
285285
assert tau_hat_2.shape == (n_test, num_mcmc)
286286

287287
# Combine into a single model
@@ -336,7 +336,7 @@ def test_continuous_univariate_bcf(self):
336336
assert y_hat.shape == (n_test, num_mcmc)
337337

338338
# Check treatment effect prediction method
339-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
339+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
340340
assert tau_hat.shape == (n_test, num_mcmc)
341341

342342
# Run BCF with test set and without propensity score
@@ -372,7 +372,7 @@ def test_continuous_univariate_bcf(self):
372372
assert y_hat.shape == (n_test, num_mcmc)
373373

374374
# Check treatment effect prediction method
375-
tau_hat = bcf_model.predict_tau(X_test, Z_test)
375+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate")
376376
assert tau_hat.shape == (n_test, num_mcmc)
377377

378378
# Run BCF without test set and without propensity score
@@ -402,7 +402,7 @@ def test_continuous_univariate_bcf(self):
402402
assert y_hat.shape == (n_test, num_mcmc)
403403

404404
# Check treatment effect prediction method
405-
tau_hat = bcf_model.predict_tau(X_test, Z_test)
405+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate")
406406

407407
# Run second BCF model with test set and propensity score
408408
bcf_model_2 = BCFModel()
@@ -430,7 +430,7 @@ def test_continuous_univariate_bcf(self):
430430
assert y_hat_2.shape == (n_test, num_mcmc)
431431

432432
# Check treatment effect prediction method
433-
tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test)
433+
tau_hat_2 = bcf_model_2.predict(X = X_test, Z = Z_test, terms = "cate")
434434
assert tau_hat_2.shape == (n_test, num_mcmc)
435435

436436
# Combine into a single model
@@ -528,7 +528,7 @@ def test_multivariate_bcf(self):
528528
assert y_hat.shape == (n_test, num_mcmc)
529529

530530
# Check treatment effect prediction method
531-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
531+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
532532
assert tau_hat.shape == (n_test, num_mcmc, treatment_dim)
533533

534534
# Run BCF without test set and with propensity score
@@ -558,7 +558,7 @@ def test_multivariate_bcf(self):
558558
assert y_hat.shape == (n_test, num_mcmc)
559559

560560
# Check treatment effect prediction method
561-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
561+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
562562
assert tau_hat.shape == (n_test, num_mcmc, treatment_dim)
563563

564564
# Run BCF with test set and without propensity score
@@ -665,7 +665,7 @@ def test_binary_bcf_heteroskedastic(self):
665665
assert sigma2_x_hat.shape == (n_test, num_mcmc)
666666

667667
# Check treatment effect prediction method
668-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
668+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
669669
assert tau_hat.shape == (n_test, num_mcmc)
670670

671671
# Run BCF without test set and with propensity score
@@ -715,7 +715,7 @@ def test_binary_bcf_heteroskedastic(self):
715715
)
716716

717717
# Check treatment effect prediction method
718-
tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test)
718+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate")
719719
assert tau_hat.shape == (n_test, num_mcmc)
720720

721721
# Run BCF with test set and without propensity score
@@ -752,7 +752,7 @@ def test_binary_bcf_heteroskedastic(self):
752752
assert bcf_preds['variance_forest_predictions'].shape == (n_test, num_mcmc)
753753

754754
# Check treatment effect prediction method
755-
tau_hat = bcf_model.predict_tau(X_test, Z_test)
755+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate")
756756
assert tau_hat.shape == (n_test, num_mcmc)
757757

758758
# Run BCF without test set and without propensity score
@@ -781,7 +781,7 @@ def test_binary_bcf_heteroskedastic(self):
781781
assert bcf_preds['y_hat'].shape == (n_test, num_mcmc)
782782

783783
# Check treatment effect prediction method
784-
tau_hat = bcf_model.predict_tau(X_test, Z_test)
784+
tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate")
785785

786786
def test_bcf_rfx_parameters(self):
787787
# RNG

test/python/test_predict.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,6 @@ def test_bart_prediction(self):
263263
def test_bcf_prediction(self):
264264
# Generate data and test/train split
265265
rng = np.random.default_rng(1234)
266-
267-
268-
# Convert the R code down below to Python
269-
rng = np.random.default_rng(1234)
270266
n = 100
271267
g = lambda x: np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4))
272268
x1 = rng.normal(size=n)
@@ -328,7 +324,7 @@ def g(x5):
328324
num_mcmc = 10
329325
)
330326

331-
# Check that the default predict method returns a list
327+
# Check that the default predict method returns a dictionary
332328
pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test)
333329
y_hat_posterior_test = pred['y_hat']
334330
assert y_hat_posterior_test.shape == (20, 10)

0 commit comments

Comments
 (0)