@@ -76,8 +76,8 @@ def test_binary_bcf(self):
7676 assert mu_hat .shape == (n_test , num_mcmc )
7777 assert y_hat .shape == (n_test , num_mcmc )
7878
79- # Check treatment effect prediction method
80- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
79+ # Check that we can predict just treatment effects
80+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
8181 assert tau_hat .shape == (n_test , num_mcmc )
8282
8383 # Run BCF without test set and with propensity score
@@ -106,7 +106,7 @@ def test_binary_bcf(self):
106106 assert mu_hat .shape == (n_test , num_mcmc )
107107 assert y_hat .shape == (n_test , num_mcmc )
108108 # Check treatment effect prediction method
109- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
109+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
110110 assert tau_hat .shape == (n_test , num_mcmc )
111111
112112 # Run BCF with test set and without propensity score
@@ -142,7 +142,7 @@ def test_binary_bcf(self):
142142 assert y_hat .shape == (n_test , num_mcmc )
143143
144144 # Check treatment effect prediction method
145- tau_hat = bcf_model .predict_tau ( X_test , Z_test )
145+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , terms = "cate" )
146146 assert tau_hat .shape == (n_test , num_mcmc )
147147
148148 # Run BCF without test set and without propensity score
@@ -172,7 +172,7 @@ def test_binary_bcf(self):
172172 assert y_hat .shape == (n_test , num_mcmc )
173173
174174 # Check treatment effect prediction method
175- tau_hat = bcf_model .predict_tau ( X_test , Z_test )
175+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , terms = "cate" )
176176
177177 def test_continuous_univariate_bcf (self ):
178178 # RNG
@@ -245,7 +245,7 @@ def test_continuous_univariate_bcf(self):
245245 assert y_hat .shape == (n_test , num_mcmc )
246246
247247 # Check treatment effect prediction method
248- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
248+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
249249 assert tau_hat .shape == (n_test , num_mcmc )
250250
251251 # Run second BCF model with test set and propensity score
@@ -281,7 +281,7 @@ def test_continuous_univariate_bcf(self):
281281 assert y_hat_2 .shape == (n_test , num_mcmc )
282282
283283 # Check treatment effect prediction method
284- tau_hat_2 = bcf_model_2 .predict_tau ( X_test , Z_test , pi_test )
284+ tau_hat_2 = bcf_model_2 .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
285285 assert tau_hat_2 .shape == (n_test , num_mcmc )
286286
287287 # Combine into a single model
@@ -336,7 +336,7 @@ def test_continuous_univariate_bcf(self):
336336 assert y_hat .shape == (n_test , num_mcmc )
337337
338338 # Check treatment effect prediction method
339- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
339+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
340340 assert tau_hat .shape == (n_test , num_mcmc )
341341
342342 # Run BCF with test set and without propensity score
@@ -372,7 +372,7 @@ def test_continuous_univariate_bcf(self):
372372 assert y_hat .shape == (n_test , num_mcmc )
373373
374374 # Check treatment effect prediction method
375- tau_hat = bcf_model .predict_tau ( X_test , Z_test )
375+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , terms = "cate" )
376376 assert tau_hat .shape == (n_test , num_mcmc )
377377
378378 # Run BCF without test set and without propensity score
@@ -402,7 +402,7 @@ def test_continuous_univariate_bcf(self):
402402 assert y_hat .shape == (n_test , num_mcmc )
403403
404404 # Check treatment effect prediction method
405- tau_hat = bcf_model .predict_tau ( X_test , Z_test )
405+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , terms = "cate" )
406406
407407 # Run second BCF model with test set and propensity score
408408 bcf_model_2 = BCFModel ()
@@ -430,7 +430,7 @@ def test_continuous_univariate_bcf(self):
430430 assert y_hat_2 .shape == (n_test , num_mcmc )
431431
432432 # Check treatment effect prediction method
433- tau_hat_2 = bcf_model_2 .predict_tau ( X_test , Z_test )
433+ tau_hat_2 = bcf_model_2 .predict ( X = X_test , Z = Z_test , terms = "cate" )
434434 assert tau_hat_2 .shape == (n_test , num_mcmc )
435435
436436 # Combine into a single model
@@ -528,7 +528,7 @@ def test_multivariate_bcf(self):
528528 assert y_hat .shape == (n_test , num_mcmc )
529529
530530 # Check treatment effect prediction method
531- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
531+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
532532 assert tau_hat .shape == (n_test , num_mcmc , treatment_dim )
533533
534534 # Run BCF without test set and with propensity score
@@ -558,7 +558,7 @@ def test_multivariate_bcf(self):
558558 assert y_hat .shape == (n_test , num_mcmc )
559559
560560 # Check treatment effect prediction method
561- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
561+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
562562 assert tau_hat .shape == (n_test , num_mcmc , treatment_dim )
563563
564564 # Run BCF with test set and without propensity score
@@ -665,7 +665,7 @@ def test_binary_bcf_heteroskedastic(self):
665665 assert sigma2_x_hat .shape == (n_test , num_mcmc )
666666
667667 # Check treatment effect prediction method
668- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
668+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
669669 assert tau_hat .shape == (n_test , num_mcmc )
670670
671671 # Run BCF without test set and with propensity score
@@ -715,7 +715,7 @@ def test_binary_bcf_heteroskedastic(self):
715715 )
716716
717717 # Check treatment effect prediction method
718- tau_hat = bcf_model .predict_tau ( X_test , Z_test , pi_test )
718+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , propensity = pi_test , terms = "cate" )
719719 assert tau_hat .shape == (n_test , num_mcmc )
720720
721721 # Run BCF with test set and without propensity score
@@ -752,7 +752,7 @@ def test_binary_bcf_heteroskedastic(self):
752752 assert bcf_preds ['variance_forest_predictions' ].shape == (n_test , num_mcmc )
753753
754754 # Check treatment effect prediction method
755- tau_hat = bcf_model .predict_tau ( X_test , Z_test )
755+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , terms = "cate" )
756756 assert tau_hat .shape == (n_test , num_mcmc )
757757
758758 # Run BCF without test set and without propensity score
@@ -781,7 +781,7 @@ def test_binary_bcf_heteroskedastic(self):
781781 assert bcf_preds ['y_hat' ].shape == (n_test , num_mcmc )
782782
783783 # Check treatment effect prediction method
784- tau_hat = bcf_model .predict_tau ( X_test , Z_test )
784+ tau_hat = bcf_model .predict ( X = X_test , Z = Z_test , terms = "cate" )
785785
786786 def test_bcf_rfx_parameters (self ):
787787 # RNG
0 commit comments