@@ -1815,21 +1815,21 @@ def sample(
18151815
18161816 # Sample coding parameters (if requested)
18171817 if self .adaptive_coding :
1818- partial_outcome_pred = active_forest_mu .predict_raw (forest_dataset_train )
1818+ mu_x = active_forest_mu .predict_raw (forest_dataset_train )
18191819 tau_x = np .squeeze (
18201820 active_forest_tau .predict_raw (forest_dataset_train )
18211821 )
1822+ partial_resid_train = np .squeeze (resid_train - mu_x )
18221823 if self .has_rfx :
1823- rfx_pred = rfx_model .predict (rfx_dataset_train , rfx_tracker )
1824- partial_outcome_pred = partial_outcome_pred + rfx_pred
1824+ rfx_pred = np . squeeze ( rfx_model .predict (rfx_dataset_train , rfx_tracker ) )
1825+ partial_resid_train = partial_resid_train - rfx_pred
18251826 s_tt0 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 0 ))
18261827 s_tt1 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 1 ))
1827- partial_resid = np .squeeze (resid_train - partial_outcome_pred )
18281828 s_ty0 = np .sum (
1829- tau_x * partial_resid * (np .squeeze (Z_train ) == 0 )
1829+ tau_x * partial_resid_train * (np .squeeze (Z_train ) == 0 )
18301830 )
18311831 s_ty1 = np .sum (
1832- tau_x * partial_resid * (np .squeeze (Z_train ) == 1 )
1832+ tau_x * partial_resid_train * (np .squeeze (Z_train ) == 1 )
18331833 )
18341834 current_b_0 = self .rng .normal (
18351835 loc = (s_ty0 / (s_tt0 + 2 * current_sigma2 )),
@@ -2021,21 +2021,21 @@ def sample(
20212021
20222022 # Sample coding parameters (if requested)
20232023 if self .adaptive_coding :
2024- partial_outcome_pred = active_forest_mu .predict_raw (forest_dataset_train )
2024+ mu_x = active_forest_mu .predict_raw (forest_dataset_train )
20252025 tau_x = np .squeeze (
20262026 active_forest_tau .predict_raw (forest_dataset_train )
20272027 )
2028+ partial_resid_train = np .squeeze (resid_train - mu_x )
20282029 if self .has_rfx :
2029- rfx_pred = rfx_model .predict (rfx_dataset_train , rfx_tracker )
2030- partial_outcome_pred = partial_outcome_pred + rfx_pred
2030+ rfx_pred = np . squeeze ( rfx_model .predict (rfx_dataset_train , rfx_tracker ) )
2031+ partial_resid_train = partial_resid_train - rfx_pred
20312032 s_tt0 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 0 ))
20322033 s_tt1 = np .sum (tau_x * tau_x * (np .squeeze (Z_train ) == 1 ))
2033- partial_resid = np .squeeze (resid_train - partial_outcome_pred )
20342034 s_ty0 = np .sum (
2035- tau_x * partial_resid * (np .squeeze (Z_train ) == 0 )
2035+ tau_x * partial_resid_train * (np .squeeze (Z_train ) == 0 )
20362036 )
20372037 s_ty1 = np .sum (
2038- tau_x * partial_resid * (np .squeeze (Z_train ) == 1 )
2038+ tau_x * partial_resid_train * (np .squeeze (Z_train ) == 1 )
20392039 )
20402040 current_b_0 = self .rng .normal (
20412041 loc = (s_ty0 / (s_tt0 + 2 * current_sigma2 )),
@@ -2459,8 +2459,6 @@ def predict(
24592459 rfx_preds = (
24602460 self .rfx_container .predict (rfx_group_ids , rfx_basis ) * self .y_std
24612461 )
2462- if predict_mean :
2463- rfx_preds = np .mean (rfx_preds , axis = 1 )
24642462
24652463 # Combine into y hat predictions
24662464 if predict_y_hat and has_mu_forest and has_rfx :
0 commit comments