@@ -2,9 +2,6 @@
 
 # Load necessary libraries
 import numpy as np
-import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt
 from stochtree import BARTModel
 from sklearn.model_selection import train_test_split
 import timeit
@@ -16,25 +13,21 @@
 
 # Generate covariates and basis
 n = 1000
-p_X = 100
-X = rng.uniform(0, 1, (n, p_X))
+p = 100
+X = rng.uniform(0, 1, (n, p))
 
 # Define the outcome mean function
 def outcome_mean(X):
-    return np.where(
-        (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5,
-        np.where(
-            (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5,
-            np.where(
-                (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5,
-                7.5
-            )
-        )
+    return (
+        np.sin(4 * np.pi * X[:,0]) + np.cos(4 * np.pi * X[:,1]) + np.sin(4 * np.pi * X[:,2]) + np.cos(4 * np.pi * X[:,3])
     )
 
 # Generate outcome
-epsilon = rng.normal(0, 1, n)
-y = outcome_mean(X) + epsilon
+snr = 2
+f_X = outcome_mean(X)
+noise_sd = np.std(f_X) / snr
+epsilon = rng.normal(0, 1, n) * noise_sd
+y = f_X + epsilon
 
 # Test-train split
 sample_inds = np.arange(n)
@@ -50,12 +43,28 @@ def outcome_mean(X):
 forest_config_a = {"num_trees": 100}
 bart_model_a.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_a)
 """
-print(timeit.timeit(stmt=s, number=5, globals=globals()))
+timing_no_subsampling = timeit.timeit(stmt=s, number=5, globals=globals())
+print(f"Average runtime, without feature subsampling (p = {p:d}): {timing_no_subsampling:.2f}")
 
 # Run XBART with each tree considering random subsets of 5 features
 s = """\
 bart_model_b = BARTModel()
 forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
 bart_model_b.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_b)
 """
-print(timeit.timeit(stmt=s, number=5, globals=globals()))
+timing_subsampling = timeit.timeit(stmt=s, number=5, globals=globals())
+print(f"Average runtime, subsampling 5 out of {p:d} features: {timing_subsampling:.2f}")
+
+# Compare RMSEs of each model
+bart_model_a = BARTModel()
+forest_config_a = {"num_trees": 100}
+bart_model_a.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_a)
+bart_model_b = BARTModel()
+forest_config_b = {"num_trees": 100, "num_features_subsample": 5}
+bart_model_b.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=100, num_mcmc=0, mean_forest_params=forest_config_b)
+y_hat_test_a = np.squeeze(bart_model_a.y_hat_test).mean(axis=1)
+rmse_no_subsampling = np.sqrt(np.mean(np.power(y_test - y_hat_test_a,2)))
+print(f"Test set RMSE, no subsampling (p = {p:d}): {rmse_no_subsampling:.2f}")
+y_hat_test_b = np.squeeze(bart_model_b.y_hat_test).mean(axis=1)
+rmse_subsampling = np.sqrt(np.mean(np.power(y_test - y_hat_test_b,2)))
+print(f"Test set RMSE, subsampling 5 out of {p:d} features: {rmse_subsampling:.2f}")
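
The hunks above use rng, X_train, X_test, y_train, and y_test, which are defined in parts of the script not shown in this diff. A minimal sketch of that setup, assuming a NumPy default_rng generator and an 80/20 split via sklearn's train_test_split; the seed, test fraction, and split style are illustrative assumptions, not taken from the file:

# Hypothetical reconstruction of the elided setup (names follow the diff)
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(101)        # seed is an assumption
n, p = 1000, 100
X = rng.uniform(0, 1, (n, p))
f_X = (np.sin(4 * np.pi * X[:,0]) + np.cos(4 * np.pi * X[:,1])
       + np.sin(4 * np.pi * X[:,2]) + np.cos(4 * np.pi * X[:,3]))
y = f_X + rng.normal(0, 1, n) * (np.std(f_X) / 2)   # snr = 2, as in the diff

# Split row indices into train and test sets (test fraction assumed)
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.2, random_state=101)
X_train, X_test = X[train_inds,:], X[test_inds,:]
y_train, y_test = y[train_inds], y[test_inds]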