Skip to content

Commit 19302ac

Browse files
committed
Fixed bug in BCF probit RFX in Python
1 parent 8426d98 commit 19302ac

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

stochtree/bart.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,16 +1800,12 @@ def predict(
18001800
pred_dataset.dataset_cpp
18011801
)
18021802
mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar
1803-
# if predict_mean:
1804-
# mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1)
18051803

18061804
# Random effects predictions
18071805
if predict_rfx or predict_rfx_intermediate:
18081806
rfx_predictions = (
18091807
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
18101808
)
1811-
# if predict_mean:
1812-
# rfx_predictions = np.mean(rfx_predictions, axis = 1)
18131809

18141810
# Combine into y hat predictions
18151811
if probability_scale:

stochtree/bcf.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,21 +1815,21 @@ def sample(
18151815

18161816
# Sample coding parameters (if requested)
18171817
if self.adaptive_coding:
1818-
partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train)
1818+
mu_x = active_forest_mu.predict_raw(forest_dataset_train)
18191819
tau_x = np.squeeze(
18201820
active_forest_tau.predict_raw(forest_dataset_train)
18211821
)
1822+
partial_resid_train = np.squeeze(resid_train - mu_x)
18221823
if self.has_rfx:
1823-
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1824-
partial_outcome_pred = partial_outcome_pred + rfx_pred
1824+
rfx_pred = np.squeeze(rfx_model.predict(rfx_dataset_train, rfx_tracker))
1825+
partial_resid_train = partial_resid_train - rfx_pred
18251826
s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0))
18261827
s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1))
1827-
partial_resid = np.squeeze(resid_train - partial_outcome_pred)
18281828
s_ty0 = np.sum(
1829-
tau_x * partial_resid * (np.squeeze(Z_train) == 0)
1829+
tau_x * partial_resid_train * (np.squeeze(Z_train) == 0)
18301830
)
18311831
s_ty1 = np.sum(
1832-
tau_x * partial_resid * (np.squeeze(Z_train) == 1)
1832+
tau_x * partial_resid_train * (np.squeeze(Z_train) == 1)
18331833
)
18341834
current_b_0 = self.rng.normal(
18351835
loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
@@ -2021,21 +2021,21 @@ def sample(
20212021

20222022
# Sample coding parameters (if requested)
20232023
if self.adaptive_coding:
2024-
partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train)
2024+
mu_x = active_forest_mu.predict_raw(forest_dataset_train)
20252025
tau_x = np.squeeze(
20262026
active_forest_tau.predict_raw(forest_dataset_train)
20272027
)
2028+
partial_resid_train = np.squeeze(resid_train - mu_x)
20282029
if self.has_rfx:
2029-
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
2030-
partial_outcome_pred = partial_outcome_pred + rfx_pred
2030+
rfx_pred = np.squeeze(rfx_model.predict(rfx_dataset_train, rfx_tracker))
2031+
partial_resid_train = partial_resid_train - rfx_pred
20312032
s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0))
20322033
s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1))
2033-
partial_resid = np.squeeze(resid_train - partial_outcome_pred)
20342034
s_ty0 = np.sum(
2035-
tau_x * partial_resid * (np.squeeze(Z_train) == 0)
2035+
tau_x * partial_resid_train * (np.squeeze(Z_train) == 0)
20362036
)
20372037
s_ty1 = np.sum(
2038-
tau_x * partial_resid * (np.squeeze(Z_train) == 1)
2038+
tau_x * partial_resid_train * (np.squeeze(Z_train) == 1)
20392039
)
20402040
current_b_0 = self.rng.normal(
20412041
loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
@@ -2459,8 +2459,6 @@ def predict(
24592459
rfx_preds = (
24602460
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
24612461
)
2462-
if predict_mean:
2463-
rfx_preds = np.mean(rfx_preds, axis=1)
24642462

24652463
# Combine into y hat predictions
24662464
if predict_y_hat and has_mu_forest and has_rfx:

0 commit comments

Comments
 (0)