Skip to content

Commit f9e6206

Browse files
committed
Updated R BART use of user-provided rfx parameters
1 parent 47faa3d commit f9e6206

File tree

2 files changed

+96
-5
lines changed

2 files changed

+96
-5
lines changed

R/bart.R

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
706706
if (is.null(rfx_group_parameter_prior_mean)) {
707707
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
708708
} else {
709-
xi_init <- expand_dims_1d(rfx_group_parameter_prior_mean, num_rfx_components)
710-
# If it's a vector, expand to matrix
711-
if (is.vector(xi_init)) {
712-
xi_init <- matrix(rep(xi_init, num_rfx_groups), num_rfx_components, num_rfx_groups)
713-
}
709+
xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
714710
}
715711

716712
if (is.null(rfx_working_parameter_prior_cov)) {

test/R/testthat/test-bart.R

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,98 @@ test_that("BART Predictions", {
336336
expect_equal(train_preds_mean_cached, train_preds_mean_recomputed)
337337
expect_equal(train_preds_variance_cached, train_preds_variance_recomputed)
338338
})
339+
340+
test_that("Random Effects BART", {
341+
skip_on_cran()
342+
343+
# Generate simulated data
344+
n <- 100
345+
p <- 5
346+
p_w <- 2
347+
X <- matrix(runif(n*p), ncol = p)
348+
W <- matrix(runif(n*p_w), ncol = p_w)
349+
f_XW <- (
350+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
351+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
352+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
353+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
354+
)
355+
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
356+
rfx_basis <- cbind(rep(1, n), runif(n))
357+
num_rfx_components <- ncol(rfx_basis)
358+
num_rfx_groups <- length(unique(rfx_group_ids))
359+
rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T)
360+
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
361+
noise_sd <- 1
362+
y <- f_XW + rfx_term + rnorm(n, 0, noise_sd)
363+
test_set_pct <- 0.2
364+
n_test <- round(test_set_pct*n)
365+
n_train <- n - n_test
366+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
367+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
368+
X_test <- X[test_inds,]
369+
X_train <- X[train_inds,]
370+
W_test <- W[test_inds,]
371+
W_train <- W[train_inds,]
372+
rfx_group_ids_test <- rfx_group_ids[test_inds]
373+
rfx_group_ids_train <- rfx_group_ids[train_inds]
374+
rfx_basis_test <- rfx_basis[test_inds,]
375+
rfx_basis_train <- rfx_basis[train_inds,]
376+
y_test <- y[test_inds]
377+
y_train <- y[train_inds]
378+
379+
# Specify no rfx parameters
380+
general_param_list <- list()
381+
mean_forest_param_list <- list(sample_sigma2_leaf = FALSE)
382+
expect_no_error(
383+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
384+
leaf_basis_train = W_train, leaf_basis_test = W_test,
385+
rfx_group_ids_train = rfx_group_ids_train,
386+
rfx_group_ids_test = rfx_group_ids_test,
387+
rfx_basis_train = rfx_basis_train,
388+
rfx_basis_test = rfx_basis_test,
389+
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
390+
general_params = general_param_list,
391+
mean_forest_params = mean_forest_param_list)
392+
)
393+
394+
# Specify all rfx parameters as scalars
395+
general_param_list <- list(rfx_working_parameter_prior_mean = 1.,
396+
rfx_group_parameter_prior_mean = 1.,
397+
rfx_working_parameter_prior_cov = 1.,
398+
rfx_group_parameter_prior_cov = 1.,
399+
rfx_variance_prior_shape = 1,
400+
rfx_variance_prior_scale = 1)
401+
mean_forest_param_list <- list(sample_sigma2_leaf = FALSE)
402+
expect_no_error(
403+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
404+
leaf_basis_train = W_train, leaf_basis_test = W_test,
405+
rfx_group_ids_train = rfx_group_ids_train,
406+
rfx_group_ids_test = rfx_group_ids_test,
407+
rfx_basis_train = rfx_basis_train,
408+
rfx_basis_test = rfx_basis_test,
409+
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
410+
general_params = general_param_list,
411+
mean_forest_params = mean_forest_param_list)
412+
)
413+
414+
# Specify all relevant rfx parameters as vectors
415+
general_param_list <- list(rfx_working_parameter_prior_mean = c(1.,1.),
416+
rfx_group_parameter_prior_mean = c(1.,1.),
417+
rfx_working_parameter_prior_cov = diag(1.,2),
418+
rfx_group_parameter_prior_cov = diag(1.,2),
419+
rfx_variance_prior_shape = 1,
420+
rfx_variance_prior_scale = 1)
421+
mean_forest_param_list <- list(sample_sigma2_leaf = FALSE)
422+
expect_no_error(
423+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
424+
leaf_basis_train = W_train, leaf_basis_test = W_test,
425+
rfx_group_ids_train = rfx_group_ids_train,
426+
rfx_group_ids_test = rfx_group_ids_test,
427+
rfx_basis_train = rfx_basis_train,
428+
rfx_basis_test = rfx_basis_test,
429+
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
430+
general_params = general_param_list,
431+
mean_forest_params = mean_forest_param_list)
432+
)
433+
})

0 commit comments

Comments
 (0)