@@ -426,4 +426,71 @@ test_that("Multivariate Treatment MCMC BCF", {
426426 propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
427427 num_mcmc = 10 , general_params = general_param_list )
428428 )
429- })
429+ })
430+
431+ test_that(" BCF Predictions" , {
432+ skip_on_cran()
433+
434+ # Generate simulated data
435+ n <- 100
436+ p <- 5
437+ X <- matrix (runif(n * p ), ncol = p )
438+ mu_X <- (
439+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (- 7.5 ) +
440+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (- 2.5 ) +
441+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (2.5 ) +
442+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (7.5 )
443+ )
444+ pi_X <- (
445+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (0.2 ) +
446+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (0.4 ) +
447+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (0.6 ) +
448+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (0.8 )
449+ )
450+ tau_X <- (
451+ ((0 < = X [,2 ]) & (0.25 > X [,2 ])) * (0.5 ) +
452+ ((0.25 < = X [,2 ]) & (0.5 > X [,2 ])) * (1.0 ) +
453+ ((0.5 < = X [,2 ]) & (0.75 > X [,2 ])) * (1.5 ) +
454+ ((0.75 < = X [,2 ]) & (1 > X [,2 ])) * (2.0 )
455+ )
456+ Z <- rbinom(n , 1 , pi_X )
457+ noise_sd <- 1
458+ y <- mu_X + tau_X * Z + rnorm(n , 0 , noise_sd )
459+ test_set_pct <- 0.2
460+ n_test <- round(test_set_pct * n )
461+ n_train <- n - n_test
462+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
463+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
464+ X_test <- X [test_inds ,]
465+ X_train <- X [train_inds ,]
466+ Z_test <- Z [test_inds ]
467+ Z_train <- Z [train_inds ]
468+ pi_test <- pi_X [test_inds ]
469+ pi_train <- pi_X [train_inds ]
470+ mu_test <- mu_X [test_inds ]
471+ mu_train <- mu_X [train_inds ]
472+ tau_test <- tau_X [test_inds ]
473+ tau_train <- tau_X [train_inds ]
474+ y_test <- y [test_inds ]
475+ y_train <- y [train_inds ]
476+
477+ # Run a BCF model with only GFR
478+ general_params <- list (num_chains = 1 , keep_every = 1 )
479+ variance_forest_params <- list (num_trees = 50 )
480+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
481+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
482+ propensity_test = pi_test , num_gfr = 10 , num_burnin = 0 ,
483+ num_mcmc = 10 , general_params = general_params ,
484+ variance_forest_params = variance_forest_params )
485+
486+ # Check that cached predictions agree with results of predict() function
487+ train_preds <- predict(bcf_model , X = X_train , Z = Z_train , propensity = pi_train )
488+ train_preds_mean_cached <- bcf_model $ y_hat_train
489+ train_preds_mean_recomputed <- train_preds $ y_hat
490+ train_preds_variance_cached <- bcf_model $ sigma2_x_hat_train
491+ train_preds_variance_recomputed <- train_preds $ variance_forest_predictions
492+
493+ # Assertion
494+ expect_equal(train_preds_mean_cached , train_preds_mean_recomputed )
495+ expect_equal(train_preds_variance_cached , train_preds_variance_recomputed )
496+ })
0 commit comments