Skip to content

Commit b0b59f2

Browse files
committed
Reformatting R and Python unit test code
1 parent c00e480 commit b0b59f2

File tree

12 files changed

+801
-574
lines changed

12 files changed

+801
-574
lines changed

test/R/testthat/test-bart.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,13 @@ test_that("Random Effects BART", {
582582
mean_forest_params = mean_forest_param_list,
583583
random_effects_params = rfx_param_list
584584
)
585-
preds <- predict(bart_model, covariates = X_test, leaf_basis = W_test, rfx_group_ids = rfx_group_ids_test, type = "posterior", terms = "rfx")
585+
preds <- predict(
586+
bart_model,
587+
covariates = X_test,
588+
leaf_basis = W_test,
589+
rfx_group_ids = rfx_group_ids_test,
590+
type = "posterior",
591+
terms = "rfx"
592+
)
586593
})
587594
})

test/python/test_bart.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def outcome_mean(X):
8484

8585
# Assertions
8686
bart_preds_combined = bart_model_3.predict(covariates=X_train)
87-
y_hat_train_combined = bart_preds_combined['y_hat']
87+
y_hat_train_combined = bart_preds_combined["y_hat"]
8888
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
8989
np.testing.assert_allclose(
9090
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -192,7 +192,7 @@ def outcome_mean(X, W):
192192
bart_preds_combined = bart_model_3.predict(
193193
covariates=X_train, basis=basis_train
194194
)
195-
y_hat_train_combined = bart_preds_combined['y_hat']
195+
y_hat_train_combined = bart_preds_combined["y_hat"]
196196
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
197197
np.testing.assert_allclose(
198198
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -300,7 +300,7 @@ def outcome_mean(X, W):
300300
bart_preds_combined = bart_model_3.predict(
301301
covariates=X_train, basis=basis_train
302302
)
303-
y_hat_train_combined = bart_preds_combined['y_hat']
303+
y_hat_train_combined = bart_preds_combined["y_hat"]
304304
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
305305
np.testing.assert_allclose(
306306
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -411,7 +411,10 @@ def conditional_stddev(X):
411411

412412
# Assertions
413413
bart_preds_combined = bart_model_3.predict(covariates=X_train)
414-
y_hat_train_combined, sigma2_x_train_combined = bart_preds_combined['y_hat'], bart_preds_combined['variance_forest_predictions']
414+
y_hat_train_combined, sigma2_x_train_combined = (
415+
bart_preds_combined["y_hat"],
416+
bart_preds_combined["variance_forest_predictions"],
417+
)
415418
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
416419
assert sigma2_x_train_combined.shape == (n_train, num_mcmc * 2)
417420
np.testing.assert_allclose(
@@ -424,7 +427,8 @@ def conditional_stddev(X):
424427
sigma2_x_train_combined[:, 0:num_mcmc], bart_model.sigma2_x_train
425428
)
426429
np.testing.assert_allclose(
427-
sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train
430+
sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)],
431+
bart_model_2.sigma2_x_train,
428432
)
429433
np.testing.assert_allclose(
430434
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
@@ -543,7 +547,7 @@ def conditional_stddev(X):
543547
bart_preds_combined = bart_model_3.predict(
544548
covariates=X_train, basis=basis_train
545549
)
546-
y_hat_train_combined = bart_preds_combined['y_hat']
550+
y_hat_train_combined = bart_preds_combined["y_hat"]
547551
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
548552
np.testing.assert_allclose(
549553
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -668,7 +672,7 @@ def conditional_stddev(X):
668672
bart_preds_combined = bart_model_3.predict(
669673
covariates=X_train, basis=basis_train
670674
)
671-
y_hat_train_combined = bart_preds_combined['y_hat']
675+
y_hat_train_combined = bart_preds_combined["y_hat"]
672676
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
673677
np.testing.assert_allclose(
674678
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -825,7 +829,7 @@ def rfx_term(group_labels, basis):
825829
rfx_group_ids=group_labels_train,
826830
rfx_basis=rfx_basis_train,
827831
)
828-
y_hat_train_combined = bart_preds_combined['y_hat']
832+
y_hat_train_combined = bart_preds_combined["y_hat"]
829833
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
830834
np.testing.assert_allclose(
831835
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -999,7 +1003,7 @@ def conditional_stddev(X):
9991003
rfx_group_ids=group_labels_train,
10001004
rfx_basis=rfx_basis_train,
10011005
)
1002-
y_hat_train_combined = bart_preds_combined['y_hat']
1006+
y_hat_train_combined = bart_preds_combined["y_hat"]
10031007
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
10041008
np.testing.assert_allclose(
10051009
y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train
@@ -1059,7 +1063,7 @@ def outcome_mean(X, W):
10591063
# Define the group rfx function
10601064
def rfx_term(group_labels, basis):
10611065
return np.where(
1062-
group_labels == 0, -5 + 1. * basis[:,1], 5 - 1. * basis[:,1]
1066+
group_labels == 0, -5 + 1.0 * basis[:, 1], 5 - 1.0 * basis[:, 1]
10631067
)
10641068

10651069
# Define the conditional standard deviation function
@@ -1124,12 +1128,12 @@ def conditional_stddev(X):
11241128
# Specify scalar rfx parameters
11251129
rfx_params = {
11261130
"model_spec": "custom",
1127-
"working_parameter_prior_mean": 1.,
1128-
"group_parameter_prior_mean": 1.,
1129-
"working_parameter_prior_cov": 1.,
1130-
"group_parameter_prior_cov": 1.,
1131+
"working_parameter_prior_mean": 1.0,
1132+
"group_parameter_prior_mean": 1.0,
1133+
"working_parameter_prior_cov": 1.0,
1134+
"group_parameter_prior_cov": 1.0,
11311135
"variance_prior_shape": 1,
1132-
"variance_prior_scale": 1
1136+
"variance_prior_scale": 1,
11331137
}
11341138
bart_model_2 = BARTModel()
11351139
bart_model_2.sample(
@@ -1151,12 +1155,12 @@ def conditional_stddev(X):
11511155
# Specify all relevant rfx parameters as vectors
11521156
rfx_params = {
11531157
"model_spec": "custom",
1154-
"working_parameter_prior_mean": np.repeat(1., num_rfx_basis),
1155-
"group_parameter_prior_mean": np.repeat(1., num_rfx_basis),
1158+
"working_parameter_prior_mean": np.repeat(1.0, num_rfx_basis),
1159+
"group_parameter_prior_mean": np.repeat(1.0, num_rfx_basis),
11561160
"working_parameter_prior_cov": np.identity(num_rfx_basis),
11571161
"group_parameter_prior_cov": np.identity(num_rfx_basis),
11581162
"variance_prior_shape": 1,
1159-
"variance_prior_scale": 1
1163+
"variance_prior_scale": 1,
11601164
}
11611165
bart_model_3 = BARTModel()
11621166
bart_model_3.sample(
@@ -1176,9 +1180,7 @@ def conditional_stddev(X):
11761180
)
11771181

11781182
# Fit a simpler intercept-only RFX model
1179-
rfx_params = {
1180-
"model_spec": "intercept_only"
1181-
}
1183+
rfx_params = {"model_spec": "intercept_only"}
11821184
bart_model_4 = BARTModel()
11831185
bart_model_4.sample(
11841186
X_train=X_train,
@@ -1198,6 +1200,6 @@ def conditional_stddev(X):
11981200
basis=basis_test,
11991201
rfx_group_ids=group_labels_test,
12001202
type="posterior",
1201-
terms="rfx"
1203+
terms="rfx",
12021204
)
12031205
assert preds.shape == (n_test, num_mcmc)

0 commit comments

Comments
 (0)