7979# '
8080# ' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
8181# '
82+ # ' - `model_spec` Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train` (and through `rfx_basis_test` when random effects are requested for the test set). If "intercept_only" is specified, a random effects basis of all ones will be constructed internally at sampling and prediction time, and `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. Default: "custom".
8283# ' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
8384# ' - `group_parameter_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
8485# ' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
@@ -198,6 +199,7 @@ bart <- function(
198199
199200 # Update rfx parameters
200201 rfx_params_default <- list (
202+ model_spec = " custom" ,
201203 working_parameter_prior_mean = NULL ,
202204 group_parameter_prior_mean = NULL ,
203205 working_parameter_prior_cov = NULL ,
@@ -257,6 +259,7 @@ bart <- function(
257259 num_features_subsample_variance <- variance_forest_params_updated $ num_features_subsample
258260
259261 # 4. RFX parameters
262+ rfx_model_spec <- rfx_params_updated $ model_spec
260263 rfx_working_parameter_prior_mean <- rfx_params_updated $ working_parameter_prior_mean
261264 rfx_group_parameter_prior_mean <- rfx_params_updated $ group_parameter_prior_mean
262265 rfx_working_parameter_prior_cov <- rfx_params_updated $ working_parameter_prior_cov
@@ -614,35 +617,43 @@ bart <- function(
614617 }
615618 }
616619
617- # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
620+ # Handle the rfx basis matrices
618621 has_basis_rfx <- FALSE
619622 num_basis_rfx <- 0
620623 if (has_rfx ) {
621- if (is.null(rfx_basis_train )) {
624+ if (rfx_model_spec == " custom" ) {
625+ if (is.null(rfx_basis_train )) {
626+ stop(
627+ " A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'"
628+ )
629+ }
630+ has_basis_rfx <- TRUE
631+ num_basis_rfx <- ncol(rfx_basis_train )
632+ } else if (rfx_model_spec == " intercept_only" ) {
622633 rfx_basis_train <- matrix (
623634 rep(1 , nrow(X_train )),
624635 nrow = nrow(X_train ),
625636 ncol = 1
626637 )
627- } else {
628638 has_basis_rfx <- TRUE
629- num_basis_rfx <- ncol( rfx_basis_train )
639+ num_basis_rfx <- 1
630640 }
631641 num_rfx_groups <- length(unique(rfx_group_ids_train ))
632642 num_rfx_components <- ncol(rfx_basis_train )
633643 if (num_rfx_groups == 1 ) {
634644 warning(
635- " Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill"
645+ " Only one group was provided for random effect sampling, so the random effects model is likely overkill"
636646 )
637647 }
638648 }
639649 if (has_rfx_test ) {
640- if (is.null( rfx_basis_test ) ) {
641- if (has_basis_rfx ) {
650+ if (rfx_model_spec == " custom " ) {
651+ if (is.null( rfx_basis_test ) ) {
642652 stop(
643- " Random effects basis provided for training set, must also be provided for the test set "
653+ " A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom' "
644654 )
645655 }
656+ } else if (rfx_model_spec == " intercept_only" ) {
646657 rfx_basis_test <- matrix (
647658 rep(1 , nrow(X_test )),
648659 nrow = nrow(X_test ),
@@ -1744,7 +1755,8 @@ bart <- function(
17441755 " sample_sigma2_leaf" = sample_sigma2_leaf ,
17451756 " include_mean_forest" = include_mean_forest ,
17461757 " include_variance_forest" = include_variance_forest ,
1747- " probit_outcome_model" = probit_outcome_model
1758+ " probit_outcome_model" = probit_outcome_model ,
1759+ " rfx_model_spec" = rfx_model_spec
17481760 )
17491761 result <- list (
17501762 " model_params" = model_params ,
@@ -1878,6 +1890,8 @@ predict.bartmodel <- function(
18781890 predict_mean <- type == " mean"
18791891
18801892 # Handle prediction terms
1893+ rfx_model_spec <- object $ model_params $ rfx_model_spec
1894+ rfx_intercept <- rfx_model_spec == " intercept_only"
18811895 if (! is.character(terms )) {
18821896 stop(" type must be a string or character vector" )
18831897 }
@@ -1954,16 +1968,17 @@ predict.bartmodel <- function(
19541968 " Random effect group labels (rfx_group_ids) must be provided for this model"
19551969 )
19561970 }
1957- if ((predict_rfx ) && (is.null(rfx_basis ))) {
1971+ if ((predict_rfx ) && (is.null(rfx_basis )) && ( ! rfx_intercept ) ) {
19581972 stop(" Random effects basis (rfx_basis) must be provided for this model" )
19591973 }
19601974 if (
1961- (object $ model_params $ num_rfx_basis > 0 ) &&
1962- (ncol(rfx_basis ) != object $ model_params $ num_rfx_basis )
1975+ (object $ model_params $ num_rfx_basis > 0 ) && (! rfx_intercept )
19631976 ) {
1964- stop(
1965- " Random effects basis has a different dimension than the basis used to train this model"
1966- )
1977+ if (ncol(rfx_basis ) != object $ model_params $ num_rfx_basis ) {
1978+ stop(
1979+ " Random effects basis has a different dimension than the basis used to train this model"
1980+ )
1981+ }
19671982 }
19681983
19691984 # Preprocess covariates
@@ -1986,11 +2001,26 @@ predict.bartmodel <- function(
19862001 }
19872002 }
19882003
1989- # Produce basis for the "intercept-only" random effects case
1990- if ((predict_rfx ) && (is.null(rfx_basis ))) {
1991- rfx_basis <- matrix (rep(1 , nrow(covariates )), ncol = 1 )
2004+ # Handle RFX model specification
2005+ if (has_rfx ) {
2006+ if (object $ model_params $ rfx_model_spec == " custom" ) {
2007+ if (is.null(rfx_basis )) {
2008+ stop(
2009+ " A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2010+ )
2011+ }
2012+ } else if (object $ model_params $ rfx_model_spec == " intercept_only" ) {
2013+ # Only construct a basis if user-provided basis missing
2014+ if (is.null(rfx_basis )) {
2015+ rfx_basis <- matrix (
2016+ rep(1 , nrow(covariates )),
2017+ nrow = nrow(covariates ),
2018+ ncol = 1
2019+ )
2020+ }
2021+ }
19922022 }
1993-
2023+
19942024 # Create prediction dataset
19952025 if (! is.null(leaf_basis )) {
19962026 prediction_dataset <- createForestDataset(covariates , leaf_basis )
@@ -2033,11 +2063,40 @@ predict.bartmodel <- function(
20332063
20342064 # Compute rfx predictions (if needed)
20352065 if (predict_rfx || predict_rfx_intermediate ) {
2036- rfx_predictions <- object $ rfx_samples $ predict(
2037- rfx_group_ids ,
2038- rfx_basis
2039- ) *
2040- y_std
2066+ if (! is.null(rfx_basis )) {
2067+ rfx_predictions <- object $ rfx_samples $ predict(
2068+ rfx_group_ids ,
2069+ rfx_basis
2070+ ) *
2071+ y_std
2072+ } else {
2073+ # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
2074+ if (! rfx_intercept ) {
2075+ stop(" rfx_basis must be provided for random effects models with random slopes" )
2076+ }
2077+
2078+ # Extract the raw RFX samples and scale by train set outcome standard deviation
2079+ rfx_param_list <- object $ rfx_samples $ extract_parameter_samples()
2080+ rfx_beta_draws <- rfx_param_list $ beta_samples * y_std
2081+
2082+ # Construct a matrix with the appropriate group random effects arranged for each observation
2083+ rfx_predictions_raw <- array (
2084+ NA ,
2085+ dim = c(
2086+ nrow(X ),
2087+ ncol(rfx_basis ),
2088+ object $ model_params $ num_samples
2089+ )
2090+ )
2091+ for (i in 1 : nrow(X )) {
2092+ rfx_predictions_raw [i , , ] <-
2093+ rfx_beta_draws [, rfx_group_ids [i ], ]
2094+ }
2095+
2096+ # Intercept-only model, so the random effect prediction is simply the
2097+ # value of the respective group's intercept coefficient for each observation
2098+ rfx_predictions = rfx_predictions_raw [, 1 , ]
2099+ }
20412100 }
20422101
20432102 # Combine into y hat predictions
@@ -2310,6 +2369,10 @@ saveBARTModelToJson <- function(object) {
23102369 " probit_outcome_model" ,
23112370 object $ model_params $ probit_outcome_model
23122371 )
2372+ jsonobj $ add_string(
2373+ " rfx_model_spec" ,
2374+ object $ model_params $ rfx_model_spec
2375+ )
23132376 if (object $ model_params $ sample_sigma2_global ) {
23142377 jsonobj $ add_vector(
23152378 " sigma2_global_samples" ,
@@ -2554,6 +2617,9 @@ createBARTModelFromJson <- function(json_object) {
25542617 model_params [[" probit_outcome_model" ]] <- json_object $ get_boolean(
25552618 " probit_outcome_model"
25562619 )
2620+ model_params [[" rfx_model_spec" ]] <- json_object $ get_string(
2621+ " rfx_model_spec"
2622+ )
25572623
25582624 output [[" model_params" ]] <- model_params
25592625
@@ -2825,6 +2891,9 @@ createBARTModelFromCombinedJson <- function(json_object_list) {
28252891 model_params [[" probit_outcome_model" ]] <- json_object_default $ get_boolean(
28262892 " probit_outcome_model"
28272893 )
2894+ model_params [[" rfx_model_spec" ]] <- json_object_default $ get_string(
2895+ " rfx_model_spec"
2896+ )
28282897 model_params [[" num_chains" ]] <- json_object_default $ get_scalar(" num_chains" )
28292898 model_params [[" keep_every" ]] <- json_object_default $ get_scalar(" keep_every" )
28302899
@@ -3066,6 +3135,9 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) {
30663135 model_params [[" probit_outcome_model" ]] <- json_object_default $ get_boolean(
30673136 " probit_outcome_model"
30683137 )
3138+ model_params [[" rfx_model_spec" ]] <- json_object_default $ get_string(
3139+ " rfx_model_spec"
3140+ )
30693141
30703142 # Combine values that are sample-specific
30713143 for (i in 1 : length(json_object_list )) {
0 commit comments