Skip to content

Commit 4f1beed

Browse files
committed
Added posterior predictive sampling methods to BART and BCF in python
1 parent df746dc commit 4f1beed

File tree

5 files changed

+357
-25
lines changed

5 files changed

+357
-25
lines changed

R/posterior_transformation.R

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#' Sample from the posterior predictive distribution for outcomes modeled by BCF
22
#'
33
#' @param model_object A fitted BCF model object of class `bcfmodel`.
4-
#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
5-
#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
6-
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities.
4+
#' @param covariates A matrix or data frame of covariates.
5+
#' @param treatment A vector or matrix of treatment assignments.
6+
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
77
#' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.
88
#' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.
9-
#' @param num_draws (Optional) The number of samples to draw from the likelihood, for each draw of the posterior, in computing intervals. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).
9+
#' @param num_draws_per_sample (Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).
1010
#'
11-
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples).
11+
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
1212
#'
1313
#' @export
1414
#' @examples
@@ -30,9 +30,9 @@ sample_bcf_posterior_predictive <- function(
3030
propensity = NULL,
3131
rfx_group_ids = NULL,
3232
rfx_basis = NULL,
33-
num_draws = NULL
33+
num_draws_per_sample = NULL
3434
) {
35-
# Check the provided model object and requested term
35+
# Check the provided model object
3636
check_model_is_valid(model_object)
3737

3838
# Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
@@ -123,7 +123,7 @@ sample_bcf_posterior_predictive <- function(
123123
}
124124
}
125125

126-
# Compute posterior predictive samples
126+
# Compute posterior samples
127127
bcf_preds <- predict(
128128
model_object,
129129
X = covariates,
@@ -132,8 +132,11 @@ sample_bcf_posterior_predictive <- function(
132132
rfx_group_ids = rfx_group_ids,
133133
rfx_basis = rfx_basis,
134134
type = "posterior",
135-
terms = c("all")
135+
terms = c("all"),
136+
scale = "linear"
136137
)
138+
139+
# Compute outcome mean and variance for every posterior draw
137140
has_rfx <- model_object$model_params$has_rfx
138141
has_variance_forest <- model_object$model_params$include_variance_forest
139142
samples_global_variance <- model_object$model_params$sample_sigma2_global
@@ -155,16 +158,20 @@ sample_bcf_posterior_predictive <- function(
155158
ppd_variance <- model_object$model_params$initial_sigma2
156159
}
157160
}
158-
if (is.null(num_draws)) {
161+
162+
# Sample from the posterior predictive distribution
163+
if (is.null(num_draws_per_sample)) {
159164
ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier(
160165
num_posterior_draws,
161166
num_observations
162167
)
163168
} else {
164-
ppd_draw_multiplier <- num_draws
169+
ppd_draw_multiplier <- num_draws_per_sample
165170
}
166171
num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations
167172
ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance))
173+
174+
# Reshape data
168175
if (ppd_draw_multiplier > 1) {
169176
ppd_array <- array(
170177
ppd_vector,
@@ -177,6 +184,7 @@ sample_bcf_posterior_predictive <- function(
177184
)
178185
}
179186

187+
# Binarize outcomes for probit models
180188
if (is_probit) {
181189
ppd_array <- (ppd_array > 0.0) * 1
182190
}
@@ -187,13 +195,13 @@ sample_bcf_posterior_predictive <- function(
187195
#' Sample from the posterior predictive distribution for outcomes modeled by BART
188196
#'
189197
#' @param model_object A fitted BART model object of class `bartmodel`.
190-
#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
198+
#' @param covariates A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
191199
#' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
192200
#' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects.
193201
#' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects.
194-
#' @param num_draws The number of posterior predictive samples to draw in computing intervals. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
202+
#' @param num_draws_per_sample The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
195203
#'
196-
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples).
204+
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
197205
#'
198206
#' @export
199207
#' @examples
@@ -211,9 +219,9 @@ sample_bart_posterior_predictive <- function(
211219
basis = NULL,
212220
rfx_group_ids = NULL,
213221
rfx_basis = NULL,
214-
num_draws = NULL
222+
num_draws_per_sample = NULL
215223
) {
216-
# Check the provided model object and requested term
224+
# Check the provided model object
217225
check_model_is_valid(model_object)
218226

219227
# Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
@@ -276,16 +284,19 @@ sample_bart_posterior_predictive <- function(
276284
}
277285
}
278286

279-
# Compute posterior predictive samples
287+
# Compute posterior samples
280288
bart_preds <- predict(
281289
model_object,
282290
covariates = covariates,
283291
leaf_basis = basis,
284292
rfx_group_ids = rfx_group_ids,
285293
rfx_basis = rfx_basis,
286294
type = "posterior",
287-
terms = c("all")
295+
terms = c("all"),
296+
scale = "linear"
288297
)
298+
299+
# Compute outcome mean and variance for every posterior draw
289300
has_mean_term <- (model_object$model_params$include_mean_forest ||
290301
model_object$model_params$has_rfx)
291302
has_variance_forest <- model_object$model_params$include_variance_forest
@@ -312,16 +323,20 @@ sample_bart_posterior_predictive <- function(
312323
ppd_variance <- model_object$model_params$sigma2_init
313324
}
314325
}
315-
if (is.null(num_draws)) {
326+
327+
# Sample from the posterior predictive distribution
328+
if (is.null(num_draws_per_sample)) {
316329
ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier(
317330
num_posterior_draws,
318331
num_observations
319332
)
320333
} else {
321-
ppd_draw_multiplier <- num_draws
334+
ppd_draw_multiplier <- num_draws_per_sample
322335
}
323336
num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations
324337
ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance))
338+
339+
# Reshape data
325340
if (ppd_draw_multiplier > 1) {
326341
ppd_array <- array(
327342
ppd_vector,
@@ -334,6 +349,7 @@ sample_bart_posterior_predictive <- function(
334349
)
335350
}
336351

352+
# Binarize outcomes for probit models
337353
if (is_probit) {
338354
ppd_array <- (ppd_array > 0.0) * 1
339355
}

demo/debug/bart_predict_debug.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,25 @@
8080
(intervals["y_hat"]["lower"] <= f_X_test) & (f_X_test <= intervals["y_hat"]["upper"])
8181
)
8282
print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}")
83+
84+
# Sample from the posterior predictive distribution
85+
bart_ppd_samples = bart_model.sample_posterior_predictive(
86+
covariates = X_test, num_draws_per_sample = 10
87+
)
88+
89+
# Plot PPD mean vs actual
90+
ppd_mean = np.mean(bart_ppd_samples, axis=(0, 2))
91+
plt.clf()
92+
plt.scatter(ppd_mean, y_test, color="blue")
93+
plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
94+
plt.xlabel("Predicted")
95+
plt.ylabel("Actual")
96+
plt.title("Posterior Predictive Mean Comparison")
97+
plt.show()
98+
99+
# Check coverage of posterior predictive distribution
100+
ppd_intervals = np.percentile(bart_ppd_samples, [2.5, 97.5], axis=(0, 2))
101+
ppd_coverage = np.mean(
102+
(ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :])
103+
)
104+
print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}")

demo/debug/bcf_predict_debug.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,25 @@
116116
(intervals["mu_hat"]["lower"] <= mu_test) & (mu_test <= intervals["mu_hat"]["upper"])
117117
)
118118
print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}")
119+
120+
# Sample from the posterior predictive distribution
121+
bcf_ppd_samples = bcf_model.sample_posterior_predictive(
122+
covariates = X_test, treatment = Z_test, propensity = pi_test, num_draws_per_sample = 10
123+
)
124+
125+
# Plot PPD mean vs actual
126+
ppd_mean = np.mean(bcf_ppd_samples, axis=(0, 2))
127+
plt.clf()
128+
plt.scatter(ppd_mean, y_test, color="blue")
129+
plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
130+
plt.xlabel("Predicted")
131+
plt.ylabel("Actual")
132+
plt.title("Posterior Predictive Mean Comparison")
133+
plt.show()
134+
135+
# Check coverage of posterior predictive distribution
136+
ppd_intervals = np.percentile(bcf_ppd_samples, [2.5, 97.5], axis=(0, 2))
137+
ppd_coverage = np.mean(
138+
(ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :])
139+
)
140+
print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}")

stochtree/bart.py

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,10 +1884,12 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
18841884
dict
18851885
A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned.
18861886
"""
1887-
# Check the provided model object and requested term
1888-
self.is_sampled()
1887+
# Check the provided model object and requested terms
1888+
if not self.is_sampled():
1889+
raise ValueError("Model has not yet been sampled")
18891890
for term in terms:
1890-
self.has_term(term)
1891+
if not self.has_term(term):
1892+
warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.")
18911893

18921894
# Handle mean function scale
18931895
if not isinstance(scale, str):
@@ -1966,6 +1968,134 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
19661968
predictions, 1, level=level
19671969
)
19681970

1971+
def sample_posterior_predictive(self, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None) -> np.array:
1972+
"""
1973+
Sample from the posterior predictive distribution for outcomes modeled by BART
1974+
1975+
Parameters
1976+
----------
1977+
covariates : np.array, optional
1978+
An array or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
1979+
basis : np.array, optional
1980+
An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
1981+
rfx_group_ids : np.array, optional
1982+
An array of group IDs for random effects. Required if the BART model includes random effects.
1983+
rfx_basis : np.array, optional
1984+
An array of basis function evaluations for random effects. Required if the BART model includes random effects.
1985+
num_draws_per_sample : int, optional
1986+
The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
1987+
1988+
Returns
1989+
-------
1990+
np.array
1991+
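An array of posterior predictive samples. If one draw is taken per posterior sample, the dimensions are (num_observations, num_posterior_samples); otherwise they are (num_draws_per_sample, num_observations, num_posterior_samples).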
1992+
"""
1993+
# Check the provided model object
1994+
if not self.is_sampled():
1995+
raise ValueError("Model has not yet been sampled")
1996+
1997+
# Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
1998+
is_probit = self.probit_outcome_model
1999+
2000+
# Check that all the necessary inputs were provided for posterior predictive sampling
2001+
needs_covariates = self.include_mean_forest
2002+
if needs_covariates:
2003+
if covariates is None:
2004+
raise ValueError(
2005+
"'covariates' must be provided in order to compute the requested intervals"
2006+
)
2007+
if not isinstance(covariates, np.ndarray) and not isinstance(
2008+
covariates, pd.DataFrame
2009+
):
2010+
raise ValueError("'covariates' must be a matrix or data frame")
2011+
needs_basis = needs_covariates and self.has_basis
2012+
if needs_basis:
2013+
if basis is None:
2014+
raise ValueError(
2015+
"'basis' must be provided in order to compute the requested intervals"
2016+
)
2017+
if not isinstance(basis, np.ndarray):
2018+
raise ValueError("'basis' must be a numpy array")
2019+
if basis.shape[0] != covariates.shape[0]:
2020+
raise ValueError(
2021+
"'basis' must have the same number of rows as 'covariates'"
2022+
)
2023+
needs_rfx_data = self.has_rfx
2024+
if needs_rfx_data:
2025+
if rfx_group_ids is None:
2026+
raise ValueError(
2027+
"'rfx_group_ids' must be provided in order to compute the requested intervals"
2028+
)
2029+
if not isinstance(rfx_group_ids, np.ndarray):
2030+
raise ValueError("'rfx_group_ids' must be a numpy array")
2031+
if rfx_group_ids.shape[0] != covariates.shape[0]:
2032+
raise ValueError(
2033+
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
2034+
)
2035+
if rfx_basis is None:
2036+
raise ValueError(
2037+
"'rfx_basis' must be provided in order to compute the requested intervals"
2038+
)
2039+
if not isinstance(rfx_basis, np.ndarray):
2040+
raise ValueError("'rfx_basis' must be a numpy array")
2041+
if rfx_basis.shape[0] != covariates.shape[0]:
2042+
raise ValueError(
2043+
"'rfx_basis' must have the same number of rows as 'covariates'"
2044+
)
2045+
2046+
# Compute posterior samples of the model terms (mean and, if present, variance forest predictions)
2047+
bart_preds = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms="all")
2048+
2049+
# Compute outcome mean and variance for posterior predictive distribution
2050+
has_mean_term = (self.include_mean_forest or self.has_rfx)
2051+
has_variance_forest = self.include_variance_forest
2052+
samples_global_variance = self.sample_sigma2_global
2053+
num_posterior_draws = self.num_samples
2054+
num_observations = covariates.shape[0]
2055+
if has_mean_term:
2056+
ppd_mean = bart_preds["y_hat"]
2057+
else:
2058+
ppd_mean = 0.
2059+
if has_variance_forest:
2060+
ppd_variance = bart_preds["variance_forest_predictions"]
2061+
else:
2062+
if samples_global_variance:
2063+
ppd_variance = np.tile(
2064+
self.global_var_samples,
2065+
(num_observations, 1)
2066+
)
2067+
else:
2068+
ppd_variance = self.sigma2_init
2069+
2070+
# Sample from the posterior predictive distribution
2071+
if num_draws_per_sample is None:
2072+
ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier(
2073+
num_posterior_draws,
2074+
num_observations
2075+
)
2076+
else:
2077+
ppd_draw_multiplier = num_draws_per_sample
2078+
if ppd_draw_multiplier > 1:
2079+
ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1))
2080+
ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1))
2081+
ppd_array = np.random.normal(
2082+
loc = ppd_mean,
2083+
scale = np.sqrt(ppd_variance),
2084+
size = (ppd_draw_multiplier, num_observations, num_posterior_draws)
2085+
)
2086+
else:
2087+
ppd_array = np.random.normal(
2088+
loc = ppd_mean,
2089+
scale = np.sqrt(ppd_variance),
2090+
size = (num_observations, num_posterior_draws)
2091+
)
2092+
2093+
# Binarize outcome for probit models
2094+
if is_probit:
2095+
ppd_array = (ppd_array > 0.0) * 1
2096+
2097+
return ppd_array
2098+
19692099
def to_json(self) -> str:
19702100
"""
19712101
Converts a sampled BART model to JSON string representation (which can then be saved to a file or

0 commit comments

Comments
 (0)