# BCF Feature Subsampling Demo Script
#
# Simulates a causal-inference dataset with p = 100 covariates, then fits
# stochtree's BCFModel twice -- once with every feature available to each tree
# and once with each tree restricted to a random subset of 5 features -- and
# compares average runtime and test-set RMSE of the two fits.

# Load necessary libraries
import timeit

import numpy as np
from sklearn.model_selection import train_test_split
from stochtree import BCFModel

# Generate sample data
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates
n = 1000
p = 100
X = rng.normal(0, 1, (n, p))

# Generate outcome
snr = 2
# Prognostic function: depends only on the first three covariates
mu_x = 1 + 2*X[:,0] + np.where(X[:,1] < 0, -4, 4) + 3*(np.abs(X[:,2]) - np.sqrt(2/np.pi))
# Treatment effect function: depends only on the fourth covariate
tau_x = 1 + 2*X[:,3]
u = rng.uniform(0, 1, n)
# Conditional mean of the (continuous) treatment; passed to the model as pi_train / pi_test
pi_x = ((mu_x-1.)/4.) + 4*(u-0.5)
Z = pi_x + rng.normal(0, 1, n)
E_XZ = mu_x + Z*tau_x
# Scale the noise so that sd(signal) / sd(noise) equals snr
noise_sd = np.std(E_XZ) / snr
y = E_XZ + rng.normal(0, 1, n)*noise_sd

# Test-train split
# random_state is fixed so the split is reproducible along with the simulated data
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.2, random_state=random_seed)
X_train = X[train_inds,:]
X_test = X[test_inds,:]
pi_x_train = pi_x[train_inds]
pi_x_test = pi_x[test_inds]
Z_train = Z[train_inds]
Z_test = Z[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]

# Forest configurations for the two models being compared
# Model A: every tree may split on any of the p features
prog_forest_config_a = {"num_trees": 100}
trt_forest_config_a = {"num_trees": 50}
# Model B: each tree considers a random subset of 5 features
prog_forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
trt_forest_config_b = {"num_trees": 50, "num_features_subsample": 5}

# Number of repetitions used when timing each configuration
num_timing_reps = 5


def fit_bcf(prognostic_forest_params, treatment_effect_forest_params):
    """Fit a fresh BCFModel on the training split and return it.

    The model is sampled with 100 grow-from-root iterations and no MCMC
    iterations (``num_gfr=100, num_mcmc=0``), and the test split is supplied
    so that ``y_hat_test`` is available on the returned model.

    Parameters
    ----------
    prognostic_forest_params : dict
        Configuration forwarded as ``prognostic_forest_params``.
    treatment_effect_forest_params : dict
        Configuration forwarded as ``treatment_effect_forest_params``.

    Returns
    -------
    BCFModel
        The sampled model.
    """
    model = BCFModel()
    model.sample(
        X_train=X_train,
        Z_train=Z_train,
        pi_train=pi_x_train,
        y_train=y_train,
        X_test=X_test,
        Z_test=Z_test,
        pi_test=pi_x_test,
        num_gfr=100,
        num_mcmc=0,
        # Pass copies so that every fit starts from an untouched config dict,
        # exactly as when the dict literal is rebuilt before each fit.
        prognostic_forest_params=dict(prognostic_forest_params),
        treatment_effect_forest_params=dict(treatment_effect_forest_params),
    )
    return model


def holdout_rmse(y_hat):
    """Return the root mean squared error of ``y_hat`` against ``y_test``."""
    return np.sqrt(np.mean(np.power(y_test - y_hat, 2)))


# Time BCF with the full feature set
# timeit.timeit returns the TOTAL time over `number` executions, so divide by
# the repetition count to report the average that the message promises.
timing_no_subsampling = timeit.timeit(
    lambda: fit_bcf(prog_forest_config_a, trt_forest_config_a),
    number=num_timing_reps,
) / num_timing_reps
print(f"Average runtime, without feature subsampling (p = {p:d}): {timing_no_subsampling:.2f}")

# Time BCF with each tree considering random subsets of 5 features
timing_subsampling = timeit.timeit(
    lambda: fit_bcf(prog_forest_config_b, trt_forest_config_b),
    number=num_timing_reps,
) / num_timing_reps
print(f"Average runtime, subsampling 5 out of {p:d} features: {timing_subsampling:.2f}")

# Compare RMSEs of each model
# The models fitted during timing are discarded, so fit one of each again here.
bcf_model_a = fit_bcf(prog_forest_config_a, trt_forest_config_a)
bcf_model_b = fit_bcf(prog_forest_config_b, trt_forest_config_b)
# Average the test-set predictions over axis 1 (presumably the retained samples;
# axis 0 must index test observations for the subtraction from y_test to line up).
y_hat_test_a = np.squeeze(bcf_model_a.y_hat_test).mean(axis = 1)
rmse_no_subsampling = holdout_rmse(y_hat_test_a)
print(f"Test set RMSE, no subsampling (p = {p:d}): {rmse_no_subsampling:.2f}")
y_hat_test_b = np.squeeze(bcf_model_b.y_hat_test).mean(axis = 1)
rmse_subsampling = holdout_rmse(y_hat_test_b)
print(f"Test set RMSE, subsampling 5 out of {p:d} features: {rmse_subsampling:.2f}")