@@ -3,7 +3,7 @@ library(stochtree)
33library(microbenchmark )
44
55# Generate the data
6- n <- 5000
6+ n <- 1000
77p <- 100
88snr <- 2
99X <- matrix (runif(n * p ), ncol = p )
@@ -46,19 +46,29 @@ microbenchmark::microbenchmark(
4646)
4747
4848Rprof()
49- stochtree :: bart(
50- X_train = X , y_train = y ,
49+ model_subsampling <- stochtree :: bart(
50+ X_train = X_train , y_train = y_train ,
5151 num_gfr = num_gfr , num_burnin = num_burnin , num_mcmc = num_mcmc ,
5252 general_params = general_params , mean_forest_params = mean_params_a
5353)
5454Rprof(NULL )
5555summaryRprof()
5656
5757Rprof()
58- stochtree :: bart(
59- X_train = X , y_train = y ,
58+ model_no_subsampling <- stochtree :: bart(
59+ X_train = X_train , y_train = y_train ,
6060 num_gfr = num_gfr , num_burnin = num_burnin , num_mcmc = num_mcmc ,
6161 general_params = general_params , mean_forest_params = mean_params_b
6262)
6363Rprof(NULL )
6464summaryRprof()
65+
66+ # Compare out of sample RMSE of the two models
67+ y_hat_test_subsampling <- rowMeans(predict(model_subsampling , X = X_test )$ y_hat )
68+ rmse_subsampling <- (
69+ sqrt(mean((y_hat_test_subsampling - y_test )^ 2 ))
70+ )
71+ y_hat_test_no_subsampling <- rowMeans(predict(model_no_subsampling , X = X_test )$ y_hat )
72+ rmse_no_subsampling <- (
73+ sqrt(mean((y_hat_test_no_subsampling - y_test )^ 2 ))
74+ )
0 commit comments