Skip to content

Commit 2cd3775

Browse files
committed
Added contrast function for BART
1 parent c03b2ea commit 2cd3775

File tree

7 files changed

+505
-24
lines changed

7 files changed

+505
-24
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export(computeForestLeafVariances)
1212
export(computeForestMaxLeafIndex)
1313
export(compute_bart_posterior_interval)
1414
export(compute_bcf_posterior_interval)
15+
export(compute_contrast_bart_model)
1516
export(compute_contrast_bcf_model)
1617
export(convertPreprocessorToJson)
1718
export(createBARTModelFromCombinedJson)

R/bcf.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2591,7 +2591,7 @@ bcf <- function(
25912591
#' We do not currently support (but plan to in the near future), test set evaluation for group labels
25922592
#' that were not in the training set.
25932593
#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model.
2594-
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
2594+
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
25952595
#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
25962596
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
25972597
#' @param ... (Optional) Other prediction parameters.

R/posterior_transformation.R

Lines changed: 224 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010
#' control prediction.
1111
#'
1212
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
13-
#' @param X_0 Covariates used for prediction in the "control" case.
14-
#' @param X_1 Covariates used for prediction in the "treatment" case.
15-
#' @param Z_0 Treatments used for prediction in the "control" case.
16-
#' @param Z_1 Treatments used for prediction in the "treatment" case.
17-
#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case.
18-
#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case.
13+
#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
14+
#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
15+
#' @param Z_0 Treatments used for prediction in the "control" case. Must be a matrix or vector.
16+
#' @param Z_1 Treatments used for prediction in the "treatment" case. Must be a matrix or vector.
17+
#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector.
18+
#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector.
1919
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
2020
#' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation
21-
#' for group labels that were not in the training set.
21+
#' for group labels that were not in the training set. Must be a vector.
2222
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
2323
#' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation
24-
#' for group labels that were not in the training set.
25-
#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case.
26-
#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case.
27-
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
24+
#' for group labels that were not in the training set. Must be a vector.
25+
#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
26+
#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
27+
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
2828
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
2929
#' @param ... (Optional) Other prediction parameters.
3030
#'
@@ -247,6 +247,219 @@ compute_contrast_bcf_model <- function(
247247
}
248248
}
249249

