
Commit 21f2897

Updated benchmarking scripts
1 parent 3a576e1 commit 21f2897

File tree

2 files changed: +116 / -18 lines


demo/debug/supervised_learning_feature_subsets.py

Lines changed: 27 additions & 18 deletions
@@ -2,9 +2,6 @@
 
 # Load necessary libraries
 import numpy as np
-import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt
 from stochtree import BARTModel
 from sklearn.model_selection import train_test_split
 import timeit
@@ -16,25 +13,21 @@
 
 # Generate covariates and basis
 n = 1000
-p_X = 100
-X = rng.uniform(0, 1, (n, p_X))
+p = 100
+X = rng.uniform(0, 1, (n, p))
 
 # Define the outcome mean function
 def outcome_mean(X):
-    return np.where(
-        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5,
-        np.where(
-            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5,
-            np.where(
-                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5,
-                7.5
-            )
-        )
+    return (
+        np.sin(4*np.pi*X[:,0]) + np.cos(4*np.pi*X[:,1]) + np.sin(4*np.pi*X[:,2]) + np.cos(4*np.pi*X[:,3])
     )
 
 # Generate outcome
-epsilon = rng.normal(0, 1, n)
-y = outcome_mean(X) + epsilon
+snr = 2
+f_X = outcome_mean(X)
+noise_sd = np.std(f_X) / snr
+epsilon = rng.normal(0, 1, n) * noise_sd
+y = f_X + epsilon
 
 # Test-train split
 sample_inds = np.arange(n)
@@ -50,12 +43,28 @@ def outcome_mean(X):
 forest_config_a = {"num_trees": 100}
 bart_model_a.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_a)
 """
-print(timeit.timeit(stmt=s, number=5, globals=globals()))
+timing_no_subsampling = timeit.timeit(stmt=s, number=5, globals=globals())
+print(f"Average runtime, without feature subsampling (p = {p:d}): {timing_no_subsampling:.2f}")
 
 # Run XBART with each tree considering random subsets of 5 features
 s = """\
 bart_model_b = BARTModel()
 forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
 bart_model_b.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_b)
 """
-print(timeit.timeit(stmt=s, number=5, globals=globals()))
+timing_subsampling = timeit.timeit(stmt=s, number=5, globals=globals())
+print(f"Average runtime, subsampling 5 out of {p:d} features: {timing_subsampling:.2f}")
+
+# Compare RMSEs of each model
+bart_model_a = BARTModel()
+forest_config_a = {"num_trees": 100}
+bart_model_a.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_a)
+bart_model_b = BARTModel()
+forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
+bart_model_b.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_b)
+y_hat_test_a = np.squeeze(bart_model_a.y_hat_test).mean(axis = 1)
+rmse_no_subsampling = np.sqrt(np.mean(np.power(y_test - y_hat_test_a,2)))
+print(f"Test set RMSE, no subsampling (p = {p:d}): {rmse_no_subsampling:.2f}")
+y_hat_test_b = np.squeeze(bart_model_b.y_hat_test).mean(axis = 1)
+rmse_subsampling = np.sqrt(np.mean(np.power(y_test - y_hat_test_b,2)))
+print(f"Test set RMSE, subsampling 5 out of {p:d} features: {rmse_subsampling:.2f}")
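A note on what this script now does: the outcome mean is a smooth function of only the first 4 of p = 100 features, and the error scale is calibrated to a target signal-to-noise ratio rather than fixed at 1 (noise_sd = sd(f(X)) / snr). One caveat worth flagging: timeit.timeit(stmt=s, number=5) returns the total time across the 5 runs, so the "Average runtime" labels report a 5-run total unless the result is divided by 5. Below is a minimal standalone sketch of the SNR calibration in plain numpy; the seed and variable names are illustrative, not part of the commit.

    import numpy as np

    rng = np.random.default_rng(2025)  # illustrative seed

    # Covariates and the same smooth mean function as the updated script
    n, p = 1000, 100
    X = rng.uniform(0, 1, (n, p))
    f_X = (np.sin(4*np.pi*X[:, 0]) + np.cos(4*np.pi*X[:, 1])
           + np.sin(4*np.pi*X[:, 2]) + np.cos(4*np.pi*X[:, 3]))

    # Scale the noise so that sd(signal) / sd(noise) = snr
    snr = 2
    noise_sd = np.std(f_X) / snr
    y = f_X + rng.normal(0, 1, n) * noise_sd

    # Sanity check: the empirical SNR should come out near 2
    print(np.std(f_X) / np.std(y - f_X))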
Lines changed: 89 additions & 0 deletions (new file)
@@ -0,0 +1,89 @@
+# Load libraries
+library(stochtree)
+library(microbenchmark)
+
+# Generate the data
+n <- 1000
+p <- 100
+snr <- 2
+X <- matrix(rnorm(n*p), ncol = p)
+mu_x <- 1 + 2*X[,1] - 4*(X[,2] < 0) + 4*(X[,2] >= 0) + 3*(abs(X[,3]) - sqrt(2/pi))
+tau_x <- 1 + 2*X[,4]
+u <- runif(n)
+pi_x <- ((mu_x-1)/4) + 4*(u-0.5)
+Z <- pi_x + rnorm(n,0,1)
+E_XZ <- mu_x + Z*tau_x
+noise_sd <- sd(E_XZ) / snr
+y <- E_XZ + rnorm(n, 0, 1)*noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+Z_test <- Z[test_inds]
+Z_train <- Z[train_inds]
+pi_x_test <- pi_x[test_inds]
+pi_x_train <- pi_x[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+# Sampler settings
+num_gfr <- 100
+num_burnin <- 0
+num_mcmc <- 0
+general_params <- list(sample_sigma2_global = T)
+prog_params_a <- list(num_trees = 100, num_features_subsample = 5)
+trt_params_a <- list(num_trees = 100, num_features_subsample = 5)
+prog_params_b <- list(num_trees = 50)
+trt_params_b <- list(num_trees = 50)
+
+# Benchmark sampler with and without feature subsampling
+microbenchmark::microbenchmark(
+  stochtree::bcf(
+    X_train = X, Z_train = Z, propensity_train = pi_x, y_train = y,
+    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+    general_params = general_params, prognostic_forest_params = prog_params_a,
+    treatment_effect_forest_params = trt_params_a
+  ),
+  stochtree::bcf(
+    X_train = X, Z_train = Z, propensity_train = pi_x, y_train = y,
+    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+    general_params = general_params, prognostic_forest_params = prog_params_b,
+    treatment_effect_forest_params = trt_params_b
+  ),
+  times = 5
+)
+
+Rprof()
+model_subsampling <- stochtree::bcf(
+  X_train = X_train, Z_train = Z_train, propensity_train = pi_x_train, y_train = y_train,
+  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+  general_params = general_params, prognostic_forest_params = prog_params_a,
+  treatment_effect_forest_params = trt_params_a
+)
+Rprof(NULL)
+summaryRprof()
+
+Rprof()
+model_no_subsampling <- stochtree::bcf(
+  X_train = X_train, Z_train = Z_train, propensity_train = pi_x_train, y_train = y_train,
+  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+  general_params = general_params, prognostic_forest_params = prog_params_b,
+  treatment_effect_forest_params = trt_params_b
+)
+Rprof(NULL)
+summaryRprof()
+
+# Compare out of sample RMSE of the two models
+y_hat_test_subsampling <- rowMeans(predict(model_subsampling, X = X_test, Z = Z_test, propensity = pi_x_test)$y_hat)
+rmse_subsampling <- (
+  sqrt(mean((y_hat_test_subsampling - y_test)^2))
+)
+y_hat_test_no_subsampling <- rowMeans(predict(model_no_subsampling, X = X_test, Z = Z_test, propensity = pi_x_test)$y_hat)
+rmse_no_subsampling <- (
+  sqrt(mean((y_hat_test_no_subsampling - y_test)^2))
+)
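For readers more at home in Python, here is a rough numpy translation of the R script's data-generating process only: prognostic effect mu(x), treatment effect tau(x), a propensity-like score pi(x) correlated with mu(x), a continuous treatment Z, and SNR-calibrated noise. It deliberately stops short of the stochtree::bcf calls and the microbenchmark/Rprof profiling; column indices shift from R's 1-based to numpy's 0-based, and the seed is illustrative.

    import numpy as np

    rng = np.random.default_rng(2025)  # illustrative seed

    n, p, snr = 1000, 100, 2
    X = rng.normal(0, 1, (n, p))

    # Prognostic effect mu(x) and treatment effect tau(x), as in the R script
    mu_x = (1 + 2*X[:, 0] - 4*(X[:, 1] < 0) + 4*(X[:, 1] >= 0)
            + 3*(np.abs(X[:, 2]) - np.sqrt(2/np.pi)))
    tau_x = 1 + 2*X[:, 3]

    # Propensity-like score, correlated with mu(x), drives a continuous treatment Z
    u = rng.uniform(0, 1, n)
    pi_x = (mu_x - 1)/4 + 4*(u - 0.5)
    Z = pi_x + rng.normal(0, 1, n)

    # Outcome: conditional mean plus noise calibrated to the target SNR
    E_XZ = mu_x + Z*tau_x
    noise_sd = np.std(E_XZ) / snr
    y = E_XZ + rng.normal(0, 1, n) * noise_sd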

0 commit comments
