R/posterior_transformation.R
Lines changed: 36 additions & 20 deletions
Original file line number
Diff line number
Diff line change
@@ -1,14 +1,14 @@
1
1
#' Sample from the posterior predictive distribution for outcomes modeled by BCF
2
2
#'
3
3
#' @param model_object A fitted BCF model object of class `bcfmodel`.
4
-
#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
5
-
#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
6
-
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities.
4
+
#' @param covariates A matrix or data frame of covariates.
5
+
#' @param treatment A vector or matrix of treatment assignments.
6
+
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
7
7
#' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.
8
8
#' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.
9
-
#' @param num_draws (Optional) The number of samples to draw from the likelihood, for each draw of the posterior, in computing intervals. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).
9
+
#' @param num_draws_per_sample (Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).
10
10
#'
11
-
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples).
11
+
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
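The default for `num_draws_per_sample` described above can be sketched as follows. This is a minimal illustration, not the package's implementation: the function name `default_draws_per_sample`, the `target` argument, and the use of ceiling rounding are assumptions, since the diff only states the rule in prose.

```python
import math


def default_draws_per_sample(num_posterior_samples: int, target: int = 1000) -> int:
    """Hypothetical helper mirroring the documented heuristic.

    If the model already holds more than `target` posterior samples, one
    likelihood draw per sample is used. Otherwise enough draws per sample are
    taken that the total number of posterior predictive draws reaches `target`.
    """
    if num_posterior_samples <= 0:
        raise ValueError("num_posterior_samples must be positive")
    if num_posterior_samples > target:
        return 1
    # Ceiling rounding is an assumption; it guarantees at least `target` draws.
    return math.ceil(target / num_posterior_samples)
```

Under these assumptions a model with 100 retained samples would draw 10 times per sample, and a model with 2000 samples would draw once per sample.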
#' Sample from the posterior predictive distribution for outcomes modeled by BART
188
196
#'
189
197
#' @param model_object A fitted BART model object of class `bartmodel`.
190
-
#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
198
+
#' @param covariates A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
191
199
#' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
192
200
#' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects.
193
201
#' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects.
194
-
#' @param num_draws The number of posterior predictive samples to draw in computing intervals. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
202
+
#' @param num_draws_per_sample The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
195
203
#'
196
-
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples).
204
+
#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
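The return shape documented above can be illustrated with a small NumPy sketch. It assumes a Gaussian outcome with one global error variance per posterior sample; the function name `sample_posterior_predictive` and its arguments are hypothetical and are not taken from the package.

```python
import numpy as np


def sample_posterior_predictive(mean_draws, sigma2_draws, num_draws_per_sample, rng=None):
    """Hypothetical sketch of posterior predictive sampling.

    mean_draws:   array of shape (num_observations, num_posterior_samples)
    sigma2_draws: array of shape (num_posterior_samples,), assumed global variance
    """
    if num_draws_per_sample < 1:
        raise ValueError("num_draws_per_sample must be at least 1")
    rng = np.random.default_rng() if rng is None else rng
    num_obs, num_samples = mean_draws.shape
    sd = np.sqrt(sigma2_draws)
    # One standard normal deviate per (observation, posterior sample, draw).
    noise = rng.standard_normal((num_obs, num_samples, num_draws_per_sample))
    samples = mean_draws[:, :, None] + sd[None, :, None] * noise
    # Drop the trailing dimension when only one draw per sample is requested.
    if num_draws_per_sample == 1:
        return samples[:, :, 0]
    return samples
```

With `num_draws_per_sample > 1` the result has three dimensions (observations, posterior samples, draws per sample); with a single draw it collapses to a two-dimensional array.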
A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned.
1886
1886
"""
1887
-
# Check the provided model object and requested term
1888
-
self.is_sampled()
1887
+
# Check the provided model object and requested terms
1888
+
if not self.is_sampled():
1889
+
    raise ValueError("Model has not yet been sampled")
1889
1890
for term in terms:
1890
-
    self.has_term(term)
1891
+
    if not self.has_term(term):
1892
+
        warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.")
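The change above turns two calls whose results were discarded into real checks: an unsampled model is an error, while an unavailable term only triggers a warning. A self-contained sketch of that pattern follows. `ToyModel` and `validate_terms` are hypothetical stand-ins, and returning the list of retained terms is an assumption about what happens after the warning.

```python
import warnings


class ToyModel:
    """Hypothetical stand-in for a fitted model object."""

    def __init__(self, sampled, available_terms):
        self._sampled = sampled
        self._terms = set(available_terms)

    def is_sampled(self):
        return self._sampled

    def has_term(self, term):
        return term in self._terms


def validate_terms(model, terms):
    """Raise if the model is unsampled; warn about and skip unavailable terms."""
    if not model.is_sampled():
        raise ValueError("Model has not yet been sampled")
    kept = []
    for term in terms:
        if not model.has_term(term):
            warnings.warn(
                f"Term {term} was not sampled in this model and its intervals will not be returned."
            )
        else:
            kept.append(term)
    return kept
```

Raising for the unsampled case stops the computation early, whereas warning for a missing term lets the remaining requested terms still be processed.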
Sample from the posterior predictive distribution for outcomes modeled by BART
1974
+
1975
+
Parameters
1976
+
----------
1977
+
covariates : np.array, optional
1978
+
An array or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
1979
+
basis : np.array, optional
1980
+
An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
1981
+
rfx_group_ids : np.array, optional
1982
+
An array of group IDs for random effects. Required if the BART model includes random effects.
1983
+
rfx_basis : np.array, optional
1984
+
An array of basis function evaluations for random effects. Required if the BART model includes random effects.
1985
+
num_draws_per_sample : int, optional
1986
+
The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
1987
+
1988
+
Returns
1989
+
-------
1990
+
np.array
1991
+
A matrix of posterior predictive samples if `num_draws_per_sample = 1`; otherwise an array with an additional dimension for the draws per posterior sample.
1992
+
"""
1993
+
# Check the provided model object
1994
+
if not self.is_sampled():
1995
+
    raise ValueError("Model has not yet been sampled")
1996
+
1997
+
# Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
1998
+
is_probit = self.probit_outcome_model
1999
+
2000
+
# Check that all the necessary inputs were provided for interval computation
2001
+
needs_covariates = self.include_mean_forest
2002
+
if needs_covariates:
2003
+
    if covariates is None:
2004
+
        raise ValueError(
2005
+
            "'covariates' must be provided in order to compute the requested intervals"