Skip to content

Commit d6e4480

Browse files
committed
Added RFX spec argument to R bcf
1 parent 97cb983 commit d6e4480

File tree

1 file changed

+74
-26
lines changed

1 file changed

+74
-26
lines changed

R/bcf.R

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
#'
100100
#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
101101
#'
102+
#' - `model_spec` Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". Default: "custom". If "custom" is specified, a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones is constructed internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" column of all ones with the treatment variable (`Z_train` at sampling time, and the corresponding treatment supplied for the test set or at prediction time) is constructed internally. If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) are ignored.
102103
#' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
103104
#' - `group_parameter_prior_mean` Prior mean for the random effects "group parameters". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
104105
#' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
@@ -269,6 +270,7 @@ bcf <- function(
269270

270271
# Update random effects parameters
271272
rfx_params_default <- list(
273+
model_spec = "custom",
272274
working_parameter_prior_mean = NULL,
273275
group_parameter_prior_mean = NULL,
274276
working_parameter_prior_cov = NULL,
@@ -348,6 +350,7 @@ bcf <- function(
348350
num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample
349351

350352
# 5. Random effects parameters
353+
rfx_model_spec <- rfx_params_updated$model_spec
351354
rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean
352355
rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean
353356
rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov
@@ -769,20 +772,6 @@ bcf <- function(
769772
}
770773
}
771774

772-
# Random effects covariance prior
773-
if (has_rfx) {
774-
if (is.null(rfx_prior_var)) {
775-
rfx_prior_var <- rep(1, ncol(rfx_basis_train))
776-
} else {
777-
if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) {
778-
stop("rfx_prior_var must be a numeric vector")
779-
}
780-
if (length(rfx_prior_var) != ncol(rfx_basis_train)) {
781-
stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)")
782-
}
783-
}
784-
}
785-
786775
# Update variable weights
787776
variable_weights_adj <- 1 /
788777
sapply(original_var_indices, function(x) sum(original_var_indices == x))
@@ -799,40 +788,74 @@ bcf <- function(
799788
] <- 0
800789
}
801790

