@@ -336,3 +336,98 @@ test_that("BART Predictions", {
336336 expect_equal(train_preds_mean_cached , train_preds_mean_recomputed )
337337 expect_equal(train_preds_variance_cached , train_preds_variance_recomputed )
338338})
339+
340+ test_that(" Random Effects BART" , {
341+ skip_on_cran()
342+
343+ # Generate simulated data
344+ n <- 100
345+ p <- 5
346+ p_w <- 2
347+ X <- matrix (runif(n * p ), ncol = p )
348+ W <- matrix (runif(n * p_w ), ncol = p_w )
349+ f_XW <- (
350+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (- 7.5 * W [,1 ]) +
351+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (- 2.5 * W [,1 ]) +
352+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (2.5 * W [,1 ]) +
353+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (7.5 * W [,1 ])
354+ )
355+ rfx_group_ids <- sample(1 : 2 , size = n , replace = TRUE )
356+ rfx_basis <- cbind(rep(1 , n ), runif(n ))
357+ num_rfx_components <- ncol(rfx_basis )
358+ num_rfx_groups <- length(unique(rfx_group_ids ))
359+ rfx_coefs <- matrix (c(- 5 , 5 , 1 , - 1 ), ncol = 2 , byrow = T )
360+ rfx_term <- rowSums(rfx_coefs [rfx_group_ids ,] * rfx_basis )
361+ noise_sd <- 1
362+ y <- f_XW + rfx_term + rnorm(n , 0 , noise_sd )
363+ test_set_pct <- 0.2
364+ n_test <- round(test_set_pct * n )
365+ n_train <- n - n_test
366+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
367+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
368+ X_test <- X [test_inds ,]
369+ X_train <- X [train_inds ,]
370+ W_test <- W [test_inds ,]
371+ W_train <- W [train_inds ,]
372+ rfx_group_ids_test <- rfx_group_ids [test_inds ]
373+ rfx_group_ids_train <- rfx_group_ids [train_inds ]
374+ rfx_basis_test <- rfx_basis [test_inds ,]
375+ rfx_basis_train <- rfx_basis [train_inds ,]
376+ y_test <- y [test_inds ]
377+ y_train <- y [train_inds ]
378+
379+ # Specify no rfx parameters
380+ general_param_list <- list ()
381+ mean_forest_param_list <- list (sample_sigma2_leaf = FALSE )
382+ expect_no_error(
383+ bart_model <- bart(X_train = X_train , y_train = y_train , X_test = X_test ,
384+ leaf_basis_train = W_train , leaf_basis_test = W_test ,
385+ rfx_group_ids_train = rfx_group_ids_train ,
386+ rfx_group_ids_test = rfx_group_ids_test ,
387+ rfx_basis_train = rfx_basis_train ,
388+ rfx_basis_test = rfx_basis_test ,
389+ num_gfr = 0 , num_burnin = 10 , num_mcmc = 10 ,
390+ general_params = general_param_list ,
391+ mean_forest_params = mean_forest_param_list )
392+ )
393+
394+ # Specify all rfx parameters as scalars
395+ general_param_list <- list (rfx_working_parameter_prior_mean = 1 . ,
396+ rfx_group_parameter_prior_mean = 1 . ,
397+ rfx_working_parameter_prior_cov = 1 . ,
398+ rfx_group_parameter_prior_cov = 1 . ,
399+ rfx_variance_prior_shape = 1 ,
400+ rfx_variance_prior_scale = 1 )
401+ mean_forest_param_list <- list (sample_sigma2_leaf = FALSE )
402+ expect_no_error(
403+ bart_model <- bart(X_train = X_train , y_train = y_train , X_test = X_test ,
404+ leaf_basis_train = W_train , leaf_basis_test = W_test ,
405+ rfx_group_ids_train = rfx_group_ids_train ,
406+ rfx_group_ids_test = rfx_group_ids_test ,
407+ rfx_basis_train = rfx_basis_train ,
408+ rfx_basis_test = rfx_basis_test ,
409+ num_gfr = 0 , num_burnin = 10 , num_mcmc = 10 ,
410+ general_params = general_param_list ,
411+ mean_forest_params = mean_forest_param_list )
412+ )
413+
414+ # Specify all relevant rfx parameters as vectors
415+ general_param_list <- list (rfx_working_parameter_prior_mean = c(1 . ,1 . ),
416+ rfx_group_parameter_prior_mean = c(1 . ,1 . ),
417+ rfx_working_parameter_prior_cov = diag(1 . ,2 ),
418+ rfx_group_parameter_prior_cov = diag(1 . ,2 ),
419+ rfx_variance_prior_shape = 1 ,
420+ rfx_variance_prior_scale = 1 )
421+ mean_forest_param_list <- list (sample_sigma2_leaf = FALSE )
422+ expect_no_error(
423+ bart_model <- bart(X_train = X_train , y_train = y_train , X_test = X_test ,
424+ leaf_basis_train = W_train , leaf_basis_test = W_test ,
425+ rfx_group_ids_train = rfx_group_ids_train ,
426+ rfx_group_ids_test = rfx_group_ids_test ,
427+ rfx_basis_train = rfx_basis_train ,
428+ rfx_basis_test = rfx_basis_test ,
429+ num_gfr = 0 , num_burnin = 10 , num_mcmc = 10 ,
430+ general_params = general_param_list ,
431+ mean_forest_params = mean_forest_param_list )
432+ )
433+ })
0 commit comments