Skip to content

Commit c03b2ea

Browse files
committed
Fixed bug in probit + RFX and added contrast computation function for BCF
1 parent f11a0f4 commit c03b2ea

File tree

8 files changed

+756
-19
lines changed

8 files changed

+756
-19
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Description: Flexible stochastic tree ensemble software.
2929
License: MIT + file LICENSE
3030
Encoding: UTF-8
3131
Roxygen: list(markdown = TRUE)
32-
RoxygenNote: 7.3.2
32+
RoxygenNote: 7.3.3
3333
LinkingTo:
3434
cpp11, BH
3535
Suggests:

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export(computeForestLeafVariances)
1212
export(computeForestMaxLeafIndex)
1313
export(compute_bart_posterior_interval)
1414
export(compute_bcf_posterior_interval)
15+
export(compute_contrast_bcf_model)
1516
export(convertPreprocessorToJson)
1617
export(createBARTModelFromCombinedJson)
1718
export(createBARTModelFromCombinedJsonString)

R/bart.R

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,18 +1106,25 @@ bart <- function(
11061106
if (include_mean_forest) {
11071107
if (probit_outcome_model) {
11081108
# Sample latent probit variable, z | -
1109-
forest_pred <- active_forest_mean$predict(
1109+
outcome_pred <- active_forest_mean$predict(
11101110
forest_dataset_train
11111111
)
1112-
mu0 <- forest_pred[y_train == 0]
1113-
mu1 <- forest_pred[y_train == 1]
1112+
if (has_rfx) {
1113+
rfx_pred <- rfx_model$predict(
1114+
rfx_dataset_train,
1115+
rfx_tracker_train
1116+
)
1117+
outcome_pred <- outcome_pred + rfx_pred
1118+
}
1119+
mu0 <- outcome_pred[y_train == 0]
1120+
mu1 <- outcome_pred[y_train == 1]
11141121
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
11151122
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
11161123
resid_train[y_train == 0] <- mu0 + qnorm(u0)
11171124
resid_train[y_train == 1] <- mu1 + qnorm(u1)
11181125

11191126
# Update outcome
1120-
outcome_train$update_data(resid_train - forest_pred)
1127+
outcome_train$update_data(resid_train - outcome_pred)
11211128
}
11221129

11231130
# Sample mean forest
@@ -1467,18 +1474,25 @@ bart <- function(
14671474
if (include_mean_forest) {
14681475
if (probit_outcome_model) {
14691476
# Sample latent probit variable, z | -
1470-
forest_pred <- active_forest_mean$predict(
1477+
outcome_pred <- active_forest_mean$predict(
14711478
forest_dataset_train
14721479
)
1473-
mu0 <- forest_pred[y_train == 0]
1474-
mu1 <- forest_pred[y_train == 1]
1480+
if (has_rfx) {
1481+
rfx_pred <- rfx_model$predict(
1482+
rfx_dataset_train,
1483+
rfx_tracker_train
1484+
)
1485+
outcome_pred <- outcome_pred + rfx_pred
1486+
}
1487+
mu0 <- outcome_pred[y_train == 0]
1488+
mu1 <- outcome_pred[y_train == 1]
14751489
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
14761490
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
14771491
resid_train[y_train == 0] <- mu0 + qnorm(u0)
14781492
resid_train[y_train == 1] <- mu1 + qnorm(u1)
14791493

14801494
# Update outcome
1481-
outcome_train$update_data(resid_train - forest_pred)
1495+
outcome_train$update_data(resid_train - outcome_pred)
14821496
}
14831497

14841498
forest_model_mean$sample_one_iteration(

R/bcf.R

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,16 +1460,23 @@ bcf <- function(
14601460
tau_forest_pred <- active_forest_tau$predict(
14611461
forest_dataset_train
14621462
)
1463-
forest_pred <- mu_forest_pred + tau_forest_pred
1464-
mu0 <- forest_pred[y_train == 0]
1465-
mu1 <- forest_pred[y_train == 1]
1463+
outcome_pred <- mu_forest_pred + tau_forest_pred
1464+
if (has_rfx) {
1465+
rfx_pred <- rfx_model$predict(
1466+
rfx_dataset_train,
1467+
rfx_tracker_train
1468+
)
1469+
outcome_pred <- outcome_pred + rfx_pred
1470+
}
1471+
mu0 <- outcome_pred[y_train == 0]
1472+
mu1 <- outcome_pred[y_train == 1]
14661473
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
14671474
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
14681475
resid_train[y_train == 0] <- mu0 + qnorm(u0)
14691476
resid_train[y_train == 1] <- mu1 + qnorm(u1)
14701477

14711478
# Update outcome
1472-
outcome_train$update_data(resid_train - forest_pred)
1479+
outcome_train$update_data(resid_train - outcome_pred)
14731480
}
14741481

14751482
# Sample the prognostic forest
@@ -2057,16 +2064,23 @@ bcf <- function(
20572064
tau_forest_pred <- active_forest_tau$predict(
20582065
forest_dataset_train
20592066
)
2060-
forest_pred <- mu_forest_pred + tau_forest_pred
2061-
mu0 <- forest_pred[y_train == 0]
2062-
mu1 <- forest_pred[y_train == 1]
2067+
outcome_pred <- mu_forest_pred + tau_forest_pred
2068+
if (has_rfx) {
2069+
rfx_pred <- rfx_model$predict(
2070+
rfx_dataset_train,
2071+
rfx_tracker_train
2072+
)
2073+
outcome_pred <- outcome_pred + rfx_pred
2074+
}
2075+
mu0 <- outcome_pred[y_train == 0]
2076+
mu1 <- outcome_pred[y_train == 1]
20632077
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
20642078
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
20652079
resid_train[y_train == 0] <- mu0 + qnorm(u0)
20662080
resid_train[y_train == 1] <- mu1 + qnorm(u1)
20672081

20682082
# Update outcome
2069-
outcome_train$update_data(resid_train - forest_pred)
2083+
outcome_train$update_data(resid_train - outcome_pred)
20702084
}
20712085

20722086
# Sample the prognostic forest
@@ -2771,7 +2785,7 @@ predict.bcfmodel <- function(
27712785
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
27722786
if (sum(is.na(group_ids_factor)) > 0) {
27732787
stop(
2774-
"All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train"
2788+
"All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids"
27752789
)
27762790
}
27772791
rfx_group_ids <- as.integer(group_ids_factor)

R/posterior_transformation.R

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,253 @@
1+
#' Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference.
2+
#' For simple BCF models with binary treatment, this will yield the same prediction as requesting `terms = "cate"`
3+
#' in the `predict.bcfmodel` function. For more general models, such as models with continuous / multivariate treatments or
4+
#' an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute a
5+
#' any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two term
6+
#' contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend of the
7+
#' contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" terminology of a classic
8+
#' two-arm experiment. We mirror the function calls and terminology of the `predict.bcfmodel` function, labeling each prediction
9+
#' data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the
10+
#' control prediction.
11+
#'
12+
#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
13+
#' @param X_0 Covariates used for prediction in the "control" case.
14+
#' @param X_1 Covariates used for prediction in the "treatment" case.
15+
#' @param Z_0 Treatments used for prediction in the "control" case.
16+
#' @param Z_1 Treatments used for prediction in the "treatment" case.
17+
#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case.
18+
#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case.
19+
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
20+
#' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation
21+
#' for group labels that were not in the training set.
22+
#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
23+
#' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation
24+
#' for group labels that were not in the training set.
25+
#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case.
26+
#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case.
27+
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
28+
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
29+
#' @param ... (Optional) Other prediction parameters.
30+
#'
31+
#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested.
32+
#' @export
33+
#'
34+
#' @examples
35+
#' n <- 500
36+
#' p <- 5
37+
#' X <- matrix(runif(n*p), ncol = p)
38+
#' mu_x <- (
39+
#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
40+
#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
41+
#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
42+
#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
43+
#' )
44+
#' pi_x <- (
45+
#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
46+
#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
47+
#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
48+
#' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
49+
#' )
50+
#' tau_x <- (
51+
#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
52+
#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
53+
#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
54+
#' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
55+
#' )
56+
#' Z <- rbinom(n, 1, pi_x)
57+
#' noise_sd <- 1
58+
#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
59+
#' test_set_pct <- 0.2
60+
#' n_test <- round(test_set_pct*n)
61+
#' n_train <- n - n_test
62+
#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
63+
#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
64+
#' X_test <- X[test_inds,]
65+
#' X_train <- X[train_inds,]
66+
#' pi_test <- pi_x[test_inds]
67+
#' pi_train <- pi_x[train_inds]
68+
#' Z_test <- Z[test_inds]
69+
#' Z_train <- Z[train_inds]
70+
#' y_test <- y[test_inds]
71+
#' y_train <- y[train_inds]
72+
#' mu_test <- mu_x[test_inds]
73+
#' mu_train <- mu_x[train_inds]
74+
#' tau_test <- tau_x[test_inds]
75+
#' tau_train <- tau_x[train_inds]
76+
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
77+
#' propensity_train = pi_train, num_gfr = 10,
78+
#' num_burnin = 0, num_mcmc = 10)
79+
#' preds <- compute_posterior_contrast_bcf_model(
80+
#' bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test),
81+
#' propensity_0 = pi_test, propensity_1 = pi_test
82+
#' )
83+
compute_contrast_bcf_model <- function(
84+
object,
85+
X_0,
86+
X_1,
87+
Z_0,
88+
Z_1,
89+
propensity_0 = NULL,
90+
propensity_1 = NULL,
91+
rfx_group_ids_0 = NULL,
92+
rfx_group_ids_1 = NULL,
93+
rfx_basis_0 = NULL,
94+
rfx_basis_1 = NULL,
95+
type = "posterior",
96+
scale = "linear",
97+
...
98+
) {
99+
# Handle mean function scale
100+
if (!is.character(scale)) {
101+
stop("scale must be a string or character vector")
102+
}
103+
if (!(scale %in% c("linear", "probability"))) {
104+
stop("scale must either be 'linear' or 'probability'")
105+
}
106+
is_probit <- object$model_params$probit_outcome_model
107+
if ((scale == "probability") && (!is_probit)) {
108+
stop(
109+
"scale cannot be 'probability' for models not fit with a probit outcome model"
110+
)
111+
}
112+
probability_scale <- scale == "probability"
113+
114+
# Handle prediction type
115+
if (!is.character(type)) {
116+
stop("type must be a string or character vector")
117+
}
118+
if (!(type %in% c("mean", "posterior"))) {
119+
stop("type must either be 'mean' or 'posterior")
120+
}
121+
predict_mean <- type == "mean"
122+
123+
# Make sure covariates are matrix or data frame
124+
if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) {
125+
stop("X_0 must be a matrix or dataframe")
126+
}
127+
if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) {
128+
stop("X_1 must be a matrix or dataframe")
129+
}
130+
131+
# Convert all input data to matrices if not already converted
132+
if ((is.null(dim(Z_0))) && (!is.null(Z_0))) {
133+
Z_0 <- as.matrix(as.numeric(Z_0))
134+
}
135+
if ((is.null(dim(Z_1))) && (!is.null(Z_1))) {
136+
Z_1 <- as.matrix(as.numeric(Z_1))
137+
}
138+
if ((is.null(dim(propensity_0))) && (!is.null(propensity_0))) {
139+
propensity_0 <- as.matrix(propensity_0)
140+
}
141+
if ((is.null(dim(propensity_1))) && (!is.null(propensity_1))) {
142+
propensity_1 <- as.matrix(propensity_1)
143+
}
144+
if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) {
145+
rfx_basis_0 <- as.matrix(rfx_basis_0)
146+
}
147+
if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) {
148+
rfx_basis_1 <- as.matrix(rfx_basis_1)
149+
}
150+
151+
# Data checks
152+
if (
153+
(object$model_params$propensity_covariate != "none") &&
154+
((is.null(propensity_0)) ||
155+
(is.null(propensity_1)))
156+
) {
157+
if (!object$model_params$internal_propensity_model) {
158+
stop("propensity_0 and propensity_1 must be provided for this model")
159+
}
160+
}
161+
if (nrow(X_0) != nrow(Z_0)) {
162+
stop("X_0 and Z_0 must have the same number of rows")
163+
}
164+
if (nrow(X_1) != nrow(Z_1)) {
165+
stop("X_1 and Z_1 must have the same number of rows")
166+
}
167+
if (object$model_params$num_covariates != ncol(X_0)) {
168+
stop(
169+
"X_0 and must have the same number of columns as the covariates used to train the model"
170+
)
171+
}
172+
if (object$model_params$num_covariates != ncol(X_1)) {
173+
stop(
174+
"X_1 and must have the same number of columns as the covariates used to train the model"
175+
)
176+
}
177+
if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_0))) {
178+
stop(
179+
"Random effect group labels (rfx_group_ids_0) must be provided for this model"
180+
)
181+
}
182+
if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_1))) {
183+
stop(
184+
"Random effect group labels (rfx_group_ids_1) must be provided for this model"
185+
)
186+
}
187+
if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_0))) {
188+
stop("Random effects basis (rfx_basis_0) must be provided for this model")
189+
}
190+
if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_1))) {
191+
stop("Random effects basis (rfx_basis_1) must be provided for this model")
192+
}
193+
if (
194+
(object$model_params$num_rfx_basis > 0) &&
195+
(ncol(rfx_basis_0) != object$model_params$num_rfx_basis)
196+
) {
197+
stop(
198+
"Random effects basis has a different dimension than the basis used to train this model"
199+
)
200+
}
201+
if (
202+
(object$model_params$num_rfx_basis > 0) &&
203+
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis)
204+
) {
205+
stop(
206+
"Random effects basis has a different dimension than the basis used to train this model"
207+
)
208+
}
209+
210+
# Predict for the control arm
211+
control_preds <- predict(
212+
object = object,
213+
X = X_0,
214+
Z = Z_0,
215+
propensity = propensity_0,
216+
rfx_group_ids = rfx_group_ids_0,
217+
rfx_basis = rfx_basis_0,
218+
type = "posterior",
219+
term = "y_hat",
220+
scale = "linear"
221+
)
222+
223+
# Predict for the treatment arm
224+
treatment_preds <- predict(
225+
object = object,
226+
X = X_1,
227+
Z = Z_1,
228+
propensity = propensity_1,
229+
rfx_group_ids = rfx_group_ids_1,
230+
rfx_basis = rfx_basis_1,
231+
type = "posterior",
232+
term = "y_hat",
233+
scale = "linear"
234+
)
235+
236+
# Transform to probability scale if requested
237+
if (probability_scale) {
238+
treatment_preds <- pnorm(treatment_preds)
239+
control_preds <- pnorm(control_preds)
240+
}
241+
242+
# Compute and return contrast
243+
if (predict_mean) {
244+
return(rowMeans(treatment_preds - control_preds))
245+
} else {
246+
return(treatment_preds - control_preds)
247+
}
248+
}
249+
250+
1251
#' Sample from the posterior predictive distribution for outcomes modeled by BCF
2252
#'
3253
#' @param model_object A fitted BCF model object of class `bcfmodel`.

0 commit comments

Comments
 (0)