@@ -84,7 +84,7 @@ def outcome_mean(X):
8484
8585 # Assertions
8686 bart_preds_combined = bart_model_3 .predict (covariates = X_train )
87- y_hat_train_combined = bart_preds_combined [' y_hat' ]
87+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
8888 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
8989 np .testing .assert_allclose (
9090 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -192,7 +192,7 @@ def outcome_mean(X, W):
192192 bart_preds_combined = bart_model_3 .predict (
193193 covariates = X_train , basis = basis_train
194194 )
195- y_hat_train_combined = bart_preds_combined [' y_hat' ]
195+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
196196 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
197197 np .testing .assert_allclose (
198198 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -300,7 +300,7 @@ def outcome_mean(X, W):
300300 bart_preds_combined = bart_model_3 .predict (
301301 covariates = X_train , basis = basis_train
302302 )
303- y_hat_train_combined = bart_preds_combined [' y_hat' ]
303+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
304304 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
305305 np .testing .assert_allclose (
306306 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -411,7 +411,10 @@ def conditional_stddev(X):
411411
412412 # Assertions
413413 bart_preds_combined = bart_model_3 .predict (covariates = X_train )
414- y_hat_train_combined , sigma2_x_train_combined = bart_preds_combined ['y_hat' ], bart_preds_combined ['variance_forest_predictions' ]
414+ y_hat_train_combined , sigma2_x_train_combined = (
415+ bart_preds_combined ["y_hat" ],
416+ bart_preds_combined ["variance_forest_predictions" ],
417+ )
415418 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
416419 assert sigma2_x_train_combined .shape == (n_train , num_mcmc * 2 )
417420 np .testing .assert_allclose (
@@ -424,7 +427,8 @@ def conditional_stddev(X):
424427 sigma2_x_train_combined [:, 0 :num_mcmc ], bart_model .sigma2_x_train
425428 )
426429 np .testing .assert_allclose (
427- sigma2_x_train_combined [:, num_mcmc : (2 * num_mcmc )], bart_model_2 .sigma2_x_train
430+ sigma2_x_train_combined [:, num_mcmc : (2 * num_mcmc )],
431+ bart_model_2 .sigma2_x_train ,
428432 )
429433 np .testing .assert_allclose (
430434 bart_model_3 .global_var_samples [0 :num_mcmc ], bart_model .global_var_samples
@@ -543,7 +547,7 @@ def conditional_stddev(X):
543547 bart_preds_combined = bart_model_3 .predict (
544548 covariates = X_train , basis = basis_train
545549 )
546- y_hat_train_combined = bart_preds_combined [' y_hat' ]
550+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
547551 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
548552 np .testing .assert_allclose (
549553 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -668,7 +672,7 @@ def conditional_stddev(X):
668672 bart_preds_combined = bart_model_3 .predict (
669673 covariates = X_train , basis = basis_train
670674 )
671- y_hat_train_combined = bart_preds_combined [' y_hat' ]
675+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
672676 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
673677 np .testing .assert_allclose (
674678 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -825,7 +829,7 @@ def rfx_term(group_labels, basis):
825829 rfx_group_ids = group_labels_train ,
826830 rfx_basis = rfx_basis_train ,
827831 )
828- y_hat_train_combined = bart_preds_combined [' y_hat' ]
832+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
829833 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
830834 np .testing .assert_allclose (
831835 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -999,7 +1003,7 @@ def conditional_stddev(X):
9991003 rfx_group_ids = group_labels_train ,
10001004 rfx_basis = rfx_basis_train ,
10011005 )
1002- y_hat_train_combined = bart_preds_combined [' y_hat' ]
1006+ y_hat_train_combined = bart_preds_combined [" y_hat" ]
10031007 assert y_hat_train_combined .shape == (n_train , num_mcmc * 2 )
10041008 np .testing .assert_allclose (
10051009 y_hat_train_combined [:, 0 :num_mcmc ], bart_model .y_hat_train
@@ -1059,7 +1063,7 @@ def outcome_mean(X, W):
10591063 # Define the group rfx function
10601064 def rfx_term (group_labels , basis ):
10611065 return np .where (
1062- group_labels == 0 , - 5 + 1. * basis [:,1 ], 5 - 1. * basis [:,1 ]
1066+ group_labels == 0 , - 5 + 1.0 * basis [:, 1 ], 5 - 1.0 * basis [:, 1 ]
10631067 )
10641068
10651069 # Define the conditional standard deviation function
@@ -1124,12 +1128,12 @@ def conditional_stddev(X):
11241128 # Specify scalar rfx parameters
11251129 rfx_params = {
11261130 "model_spec" : "custom" ,
1127- "working_parameter_prior_mean" : 1. ,
1128- "group_parameter_prior_mean" : 1. ,
1129- "working_parameter_prior_cov" : 1. ,
1130- "group_parameter_prior_cov" : 1. ,
1131+ "working_parameter_prior_mean" : 1.0 ,
1132+ "group_parameter_prior_mean" : 1.0 ,
1133+ "working_parameter_prior_cov" : 1.0 ,
1134+ "group_parameter_prior_cov" : 1.0 ,
11311135 "variance_prior_shape" : 1 ,
1132- "variance_prior_scale" : 1
1136+ "variance_prior_scale" : 1 ,
11331137 }
11341138 bart_model_2 = BARTModel ()
11351139 bart_model_2 .sample (
@@ -1151,12 +1155,12 @@ def conditional_stddev(X):
11511155 # Specify all relevant rfx parameters as vectors
11521156 rfx_params = {
11531157 "model_spec" : "custom" ,
1154- "working_parameter_prior_mean" : np .repeat (1. , num_rfx_basis ),
1155- "group_parameter_prior_mean" : np .repeat (1. , num_rfx_basis ),
1158+ "working_parameter_prior_mean" : np .repeat (1.0 , num_rfx_basis ),
1159+ "group_parameter_prior_mean" : np .repeat (1.0 , num_rfx_basis ),
11561160 "working_parameter_prior_cov" : np .identity (num_rfx_basis ),
11571161 "group_parameter_prior_cov" : np .identity (num_rfx_basis ),
11581162 "variance_prior_shape" : 1 ,
1159- "variance_prior_scale" : 1
1163+ "variance_prior_scale" : 1 ,
11601164 }
11611165 bart_model_3 = BARTModel ()
11621166 bart_model_3 .sample (
@@ -1176,9 +1180,7 @@ def conditional_stddev(X):
11761180 )
11771181
11781182 # Fit a simpler intercept-only RFX model
1179- rfx_params = {
1180- "model_spec" : "intercept_only"
1181- }
1183+ rfx_params = {"model_spec" : "intercept_only" }
11821184 bart_model_4 = BARTModel ()
11831185 bart_model_4 .sample (
11841186 X_train = X_train ,
@@ -1198,6 +1200,6 @@ def conditional_stddev(X):
11981200 basis = basis_test ,
11991201 rfx_group_ids = group_labels_test ,
12001202 type = "posterior" ,
1201- terms = "rfx"
1203+ terms = "rfx" ,
12021204 )
12031205 assert preds .shape == (n_test , num_mcmc )
0 commit comments