|
#' Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference.
#' For simple BCF models with binary treatment, this will yield the same prediction as requesting `terms = "cate"`
#' in the `predict.bcfmodel` function. For more general models, such as models with continuous / multivariate treatments or
#' an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute
#' any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two term
#' contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend of the
#' contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" terminology of a classic
#' two-arm experiment. We mirror the function calls and terminology of the `predict.bcfmodel` function, labeling each prediction
#' data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the
#' control prediction.
#'
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or data frame.
#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or data frame with the same number of rows as `X_0`.
#' @param Z_0 Treatments used for prediction in the "control" case.
#' @param Z_1 Treatments used for prediction in the "treatment" case.
#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case.
#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case.
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation
#' for group labels that were not in the training set.
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation
#' for group labels that were not in the training set.
#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case.
#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case.
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrasts. Default: "posterior".
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param ... (Optional) Other prediction parameters. Currently not used by this function.
#'
#' @return If `type = "posterior"`, the matrix of contrasts (the "treatment" posterior predictions minus the "control"
#' posterior predictions). If `type = "mean"`, a vector holding the row-wise average of that matrix, i.e. the
#' contrast averaged over posterior draws.
#' @export
#'
#' @examples
#' n <- 500
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' mu_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
#' )
#' pi_x <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
#' )
#' tau_x <- (
#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
#' )
#' Z <- rbinom(n, 1, pi_x)
#' noise_sd <- 1
#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' pi_test <- pi_x[test_inds]
#' pi_train <- pi_x[train_inds]
#' Z_test <- Z[test_inds]
#' Z_train <- Z[train_inds]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' mu_test <- mu_x[test_inds]
#' mu_train <- mu_x[train_inds]
#' tau_test <- tau_x[test_inds]
#' tau_train <- tau_x[train_inds]
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#'                  propensity_train = pi_train, num_gfr = 10,
#'                  num_burnin = 0, num_mcmc = 10)
#' preds <- compute_contrast_bcf_model(
#'     bcf_model, X_0 = X_test, X_1 = X_test,
#'     Z_0 = rep(0, n_test), Z_1 = rep(1, n_test),
#'     propensity_0 = pi_test, propensity_1 = pi_test
#' )
compute_contrast_bcf_model <- function(
  object,
  X_0,
  X_1,
  Z_0,
  Z_1,
  propensity_0 = NULL,
  propensity_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear",
  ...
) {
  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  # The length guard keeps a multi-element `scale` from reaching `if` with a
  # vector-valued condition
  if ((length(scale) != 1) || (!(scale %in% c("linear", "probability")))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  # A missing / NULL probit flag is treated as "not a probit model" rather
  # than producing a zero-length condition below
  is_probit <- isTRUE(as.logical(object$model_params$probit_outcome_model))
  probability_scale <- scale == "probability"
  if (probability_scale && (!is_probit)) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }

  # Handle prediction type
  if (!is.character(type)) {
    stop("type must be a string or character vector")
  }
  if ((length(type) != 1) || (!(type %in% c("mean", "posterior")))) {
    stop("type must either be 'mean' or 'posterior'")
  }
  predict_mean <- type == "mean"

  # Make sure covariates are matrix or data frame
  if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) {
    stop("X_0 must be a matrix or dataframe")
  }
  if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) {
    stop("X_1 must be a matrix or dataframe")
  }

  # Treatments are needed on both sides of the contrast; without this guard a
  # NULL treatment only fails later inside the row-count comparison
  if (is.null(Z_0) || is.null(Z_1)) {
    stop("Z_0 and Z_1 must be provided")
  }

  # Convert all input data to matrices if not already converted
  if (is.null(dim(Z_0))) {
    Z_0 <- as.matrix(as.numeric(Z_0))
  }
  if (is.null(dim(Z_1))) {
    Z_1 <- as.matrix(as.numeric(Z_1))
  }
  if ((is.null(dim(propensity_0))) && (!is.null(propensity_0))) {
    propensity_0 <- as.matrix(propensity_0)
  }
  if ((is.null(dim(propensity_1))) && (!is.null(propensity_1))) {
    propensity_1 <- as.matrix(propensity_1)
  }
  if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) {
    rfx_basis_0 <- as.matrix(rfx_basis_0)
  }
  if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) {
    rfx_basis_1 <- as.matrix(rfx_basis_1)
  }

  # Data checks
  # Propensities may only be omitted when the model carries its own
  # (internal) propensity model
  if (
    (object$model_params$propensity_covariate != "none") &&
      ((is.null(propensity_0)) ||
        (is.null(propensity_1)))
  ) {
    if (!object$model_params$internal_propensity_model) {
      stop("propensity_0 and propensity_1 must be provided for this model")
    }
  }
  if (nrow(X_0) != nrow(Z_0)) {
    stop("X_0 and Z_0 must have the same number of rows")
  }
  if (nrow(X_1) != nrow(Z_1)) {
    stop("X_1 and Z_1 must have the same number of rows")
  }
  # The two prediction matrices are subtracted elementwise, so both sides of
  # the contrast must describe the same number of observations
  if (nrow(X_0) != nrow(X_1)) {
    stop("X_0 and X_1 must have the same number of rows")
  }
  if (object$model_params$num_covariates != ncol(X_0)) {
    stop(
      "X_0 must have the same number of columns as the covariates used to train the model"
    )
  }
  if (object$model_params$num_covariates != ncol(X_1)) {
    stop(
      "X_1 must have the same number of columns as the covariates used to train the model"
    )
  }
  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_0))) {
    stop(
      "Random effect group labels (rfx_group_ids_0) must be provided for this model"
    )
  }
  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_1))) {
    stop(
      "Random effect group labels (rfx_group_ids_1) must be provided for this model"
    )
  }
  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_0))) {
    stop("Random effects basis (rfx_basis_0) must be provided for this model")
  }
  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_1))) {
    stop("Random effects basis (rfx_basis_1) must be provided for this model")
  }
  # Only compare basis dimensions when a basis was supplied: a required but
  # missing basis is reported by the has_rfx_basis checks above, and
  # ncol(NULL) would otherwise yield a zero-length operand for `&&`
  if (
    (!is.null(rfx_basis_0)) &&
      (object$model_params$num_rfx_basis > 0) &&
      (ncol(rfx_basis_0) != object$model_params$num_rfx_basis)
  ) {
    stop(
      "Random effects basis has a different dimension than the basis used to train this model"
    )
  }
  if (
    (!is.null(rfx_basis_1)) &&
      (object$model_params$num_rfx_basis > 0) &&
      (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)
  ) {
    stop(
      "Random effects basis has a different dimension than the basis used to train this model"
    )
  }

  # Both arms are always predicted as full posteriors on the linear scale;
  # the probability transform and averaging over draws are applied below.
  # NOTE(review): the documentation above refers to this predict argument as
  # `terms`; `term` presumably resolves to it through partial argument
  # matching -- confirm against predict.bcfmodel's signature.

  # Predict for the control arm
  control_preds <- predict(
    object = object,
    X = X_0,
    Z = Z_0,
    propensity = propensity_0,
    rfx_group_ids = rfx_group_ids_0,
    rfx_basis = rfx_basis_0,
    type = "posterior",
    term = "y_hat",
    scale = "linear"
  )

  # Predict for the treatment arm
  treatment_preds <- predict(
    object = object,
    X = X_1,
    Z = Z_1,
    propensity = propensity_1,
    rfx_group_ids = rfx_group_ids_1,
    rfx_basis = rfx_basis_1,
    type = "posterior",
    term = "y_hat",
    scale = "linear"
  )

  # Transform each arm to the probability scale before differencing, so the
  # contrast is a difference in probabilities
  if (probability_scale) {
    treatment_preds <- pnorm(treatment_preds)
    control_preds <- pnorm(control_preds)
  }

  # Compute and return contrast
  contrast <- treatment_preds - control_preds
  if (predict_mean) {
    rowMeans(contrast)
  } else {
    contrast
  }
}
| 249 | + |
| 250 | + |
1 | 251 | #' Sample from the posterior predictive distribution for outcomes modeled by BCF |
2 | 252 | #' |
3 | 253 | #' @param model_object A fitted BCF model object of class `bcfmodel`. |
|
0 commit comments