|
10 | 10 | #' control prediction. |
11 | 11 | #' |
12 | 12 | #' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs. |
13 | | -#' @param X_0 Covariates used for prediction in the "control" case. |
14 | | -#' @param X_1 Covariates used for prediction in the "treatment" case. |
15 | | -#' @param Z_0 Treatments used for prediction in the "control" case. |
16 | | -#' @param Z_1 Treatments used for prediction in the "treatment" case. |
17 | | -#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. |
18 | | -#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. |
| 13 | +#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. |
| 14 | +#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. |
| 15 | +#' @param Z_0 Treatments used for prediction in the "control" case. Must be a matrix or vector. |
| 16 | +#' @param Z_1 Treatments used for prediction in the "treatment" case. Must be a matrix or vector. |
| 17 | +#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector. |
| 18 | +#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector. |
19 | 19 | #' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects |
20 | 20 | #' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation |
21 | | -#' for group labels that were not in the training set. |
| 21 | +#' for group labels that were not in the training set. Must be a vector. |
22 | 22 | #' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects |
23 | 23 | #' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation |
24 | | -#' for group labels that were not in the training set. |
25 | | -#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. |
26 | | -#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. |
27 | | -#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". |
| 24 | +#' for group labels that were not in the training set. Must be a vector. |
| 25 | +#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. |
| 26 | +#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. |
| 27 | +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". |
28 | 28 | #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". |
29 | 29 | #' @param ... (Optional) Other prediction parameters. |
30 | 30 | #' |
@@ -247,6 +247,219 @@ compute_contrast_bcf_model <- function( |
247 | 247 | } |
248 | 248 | } |
249 | 249 |
|
#' Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference.
#' This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects
#' bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or
#' `Y0` term and the minuend of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
#' terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the `predict.bartmodel`
#' function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and
#' `0` to denote inclusion in the control prediction.
#'
#' Only valid when there is either a mean forest or a random effects term in the BART model.
#'
#' Both sides of the contrast must describe the same number of observations, since the contrast is
#' the elementwise difference of the two posterior prediction matrices.
#'
#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
#' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`.
#' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`.
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "control" case. We do not currently support (but plan to in the near future) test set evaluation
#' for group labels that were not in the training set. Must be a vector.
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
#' model in the "treatment" case. We do not currently support (but plan to in the near future) test set evaluation
#' for group labels that were not in the training set. Must be a vector.
#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param ... (Optional) Other prediction parameters. Currently not forwarded to `predict`.
#'
#' @return If `type = "posterior"`, the matrix obtained by subtracting the "control" posterior predictions
#' from the "treatment" posterior predictions. If `type = "mean"`, the vector of row means of that matrix.
#' @export
#'
#' @examples
#' n <- 100
#' p <- 5
#' X <- matrix(runif(n*p), ncol = p)
#' W <- matrix(runif(n*1), ncol = 1)
#' f_XW <- (
#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
#' )
#' noise_sd <- 1
#' y <- f_XW + rnorm(n, 0, noise_sd)
#' test_set_pct <- 0.2
#' n_test <- round(test_set_pct*n)
#' n_train <- n - n_test
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
#' X_test <- X[test_inds,]
#' X_train <- X[train_inds,]
#' W_test <- W[test_inds,]
#' W_train <- W[train_inds,]
#' y_test <- y[test_inds]
#' y_train <- y[train_inds]
#' bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train,
#'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
#' contrast <- compute_contrast_bart_model(
#'     bart_model,
#'     covariates_0 = X_test,
#'     covariates_1 = X_test,
#'     leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
#'     leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
#'     type = "posterior",
#'     scale = "linear"
#' )
compute_contrast_bart_model <- function(
  object,
  covariates_0,
  covariates_1,
  leaf_basis_0 = NULL,
  leaf_basis_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear",
  ...
) {
  # A model flag that is absent (NULL) is treated as "off" so that the
  # scalar conditions below never see a zero-length or NA value.
  flag_set <- function(flag) isTRUE(as.logical(flag))

  # Handle mean function scale
  if (!is.character(scale)) {
    stop("scale must be a string or character vector")
  }
  if (length(scale) != 1) {
    stop("scale must be a single string")
  }
  if (!(scale %in% c("linear", "probability"))) {
    stop("scale must either be 'linear' or 'probability'")
  }
  probability_scale <- scale == "probability"
  is_probit <- flag_set(object$model_params$probit_outcome_model)
  if (probability_scale && !is_probit) {
    stop(
      "scale cannot be 'probability' for models not fit with a probit outcome model"
    )
  }

  # Handle prediction type
  if (!is.character(type)) {
    stop("type must be a string or character vector")
  }
  if (length(type) != 1) {
    stop("type must be a single string")
  }
  if (!(type %in% c("mean", "posterior"))) {
    stop("type must either be 'mean' or 'posterior'")
  }
  predict_mean <- type == "mean"

  # Handle prediction terms
  has_mean_forest <- flag_set(object$model_params$include_mean_forest)
  has_rfx <- flag_set(object$model_params$has_rfx)
  if (!has_mean_forest && !has_rfx) {
    stop(
      "Model must have either or both of mean forest or random effects terms to compute the requested contrast."
    )
  }

  # Check that covariates are matrix or data frame
  if (!is.data.frame(covariates_0) && !is.matrix(covariates_0)) {
    stop("covariates_0 must be a matrix or dataframe")
  }
  if (!is.data.frame(covariates_1) && !is.matrix(covariates_1)) {
    stop("covariates_1 must be a matrix or dataframe")
  }

  # The contrast is an elementwise difference of two prediction matrices, so
  # both arms must cover the same number of observations. Checking here gives
  # a clear message before any prediction work is done.
  if (nrow(covariates_0) != nrow(covariates_1)) {
    stop("covariates_0 and covariates_1 must have the same number of rows")
  }

  # Convert all basis inputs supplied as plain vectors to one-column matrices
  if (!is.null(leaf_basis_0) && is.null(dim(leaf_basis_0))) {
    leaf_basis_0 <- as.matrix(leaf_basis_0)
  }
  if (!is.null(leaf_basis_1) && is.null(dim(leaf_basis_1))) {
    leaf_basis_1 <- as.matrix(leaf_basis_1)
  }
  if (!is.null(rfx_basis_0) && is.null(dim(rfx_basis_0))) {
    rfx_basis_0 <- as.matrix(rfx_basis_0)
  }
  if (!is.null(rfx_basis_1) && is.null(dim(rfx_basis_1))) {
    rfx_basis_1 <- as.matrix(rfx_basis_1)
  }

  # Leaf basis checks
  if (
    flag_set(object$model_params$requires_basis) &&
      (is.null(leaf_basis_0) || is.null(leaf_basis_1))
  ) {
    stop("leaf_basis_0 and leaf_basis_1 must be provided for this model")
  }
  if (!is.null(leaf_basis_0) && (nrow(covariates_0) != nrow(leaf_basis_0))) {
    stop("covariates_0 and leaf_basis_0 must have the same number of rows")
  }
  if (!is.null(leaf_basis_1) && (nrow(covariates_1) != nrow(leaf_basis_1))) {
    stop("covariates_1 and leaf_basis_1 must have the same number of rows")
  }

  # Covariate dimension checks
  if (object$model_params$num_covariates != ncol(covariates_0)) {
    stop(
      "covariates_0 must contain the same number of columns as the BART model's training dataset"
    )
  }
  if (object$model_params$num_covariates != ncol(covariates_1)) {
    stop(
      "covariates_1 must contain the same number of columns as the BART model's training dataset"
    )
  }

  # Random effects checks
  if (has_rfx) {
    if (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1)) {
      stop(
        "rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
      )
    }
    if (is.null(rfx_basis_0) || is.null(rfx_basis_1)) {
      stop(
        "rfx_basis_0 and rfx_basis_1 must be provided for this model"
      )
    }
    # One group label and one basis row is expected per observation
    if (length(rfx_group_ids_0) != nrow(covariates_0)) {
      stop("rfx_group_ids_0 must have one entry per row of covariates_0")
    }
    if (length(rfx_group_ids_1) != nrow(covariates_1)) {
      stop("rfx_group_ids_1 must have one entry per row of covariates_1")
    }
    if (nrow(rfx_basis_0) != nrow(covariates_0)) {
      stop("covariates_0 and rfx_basis_0 must have the same number of rows")
    }
    if (nrow(rfx_basis_1) != nrow(covariates_1)) {
      stop("covariates_1 and rfx_basis_1 must have the same number of rows")
    }
  }

  # Only compare basis widths when both bases are actually present, so that
  # ncol() is never evaluated on NULL.
  num_rfx_basis <- object$model_params$num_rfx_basis
  if (
    isTRUE(num_rfx_basis > 0) &&
      !is.null(rfx_basis_0) &&
      !is.null(rfx_basis_1) &&
      ((ncol(rfx_basis_0) != num_rfx_basis) ||
        (ncol(rfx_basis_1) != num_rfx_basis))
  ) {
    stop(
      "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
    )
  }

  # Predict for the control arm. Predictions are always requested as linear
  # scale posterior draws; the probability transform and the averaging over
  # draws are applied below.
  control_preds <- predict(
    object = object,
    covariates = covariates_0,
    leaf_basis = leaf_basis_0,
    rfx_group_ids = rfx_group_ids_0,
    rfx_basis = rfx_basis_0,
    type = "posterior",
    term = "y_hat",
    scale = "linear"
  )

  # Predict for the treatment arm
  treatment_preds <- predict(
    object = object,
    covariates = covariates_1,
    leaf_basis = leaf_basis_1,
    rfx_group_ids = rfx_group_ids_1,
    rfx_basis = rfx_basis_1,
    type = "posterior",
    term = "y_hat",
    scale = "linear"
  )

  # Transform each arm to the probability scale before differencing
  if (probability_scale) {
    treatment_preds <- pnorm(treatment_preds)
    control_preds <- pnorm(control_preds)
  }

  # Compute and return contrast
  contrast_draws <- treatment_preds - control_preds
  if (predict_mean) {
    rowMeans(contrast_draws)
  } else {
    contrast_draws
  }
}
250 | 463 |
|
251 | 464 | #' Sample from the posterior predictive distribution for outcomes modeled by BCF |
252 | 465 | #' |
|
0 commit comments