Skip to content

Commit 6ee1a9b

Browse files
committed
Fixed bug at the interface of probit / adaptive coding / RFX in Python
1 parent f572cb9 commit 6ee1a9b

File tree

3 files changed

+75
-35
lines changed

3 files changed

+75
-35
lines changed

stochtree/bart.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,9 +1233,12 @@ def sample(
12331233
if self.include_mean_forest:
12341234
if self.probit_outcome_model:
12351235
# Sample latent probit variable z | -
1236-
forest_pred = active_forest_mean.predict(forest_dataset_train)
1237-
mu0 = forest_pred[y_train[:, 0] == 0]
1238-
mu1 = forest_pred[y_train[:, 0] == 1]
1236+
outcome_pred = active_forest_mean.predict(forest_dataset_train)
1237+
if self.has_rfx:
1238+
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1239+
outcome_pred = outcome_pred + rfx_pred
1240+
mu0 = outcome_pred[y_train[:, 0] == 0]
1241+
mu1 = outcome_pred[y_train[:, 0] == 1]
12391242
n0 = np.sum(y_train[:, 0] == 0)
12401243
n1 = np.sum(y_train[:, 0] == 1)
12411244
u0 = self.rng.uniform(
@@ -1252,7 +1255,7 @@ def sample(
12521255
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)
12531256

12541257
# Update outcome
1255-
new_outcome = np.squeeze(resid_train) - forest_pred
1258+
new_outcome = np.squeeze(resid_train) - outcome_pred
12561259
residual_train.update_data(new_outcome)
12571260

12581261
# Sample the mean forest
@@ -1437,11 +1440,14 @@ def sample(
14371440
if self.include_mean_forest:
14381441
if self.probit_outcome_model:
14391442
# Sample latent probit variable z | -
1440-
forest_pred = active_forest_mean.predict(
1443+
outcome_pred = active_forest_mean.predict(
14411444
forest_dataset_train
14421445
)
1443-
mu0 = forest_pred[y_train[:, 0] == 0]
1444-
mu1 = forest_pred[y_train[:, 0] == 1]
1446+
if self.has_rfx:
1447+
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1448+
outcome_pred = outcome_pred + rfx_pred
1449+
mu0 = outcome_pred[y_train[:, 0] == 0]
1450+
mu1 = outcome_pred[y_train[:, 0] == 1]
14451451
n0 = np.sum(y_train[:, 0] == 0)
14461452
n1 = np.sum(y_train[:, 0] == 1)
14471453
u0 = self.rng.uniform(
@@ -1458,7 +1464,7 @@ def sample(
14581464
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)
14591465

14601466
# Update outcome
1461-
new_outcome = np.squeeze(resid_train) - forest_pred
1467+
new_outcome = np.squeeze(resid_train) - outcome_pred
14621468
residual_train.update_data(new_outcome)
14631469

14641470
# Sample the mean forest
@@ -1813,15 +1819,15 @@ def predict(
18131819
# Combine into y hat predictions
18141820
if probability_scale:
18151821
if predict_y_hat and has_mean_forest and has_rfx:
1816-
y_hat = norm.ppf(mean_forest_predictions + rfx_predictions)
1817-
mean_forest_predictions = norm.ppf(mean_forest_predictions)
1818-
rfx_predictions = norm.ppf(rfx_predictions)
1822+
y_hat = norm.cdf(mean_forest_predictions + rfx_predictions)
1823+
mean_forest_predictions = norm.cdf(mean_forest_predictions)
1824+
rfx_predictions = norm.cdf(rfx_predictions)
18191825
elif predict_y_hat and has_mean_forest:
1820-
y_hat = norm.ppf(mean_forest_predictions)
1821-
mean_forest_predictions = norm.ppf(mean_forest_predictions)
1826+
y_hat = norm.cdf(mean_forest_predictions)
1827+
mean_forest_predictions = norm.cdf(mean_forest_predictions)
18221828
elif predict_y_hat and has_rfx:
1823-
y_hat = norm.ppf(rfx_predictions)
1824-
rfx_predictions = norm.ppf(rfx_predictions)
1829+
y_hat = norm.cdf(rfx_predictions)
1830+
rfx_predictions = norm.cdf(rfx_predictions)
18251831
else:
18261832
if predict_y_hat and has_mean_forest and has_rfx:
18271833
y_hat = mean_forest_predictions + rfx_predictions
@@ -2006,8 +2012,8 @@ def compute_contrast(
20062012

20072013
# Transform to probability scale if requested
20082014
if probability_scale:
2009-
treatment_preds = norm.ppf(treatment_preds)
2010-
control_preds = norm.ppf(control_preds)
2015+
treatment_preds = norm.cdf(treatment_preds)
2016+
control_preds = norm.cdf(control_preds)
20112017

20122018
# Compute and return contrast
20132019
if predict_mean:

stochtree/bcf.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,9 +1735,12 @@ def sample(
17351735
# Sample latent probit variable z | -
17361736
forest_pred_mu = active_forest_mu.predict(forest_dataset_train)
17371737
forest_pred_tau = active_forest_tau.predict(forest_dataset_train)
1738-
forest_pred = forest_pred_mu + forest_pred_tau
1739-
mu0 = forest_pred[y_train[:, 0] == 0]
1740-
mu1 = forest_pred[y_train[:, 0] == 1]
1738+
outcome_pred = forest_pred_mu + forest_pred_tau
1739+
if self.has_rfx:
1740+
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1741+
outcome_pred = outcome_pred + rfx_pred
1742+
mu0 = outcome_pred[y_train[:, 0] == 0]
1743+
mu1 = outcome_pred[y_train[:, 0] == 1]
17411744
n0 = np.sum(y_train[:, 0] == 0)
17421745
n1 = np.sum(y_train[:, 0] == 1)
17431746
u0 = self.rng.uniform(
@@ -1754,7 +1757,7 @@ def sample(
17541757
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)
17551758

17561759
# Update outcome
1757-
new_outcome = np.squeeze(resid_train) - forest_pred
1760+
new_outcome = np.squeeze(resid_train) - outcome_pred
17581761
residual_train.update_data(new_outcome)
17591762

17601763
# Sample the prognostic forest
@@ -1817,18 +1820,21 @@ def sample(
18171820

18181821
# Sample coding parameters (if requested)
18191822
if self.adaptive_coding:
1820-
mu_x = active_forest_mu.predict_raw(forest_dataset_train)
1823+
partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train)
18211824
tau_x = np.squeeze(
18221825
active_forest_tau.predict_raw(forest_dataset_train)
18231826
)
1827+
if self.has_rfx:
1828+
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1829+
partial_outcome_pred = partial_outcome_pred + rfx_pred
18241830
s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0))
18251831
s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1))
1826-
partial_resid_mu = np.squeeze(resid_train - mu_x)
1832+
partial_resid = np.squeeze(resid_train - partial_outcome_pred)
18271833
s_ty0 = np.sum(
1828-
tau_x * partial_resid_mu * (np.squeeze(Z_train) == 0)
1834+
tau_x * partial_resid * (np.squeeze(Z_train) == 0)
18291835
)
18301836
s_ty1 = np.sum(
1831-
tau_x * partial_resid_mu * (np.squeeze(Z_train) == 1)
1837+
tau_x * partial_resid * (np.squeeze(Z_train) == 1)
18321838
)
18331839
current_b_0 = self.rng.normal(
18341840
loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
@@ -1935,9 +1941,12 @@ def sample(
19351941
# Sample latent probit variable z | -
19361942
forest_pred_mu = active_forest_mu.predict(forest_dataset_train)
19371943
forest_pred_tau = active_forest_tau.predict(forest_dataset_train)
1938-
forest_pred = forest_pred_mu + forest_pred_tau
1939-
mu0 = forest_pred[y_train[:, 0] == 0]
1940-
mu1 = forest_pred[y_train[:, 0] == 1]
1944+
outcome_pred = forest_pred_mu + forest_pred_tau
1945+
if self.has_rfx:
1946+
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
1947+
outcome_pred = outcome_pred + rfx_pred
1948+
mu0 = outcome_pred[y_train[:, 0] == 0]
1949+
mu1 = outcome_pred[y_train[:, 0] == 1]
19411950
n0 = np.sum(y_train[:, 0] == 0)
19421951
n1 = np.sum(y_train[:, 0] == 1)
19431952
u0 = self.rng.uniform(
@@ -1954,7 +1963,7 @@ def sample(
19541963
resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1)
19551964

19561965
# Update outcome
1957-
new_outcome = np.squeeze(resid_train) - forest_pred
1966+
new_outcome = np.squeeze(resid_train) - outcome_pred
19581967
residual_train.update_data(new_outcome)
19591968

19601969
# Sample the prognostic forest
@@ -2017,18 +2026,21 @@ def sample(
20172026

20182027
# Sample coding parameters (if requested)
20192028
if self.adaptive_coding:
2020-
mu_x = active_forest_mu.predict_raw(forest_dataset_train)
2029+
partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train)
20212030
tau_x = np.squeeze(
20222031
active_forest_tau.predict_raw(forest_dataset_train)
20232032
)
2033+
if self.has_rfx:
2034+
rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker)
2035+
partial_outcome_pred = partial_outcome_pred + rfx_pred
20242036
s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0))
20252037
s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1))
2026-
partial_resid_mu = np.squeeze(resid_train - mu_x)
2038+
partial_resid = np.squeeze(resid_train - partial_outcome_pred)
20272039
s_ty0 = np.sum(
2028-
tau_x * partial_resid_mu * (np.squeeze(Z_train) == 0)
2040+
tau_x * partial_resid * (np.squeeze(Z_train) == 0)
20292041
)
20302042
s_ty1 = np.sum(
2031-
tau_x * partial_resid_mu * (np.squeeze(Z_train) == 1)
2043+
tau_x * partial_resid * (np.squeeze(Z_train) == 1)
20322044
)
20332045
current_b_0 = self.rng.normal(
20342046
loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)),
@@ -2655,8 +2667,8 @@ def compute_contrast(
26552667

26562668
# Transform to probability scale if requested
26572669
if probability_scale:
2658-
treatment_preds = norm.ppf(treatment_preds)
2659-
control_preds = norm.ppf(control_preds)
2670+
treatment_preds = norm.cdf(treatment_preds)
2671+
control_preds = norm.cdf(control_preds)
26602672

26612673
# Compute and return contrast
26622674
if predict_mean:

stochtree/random_effects.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,28 @@ def sample(
426426
rng.rng_cpp,
427427
)
428428

429+
def predict(
430+
self, rfx_dataset: RandomEffectsDataset, rfx_tracker: RandomEffectsTracker
431+
) -> np.ndarray:
432+
"""
433+
Predict random effects for each observation in `rfx_dataset`
434+
435+
Parameters
436+
----------
437+
rfx_dataset: RandomEffectsDataset
438+
Object of type `RandomEffectsDataset`
439+
rfx_tracker: RandomEffectsTracker
440+
Object of type `RandomEffectsTracker`
441+
442+
Returns
443+
-------
444+
np.ndarray
445+
Numpy array with as many rows as observations in `rfx_dataset` and as many columns as samples in the container
446+
"""
447+
return self.rfx_model_cpp.Predict(
448+
rfx_dataset.rfx_dataset_cpp, rfx_tracker.rfx_tracker_cpp
449+
)
450+
429451
def set_working_parameter(self, working_parameter: np.ndarray) -> None:
430452
"""
431453
Set values for the "working parameter." This is typically used for initialization,

0 commit comments

Comments
 (0)