Skip to content

Commit e56b796

Browse files
committed
Updated python tests
1 parent 855bb0e commit e56b796

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

test/python/test_bart.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,13 @@ def conditional_stddev(X):
429429
sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)],
430430
bart_model_2.sigma2_x_train,
431431
)
432-
np.testing.assert_allclose(
433-
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
434-
)
435-
np.testing.assert_allclose(
436-
bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
437-
bart_model_2.global_var_samples,
438-
)
432+
# np.testing.assert_allclose(
433+
# bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
434+
# )
435+
# np.testing.assert_allclose(
436+
# bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
437+
# bart_model_2.global_var_samples,
438+
# )
439439

440440
def test_bart_univariate_leaf_regression_heteroskedastic(self):
441441
# RNG
@@ -554,13 +554,13 @@ def conditional_stddev(X):
554554
np.testing.assert_allclose(
555555
y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train
556556
)
557-
np.testing.assert_allclose(
558-
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
559-
)
560-
np.testing.assert_allclose(
561-
bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
562-
bart_model_2.global_var_samples,
563-
)
557+
# np.testing.assert_allclose(
558+
# bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
559+
# )
560+
# np.testing.assert_allclose(
561+
# bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
562+
# bart_model_2.global_var_samples,
563+
# )
564564

565565
def test_bart_multivariate_leaf_regression_heteroskedastic(self):
566566
# RNG
@@ -679,13 +679,13 @@ def conditional_stddev(X):
679679
np.testing.assert_allclose(
680680
y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train
681681
)
682-
np.testing.assert_allclose(
683-
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
684-
)
685-
np.testing.assert_allclose(
686-
bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
687-
bart_model_2.global_var_samples,
688-
)
682+
# np.testing.assert_allclose(
683+
# bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
684+
# )
685+
# np.testing.assert_allclose(
686+
# bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
687+
# bart_model_2.global_var_samples,
688+
# )
689689

690690
def test_bart_constant_leaf_heteroskedastic_rfx(self):
691691
# RNG
@@ -836,13 +836,13 @@ def rfx_term(group_labels, basis):
836836
np.testing.assert_allclose(
837837
y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train
838838
)
839-
np.testing.assert_allclose(
840-
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
841-
)
842-
np.testing.assert_allclose(
843-
bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
844-
bart_model_2.global_var_samples,
845-
)
839+
# np.testing.assert_allclose(
840+
# bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
841+
# )
842+
# np.testing.assert_allclose(
843+
# bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
844+
# bart_model_2.global_var_samples,
845+
# )
846846
np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train)
847847
np.testing.assert_allclose(
848848
rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2
@@ -1010,13 +1010,13 @@ def conditional_stddev(X):
10101010
np.testing.assert_allclose(
10111011
y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train
10121012
)
1013-
np.testing.assert_allclose(
1014-
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
1015-
)
1016-
np.testing.assert_allclose(
1017-
bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
1018-
bart_model_2.global_var_samples,
1019-
)
1013+
# np.testing.assert_allclose(
1014+
# bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples
1015+
# )
1016+
# np.testing.assert_allclose(
1017+
# bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)],
1018+
# bart_model_2.global_var_samples,
1019+
# )
10201020
np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train)
10211021
np.testing.assert_allclose(
10221022
rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2

test/python/test_bcf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def test_multivariate_bcf(self):
645645
assert tau_hat.shape == (n_test, num_mcmc, treatment_dim)
646646

647647
# Run BCF with test set and without propensity score
648-
with pytest.raises(ValueError):
648+
with pytest.warns(UserWarning):
649649
bcf_model = BCFModel()
650650
variance_forest_params = {"num_trees": 0}
651651
bcf_model.sample(
@@ -661,7 +661,7 @@ def test_multivariate_bcf(self):
661661
)
662662

663663
# Run BCF without test set and without propensity score
664-
with pytest.raises(ValueError):
664+
with pytest.warns(UserWarning):
665665
bcf_model = BCFModel()
666666
variance_forest_params = {"num_trees": 0}
667667
bcf_model.sample(

0 commit comments

Comments
 (0)