Skip to content

Commit 691eed7

Browse files
committed
Updated predict method and added demo script to check the behavior
1 parent bb09de5 commit 691eed7

File tree

2 files changed

+90
-21
lines changed

2 files changed

+90
-21
lines changed

R/bcf.R

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2628,7 +2628,7 @@ bcf <- function(
26282628
#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
26292629
#' We do not currently support (but plan to in the near future) test set evaluation for group labels
26302630
#' that were not in the training set.
2631-
#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model.
2631+
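#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. Required if the model was sampled with a random effects `model_spec` of "custom". If the model was sampled with a `model_spec` of "intercept_only" or "intercept_plus_treatment", the basis is constructed automatically when this argument is omitted; if a basis is provided, it is used instead of the automatically constructed one.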
26322632
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
26332633
#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
26342634
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
@@ -2830,31 +2830,41 @@ predict.bcfmodel <- function(
28302830
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
28312831
if (sum(is.na(group_ids_factor)) > 0) {
28322832
stop(
2833-
"All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids"
2833+
"All random effect group labels provided in rfx_group_ids must have been present at sampling time"
28342834
)
28352835
}
28362836
rfx_group_ids <- as.integer(group_ids_factor)
28372837
has_rfx <- TRUE
28382838
}
28392839

28402840
# Handle RFX model specification
2841-
if (object$model_params$rfx_model_spec == "custom") {
2842-
if (is.null(rfx_basis)) {
2843-
stop(
2844-
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2845-
)
2841+
if (has_rfx) {
2842+
if (object$model_params$rfx_model_spec == "custom") {
2843+
if (is.null(rfx_basis)) {
2844+
stop(
2845+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2846+
)
2847+
}
2848+
} else if (object$model_params$rfx_model_spec == "intercept_only") {
2849+
# Only construct a basis if user-provided basis missing
2850+
if (is.null(rfx_basis)) {
2851+
rfx_basis <- matrix(
2852+
rep(1, nrow(X)),
2853+
nrow = nrow(X),
2854+
ncol = 1
2855+
)
2856+
}
2857+
} else if (
2858+
object$model_params$rfx_model_spec == "intercept_plus_treatment"
2859+
) {
2860+
# Only construct a basis if user-provided basis missing
2861+
if (is.null(rfx_basis)) {
2862+
rfx_basis <- cbind(
2863+
rep(1, nrow(X)),
2864+
Z
2865+
)
2866+
}
28462867
}
2847-
} else if (object$model_params$rfx_model_spec == "intercept_only") {
2848-
rfx_basis <- matrix(
2849-
rep(1, nrow(X)),
2850-
nrow = nrow(X),
2851-
ncol = 1
2852-
)
2853-
} else if (object$model_params$rfx_model_spec == "intercept_plus_treatment") {
2854-
rfx_basis <- cbind(
2855-
rep(1, nrow(X)),
2856-
Z
2857-
)
28582868
}
28592869

28602870
# Add propensities to covariate set if necessary
@@ -2953,14 +2963,18 @@ predict.bcfmodel <- function(
29532963
rfx_predictions_raw[i, , ] <-
29542964
rfx_beta_draws[, rfx_group_ids[i], ]
29552965
}
2966+
}
29562967

2957-
# Add these RFX predictions to mu and tau if warranted by the RFX model spec
2958-
if (predict_mu_forest && rfx_intercept) {
2968+
# Add raw RFX predictions to mu and tau if warranted by the RFX model spec
2969+
if (predict_mu_forest || predict_mu_forest_intermediate) {
2970+
if (rfx_intercept && predict_rfx_raw) {
29592971
mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ]
29602972
} else {
29612973
mu_hat_final <- mu_hat_forest
29622974
}
2963-
if (predict_tau_forest && rfx_intercept_plus_treatment) {
2975+
}
2976+
if (predict_tau_forest || predict_tau_forest_intermediate) {
2977+
if (rfx_intercept_plus_treatment && predict_rfx_raw) {
29642978
tau_hat_final <- (tau_hat_forest +
29652979
rfx_predictions_raw[, 2:ncol(rfx_basis), ])
29662980
} else {

tools/debug/bcf_cate_debug.R

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,61 @@ all(
193193
abs(tau_diff) < 0.001
194194
)
195195

196+
# Now repeat the same process but via random effects model spec
197+
bcf_model <- bcf(
198+
X_train = X_train,
199+
Z_train = Z_train,
200+
y_train = y_train,
201+
propensity_train = pi_train,
202+
rfx_group_ids_train = group_ids_train,
203+
X_test = X_test,
204+
Z_test = Z_test,
205+
propensity_test = pi_test,
206+
rfx_group_ids_test = group_ids_test,
207+
num_gfr = 10,
208+
num_burnin = 0,
209+
num_mcmc = 1000,
210+
random_effects_params = list(
211+
model_spec = "intercept_plus_treatment"
212+
)
213+
)
214+
215+
# Compute CATE posterior
216+
tau_hat_posterior_test <- compute_contrast_bcf_model(
217+
bcf_model,
218+
X_0 = X_test,
219+
X_1 = X_test,
220+
Z_0 = rep(0, n_test),
221+
Z_1 = rep(1, n_test),
222+
propensity_0 = pi_test,
223+
propensity_1 = pi_test,
224+
rfx_group_ids_0 = group_ids_test,
225+
rfx_group_ids_1 = group_ids_test,
226+
rfx_basis_0 = cbind(1, rep(0, n_test)),
227+
rfx_basis_1 = cbind(1, rep(1, n_test)),
228+
type = "posterior",
229+
scale = "linear"
230+
)
231+
232+
# Compute the same quantity via predict
233+
tau_hat_posterior_test_comparison <- predict(
234+
bcf_model,
235+
X = X_test,
236+
Z = Z_test,
237+
propensity = pi_test,
238+
rfx_group_ids = group_ids_test,
239+
rfx_basis = rfx_basis_test,
240+
type = "posterior",
241+
terms = "cate",
242+
scale = "linear"
243+
)
244+
245+
# Compare results
246+
tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test
247+
all(
248+
abs(tau_diff) < 0.001
249+
)
250+
196251
# Generate data for a probit BCF model with random effects
197252
X <- matrix(rnorm(n * p), ncol = p)
198253
mu_x <- X[, 1]

0 commit comments

Comments
 (0)