
Commit 8f51e2b

Updated python prediction and interval methods
1 parent 169e3fa commit 8f51e2b


7 files changed (+461 / -20 lines changed)


R/posterior_transformation.R

Lines changed: 2 additions & 6 deletions
@@ -361,7 +361,7 @@ posterior_predictive_heuristic_multiplier <- function(
 #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
 #' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
 #' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
-#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities.
+#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
 #' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
 #' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
 #'
@@ -417,7 +417,6 @@ compute_bcf_posterior_interval <- function(
       "scale cannot be 'probability' for models not fit with a probit outcome model"
     )
   }
-  probability_scale <- scale == "probability"

   # Check that all the necessary inputs were provided for interval computation
   needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
@@ -547,9 +546,7 @@ compute_bcf_posterior_interval <- function(
   }
 }

-#' Compute posterior credible intervals for BART model terms
-#'
-#' This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
+#' Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
 #' @param model_object A fitted BART or BCF model object of class `bartmodel`.
 #' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
 #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
@@ -604,7 +601,6 @@ compute_bart_posterior_interval <- function(
       "scale cannot be 'probability' for models not fit with a probit outcome model"
     )
   }
-  probability_scale <- scale == "probability"

   # Check that all the necessary inputs were provided for interval computation
   needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
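
A note on the scale = "probability" option documented above: for a probit outcome model the mean-function draws live on the latent (linear) scale, and the probability scale applies the standard normal CDF to each draw. A minimal Python sketch of that mapping, illustrative only (the array of latent draws below is made up; this is not the package's internal code):

# Illustrative only: how a probit "probability" scale maps latent mean-function
# draws to P(y == 1). The latent_draws array is hypothetical.
import numpy as np
from scipy.stats import norm

latent_draws = np.array([[-0.3, 0.1, 0.8], [0.0, 0.5, 1.2]])  # made-up posterior draws of f(x)
prob_draws = norm.cdf(latent_draws)  # probability-scale draws, each in (0, 1)
print(prob_draws.round(3))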

demo/debug/bart_predict_debug.py

Lines changed: 16 additions & 2 deletions
@@ -8,7 +8,7 @@

 # Generate data
 rng = np.random.default_rng()
-n = 100
+n = 500
 p = 5
 X = rng.uniform(low=0.0, high=1.0, size=(n, p))
 f_X = np.where(
@@ -42,7 +42,7 @@
     X_test=X_test,
     num_gfr=10,
     num_burnin=0,
-    num_mcmc=10,
+    num_mcmc=1000,
 )

 # # Check several predict approaches
@@ -66,3 +66,17 @@
 plt.ylabel("Actual")
 plt.title("Y hat")
 plt.show()
+
+# Compute posterior interval
+intervals = bart_model.compute_posterior_interval(
+    terms = "all",
+    scale = "linear",
+    level = 0.95,
+    covariates = X_test
+)
+
+# Check coverage
+mean_coverage = np.mean(
+    (intervals["y_hat"]["lower"] <= f_X_test) & (f_X_test <= intervals["y_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}")
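
A possible continuation of the demo above, reusing the intervals["y_hat"]["lower"/"upper"] layout shown in the coverage check (it assumes the script's `intervals` dict is in scope, so it is a sketch rather than a standalone example); it summarizes interval widths alongside the coverage number:

# Continuation of the BART demo above (assumes `intervals` from that script).
# Sketch: summarize 95% interval widths alongside the coverage printed above.
import numpy as np
import matplotlib.pyplot as plt

lower = np.asarray(intervals["y_hat"]["lower"])
upper = np.asarray(intervals["y_hat"]["upper"])
widths = upper - lower
print(f"Average 95% interval width for f(X): {np.mean(widths):.3f}")

plt.hist(widths, bins=30, color="gray")
plt.xlabel("Interval width")
plt.title("95% posterior interval widths")
plt.show()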

demo/debug/bcf_predict_debug.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Demo of updated predict method for BCF
+
+# Load library
+from stochtree import BCFModel
+import numpy as np
+from sklearn.model_selection import train_test_split
+from scipy.stats import norm
+import matplotlib.pyplot as plt
+
+# Generate data
+rng = np.random.default_rng()
+n = 1000
+p = 5
+X = rng.normal(loc=0.0, scale=1.0, size=(n, p))
+mu_X = X[:,0]
+tau_X = 0.25 * X[:,1]
+pi_X = norm.cdf(0.5 * X[:,1])
+Z = rng.binomial(n=1, p=pi_X, size=(n,))
+E_XZ = mu_X + tau_X * Z
+snr = 2.0
+noise_sd = np.std(E_XZ) / snr
+y = E_XZ + rng.normal(loc=0.0, scale=noise_sd, size=(n,))
+
+# Train-test split
+sample_inds = np.arange(n)
+test_set_pct = 0.2
+train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+Z_train = Z[train_inds]
+Z_test = Z[test_inds]
+pi_train = pi_X[train_inds]
+pi_test = pi_X[test_inds]
+tau_train = tau_X[train_inds]
+tau_test = tau_X[test_inds]
+mu_train = mu_X[train_inds]
+mu_test = mu_X[test_inds]
+y_train = y[train_inds]
+y_test = y[test_inds]
+E_XZ_train = E_XZ[train_inds]
+E_XZ_test = E_XZ[test_inds]
+
+# Fit simple BCF model
+bcf_model = BCFModel()
+bcf_model.sample(
+    X_train=X_train,
+    Z_train=Z_train,
+    pi_train=pi_train,
+    y_train=y_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Check several predict approaches
+bcf_preds = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)
+y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)['y_hat']
+y_hat_mean_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test,
+    type = "mean",
+    terms = ["y_hat"]
+)
+tau_hat_mean_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test,
+    type = "mean",
+    terms = ["cate"]
+)
+# Check that this raises a warning
+y_hat_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test,
+    type = "mean",
+    terms = ["rfx", "variance"]
+)
+
+# Plot predicted versus actual
+plt.scatter(y_hat_mean_test, y_test, color="black")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("Y hat")
+plt.show()
+
+# Plot predicted versus actual
+plt.clf()
+plt.scatter(tau_hat_mean_test, tau_test, color="black")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("CATE function")
+plt.show()
+
+# Compute posterior interval
+intervals = bcf_model.compute_posterior_interval(
+    terms = "all",
+    scale = "linear",
+    level = 0.95,
+    covariates = X_test,
+    treatment = Z_test,
+    propensity = pi_test
+)
+
+# Check coverage of E[Y | X, Z]
+mean_coverage = np.mean(
+    (intervals["y_hat"]["lower"] <= E_XZ_test) & (E_XZ_test <= intervals["y_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for E[Y|X,Z]: {mean_coverage:.3f}")
+
+# Check coverage of tau(X)
+tau_coverage = np.mean(
+    (intervals["tau_hat"]["lower"] <= tau_test) & (tau_test <= intervals["tau_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for tau(X): {tau_coverage:.3f}")
+
+# Check coverage of mu(X)
+mu_coverage = np.mean(
+    (intervals["mu_hat"]["lower"] <= mu_test) & (mu_test <= intervals["mu_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}")
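
A possible continuation of the BCF demo above (it assumes `bcf_model`, `X_test`, `Z_test`, `pi_test`, and `tau_test` from that script; the key names "tau_hat"/"lower"/"upper" are taken from the coverage checks). It repeats the interval computation at several nominal levels and compares nominal to empirical coverage for tau(X):

# Continuation of the BCF demo above; a sketch, not a standalone script.
# Compare nominal vs. empirical coverage of tau(X) intervals at several levels.
import numpy as np

for level in [0.5, 0.8, 0.9, 0.95]:
    ivl = bcf_model.compute_posterior_interval(
        terms="all", scale="linear", level=level,
        covariates=X_test, treatment=Z_test, propensity=pi_test,
    )
    cover = np.mean(
        (ivl["tau_hat"]["lower"] <= tau_test) & (tau_test <= ivl["tau_hat"]["upper"])
    )
    print(f"Nominal {level:.2f} vs. empirical {cover:.3f} coverage for tau(X)")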

stochtree/bart.py

Lines changed: 137 additions & 4 deletions
@@ -1,7 +1,3 @@
-"""
-Bayesian Additive Regression Trees (BART) module
-"""
-
 import warnings
 from math import log
 from numbers import Integral
@@ -28,6 +24,8 @@
     _expand_dims_1d,
     _expand_dims_2d,
     _expand_dims_2d_diag,
+    _posterior_predictive_heuristic_multiplier,
+    _summarize_interval
 )


@@ -1860,6 +1858,114 @@ def predict(
             result["variance_forest_predictions"] = None
         return result

+    def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict:
+        """
+        Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
+
+        Parameters
+        ----------
+        terms : str, optional
+            Character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Defaults to `"all"`.
+        scale : str, optional
+            Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`.
+        level : float, optional
+            A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
+        covariates : np.array, optional
+            Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
+        basis : np.array, optional
+            Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
+        rfx_group_ids : np.array, optional
+            Optional vector of group IDs for random effects. Required if the requested term includes random effects.
+        rfx_basis : np.array, optional
+            Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
+
+        Returns
+        -------
+        dict
+            A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned.
+        """
+        # Check the provided model object and requested term
+        self.is_sampled()
+        for term in terms:
+            self.has_term(term)
+
+        # Handle mean function scale
+        if not isinstance(scale, str):
+            raise ValueError("scale must be a string")
+        if scale not in ["linear", "probability"]:
+            raise ValueError("scale must either be 'linear' or 'probability'")
+        is_probit = self.probit_outcome_model
+        if (scale == "probability") and (not is_probit):
+            raise ValueError(
+                "scale cannot be 'probability' for models not fit with a probit outcome model"
+            )
+
+        # Check that all the necessary inputs were provided for interval computation
+        needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.include_mean_forest
+        needs_covariates = ("mean_forest" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate
+        if needs_covariates:
+            if covariates is None:
+                raise ValueError(
+                    "'covariates' must be provided in order to compute the requested intervals"
+                )
+            if not isinstance(covariates, np.ndarray) and not isinstance(
+                covariates, pd.DataFrame
+            ):
+                raise ValueError("'covariates' must be a matrix or data frame")
+        needs_basis = needs_covariates and self.has_basis
+        if needs_basis:
+            if basis is None:
+                raise ValueError(
+                    "'basis' must be provided in order to compute the requested intervals"
+                )
+            if not isinstance(basis, np.ndarray):
+                raise ValueError("'basis' must be a numpy array")
+            if basis.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'basis' must have the same number of rows as 'covariates'"
+                )
+        needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx
+        needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate
+        if needs_rfx_data:
+            if rfx_group_ids is None:
+                raise ValueError(
+                    "'rfx_group_ids' must be provided in order to compute the requested intervals"
+                )
+            if not isinstance(rfx_group_ids, np.ndarray):
+                raise ValueError("'rfx_group_ids' must be a numpy array")
+            if rfx_group_ids.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
+                )
+            if rfx_basis is None:
+                raise ValueError(
+                    "'rfx_basis' must be provided in order to compute the requested intervals"
+                )
+            if not isinstance(rfx_basis, np.ndarray):
+                raise ValueError("'rfx_basis' must be a numpy array")
+            if rfx_basis.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'rfx_basis' must have the same number of rows as 'covariates'"
+                )
+
+        # Compute posterior matrices for the requested model terms
+        predictions = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale)
+        has_multiple_terms = True if isinstance(predictions, dict) else False
+
+        # Compute posterior intervals
+        if has_multiple_terms:
+            result = dict()
+            for term in predictions.keys():
+                if predictions[term] is not None:
+                    result[term] = _summarize_interval(
+                        predictions[term], 1, level=level
+                    )
+            return result
+        else:
+            return _summarize_interval(
+                predictions, 1, level=level
+            )
+
     def to_json(self) -> str:
         """
         Converts a sampled BART model to JSON string representation (which can then be saved to a file or
@@ -2145,3 +2251,30 @@ def is_sampled(self) -> bool:
             `True` if a BART model has been sampled, `False` otherwise
         """
         return self.sampled
+
+    def has_term(self, term: str) -> bool:
+        """
+        Whether or not a model includes a term.
+
+        Parameters
+        ----------
+        term : str
+            Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`.
+
+        Returns
+        -------
+        bool
+            `True` if the model includes the specified term, `False` otherwise
+        """
+        if term == "mean_forest":
+            return self.include_mean_forest
+        elif term == "variance_forest":
+            return self.include_variance_forest
+        elif term == "rfx":
+            return self.has_rfx
+        elif term == "y_hat":
+            return self.include_mean_forest or self.has_rfx
+        elif term == "all":
+            return True
+        else:
+            return False
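
The new method delegates interval construction to _summarize_interval(draws, 1, level=level), whose implementation is not part of this diff. Below is a minimal stand-in sketch of an equal-tailed quantile summary along a draw axis, returning the {"lower": ..., "upper": ...} layout used by the demos above; this is an illustrative assumption, not the library's actual helper:

# Illustrative stand-in for _summarize_interval (not shown in this diff).
# Equal-tailed credible interval taken along `axis` of a matrix of posterior draws.
import numpy as np

def summarize_interval_sketch(draws: np.ndarray, axis: int = 1, level: float = 0.95) -> dict:
    alpha = 1.0 - level
    lower = np.quantile(draws, alpha / 2.0, axis=axis)  # lower quantile per observation
    upper = np.quantile(draws, 1.0 - alpha / 2.0, axis=axis)  # upper quantile per observation
    return {"lower": lower, "upper": upper}

# Toy check: rows are observations, columns are posterior draws
rng = np.random.default_rng(0)
toy_draws = rng.normal(size=(5, 1000))
print(summarize_interval_sketch(toy_draws, axis=1, level=0.95))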
