@@ -2720,6 +2720,10 @@ predict.bcfmodel <- function(
27202720 predict_mean <- type == " mean"
27212721
27222722 # Handle prediction terms
2723+ rfx_model_spec = object $ model_params $ rfx_model_spec
2724+ rfx_intercept_only <- rfx_model_spec == " intercept_only"
2725+ rfx_intercept_plus_treatment <- (rfx_model_spec == " intercept_plus_treatment" )
2726+ rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment
27232727 if (! is.character(terms )) {
27242728 stop(" type must be a string or character vector" )
27252729 }
@@ -2756,7 +2760,10 @@ predict.bcfmodel <- function(
27562760 ))
27572761 return (NULL )
27582762 }
2759- predict_rfx_intermediate <- (predict_y_hat && has_rfx )
2763+ predict_rfx_intermediate <- ((predict_y_hat && has_rfx ))
2764+ predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept_only ) ||
2765+ (predict_mu_forest && has_rfx && rfx_intercept_plus_treatment ) ||
2766+ (predict_tau_forest && has_rfx && rfx_intercept_plus_treatment ))
27602767 predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest )
27612768 predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest )
27622769
@@ -2887,30 +2894,33 @@ predict.bcfmodel <- function(
28872894
28882895 # Compute mu forest predictions
28892896 if (predict_mu_forest || predict_mu_forest_intermediate ) {
2890- mu_hat <- object $ forests_mu $ predict(forest_dataset_pred ) * y_std + y_bar
2897+ mu_hat_forest <- object $ forests_mu $ predict(forest_dataset_pred ) *
2898+ y_std +
2899+ y_bar
28912900 }
28922901
28932902 # Compute CATE forest predictions
28942903 if (predict_tau_forest || predict_tau_forest_intermediate ) {
28952904 if (object $ model_params $ adaptive_coding ) {
28962905 tau_hat_raw <- object $ forests_tau $ predict_raw(forest_dataset_pred )
2897- tau_hat <- t(
2906+ tau_hat_forest <- t(
28982907 t(tau_hat_raw ) * (object $ b_1_samples - object $ b_0_samples )
28992908 ) *
29002909 y_std
29012910 } else {
2902- tau_hat <- object $ forests_tau $ predict_raw(forest_dataset_pred ) * y_std
2911+ tau_hat_forest <- object $ forests_tau $ predict_raw(forest_dataset_pred ) *
2912+ y_std
29032913 }
29042914 if (object $ model_params $ multivariate_treatment ) {
2905- tau_dim <- dim(tau_hat )
2915+ tau_dim <- dim(tau_hat_forest )
29062916 tau_num_obs <- tau_dim [1 ]
29072917 tau_num_samples <- tau_dim [3 ]
29082918 treatment_term <- matrix (NA_real_ , nrow = tau_num_obs , tau_num_samples )
29092919 for (i in 1 : nrow(Z )) {
2910- treatment_term [i , ] <- colSums(tau_hat [i , , ] * Z [i , ])
2920+ treatment_term [i , ] <- colSums(tau_hat_forest [i , , ] * Z [i , ])
29112921 }
29122922 } else {
2913- treatment_term <- tau_hat * as.numeric(Z )
2923+ treatment_term <- tau_hat_forest * as.numeric(Z )
29142924 }
29152925 }
29162926
@@ -2923,6 +2933,41 @@ predict.bcfmodel <- function(
29232933 y_std
29242934 }
29252935
2936+ # Extract "raw" rfx coefficients for each rfx basis term if needed
2937+ if (predict_rfx_raw ) {
2938+ # Extract the raw RFX samples and scale by train set outcome sd
2939+ rfx_param_list <- object $ rfx_samples $ extract_parameter_samples()
2940+ rfx_beta_draws <- rfx_param_list $ beta_samples *
2941+ object $ model_params $ outcome_scale
2942+
2943+ # Construct a matrix with the correct random effects
2944+ rfx_predictions_raw <- array (
2945+ NA ,
2946+ dim = c(
2947+ nrow(X ),
2948+ ncol(rfx_basis ),
2949+ object $ model_params $ num_samples
2950+ )
2951+ )
2952+ for (i in 1 : nrow(X )) {
2953+ rfx_predictions_raw [i , , ] <-
2954+ rfx_beta_draws [, rfx_group_ids [i ], ]
2955+ }
2956+
2957+ # Add these RFX predictions to mu and tau if warranted by the RFX model spec
2958+ if (predict_mu_forest && rfx_intercept ) {
2959+ mu_hat_final <- mu_hat_forest + rfx_predictions_raw [, 1 , ]
2960+ } else {
2961+ mu_hat_final <- mu_hat_forest
2962+ }
2963+ if (predict_tau_forest && rfx_intercept_plus_treatment ) {
2964+ tau_hat_final <- (tau_hat_forest +
2965+ rfx_predictions_raw [, 2 : ncol(rfx_basis ), ])
2966+ } else {
2967+ tau_hat_final <- tau_hat_forest
2968+ }
2969+ }
2970+
29262971 # Combine into y hat predictions
29272972 needs_mean_term_preds <- predict_y_hat ||
29282973 predict_mu_forest ||
@@ -2932,32 +2977,38 @@ predict.bcfmodel <- function(
29322977 if (probability_scale ) {
29332978 if (has_rfx ) {
29342979 if (predict_y_hat ) {
2935- y_hat <- pnorm(mu_hat + treatment_term + rfx_predictions )
2980+ y_hat <- pnorm(mu_hat_forest + treatment_term + rfx_predictions )
29362981 }
29372982 if (predict_rfx ) {
29382983 rfx_predictions <- pnorm(rfx_predictions )
29392984 }
29402985 } else {
29412986 if (predict_y_hat ) {
2942- y_hat <- pnorm(mu_hat + treatment_term )
2987+ y_hat <- pnorm(mu_hat_forest + treatment_term )
29432988 }
29442989 }
29452990 if (predict_mu_forest ) {
2946- mu_hat <- pnorm(mu_hat )
2991+ mu_hat <- pnorm(mu_hat_final )
29472992 }
29482993 if (predict_tau_forest ) {
2949- tau_hat <- pnorm(tau_hat )
2994+ tau_hat <- pnorm(tau_hat_final )
29502995 }
29512996 } else {
29522997 if (has_rfx ) {
29532998 if (predict_y_hat ) {
2954- y_hat <- mu_hat + treatment_term + rfx_predictions
2999+ y_hat <- mu_hat_forest + treatment_term + rfx_predictions
29553000 }
29563001 } else {
29573002 if (predict_y_hat ) {
2958- y_hat <- mu_hat + treatment_term
3003+ y_hat <- mu_hat_forest + treatment_term
29593004 }
29603005 }
3006+ if (predict_mu_forest ) {
3007+ mu_hat <- mu_hat_final
3008+ }
3009+ if (predict_tau_forest ) {
3010+ tau_hat <- tau_hat_final
3011+ }
29613012 }
29623013 }
29633014
0 commit comments