Skip to content

Commit 3a576e1

Browse files
committed
Updated feature subsampling benchmark script
1 parent cf03f4d commit 3a576e1

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

tools/perf/gfr_feature_subsample_microbenchmark.R

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ library(stochtree)
33
library(microbenchmark)
44

55
# Generate the data
6-
n <- 5000
6+
n <- 1000
77
p <- 100
88
snr <- 2
99
X <- matrix(runif(n*p), ncol = p)
@@ -46,19 +46,29 @@ microbenchmark::microbenchmark(
4646
)
4747

4848
Rprof()
49-
stochtree::bart(
50-
X_train = X, y_train = y,
49+
model_subsampling <- stochtree::bart(
50+
X_train = X_train, y_train = y_train,
5151
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
5252
general_params = general_params, mean_forest_params = mean_params_a
5353
)
5454
Rprof(NULL)
5555
summaryRprof()
5656

5757
Rprof()
58-
stochtree::bart(
59-
X_train = X, y_train = y,
58+
model_no_subsampling <- stochtree::bart(
59+
X_train = X_train, y_train = y_train,
6060
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
6161
general_params = general_params, mean_forest_params = mean_params_b
6262
)
6363
Rprof(NULL)
6464
summaryRprof()
65+
66+
# Compare out of sample RMSE of the two models
67+
y_hat_test_subsampling <- rowMeans(predict(model_subsampling, X = X_test)$y_hat)
68+
rmse_subsampling <- (
69+
sqrt(mean((y_hat_test_subsampling - y_test)^2))
70+
)
71+
y_hat_test_no_subsampling <- rowMeans(predict(model_no_subsampling, X = X_test)$y_hat)
72+
rmse_no_subsampling <- (
73+
sqrt(mean((y_hat_test_no_subsampling - y_test)^2))
74+
)

0 commit comments

Comments
 (0)