
Commit 8588864

Reformatted python code
1 parent: 4f1beed

4 files changed: 189 additions & 121 deletions

demo/debug/bart_predict_debug.py

Lines changed: 9 additions & 17 deletions
@@ -47,50 +47,42 @@
 
 # # Check several predict approaches
 bart_preds = bart_model.predict(covariates=X_test)
-y_hat_posterior_test = bart_model.predict(covariates=X_test)['y_hat']
-y_hat_mean_test = bart_model.predict(
-    covariates=X_test,
-    type = "mean",
-    terms = ["y_hat"]
-)
+y_hat_posterior_test = bart_model.predict(covariates=X_test)["y_hat"]
+y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"])
 y_hat_test = bart_model.predict(
-    covariates=X_test,
-    type = "mean",
-    terms = ["rfx", "variance"]
+    covariates=X_test, type="mean", terms=["rfx", "variance"]
 )
 
 # Plot predicted versus actual
 plt.scatter(y_hat_mean_test, y_test, color="black")
-plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
 plt.xlabel("Predicted")
 plt.ylabel("Actual")
 plt.title("Y hat")
 plt.show()
 
 # Compute posterior interval
 intervals = bart_model.compute_posterior_interval(
-    terms = "all",
-    scale = "linear",
-    level = 0.95,
-    covariates = X_test
+    terms="all", scale="linear", level=0.95, covariates=X_test
 )
 
 # Check coverage
 mean_coverage = np.mean(
-    (intervals["y_hat"]["lower"] <= f_X_test) & (f_X_test <= intervals["y_hat"]["upper"])
+    (intervals["y_hat"]["lower"] <= f_X_test)
+    & (f_X_test <= intervals["y_hat"]["upper"])
 )
 print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}")
 
 # Sample from the posterior predictive distribution
 bart_ppd_samples = bart_model.sample_posterior_predictive(
-    covariates = X_test, num_draws_per_sample = 10
+    covariates=X_test, num_draws_per_sample=10
 )
 
 # Plot PPD mean vs actual
 ppd_mean = np.mean(bart_ppd_samples, axis=(0, 2))
 plt.clf()
 plt.scatter(ppd_mean, y_test, color="blue")
-plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
 plt.xlabel("Predicted")
 plt.ylabel("Actual")
 plt.title("Posterior Predictive Mean Comparison")

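For reference, a minimal sketch of the workflow this demo exercises, on synthetic data. The BARTModel() constructor and the sample(X_train=..., y_train=..., num_gfr=..., num_mcmc=...) fitting call are assumptions about the wider stochtree API and do not appear in this diff; the predict, compute_posterior_interval, and sample_posterior_predictive calls follow the reformatted signatures above.

# Sketch of the bart_predict_debug.py workflow on synthetic data.
# NOTE: BARTModel() and sample(...) below are assumed stochtree APIs (not in this diff);
# the predict / interval / PPD calls mirror the reformatted demo above.
import numpy as np
from stochtree import BARTModel

rng = np.random.default_rng(2024)
n, p = 500, 5
X = rng.normal(size=(n, p))
f_X = 2.0 * X[:, 0]
y = f_X + rng.normal(scale=0.5, size=n)
X_train, X_test = X[:400], X[400:]
y_train, f_X_test = y[:400], f_X[400:]

bart_model = BARTModel()
bart_model.sample(X_train=X_train, y_train=y_train, num_gfr=10, num_mcmc=100)  # assumed fit API

# Posterior-mean prediction of the mean function
y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"])

# 95% posterior intervals and coverage of the true f(X)
intervals = bart_model.compute_posterior_interval(
    terms="all", scale="linear", level=0.95, covariates=X_test
)
mean_coverage = np.mean(
    (intervals["y_hat"]["lower"] <= f_X_test)
    & (f_X_test <= intervals["y_hat"]["upper"])
)
print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}")

# Posterior predictive draws: 10 likelihood draws per posterior sample
bart_ppd_samples = bart_model.sample_posterior_predictive(
    covariates=X_test, num_draws_per_sample=10
)
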
demo/debug/bcf_predict_debug.py

Lines changed: 25 additions & 26 deletions
@@ -12,9 +12,9 @@
 n = 1000
 p = 5
 X = rng.normal(loc=0.0, scale=1.0, size=(n, p))
-mu_X = X[:,0]
-tau_X = 0.25 * X[:,1]
-pi_X = norm.cdf(0.5 * X[:,1])
+mu_X = X[:, 0]
+tau_X = 0.25 * X[:, 1]
+pi_X = norm.cdf(0.5 * X[:, 1])
 Z = rng.binomial(n=1, p=pi_X, size=(n,))
 E_XZ = mu_X + tau_X * Z
 snr = 2.0
@@ -54,27 +54,23 @@
 
 # Check several predict approaches
 bcf_preds = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)
-y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)['y_hat']
+y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)[
+    "y_hat"
+]
 y_hat_mean_test = bcf_model.predict(
-    X=X_test, Z=Z_test, propensity=pi_test,
-    type = "mean",
-    terms = ["y_hat"]
+    X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["y_hat"]
 )
 tau_hat_mean_test = bcf_model.predict(
-    X=X_test, Z=Z_test, propensity=pi_test,
-    type = "mean",
-    terms = ["cate"]
+    X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["cate"]
 )
 # Check that this raises a warning
 y_hat_test = bcf_model.predict(
-    X=X_test, Z=Z_test, propensity=pi_test,
-    type = "mean",
-    terms = ["rfx", "variance"]
+    X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["rfx", "variance"]
 )
 
 # Plot predicted versus actual
 plt.scatter(y_hat_mean_test, y_test, color="black")
-plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
 plt.xlabel("Predicted")
 plt.ylabel("Actual")
 plt.title("Y hat")
@@ -83,50 +79,53 @@
 # Plot predicted versus actual
 plt.clf()
 plt.scatter(tau_hat_mean_test, tau_test, color="black")
-plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
 plt.xlabel("Predicted")
 plt.ylabel("Actual")
 plt.title("CATE function")
 plt.show()
 
 # Compute posterior interval
 intervals = bcf_model.compute_posterior_interval(
-    terms = "all",
-    scale = "linear",
-    level = 0.95,
-    covariates = X_test,
-    treatment = Z_test,
-    propensity = pi_test
+    terms="all",
+    scale="linear",
+    level=0.95,
+    covariates=X_test,
+    treatment=Z_test,
+    propensity=pi_test,
 )
 
 # Check coverage of E[Y | X, Z]
 mean_coverage = np.mean(
-    (intervals["y_hat"]["lower"] <= E_XZ_test) & (E_XZ_test <= intervals["y_hat"]["upper"])
+    (intervals["y_hat"]["lower"] <= E_XZ_test)
+    & (E_XZ_test <= intervals["y_hat"]["upper"])
 )
 print(f"Coverage of 95% posterior interval for E[Y|X,Z]: {mean_coverage:.3f}")
 
 # Check coverage of tau(X)
 tau_coverage = np.mean(
-    (intervals["tau_hat"]["lower"] <= tau_test) & (tau_test <= intervals["tau_hat"]["upper"])
+    (intervals["tau_hat"]["lower"] <= tau_test)
+    & (tau_test <= intervals["tau_hat"]["upper"])
 )
 print(f"Coverage of 95% posterior interval for tau(X): {tau_coverage:.3f}")
 
 # Check coverage of mu(X)
 mu_coverage = np.mean(
-    (intervals["mu_hat"]["lower"] <= mu_test) & (mu_test <= intervals["mu_hat"]["upper"])
+    (intervals["mu_hat"]["lower"] <= mu_test)
+    & (mu_test <= intervals["mu_hat"]["upper"])
 )
 print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}")
 
 # Sample from the posterior predictive distribution
 bcf_ppd_samples = bcf_model.sample_posterior_predictive(
-    covariates = X_test, treatment = Z_test, propensity = pi_test, num_draws_per_sample = 10
+    covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10
 )
 
 # Plot PPD mean vs actual
 ppd_mean = np.mean(bcf_ppd_samples, axis=(0, 2))
 plt.clf()
 plt.scatter(ppd_mean, y_test, color="blue")
-plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
 plt.xlabel("Predicted")
 plt.ylabel("Actual")
 plt.title("Posterior Predictive Mean Comparison")

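The BCF demo follows the same keyword-only call pattern. A small helper sketch that bundles the post-fit summaries is shown below; it assumes an already-fitted stochtree BCFModel (the fitting step is not part of this diff) and uses only the predict, compute_posterior_interval, and sample_posterior_predictive signatures that appear in the hunks above.

# Hedged sketch: post-fit summaries for a fitted stochtree BCFModel.
# bcf_model, the test arrays, and tau_test are supplied by the caller; the fit
# step is intentionally omitted because it is not shown in this commit.
import numpy as np

def summarize_bcf_fit(bcf_model, X_test, Z_test, pi_test, tau_test):
    # Posterior-mean CATE predictions, as in the demo above
    tau_hat_mean = bcf_model.predict(
        X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["cate"]
    )
    # 95% posterior intervals for all model terms
    intervals = bcf_model.compute_posterior_interval(
        terms="all",
        scale="linear",
        level=0.95,
        covariates=X_test,
        treatment=Z_test,
        propensity=pi_test,
    )
    # Empirical coverage of the true CATE by the tau_hat interval
    tau_coverage = np.mean(
        (intervals["tau_hat"]["lower"] <= tau_test)
        & (tau_test <= intervals["tau_hat"]["upper"])
    )
    # Posterior predictive draws: 10 likelihood draws per posterior sample
    ppd_samples = bcf_model.sample_posterior_predictive(
        covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10
    )
    return tau_hat_mean, tau_coverage, ppd_samples
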
stochtree/bart.py

Lines changed: 70 additions & 35 deletions
@@ -24,8 +24,8 @@
     _expand_dims_1d,
     _expand_dims_2d,
     _expand_dims_2d_diag,
-    _posterior_predictive_heuristic_multiplier,
-    _summarize_interval
+    _posterior_predictive_heuristic_multiplier,
+    _summarize_interval,
 )
 
 
@@ -1858,7 +1858,16 @@ def predict(
             result["variance_forest_predictions"] = None
         return result
 
-    def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict:
+    def compute_posterior_interval(
+        self,
+        terms: Union[list[str], str] = "all",
+        scale: str = "linear",
+        level: float = 0.95,
+        covariates: np.array = None,
+        basis: np.array = None,
+        rfx_group_ids: np.array = None,
+        rfx_basis: np.array = None,
+    ) -> dict:
         """
         Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
 
@@ -1889,7 +1898,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
             raise ValueError("Model has not yet been sampled")
         for term in terms:
             if not self.has_term(term):
-                warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.")
+                warnings.warn(
+                    f"Term {term} was not sampled in this model and its intervals will not be returned."
+                )
 
         # Handle mean function scale
         if not isinstance(scale, str):
@@ -1903,8 +1914,14 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
             )
 
         # Check that all the necessary inputs were provided for interval computation
-        needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.include_mean_forest
-        needs_covariates = ("mean_forest" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate
+        needs_covariates_intermediate = (
+            ("y_hat" in terms) or ("all" in terms)
+        ) and self.include_mean_forest
+        needs_covariates = (
+            ("mean_forest" in terms)
+            or ("variance_forest" in terms)
+            or needs_covariates_intermediate
+        )
         if needs_covariates:
             if covariates is None:
                 raise ValueError(
@@ -1926,7 +1943,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
                 raise ValueError(
                     "'basis' must have the same number of rows as 'covariates'"
                 )
-        needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx
+        needs_rfx_data_intermediate = (
+            ("y_hat" in terms) or ("all" in terms)
+        ) and self.has_rfx
         needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate
         if needs_rfx_data:
             if rfx_group_ids is None:
@@ -1951,7 +1970,15 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
                 )
 
         # Compute posterior matrices for the requested model terms
-        predictions = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale)
+        predictions = self.predict(
+            covariates=covariates,
+            basis=basis,
+            rfx_group_ids=rfx_group_ids,
+            rfx_basis=rfx_basis,
+            type="posterior",
+            terms=terms,
+            scale=scale,
+        )
         has_multiple_terms = True if isinstance(predictions, dict) else False
 
         # Compute posterior intervals
@@ -1964,11 +1991,16 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
                 )
             return result
         else:
-            return _summarize_interval(
-                predictions, 1, level=level
-            )
-
-    def sample_posterior_predictive(self, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None) -> np.array:
+            return _summarize_interval(predictions, 1, level=level)
+
+    def sample_posterior_predictive(
+        self,
+        covariates: np.array = None,
+        basis: np.array = None,
+        rfx_group_ids: np.array = None,
+        rfx_basis: np.array = None,
+        num_draws_per_sample: int = None,
+    ) -> np.array:
         """
         Sample from the posterior predictive distribution for outcomes modeled by BART
 
@@ -1984,7 +2016,7 @@ def sample_posterior_predictive(self, covariates: np.arr
             An array of basis function evaluations for random effects. Required if the BART model includes random effects.
         num_draws_per_sample : int, optional
             The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
-
+
         Returns
         -------
         np.array
@@ -2044,58 +2076,61 @@ def sample_posterior_predictive(self, covariates: np.arr
                 )
 
         # Compute posterior predictive samples
-        bart_preds = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms="all")
+        bart_preds = self.predict(
+            covariates=covariates,
+            basis=basis,
+            rfx_group_ids=rfx_group_ids,
+            rfx_basis=rfx_basis,
+            type="posterior",
+            terms="all",
+        )
 
         # Compute outcome mean and variance for posterior predictive distribution
-        has_mean_term = (self.include_mean_forest or self.has_rfx)
+        has_mean_term = self.include_mean_forest or self.has_rfx
         has_variance_forest = self.include_variance_forest
         samples_global_variance = self.sample_sigma2_global
         num_posterior_draws = self.num_samples
         num_observations = covariates.shape[0]
         if has_mean_term:
             ppd_mean = bart_preds["y_hat"]
         else:
-            ppd_mean = 0.
+            ppd_mean = 0.0
         if has_variance_forest:
             ppd_variance = bart_preds["variance_forest_predictions"]
         else:
             if samples_global_variance:
-                ppd_variance = np.tile(
-                    self.global_var_samples,
-                    (num_observations, 1)
-                )
+                ppd_variance = np.tile(self.global_var_samples, (num_observations, 1))
             else:
                 ppd_variance = self.sigma2_init
-
+
         # Sample from the posterior predictive distribution
         if num_draws_per_sample is None:
             ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier(
-                num_posterior_draws,
-                num_observations
+                num_posterior_draws, num_observations
             )
         else:
             ppd_draw_multiplier = num_draws_per_sample
         if ppd_draw_multiplier > 1:
             ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1))
             ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1))
             ppd_array = np.random.normal(
-                loc = ppd_mean,
-                scale = np.sqrt(ppd_variance),
-                size = (ppd_draw_multiplier, num_observations, num_posterior_draws)
+                loc=ppd_mean,
+                scale=np.sqrt(ppd_variance),
+                size=(ppd_draw_multiplier, num_observations, num_posterior_draws),
             )
         else:
             ppd_array = np.random.normal(
-                loc = ppd_mean,
-                scale = np.sqrt(ppd_variance),
-                size = (num_observations, num_posterior_draws)
+                loc=ppd_mean,
+                scale=np.sqrt(ppd_variance),
+                size=(num_observations, num_posterior_draws),
            )
-
+
         # Binarize outcome for probit models
         if is_probit:
             ppd_array = (ppd_array > 0.0) * 1
-
+
         return ppd_array
-
+
     def to_json(self) -> str:
         """
         Converts a sampled BART model to JSON string representation (which can then be saved to a file or
@@ -2381,7 +2416,7 @@ def is_sampled(self) -> bool:
             `True` if a BART model has been sampled, `False` otherwise
         """
         return self.sampled
-
+
     def has_term(self, term: str) -> bool:
         """
         Whether or not a model includes a term.
@@ -2390,7 +2425,7 @@ def has_term(self, term: str) -> bool:
         ----------
         term : str
             Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`.
-
+
         Returns
         -------
         bool

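The num_draws_per_sample default documented in the sample_posterior_predictive docstring above (more than 1000 posterior draws: one likelihood draw per sample; otherwise upsample so at least 1000 posterior predictive draws are available) can be read as the sketch below. This is a hypothetical stand-in, not stochtree's _posterior_predictive_heuristic_multiplier, which also receives num_observations and may behave differently.

import math

def heuristic_ppd_multiplier(num_posterior_draws: int, num_observations: int) -> int:
    # Hypothetical reading of the documented heuristic; the real helper may differ
    # and may also use num_observations (ignored here).
    if num_posterior_draws > 1000:
        return 1
    return math.ceil(1000 / num_posterior_draws)

# 100 MCMC draws -> 10 likelihood draws per posterior sample (1000 PPD draws total)
assert heuristic_ppd_multiplier(100, 500) == 10
# With more than 1000 posterior draws, one likelihood draw per sample suffices
assert heuristic_ppd_multiplier(2000, 500) == 1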