802-
# Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
791+
# Handle the rfx basis matrices
803792
has_basis_rfx <- FALSE
804793
num_basis_rfx <- 0
805794
if (has_rfx) {
806-
if (is.null(rfx_basis_train)) {
795+
if (rfx_model_spec == "custom") {
796+
if (is.null(rfx_basis_train)) {
797+
stop(
798+
"A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'"
799+
)
800+
}
801+
has_basis_rfx <- TRUE
802+
num_basis_rfx <- ncol(rfx_basis_train)
803+
} else if (rfx_model_spec == "intercept_only") {
807804
rfx_basis_train <- matrix(
808805
rep(1, nrow(X_train)),
809806
nrow = nrow(X_train),
810807
ncol = 1
811808
)
812-
} else {
813809
has_basis_rfx <- TRUE
814-
num_basis_rfx <- ncol(rfx_basis_train)
810+
num_basis_rfx <- 1
811+
} else if (rfx_model_spec == "intercept_plus_treatment") {
812+
rfx_basis_train <- cbind(
813+
rep(1, nrow(X_train)),
814+
Z_train
815+
)
816+
has_basis_rfx <- TRUE
817+
num_basis_rfx <- 1 + ncol(Z_train)
815818
}
816819
num_rfx_groups <- length(unique(rfx_group_ids_train))
817820
num_rfx_components <- ncol(rfx_basis_train)
818821
if (num_rfx_groups == 1) {
819822
warning(
820-
"Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill"
823+
"Only one group was provided for random effect sampling, so the random effects model is likely overkill"
821824
)
822825
}
823826
}
824827
if (has_rfx_test) {
825-
if (is.null(rfx_basis_test)) {
826-
if (!is.null(rfx_basis_train)) {
828+
if (rfx_model_spec == "custom") {
829+
if (is.null(rfx_basis_test)) {
827830
stop(
828-
"Random effects basis provided for training set, must also be provided for the test set"
831+
"A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'"
829832
)
830833
}
834+
} else if (rfx_model_spec == "intercept_only") {
831835
rfx_basis_test <- matrix(
832836
rep(1, nrow(X_test)),
833837
nrow = nrow(X_test),
834838
ncol = 1
835839
)
840+
} else if (rfx_model_spec == "intercept_plus_treatment") {
841+
rfx_basis_test <- cbind(
842+
rep(1, nrow(X_test)),
843+
Z_test
844+
)
845+
}
846+
}
847+
848+
# Random effects covariance prior
849+
if (has_rfx) {
850+
if (is.null(rfx_prior_var)) {
851+
rfx_prior_var <- rep(1, ncol(rfx_basis_train))
852+
} else {
853+
if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) {
854+
stop("rfx_prior_var must be a numeric vector")
855+
}
856+
if (length(rfx_prior_var) != ncol(rfx_basis_train)) {
857+
stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)")
858+
}
836859
}
837860
}
838861

@@ -2536,7 +2559,8 @@ bcf <- function(
25362559
"sample_sigma2_global" = sample_sigma2_global,
25372560
"sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu,
25382561
"sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau,
2539-
"probit_outcome_model" = probit_outcome_model
2562+
"probit_outcome_model" = probit_outcome_model,
2563+
"rfx_model_spec" = rfx_model_spec
25402564
)
25412565
result <- list(
25422566
"forests_mu" = forest_samples_mu,
@@ -2806,9 +2830,24 @@ predict.bcfmodel <- function(
28062830
has_rfx <- TRUE
28072831
}
28082832

2809-
# Produce basis for the "intercept-only" random effects case
2810-
if ((object$model_params$has_rfx) && (is.null(rfx_basis))) {
2811-
rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1)
2833+
# Handle RFX model specification
2834+
if (object$model_params$rfx_model_spec == "custom") {
2835+
if (is.null(rfx_basis)) {
2836+
stop(
2837+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2838+
)
2839+
}
2840+
} else if (object$model_params$rfx_model_spec == "intercept_only") {
2841+
rfx_basis <- matrix(
2842+
rep(1, nrow(X)),
2843+
nrow = nrow(X),
2844+
ncol = 1
2845+
)
2846+
} else if (object$model_params$rfx_model_spec == "intercept_plus_treatment") {
2847+
rfx_basis <- cbind(
2848+
rep(1, nrow(X)),
2849+
Z
2850+
)
28122851
}
28132852

28142853
# Add propensities to covariate set if necessary
@@ -3650,6 +3689,9 @@ createBCFModelFromJson <- function(json_object) {
36503689
model_params[["probit_outcome_model"]] <- json_object$get_boolean(
36513690
"probit_outcome_model"
36523691
)
3692+
model_params[["rfx_model_spec"]] <- json_object$get_string(
3693+
"rfx_model_spec"
3694+
)
36533695
output[["model_params"]] <- model_params
36543696

36553697
# Unpack sampled parameters
@@ -4069,6 +4111,9 @@ createBCFModelFromCombinedJson <- function(json_object_list) {
40694111
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean(
40704112
"probit_outcome_model"
40714113
)
4114+
model_params[["rfx_model_spec"]] <- json_object_default$get_string(
4115+
"rfx_model_spec"
4116+
)
40724117

40734118
# Combine values that are sample-specific
40744119
for (i in 1:length(json_object_list)) {
@@ -4423,6 +4468,9 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) {
44234468
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean(
44244469
"probit_outcome_model"
44254470
)
4471+
model_params[["rfx_model_spec"]] <- json_object_default$get_string(
4472+
"rfx_model_spec"
4473+
)
44264474

44274475
# Combine values that are sample-specific
44284476
for (i in 1:length(json_object_list)) {

0 commit comments

Comments
 (0)