From 10b0f2a810e0d1ffc54b60ba0f25e1b4782976a7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 15 Dec 2025 15:27:47 -0600 Subject: [PATCH 1/3] Fixed bug in R BART when predicting y_hat only for models with non-default RFX IDs --- R/bart.R | 18 +- R/bcf.R | 2 - test/R/testthat/test-random-effects.R | 711 ++++++++++++++++++++++++++ 3 files changed, 719 insertions(+), 12 deletions(-) create mode 100644 test/R/testthat/test-random-effects.R diff --git a/R/bart.R b/R/bart.R index 23fee012..8fcde545 100644 --- a/R/bart.R +++ b/R/bart.R @@ -2134,17 +2134,15 @@ predict.bartmodel <- function( X <- preprocessPredictionData(X, train_set_metadata) # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - if (predict_rfx) { - if (!is.null(rfx_group_ids)) { - rfx_unique_group_ids <- object$rfx_unique_group_ids - group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) - if (sum(is.na(group_ids_factor)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train" - ) - } - rfx_group_ids <- as.integer(group_ids_factor) + if (!is.null(rfx_group_ids)) { + rfx_unique_group_ids <- object$rfx_unique_group_ids + group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) + if (sum(is.na(group_ids_factor)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train" + ) } + rfx_group_ids <- as.integer(group_ids_factor) } # Handle RFX model specification diff --git a/R/bcf.R b/R/bcf.R index c634295b..d8fc8e66 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -3040,7 +3040,6 @@ predict.bcfmodel <- function( X <- preprocessPredictionData(X, train_set_metadata) # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
- has_rfx <- FALSE if (!is.null(rfx_group_ids)) { rfx_unique_group_ids <- object$rfx_unique_group_ids group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) @@ -3050,7 +3049,6 @@ predict.bcfmodel <- function( ) } rfx_group_ids <- as.integer(group_ids_factor) - has_rfx <- TRUE } # Handle RFX model specification diff --git a/test/R/testthat/test-random-effects.R b/test/R/testthat/test-random-effects.R new file mode 100644 index 00000000..f0c1830c --- /dev/null +++ b/test/R/testthat/test-random-effects.R @@ -0,0 +1,711 @@ +test_that("Random Effects BART with Default Numbering", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + W_train <- W[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx parameters + general_param_list <- list() + 
mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list + ) + ) + + # Specify all rfx parameters as scalars + rfx_param_list <- list( + working_parameter_prior_mean = 1., + group_parameter_prior_mean = 1., + working_parameter_prior_cov = 1., + group_parameter_prior_cov = 1., + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list + ) + ) + + # Specify all relevant rfx parameters as vectors + rfx_param_list <- list( + working_parameter_prior_mean = c(1., 1.), + group_parameter_prior_mean = c(1., 1.), + working_parameter_prior_cov = diag(1., 2), + group_parameter_prior_cov = diag(1., 2), + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = 
rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list + ) + ) + + # Specify simpler intercept-only RFX model + rfx_param_list <- list( + model_spec = "intercept_only" + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list + ) + preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + }) +}) + +test_that("Random Effects BART with Default Numbering", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + 
W_train <- W[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx basis directly + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list + ) + preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "rfx" + ) + }) + + # Intercept-only RFX model + rfx_param_list <- list( + model_spec = "intercept_only" + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list + ) + preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + }) +}) + +test_that("Random Effects BART with Offset Numbering", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & 
(0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + 2 + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids - 2, ] * rfx_basis) + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + W_train <- W[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx basis directly + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list + ) + rfx_preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "y_hat" + ) + }) + + # Intercept-only RFX model + 
rfx_param_list <- list( + model_spec = "intercept_only" + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list + ) + rfx_preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bart_model, + X = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "y_hat" + ) + }) +}) + +test_that("Random Effects BCF with Default Numbering", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- mu_X + 
tau_X * Z + rfx_term + rnorm(n, 0, noise_sd) + + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a BCF model with random effects by passing the basis directly + expect_no_error({ + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + rfx_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "y_hat" + ) + }) + + # Fit a BCF model with random effects by specifying an "intercept only" model + expect_no_error({ + rfx_param_list <- list( + model_spec = "intercept_only" + ) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + X_test = 
X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + random_effects_params = rfx_param_list + ) + rfx_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "y_hat" + ) + }) + + # Fit a BCF model with random effects by specifying an "intercept plus treatment" model + expect_no_error({ + rfx_param_list <- list( + model_spec = "intercept_plus_treatment" + ) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + random_effects_params = rfx_param_list + ) + rfx_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "y_hat" + ) + }) +}) + +test_that("Random Effects BCF with Offset Numbering", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 
1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + 2 + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids - 2, ] * rfx_basis) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rfx_term + rnorm(n, 0, noise_sd) + + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a BCF model with random effects by passing the basis directly + expect_no_error({ + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + rfx_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + 
rfx_basis = rfx_basis_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "y_hat" + ) + }) + + # Fit a BCF model with random effects by specifying an "intercept only" model + expect_no_error({ + rfx_param_list <- list( + model_spec = "intercept_only" + ) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + random_effects_params = rfx_param_list + ) + rfx_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "y_hat" + ) + }) + + # Fit a BCF model with random effects by specifying an "intercept plus treatment" model + expect_no_error({ + rfx_param_list <- list( + model_spec = "intercept_plus_treatment" + ) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + random_effects_params = rfx_param_list + ) + rfx_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + yhat_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = rfx_group_ids_test, + 
type = "posterior", + terms = "y_hat" + ) + }) +}) From f555813e2ae4f8f2af1a8cea45e99cc6fbfe2cfc Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 15 Dec 2025 15:52:33 -0600 Subject: [PATCH 2/3] Fixed prediction bug for Python random effects models with "non-default" random effects IDs --- src/py_stochtree.cpp | 16 +- stochtree/bart.py | 8 +- stochtree/bcf.py | 10 +- stochtree/random_effects.py | 32 ++ test/python/test_random_effects.py | 875 +++++++++++++++++++++++++++++ 5 files changed, 937 insertions(+), 4 deletions(-) diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 05bcc876..a8a31bc3 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1573,6 +1573,18 @@ class RandomEffectsLabelMapperCpp { StochTree::LabelMapper* GetLabelMapper() { return rfx_label_mapper_.get(); } + int MapGroupIdToArrayIndex(int original_label) { + return rfx_label_mapper_->CategoryNumber(original_label); + } + py::array_t MapMultipleGroupIdsToArrayIndices(py::array_t original_labels) { + int output_size = original_labels.size(); + auto result = py::array_t(py::detail::any_container({output_size})); + auto accessor = result.mutable_unchecked<1>(); + for (int i = 0; i < output_size; i++) { + accessor(i) = rfx_label_mapper_->CategoryNumber(original_labels.at(i)); + } + return result; + } private: std::unique_ptr rfx_label_mapper_; @@ -2410,7 +2422,9 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("DumpJsonString", &RandomEffectsLabelMapperCpp::DumpJsonString) .def("LoadFromJsonString", &RandomEffectsLabelMapperCpp::LoadFromJsonString) .def("LoadFromJson", &RandomEffectsLabelMapperCpp::LoadFromJson) - .def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper); + .def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper) + .def("MapGroupIdToArrayIndex", &RandomEffectsLabelMapperCpp::MapGroupIdToArrayIndex) + .def("MapMultipleGroupIdsToArrayIndices", &RandomEffectsLabelMapperCpp::MapMultipleGroupIdsToArrayIndices); py::class_(m, 
"RandomEffectsModelCpp") .def(py::init()) diff --git a/stochtree/bart.py b/stochtree/bart.py index 832d9d00..4697d41d 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1988,6 +1988,12 @@ def predict( "Random effects basis has a different dimension than the basis used to train this model" ) + # Convert rfx_group_ids to their corresponding array position indices in the random effects parameter sample arrays + if rfx_group_ids is not None: + rfx_group_id_indices = self.rfx_container.map_group_ids_to_array_indices( + rfx_group_ids + ) + # Random effects predictions if predict_rfx or predict_rfx_intermediate: if rfx_basis is not None: @@ -2017,7 +2023,7 @@ def predict( ) for i in range(n_train): rfx_predictions_raw[i, 0, :] = rfx_beta_draws[ - rfx_group_ids[i], : + rfx_group_id_indices[i], : ] rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :]) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index dfe0610e..28752c73 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -3053,6 +3053,12 @@ def predict( raise ValueError( "rfx_basis must have the same number of columns as the random effects basis used to sample this model" ) + + # Convert rfx_group_ids to their corresponding array position indices in the random effects parameter sample arrays + if rfx_group_ids is not None: + rfx_group_id_indices = self.rfx_container.map_group_ids_to_array_indices( + rfx_group_ids + ) # Random effects predictions if predict_rfx or predict_rfx_intermediate: @@ -3073,14 +3079,14 @@ def predict( ) for i in range(X.shape[0]): rfx_predictions_raw[i, :, :] = rfx_beta_draws[ - :, rfx_group_ids[i], : + :, rfx_group_id_indices[i], : ] elif rfx_beta_draws.ndim == 2: rfx_predictions_raw = np.empty( shape=(X.shape[0], 1, rfx_beta_draws.shape[1]) ) for i in range(X.shape[0]): - rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_ids[i], :] + rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_id_indices[i], :] else: raise ValueError( "Unexpected number of dimensions in 
extracted random effects samples" diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 404da725..ea028b3e 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -418,6 +418,38 @@ def extract_parameter_samples(self) -> dict[str, np.ndarray]: "sigma_samples": sigma_samples, } return output + + def map_group_id_to_array_index(self, group_id: int) -> int: + """ + Map an integer-valued random effects group ID to its group's corresponding position in the arrays that store random effects parameter samples. + + Parameters + ---------- + group_id : int + Group identifier to be converted to an array position. + + Returns + ------- + int + The position of `group_id` in the parameter sample arrays underlying the random effects container. + """ + return self.rfx_label_mapper_cpp.MapGroupIdToArrayIndex(group_id) + + def map_group_ids_to_array_indices(self, group_ids: np.ndarray) -> np.ndarray: + """ + Map an array of integer-valued random effects group IDs to their groups' corresponding positions in the arrays that store random effects parameter samples. + + Parameters + ---------- + group_ids : np.ndarray + Array of group identifiers (integer-valued) to be converted to an array position. + + Returns + ------- + np.ndarray + Numpy array of the position of `group_id` in the parameter sample arrays underlying the random effects container. 
+ """ + return self.rfx_label_mapper_cpp.MapMultipleGroupIdsToArrayIndices(group_ids) class RandomEffectsModel: diff --git a/test/python/test_random_effects.py b/test/python/test_random_effects.py index f6616240..ce0b7c09 100644 --- a/test/python/test_random_effects.py +++ b/test/python/test_random_effects.py @@ -1,6 +1,8 @@ import numpy as np from stochtree import ( + BARTModel, + BCFModel, RandomEffectsContainer, RandomEffectsDataset, RandomEffectsModel, @@ -136,3 +138,876 @@ def outcome_mean(group_labels, basis): # Inspect the samples rfx_preds = rfx_container.predict(group_labels, basis) * y_std + y_bar assert rfx_preds.shape == (num_observations, num_mcmc) + + def test_bart_rfx_default_numbering(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the covariate-dependent function + def covariate_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 0, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 1, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 2, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + mean_term = covariate_fn(X) + rfx_term = rfx_fn(group_labels, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X, + y_train=y, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis + ) + + # Check all of the prediction / summary computation methods + + # Predict + mean_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + 
rfx_basis=basis, + type="mean", + terms="all", + scale="linear", + ) + posterior_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="posterior", + terms="all", + scale="linear", + ) + + # Compute intervals + posterior_interval = bart_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + ) + + # Sample posterior predictive + posterior_predictive_draws = bart_model.sample_posterior_predictive( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + num_draws_per_sample=5, + ) + + def test_bart_rfx_offset_numbering(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(range(2, 2 + num_rfx_groups), size=n) + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the covariate-dependent function + def covariate_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 2, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 3, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 4, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + mean_term = covariate_fn(X) + rfx_term = rfx_fn(group_labels, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X, + y_train=y, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis + ) + + # Check all of the prediction / summary computation methods + + # Predict + mean_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="mean", + terms="all", + 
scale="linear", + ) + posterior_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="posterior", + terms="all", + scale="linear", + ) + + # Compute intervals + posterior_interval = bart_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + ) + + # Sample posterior predictive + posterior_predictive_draws = bart_model.sample_posterior_predictive( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + num_draws_per_sample=5, + ) + + def test_bart_rfx_default_numbering_model_spec(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the covariate-dependent function + def covariate_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 0, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 1, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 2, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + mean_term = covariate_fn(X) + rfx_term = rfx_fn(group_labels, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X, + y_train=y, + rfx_group_ids_train=group_labels, + random_effects_params={"model_spec": "intercept_only"} + ) + + # Check all of the prediction / summary computation methods + + # Predict + mean_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="mean", + terms="all", + scale="linear", + ) + 
posterior_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="posterior", + terms="all", + scale="linear", + ) + + # Compute intervals + posterior_interval = bart_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + ) + + # Sample posterior predictive + posterior_predictive_draws = bart_model.sample_posterior_predictive( + X=X, + rfx_group_ids=group_labels, + rfx_basis=basis, + num_draws_per_sample=5, + ) + + def test_bart_rfx_offset_numbering_model_spec(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(range(2, 2 + num_rfx_groups), size=n) + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the covariate-dependent function + def covariate_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 2, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 3, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 4, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + mean_term = covariate_fn(X) + rfx_term = rfx_fn(group_labels, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X, + y_train=y, + rfx_group_ids_train=group_labels, + random_effects_params={"model_spec": "intercept_only"} + ) + + # Check all of the prediction / summary computation methods + + # Predict + mean_preds = bart_model.predict( + X=X, + rfx_group_ids=group_labels, + type="mean", + terms="all", + scale="linear", + ) + posterior_preds = bart_model.predict( 
+ X=X, + rfx_group_ids=group_labels, + type="posterior", + terms="all", + scale="linear", + ) + + # Compute intervals + posterior_interval = bart_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + rfx_group_ids=group_labels, + ) + + # Sample posterior predictive + posterior_predictive_draws = bart_model.sample_posterior_predictive( + X=X, + rfx_group_ids=group_labels, + num_draws_per_sample=5, + ) + + def test_bcf_rfx_default_numbering(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the prognostic function + def prog_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the treatment effect function + def cate_fn(X): + return 1.0 * X[:, 1] + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 0, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 1, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 2, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + propensity = rng.uniform(0.4, 0.6, n) + Z = rng.binomial(1, propensity, n) + mean_term = prog_fn(X) + Z * cate_fn(X) + rfx_term = rfx_fn(group_labels, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BCF model with intercept plus treatment random effects model specification + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + Z_train=Z, + propensity_train=propensity, + y_train=y, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis, + ) + + # Check all of the prediction / summary computation methods + + # Predict + rfx_preds = bcf_model.predict( + X=X, + Z=Z, + 
propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="posterior", + terms="rfx", + scale="linear", + ) + yhat_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + num_draws_per_sample=5, + ) + + def test_bcf_rfx_default_numbering_model_spec(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the prognostic function + def prog_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the treatment effect function + def cate_fn(X): + return 1.0 * X[:, 1] + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 0, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 1, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 2, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + propensity = rng.uniform(0.4, 0.6, n) + Z = rng.binomial(1, propensity, n) + mean_term = prog_fn(X) + Z * cate_fn(X) + rfx_term = rfx_fn(group_labels, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BCF model with intercept plus treatment random effects 
model specification + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + Z_train=Z, + propensity_train=propensity, + y_train=y, + rfx_group_ids_train=group_labels, + random_effects_params={"model_spec": "intercept_plus_treatment"} + ) + + # Check all of the prediction / summary computation methods + + # Predict + rfx_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="rfx", + scale="linear", + ) + yhat_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + num_draws_per_sample=5, + ) + + # Fit a BCF model with intercept-only random effects model specification + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + Z_train=Z, + propensity_train=propensity, + y_train=y, + rfx_group_ids_train=group_labels, + random_effects_params={"model_spec": "intercept_only"} + ) + + # Check all of the prediction / summary computation methods + + # Predict + rfx_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="rfx", + scale="linear", + ) + yhat_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + ) + + # Sample posterior predictive + 
posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + num_draws_per_sample=5, + ) + + def test_bcf_rfx_offset_numbering(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + 2 + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the prognostic function + def prog_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the treatment effect function + def cate_fn(X): + return 1.0 * X[:, 1] + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 0, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 1, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 2, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + propensity = rng.uniform(0.4, 0.6, n) + Z = rng.binomial(1, propensity, n) + mean_term = prog_fn(X) + Z * cate_fn(X) + rfx_term = rfx_fn(group_labels - 2, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BCF model with intercept plus treatment random effects model specification + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + Z_train=Z, + propensity_train=propensity, + y_train=y, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis, + ) + + # Check all of the prediction / summary computation methods + + # Predict + rfx_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + type="posterior", + terms="rfx", + scale="linear", + ) + yhat_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + 
type="posterior", + terms="y_hat", + scale="linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + rfx_basis=basis, + num_draws_per_sample=5, + ) + + def test_bcf_rfx_offset_numbering_model_spec(self): + # RNG + rng = np.random.default_rng() + + # Generate covariates + n = 100 + p = 10 + X = rng.uniform(0, 1, (n, p)) + + # Generate group labels and random effects basis + num_rfx_basis = 2 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + 2 + basis = np.empty((n, num_rfx_basis)) + basis[:, 0] = 1.0 + if num_rfx_basis > 1: + basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the prognostic function + def prog_fn(X): + return 5 * np.sin(2 * np.pi * X[:, 0]) + + # Define the treatment effect function + def cate_fn(X): + return 1.0 * X[:, 1] + + # Define the group rfx function + def rfx_fn(group_labels, basis): + return np.where( + group_labels == 0, + 0 - 1 * basis[:, 1], + np.where( + group_labels == 1, + 4 + 1 * basis[:, 1], + np.where( + group_labels == 2, 8 + 3 * basis[:, 1], 12 + 5 * basis[:, 1] + ), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + propensity = rng.uniform(0.4, 0.6, n) + Z = rng.binomial(1, propensity, n) + mean_term = prog_fn(X) + Z * cate_fn(X) + rfx_term = rfx_fn(group_labels - 2, basis) + y = mean_term + rfx_term + epsilon + + # Fit a BCF model with intercept plus treatment random effects model specification + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + Z_train=Z, + propensity_train=propensity, + y_train=y, + rfx_group_ids_train=group_labels, + random_effects_params={"model_spec": "intercept_plus_treatment"} + ) + + # 
Check all of the prediction / summary computation methods + + # Predict + rfx_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="rfx", + scale="linear", + ) + yhat_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + num_draws_per_sample=5, + ) + + # Fit a BCF model with intercept-only random effects model specification + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + Z_train=Z, + propensity_train=propensity, + y_train=y, + rfx_group_ids_train=group_labels, + random_effects_params={"model_spec": "intercept_only"} + ) + + # Check all of the prediction / summary computation methods + + # Predict + rfx_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="rfx", + scale="linear", + ) + yhat_preds = bcf_model.predict( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="linear", + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X, + Z=Z, + propensity=propensity, + rfx_group_ids=group_labels, + num_draws_per_sample=5, + ) + + + From 174fc276a16b4c324018f42846b6f5b4edf2d279 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: 
Mon, 15 Dec 2025 15:54:43 -0600 Subject: [PATCH 3/3] Updated NEWS and CHANGELOG --- CHANGELOG.md | 7 +++++++ NEWS.md | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index af065869..08bed573 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,17 @@ # Changelog +# stochtree 0.2.1.9000 + +## Bug Fixes + +* Fix prediction bug for R BART models whose random effects group labels are not the consecutive integers `1:num_groups`, which occurred when only `y_hat` was requested ([#256](https://github.com/StochasticTree/stochtree/pull/256)) + # stochtree 0.2.1 ## Bug Fixes * Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248)) +* Fix prediction bug for Python BART and BCF models whose random effects group labels are not the consecutive integers `0:(num_groups-1)` ([#256](https://github.com/StochasticTree/stochtree/pull/256)) ## Other Changes diff --git a/NEWS.md b/NEWS.md index 644e6f72..a9e06e93 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,8 +1,15 @@ +# stochtree 0.2.1.9000 + +## Bug Fixes + +* Fix prediction bug for R BART models whose random effects group labels are not the consecutive integers `1:num_groups`, which occurred when only `y_hat` was requested ([#256](https://github.com/StochasticTree/stochtree/pull/256)) + # stochtree 0.2.1 ## Bug Fixes * Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248)) +* Fix prediction bug for Python BART and BCF models whose random effects group labels are not the consecutive integers `0:(num_groups-1)` ([#256](https://github.com/StochasticTree/stochtree/pull/256)) ## Other Changes