Skip to content

Commit e995300

Browse files
committed
Added python demo and benchmark script
1 parent 21f2897 commit e995300

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# BCF Feature Subsampling Demo and Benchmark Script
2+
3+
# Load necessary libraries
4+
import numpy as np
5+
from stochtree import BCFModel
6+
from sklearn.model_selection import train_test_split
7+
import timeit
8+
9+
# Generate sample data
# RNG (seeded so the simulated dataset is identical on every run).
# NOTE: the order of the rng.* calls below determines the simulated values;
# reordering them changes the dataset.
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates: n rows of p independent standard normal features
n = 1000
p = 100
X = rng.normal(0, 1, (n, p))

# Generate outcome
# Target signal-to-noise ratio, used below as sd(E_XZ) / noise_sd
snr = 2
# Prognostic term: uses only the first three columns of X.
# sqrt(2/pi) is E|N(0, 1)|, so the absolute-value term is mean-centered.
mu_x = (
    1
    + 2 * X[:, 0]
    + np.where(X[:, 1] < 0, -4, 4)
    + 3 * (np.abs(X[:, 2]) - np.sqrt(2 / np.pi))
)
# Treatment effect: linear in the fourth column of X
tau_x = 1 + 2 * X[:, 3]
u = rng.uniform(0, 1, n)
# Mean of the treatment: a rescaled mu_x plus uniform jitter in [-2, 2)
pi_x = ((mu_x - 1.) / 4.) + 4 * (u - 0.5)
# Continuous treatment: pi_x plus standard normal noise
Z = pi_x + rng.normal(0, 1, n)
# Conditional mean of the outcome given covariates and treatment
E_XZ = mu_x + Z * tau_x
# Scale the outcome noise so that sd(E_XZ) / noise_sd equals snr
noise_sd = np.std(E_XZ) / snr
y = E_XZ + rng.normal(0, 1, n) * noise_sd
29+
30+
# Test-train split (80% train / 20% test)
# Pass the script's seed as random_state so the held-out set is the same on
# every run; without it train_test_split reshuffles on each invocation.
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(
    sample_inds, test_size=0.2, random_state=random_seed
)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
pi_x_train = pi_x[train_inds]
pi_x_test = pi_x[test_inds]
Z_train = Z[train_inds]
Z_test = Z[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]
41+
42+
# Time BCF fits that use the full feature set.
# NOTE(review): num_gfr=100 with num_mcmc=0 presumably means only
# grow-from-root (XBART-style) iterations are run -- confirm against the
# stochtree documentation.
n_timing_runs = 5
s = """\
bcf_model_a = BCFModel()
prog_forest_config_a = {"num_trees": 100}
trt_forest_config_a = {"num_trees": 50}
bcf_model_a.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a)
"""
# timeit.timeit returns the TOTAL seconds for `number` executions, so divide
# by the run count to get the per-fit average that the message reports.
timing_no_subsampling = (
    timeit.timeit(stmt=s, number=n_timing_runs, globals=globals()) / n_timing_runs
)
print(f"Average runtime, without feature subsampling (p = {p:d}): {timing_no_subsampling:.2f}")
51+
52+
# Time BCF fits in which both forests are configured with
# num_features_subsample=5.
# NOTE(review): num_gfr=100 with num_mcmc=0 presumably means only
# grow-from-root (XBART-style) iterations are run -- confirm against the
# stochtree documentation.
n_timing_runs = 5
s = """\
bcf_model_b = BCFModel()
prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5}
bcf_model_b.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b)
"""
# timeit.timeit returns the TOTAL seconds for `number` executions, so divide
# by the run count to get the per-fit average that the message reports.
timing_subsampling = (
    timeit.timeit(stmt=s, number=n_timing_runs, globals=globals()) / n_timing_runs
)
print(f"Average runtime, subsampling 5 out of {p:d} features: {timing_subsampling:.2f}")
61+
62+
# Compare RMSEs of each model
def _rmse(y_true, y_pred):
    """Return the root mean squared error between two equal-length arrays."""
    return np.sqrt(np.mean(np.power(y_true - y_pred, 2)))


# Model A: both forests use the full feature set
bcf_model_a = BCFModel()
prog_forest_config_a = {"num_trees": 100}
trt_forest_config_a = {"num_trees": 50}
bcf_model_a.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_a, treatment_effect_forest_params=trt_forest_config_a)

# Model B: both forests are configured with num_features_subsample=5
bcf_model_b = BCFModel()
prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5}
bcf_model_b.sample(X_train=X_train, Z_train=Z_train, pi_train=pi_x_train, y_train=y_train, X_test=X_test, Z_test=Z_test, pi_test=pi_x_test, num_gfr=100, num_mcmc=0, prognostic_forest_params=prog_forest_config_b, treatment_effect_forest_params=trt_forest_config_b)

# Average the test-set predictions over axis 1, then score against y_test.
# NOTE(review): assumes y_hat_test squeezes to (n_test, n_samples) -- confirm
# against the stochtree documentation.
y_hat_test_a = np.squeeze(bcf_model_a.y_hat_test).mean(axis=1)
rmse_no_subsampling = _rmse(y_test, y_hat_test_a)
print(f"Test set RMSE, no subsampling (p = {p:d}): {rmse_no_subsampling:.2f}")
y_hat_test_b = np.squeeze(bcf_model_b.y_hat_test).mean(axis=1)
rmse_subsampling = _rmse(y_test, y_hat_test_b)
print(f"Test set RMSE, subsampling 5 out of {p:d} features: {rmse_subsampling:.2f}")

0 commit comments

Comments
 (0)