Skip to content

Commit bb09de5

Browse files
committed
Handling RFX specifications in the predict BCF method
1 parent d6e4480 commit bb09de5

File tree

1 file changed

+64
-13
lines changed

1 file changed

+64
-13
lines changed

R/bcf.R

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2720,6 +2720,10 @@ predict.bcfmodel <- function(
27202720
predict_mean <- type == "mean"
27212721

27222722
# Handle prediction terms
2723+
rfx_model_spec = object$model_params$rfx_model_spec
2724+
rfx_intercept_only <- rfx_model_spec == "intercept_only"
2725+
rfx_intercept_plus_treatment <- (rfx_model_spec == "intercept_plus_treatment")
2726+
rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment
27232727
if (!is.character(terms)) {
27242728
stop("type must be a string or character vector")
27252729
}
@@ -2756,7 +2760,10 @@ predict.bcfmodel <- function(
27562760
))
27572761
return(NULL)
27582762
}
2759-
predict_rfx_intermediate <- (predict_y_hat && has_rfx)
2763+
predict_rfx_intermediate <- ((predict_y_hat && has_rfx))
2764+
predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept_only) ||
2765+
(predict_mu_forest && has_rfx && rfx_intercept_plus_treatment) ||
2766+
(predict_tau_forest && has_rfx && rfx_intercept_plus_treatment))
27602767
predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest)
27612768
predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest)
27622769

@@ -2887,30 +2894,33 @@ predict.bcfmodel <- function(
28872894

28882895
# Compute mu forest predictions
28892896
if (predict_mu_forest || predict_mu_forest_intermediate) {
2890-
mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar
2897+
mu_hat_forest <- object$forests_mu$predict(forest_dataset_pred) *
2898+
y_std +
2899+
y_bar
28912900
}
28922901

28932902
# Compute CATE forest predictions
28942903
if (predict_tau_forest || predict_tau_forest_intermediate) {
28952904
if (object$model_params$adaptive_coding) {
28962905
tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred)
2897-
tau_hat <- t(
2906+
tau_hat_forest <- t(
28982907
t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples)
28992908
) *
29002909
y_std
29012910
} else {
2902-
tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std
2911+
tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) *
2912+
y_std
29032913
}
29042914
if (object$model_params$multivariate_treatment) {
2905-
tau_dim <- dim(tau_hat)
2915+
tau_dim <- dim(tau_hat_forest)
29062916
tau_num_obs <- tau_dim[1]
29072917
tau_num_samples <- tau_dim[3]
29082918
treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
29092919
for (i in 1:nrow(Z)) {
2910-
treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ])
2920+
treatment_term[i, ] <- colSums(tau_hat_forest[i, , ] * Z[i, ])
29112921
}
29122922
} else {
2913-
treatment_term <- tau_hat * as.numeric(Z)
2923+
treatment_term <- tau_hat_forest * as.numeric(Z)
29142924
}
29152925
}
29162926

@@ -2923,6 +2933,41 @@ predict.bcfmodel <- function(
29232933
y_std
29242934
}
29252935

2936+
# Extract "raw" rfx coefficients for each rfx basis term if needed
2937+
if (predict_rfx_raw) {
2938+
# Extract the raw RFX samples and scale by train set outcome sd
2939+
rfx_param_list <- object$rfx_samples$extract_parameter_samples()
2940+
rfx_beta_draws <- rfx_param_list$beta_samples *
2941+
object$model_params$outcome_scale
2942+
2943+
# Construct a matrix with the correct random effects
2944+
rfx_predictions_raw <- array(
2945+
NA,
2946+
dim = c(
2947+
nrow(X),
2948+
ncol(rfx_basis),
2949+
object$model_params$num_samples
2950+
)
2951+
)
2952+
for (i in 1:nrow(X)) {
2953+
rfx_predictions_raw[i, , ] <-
2954+
rfx_beta_draws[, rfx_group_ids[i], ]
2955+
}
2956+
2957+
# Add these RFX predictions to mu and tau if warranted by the RFX model spec
2958+
if (predict_mu_forest && rfx_intercept) {
2959+
mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ]
2960+
} else {
2961+
mu_hat_final <- mu_hat_forest
2962+
}
2963+
if (predict_tau_forest && rfx_intercept_plus_treatment) {
2964+
tau_hat_final <- (tau_hat_forest +
2965+
rfx_predictions_raw[, 2:ncol(rfx_basis), ])
2966+
} else {
2967+
tau_hat_final <- tau_hat_forest
2968+
}
2969+
}
2970+
29262971
# Combine into y hat predictions
29272972
needs_mean_term_preds <- predict_y_hat ||
29282973
predict_mu_forest ||
@@ -2932,32 +2977,38 @@ predict.bcfmodel <- function(
29322977
if (probability_scale) {
29332978
if (has_rfx) {
29342979
if (predict_y_hat) {
2935-
y_hat <- pnorm(mu_hat + treatment_term + rfx_predictions)
2980+
y_hat <- pnorm(mu_hat_forest + treatment_term + rfx_predictions)
29362981
}
29372982
if (predict_rfx) {
29382983
rfx_predictions <- pnorm(rfx_predictions)
29392984
}
29402985
} else {
29412986
if (predict_y_hat) {
2942-
y_hat <- pnorm(mu_hat + treatment_term)
2987+
y_hat <- pnorm(mu_hat_forest + treatment_term)
29432988
}
29442989
}
29452990
if (predict_mu_forest) {
2946-
mu_hat <- pnorm(mu_hat)
2991+
mu_hat <- pnorm(mu_hat_final)
29472992
}
29482993
if (predict_tau_forest) {
2949-
tau_hat <- pnorm(tau_hat)
2994+
tau_hat <- pnorm(tau_hat_final)
29502995
}
29512996
} else {
29522997
if (has_rfx) {
29532998
if (predict_y_hat) {
2954-
y_hat <- mu_hat + treatment_term + rfx_predictions
2999+
y_hat <- mu_hat_forest + treatment_term + rfx_predictions
29553000
}
29563001
} else {
29573002
if (predict_y_hat) {
2958-
y_hat <- mu_hat + treatment_term
3003+
y_hat <- mu_hat_forest + treatment_term
29593004
}
29603005
}
3006+
if (predict_mu_forest) {
3007+
mu_hat <- mu_hat_final
3008+
}
3009+
if (predict_tau_forest) {
3010+
tau_hat <- tau_hat_final
3011+
}
29613012
}
29623013
}
29633014

0 commit comments

Comments
 (0)