@@ -1735,9 +1735,12 @@ def sample(
17351735 # Sample latent probit variable z | -
17361736 forest_pred_mu = active_forest_mu .predict (forest_dataset_train )
17371737 forest_pred_tau = active_forest_tau .predict (forest_dataset_train )
1738- forest_pred = forest_pred_mu + forest_pred_tau
1739- mu0 = forest_pred [y_train [:, 0 ] == 0 ]
1740- mu1 = forest_pred [y_train [:, 0 ] == 1 ]
1738+ outcome_pred = forest_pred_mu + forest_pred_tau
1739+ if self .has_rfx :
1740+ rfx_pred = rfx_model .predict (rfx_dataset_train , rfx_tracker )
1741+ outcome_pred = outcome_pred + rfx_pred
1742+ mu0 = outcome_pred [y_train [:, 0 ] == 0 ]
1743+ mu1 = outcome_pred [y_train [:, 0 ] == 1 ]
17411744 n0 = np .sum (y_train [:, 0 ] == 0 )
17421745 n1 = np .sum (y_train [:, 0 ] == 1 )
17431746 u0 = self .rng .uniform (
@@ -1754,7 +1757,7 @@ def sample(
17541757 resid_train [y_train [:, 0 ] == 1 , 0 ] = mu1 + norm .ppf (u1 )
17551758
17561759 # Update outcome
1757- new_outcome = np .squeeze (resid_train ) - forest_pred
1760+ new_outcome = np .squeeze (resid_train ) - outcome_pred
17581761 residual_train .update_data (new_outcome )
17591762
17601763 # Sample the prognostic forest
@@ -1817,18 +1820,21 @@ def sample(
18171820
18181821 # Sample coding parameters (if requested)
18191822 if self .adaptive_coding :
1820- mu_x = active_forest_mu .predict_raw (forest_dataset_train )
1823+ partial_outcome_pred = active_forest_mu .predict_raw (forest_dataset_train )
18211824 tau_x = np .squeeze (
18221825 active_forest_tau .predict_raw (forest_dataset_train )
18231826 )
1827+ if self .has_rfx :
1828+ rfx_pred = rfx_model .predict (rfx_dataset_train , rfx_tracker )
1829+ partial_outcome_pred = partial_outcome_pred + rfx_pred
18241830 s_tt0 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 0 ))
18251831 s_tt1 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 1 ))
1826- partial_resid_mu = np .squeeze (resid_train - mu_x )
1832+ partial_resid = np .squeeze (resid_train - partial_outcome_pred )
18271833 s_ty0 = np .sum (
1828- tau_x * partial_resid_mu * (np .squeeze (Z_train ) == 0 )
1834+ tau_x * partial_resid * (np .squeeze (Z_train ) == 0 )
18291835 )
18301836 s_ty1 = np .sum (
1831- tau_x * partial_resid_mu * (np .squeeze (Z_train ) == 1 )
1837+ tau_x * partial_resid * (np .squeeze (Z_train ) == 1 )
18321838 )
18331839 current_b_0 = self .rng .normal (
18341840 loc = (s_ty0 / (s_tt0 + 2 * current_sigma2 )),
@@ -1935,9 +1941,12 @@ def sample(
19351941 # Sample latent probit variable z | -
19361942 forest_pred_mu = active_forest_mu .predict (forest_dataset_train )
19371943 forest_pred_tau = active_forest_tau .predict (forest_dataset_train )
1938- forest_pred = forest_pred_mu + forest_pred_tau
1939- mu0 = forest_pred [y_train [:, 0 ] == 0 ]
1940- mu1 = forest_pred [y_train [:, 0 ] == 1 ]
1944+ outcome_pred = forest_pred_mu + forest_pred_tau
1945+ if self .has_rfx :
1946+ rfx_pred = rfx_model .predict (rfx_dataset_train , rfx_tracker )
1947+ outcome_pred = outcome_pred + rfx_pred
1948+ mu0 = outcome_pred [y_train [:, 0 ] == 0 ]
1949+ mu1 = outcome_pred [y_train [:, 0 ] == 1 ]
19411950 n0 = np .sum (y_train [:, 0 ] == 0 )
19421951 n1 = np .sum (y_train [:, 0 ] == 1 )
19431952 u0 = self .rng .uniform (
@@ -1954,7 +1963,7 @@ def sample(
19541963 resid_train [y_train [:, 0 ] == 1 , 0 ] = mu1 + norm .ppf (u1 )
19551964
19561965 # Update outcome
1957- new_outcome = np .squeeze (resid_train ) - forest_pred
1966+ new_outcome = np .squeeze (resid_train ) - outcome_pred
19581967 residual_train .update_data (new_outcome )
19591968
19601969 # Sample the prognostic forest
@@ -2017,18 +2026,21 @@ def sample(
20172026
20182027 # Sample coding parameters (if requested)
20192028 if self .adaptive_coding :
2020- mu_x = active_forest_mu .predict_raw (forest_dataset_train )
2029+ partial_outcome_pred = active_forest_mu .predict_raw (forest_dataset_train )
20212030 tau_x = np .squeeze (
20222031 active_forest_tau .predict_raw (forest_dataset_train )
20232032 )
2033+ if self .has_rfx :
2034+ rfx_pred = rfx_model .predict (rfx_dataset_train , rfx_tracker )
2035+ partial_outcome_pred = partial_outcome_pred + rfx_pred
20242036 s_tt0 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 0 ))
20252037 s_tt1 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 1 ))
2026- partial_resid_mu = np .squeeze (resid_train - mu_x )
2038+ partial_resid = np .squeeze (resid_train - partial_outcome_pred )
20272039 s_ty0 = np .sum (
2028- tau_x * partial_resid_mu * (np .squeeze (Z_train ) == 0 )
2040+ tau_x * partial_resid * (np .squeeze (Z_train ) == 0 )
20292041 )
20302042 s_ty1 = np .sum (
2031- tau_x * partial_resid_mu * (np .squeeze (Z_train ) == 1 )
2043+ tau_x * partial_resid * (np .squeeze (Z_train ) == 1 )
20322044 )
20332045 current_b_0 = self .rng .normal (
20342046 loc = (s_ty0 / (s_tt0 + 2 * current_sigma2 )),
@@ -2655,8 +2667,8 @@ def compute_contrast(
26552667
26562668 # Transform to probability scale if requested
26572669 if probability_scale :
2658- treatment_preds = norm .ppf (treatment_preds )
2659- control_preds = norm .ppf (control_preds )
2670+ treatment_preds = norm .cdf (treatment_preds )
2671+ control_preds = norm .cdf (control_preds )
26602672
26612673 # Compute and return contrast
26622674 if predict_mean :
0 commit comments