Skip to content

Commit 5d6535c

Browse files
committed
Updated R BART to include intercept only RFX specification
1 parent 3c2d4ce commit 5d6535c

File tree

6 files changed

+123
-26
lines changed

6 files changed

+123
-26
lines changed

R/bart.R

Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#'
8080
#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
8181
#'
82+
#' - `model_spec` Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored.
8283
#' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
8384
#' - `group_parameter_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
8485
#' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
@@ -198,6 +199,7 @@ bart <- function(
198199

199200
# Update rfx parameters
200201
rfx_params_default <- list(
202+
model_spec = "custom",
201203
working_parameter_prior_mean = NULL,
202204
group_parameter_prior_mean = NULL,
203205
working_parameter_prior_cov = NULL,
@@ -257,6 +259,7 @@ bart <- function(
257259
num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample
258260

259261
# 4. RFX parameters
262+
rfx_model_spec <- rfx_params_updated$model_spec
260263
rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean
261264
rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean
262265
rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov
@@ -614,35 +617,43 @@ bart <- function(
614617
}
615618
}
616619

617-
# Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
620+
# Handle the rfx basis matrices
618621
has_basis_rfx <- FALSE
619622
num_basis_rfx <- 0
620623
if (has_rfx) {
621-
if (is.null(rfx_basis_train)) {
624+
if (rfx_model_spec == "custom") {
625+
if (is.null(rfx_basis_train)) {
626+
stop(
627+
"A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'"
628+
)
629+
}
630+
has_basis_rfx <- TRUE
631+
num_basis_rfx <- ncol(rfx_basis_train)
632+
} else if (rfx_model_spec == "intercept_only") {
622633
rfx_basis_train <- matrix(
623634
rep(1, nrow(X_train)),
624635
nrow = nrow(X_train),
625636
ncol = 1
626637
)
627-
} else {
628638
has_basis_rfx <- TRUE
629-
num_basis_rfx <- ncol(rfx_basis_train)
639+
num_basis_rfx <- 1
630640
}
631641
num_rfx_groups <- length(unique(rfx_group_ids_train))
632642
num_rfx_components <- ncol(rfx_basis_train)
633643
if (num_rfx_groups == 1) {
634644
warning(
635-
"Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill"
645+
"Only one group was provided for random effect sampling, so the random effects model is likely overkill"
636646
)
637647
}
638648
}
639649
if (has_rfx_test) {
640-
if (is.null(rfx_basis_test)) {
641-
if (has_basis_rfx) {
650+
if (rfx_model_spec == "custom") {
651+
if (is.null(rfx_basis_test)) {
642652
stop(
643-
"Random effects basis provided for training set, must also be provided for the test set"
653+
"A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'"
644654
)
645655
}
656+
} else if (rfx_model_spec == "intercept_only") {
646657
rfx_basis_test <- matrix(
647658
rep(1, nrow(X_test)),
648659
nrow = nrow(X_test),
@@ -1744,7 +1755,8 @@ bart <- function(
17441755
"sample_sigma2_leaf" = sample_sigma2_leaf,
17451756
"include_mean_forest" = include_mean_forest,
17461757
"include_variance_forest" = include_variance_forest,
1747-
"probit_outcome_model" = probit_outcome_model
1758+
"probit_outcome_model" = probit_outcome_model,
1759+
"rfx_model_spec" = rfx_model_spec
17481760
)
17491761
result <- list(
17501762
"model_params" = model_params,
@@ -1878,6 +1890,8 @@ predict.bartmodel <- function(
18781890
predict_mean <- type == "mean"
18791891

18801892
# Handle prediction terms
1893+
rfx_model_spec <- object$model_params$rfx_model_spec
1894+
rfx_intercept <- rfx_model_spec == "intercept_only"
18811895
if (!is.character(terms)) {
18821896
stop("type must be a string or character vector")
18831897
}
@@ -1954,16 +1968,17 @@ predict.bartmodel <- function(
19541968
"Random effect group labels (rfx_group_ids) must be provided for this model"
19551969
)
19561970
}
1957-
if ((predict_rfx) && (is.null(rfx_basis))) {
1971+
if ((predict_rfx) && (is.null(rfx_basis)) && (!rfx_intercept)) {
19581972
stop("Random effects basis (rfx_basis) must be provided for this model")
19591973
}
19601974
if (
1961-
(object$model_params$num_rfx_basis > 0) &&
1962-
(ncol(rfx_basis) != object$model_params$num_rfx_basis)
1975+
(object$model_params$num_rfx_basis > 0) && (!rfx_intercept)
19631976
) {
1964-
stop(
1965-
"Random effects basis has a different dimension than the basis used to train this model"
1966-
)
1977+
if (ncol(rfx_basis) != object$model_params$num_rfx_basis) {
1978+
stop(
1979+
"Random effects basis has a different dimension than the basis used to train this model"
1980+
)
1981+
}
19671982
}
19681983

19691984
# Preprocess covariates
@@ -1986,11 +2001,26 @@ predict.bartmodel <- function(
19862001
}
19872002
}
19882003

1989-
# Produce basis for the "intercept-only" random effects case
1990-
if ((predict_rfx) && (is.null(rfx_basis))) {
1991-
rfx_basis <- matrix(rep(1, nrow(covariates)), ncol = 1)
2004+
# Handle RFX model specification
2005+
if (has_rfx) {
2006+
if (object$model_params$rfx_model_spec == "custom") {
2007+
if (is.null(rfx_basis)) {
2008+
stop(
2009+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2010+
)
2011+
}
2012+
} else if (object$model_params$rfx_model_spec == "intercept_only") {
2013+
# Only construct a basis if user-provided basis missing
2014+
if (is.null(rfx_basis)) {
2015+
rfx_basis <- matrix(
2016+
rep(1, nrow(covariates)),
2017+
nrow = nrow(covariates),
2018+
ncol = 1
2019+
)
2020+
}
2021+
}
19922022
}
1993-
2023+
19942024
# Create prediction dataset
19952025
if (!is.null(leaf_basis)) {
19962026
prediction_dataset <- createForestDataset(covariates, leaf_basis)
@@ -2033,11 +2063,40 @@ predict.bartmodel <- function(
20332063

20342064
# Compute rfx predictions (if needed)
20352065
if (predict_rfx || predict_rfx_intermediate) {
2036-
rfx_predictions <- object$rfx_samples$predict(
2037-
rfx_group_ids,
2038-
rfx_basis
2039-
) *
2040-
y_std
2066+
if (!is.null(rfx_basis)) {
2067+
rfx_predictions <- object$rfx_samples$predict(
2068+
rfx_group_ids,
2069+
rfx_basis
2070+
) *
2071+
y_std
2072+
} else {
2073+
# Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
2074+
if (!rfx_intercept) {
2075+
stop("rfx_basis must be provided for random effects models with random slopes")
2076+
}
2077+
2078+
# Extract the raw RFX samples and scale by train set outcome standard deviation
2079+
rfx_param_list <- object$rfx_samples$extract_parameter_samples()
2080+
rfx_beta_draws <- rfx_param_list$beta_samples * y_std
2081+
2082+
# Construct a matrix with the appropriate group random effects arranged for each observation
2083+
rfx_predictions_raw <- array(
2084+
NA,
2085+
dim = c(
2086+
nrow(X),
2087+
ncol(rfx_basis),
2088+
object$model_params$num_samples
2089+
)
2090+
)
2091+
for (i in 1:nrow(X)) {
2092+
rfx_predictions_raw[i, , ] <-
2093+
rfx_beta_draws[, rfx_group_ids[i], ]
2094+
}
2095+
2096+
# Intercept-only model, so the random effect prediction is simply the
2097+
# value of the respective group's intercept coefficient for each observation
2098+
rfx_predictions = rfx_predictions_raw[, 1, ]
2099+
}
20412100
}
20422101

20432102
# Combine into y hat predictions
@@ -2310,6 +2369,10 @@ saveBARTModelToJson <- function(object) {
23102369
"probit_outcome_model",
23112370
object$model_params$probit_outcome_model
23122371
)
2372+
jsonobj$add_string(
2373+
"rfx_model_spec",
2374+
object$model_params$rfx_model_spec
2375+
)
23132376
if (object$model_params$sample_sigma2_global) {
23142377
jsonobj$add_vector(
23152378
"sigma2_global_samples",
@@ -2554,6 +2617,9 @@ createBARTModelFromJson <- function(json_object) {
25542617
model_params[["probit_outcome_model"]] <- json_object$get_boolean(
25552618
"probit_outcome_model"
25562619
)
2620+
model_params[["rfx_model_spec"]] <- json_object$get_string(
2621+
"rfx_model_spec"
2622+
)
25572623

25582624
output[["model_params"]] <- model_params
25592625

@@ -2825,6 +2891,9 @@ createBARTModelFromCombinedJson <- function(json_object_list) {
28252891
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean(
28262892
"probit_outcome_model"
28272893
)
2894+
model_params[["rfx_model_spec"]] <- json_object_default$get_string(
2895+
"rfx_model_spec"
2896+
)
28282897
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
28292898
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
28302899

@@ -3066,6 +3135,9 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) {
30663135
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean(
30673136
"probit_outcome_model"
30683137
)
3138+
model_params[["rfx_model_spec"]] <- json_object_default$get_string(
3139+
"rfx_model_spec"
3140+
)
30693141

30703142
# Combine values that are sample-specific
30713143
for (i in 1:length(json_object_list)) {

man/RandomEffectSamples.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/bart.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/bcf.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict.bcfmodel.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

test/R/testthat/test-bart.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,4 +561,27 @@ test_that("Random Effects BART", {
561561
random_effects_params = rfx_param_list
562562
)
563563
)
564+
565+
# Specify simpler intercept-only RFX model
566+
rfx_param_list <- list(
567+
model_spec = "intercept_only"
568+
)
569+
mean_forest_param_list <- list(sample_sigma2_leaf = FALSE)
570+
expect_no_error({
571+
bart_model <- bart(
572+
X_train = X_train,
573+
y_train = y_train,
574+
X_test = X_test,
575+
leaf_basis_train = W_train,
576+
leaf_basis_test = W_test,
577+
rfx_group_ids_train = rfx_group_ids_train,
578+
rfx_group_ids_test = rfx_group_ids_test,
579+
num_gfr = 0,
580+
num_burnin = 10,
581+
num_mcmc = 10,
582+
mean_forest_params = mean_forest_param_list,
583+
random_effects_params = rfx_param_list
584+
)
585+
preds <- predict(bart_model, covariates = X_test, leaf_basis = W_test, rfx_group_ids = rfx_group_ids_test, type = "posterior", terms = "rfx")
586+
})
564587
})

0 commit comments

Comments
 (0)