250+
#' Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference.
251+
#' This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects
252+
#' bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or
253+
#' `Y0` term and the minuend of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
254+
#' terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the `predict.bartmodel`
255+
#' function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and
256+
#' `0` to denote inclusion in the control prediction.
257+
#'
258+
#' Only valid when there is either a mean forest or a random effects term in the BART model.
259+
#'
260+
#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
261+
#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
262+
#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
263+
#' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`.
264+
#' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`.
265+
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
266+
#' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation
267+
#' for group labels that were not in the training set. Must be a vector.
268+
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
269+
#' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation
270+
#' for group labels that were not in the training set. Must be a vector.
271+
#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
272+
#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
273+
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
274+
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
275+
#'
276+
#' @return Contrast matrix or vector, depending on whether type = "mean" or "posterior".
277+
#' @export
278+
#'
279+
#' @examples
280+
#' n <- 100
281+
#' p <- 5
282+
#' X <- matrix(runif(n*p), ncol = p)
283+
#' W <- matrix(runif(n*1), ncol = 1)
284+
#' f_XW <- (
285+
#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
286+
#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
287+
#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
288+
#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
289+
#' )
290+
#' noise_sd <- 1
291+
#' y <- f_XW + rnorm(n, 0, noise_sd)
292+
#' test_set_pct <- 0.2
293+
#' n_test <- round(test_set_pct*n)
294+
#' n_train <- n - n_test
295+
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
296+
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
297+
#' X_test <- X[test_inds,]
298+
#' X_train <- X[train_inds,]
299+
#' W_test <- W[test_inds,]
300+
#' W_train <- W[train_inds,]
301+
#' y_test <- y[test_inds]
302+
#' y_train <- y[train_inds]
303+
#' bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train,
304+
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
305+
#' contrast <- compute_contrast_bart_model(
306+
#' bart_model,
307+
#' covariates_0 = X_test,
308+
#' covariates_1 = X_test,
309+
#' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
310+
#' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
311+
#' type = "posterior",
312+
#' scale = "linear"
313+
#' )
314+
compute_contrast_bart_model <- function(
315+
object,
316+
covariates_0,
317+
covariates_1,
318+
leaf_basis_0 = NULL,
319+
leaf_basis_1 = NULL,
320+
rfx_group_ids_0 = NULL,
321+
rfx_group_ids_1 = NULL,
322+
rfx_basis_0 = NULL,
323+
rfx_basis_1 = NULL,
324+
type = "posterior",
325+
scale = "linear",
326+
...
327+
) {
328+
# Handle mean function scale
329+
if (!is.character(scale)) {
330+
stop("scale must be a string or character vector")
331+
}
332+
if (!(scale %in% c("linear", "probability"))) {
333+
stop("scale must either be 'linear' or 'probability'")
334+
}
335+
is_probit <- object$model_params$probit_outcome_model
336+
if ((scale == "probability") && (!is_probit)) {
337+
stop(
338+
"scale cannot be 'probability' for models not fit with a probit outcome model"
339+
)
340+
}
341+
probability_scale <- scale == "probability"
342+
343+
# Handle prediction type
344+
if (!is.character(type)) {
345+
stop("type must be a string or character vector")
346+
}
347+
if (!(type %in% c("mean", "posterior"))) {
348+
stop("type must either be 'mean' or 'posterior'")
349+
}
350+
predict_mean <- type == "mean"
351+
352+
# Handle prediction terms
353+
has_mean_forest <- object$model_params$include_mean_forest
354+
has_rfx <- object$model_params$has_rfx
355+
if ((!has_mean_forest) && (!has_rfx)) {
356+
stop(
357+
"Model must have either or both of mean forest or random effects terms to compute the requested contrast."
358+
)
359+
}
360+
361+
# Check that covariates are matrix or data frame
362+
if ((!is.data.frame(covariates_0)) && (!is.matrix(covariates_0))) {
363+
stop("covariates_0 must be a matrix or dataframe")
364+
}
365+
if ((!is.data.frame(covariates_1)) && (!is.matrix(covariates_1))) {
366+
stop("covariates_1 must be a matrix or dataframe")
367+
}
368+
369+
# Convert all input data to matrices if not already converted
370+
if ((is.null(dim(leaf_basis_0))) && (!is.null(leaf_basis_0))) {
371+
leaf_basis_0 <- as.matrix(leaf_basis_0)
372+
}
373+
if ((is.null(dim(leaf_basis_1))) && (!is.null(leaf_basis_1))) {
374+
leaf_basis_1 <- as.matrix(leaf_basis_1)
375+
}
376+
if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) {
377+
rfx_basis_0 <- as.matrix(rfx_basis_0)
378+
}
379+
if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) {
380+
rfx_basis_1 <- as.matrix(rfx_basis_1)
381+
}
382+
383+
# Data checks
384+
if (
385+
(object$model_params$requires_basis) &&
386+
(is.null(leaf_basis_0) || is.null(leaf_basis_1))
387+
) {
388+
stop("leaf_basis_0 and leaf_basis_1 must be provided for this model")
389+
}
390+
if ((!is.null(leaf_basis_0)) && (nrow(covariates_0) != nrow(leaf_basis_0))) {
391+
stop("covariates_0 and leaf_basis_0 must have the same number of rows")
392+
}
393+
if ((!is.null(leaf_basis_1)) && (nrow(covariates_1) != nrow(leaf_basis_1))) {
394+
stop("covariates_1 and leaf_basis_1 must have the same number of rows")
395+
}
396+
if (object$model_params$num_covariates != ncol(covariates_0)) {
397+
stop(
398+
"covariates_0 must contain the same number of columns as the BART model's training dataset"
399+
)
400+
}
401+
if (object$model_params$num_covariates != ncol(covariates_1)) {
402+
stop(
403+
"covariates_1 must contain the same number of columns as the BART model's training dataset"
404+
)
405+
}
406+
if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) {
407+
stop(
408+
"rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
409+
)
410+
}
411+
if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
412+
stop(
413+
"rfx_basis_0 and rfx_basis_1 must be provided for this model"
414+
)
415+
}
416+
if (
417+
(object$model_params$num_rfx_basis > 0) &&
418+
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
419+
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
420+
) {
421+
stop(
422+
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
423+
)
424+
}
425+
426+
# Predict for the control arm
427+
control_preds <- predict(
428+
object = object,
429+
covariates = covariates_0,
430+
leaf_basis = leaf_basis_0,
431+
rfx_group_ids = rfx_group_ids_0,
432+
rfx_basis = rfx_basis_0,
433+
type = "posterior",
434+
term = "y_hat",
435+
scale = "linear"
436+
)
437+
438+
# Predict for the treatment arm
439+
treatment_preds <- predict(
440+
object = object,
441+
covariates = covariates_1,
442+
leaf_basis = leaf_basis_1,
443+
rfx_group_ids = rfx_group_ids_1,
444+
rfx_basis = rfx_basis_1,
445+
type = "posterior",
446+
term = "y_hat",
447+
scale = "linear"
448+
)
449+
450+
# Transform to probability scale if requested
451+
if (probability_scale) {
452+
treatment_preds <- pnorm(treatment_preds)
453+
control_preds <- pnorm(control_preds)
454+
}
455+
456+
# Compute and return contrast
457+
if (predict_mean) {
458+
return(rowMeans(treatment_preds - control_preds))
459+
} else {
460+
return(treatment_preds - control_preds)
461+
}
462+
}
250463

251464
#' Sample from the posterior predictive distribution for outcomes modeled by BCF
252465
#'

man/compute_contrast_bart_model.Rd

Lines changed: 96 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)