Skip to content

Commit c0cbdeb

Browse files
committed
Updated BART predict method in python
1 parent 4601dc3 commit c0cbdeb

File tree

1 file changed

+70
-17
lines changed

1 file changed

+70
-17
lines changed

stochtree/bart.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,7 +1579,8 @@ def predict(
15791579
rfx_group_ids: np.array = None,
15801580
rfx_basis: np.array = None,
15811581
type: str = "posterior",
1582-
terms: Union[list[str], str] = "all"
1582+
terms: Union[list[str], str] = "all",
1583+
scale: str = "linear"
15831584
) -> Union[np.array, tuple]:
15841585
"""Return predictions from every forest sampled (either / both of mean and variance).
15851586
Return type is either a single array of predictions, if a BART model only includes a
@@ -1599,11 +1600,25 @@ def predict(
15991600
Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
16001601
terms : str or list of str, optional
16011602
Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `None` along with a warning. Default: "all".
1603+
scale : str, optional
1604+
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
16021605
16031606
Returns
16041607
-------
16051608
Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested.
16061609
"""
1610+
# Handle mean function scale
1611+
if not isinstance(scale, str):
1612+
raise ValueError("scale must be a string")
1613+
if scale not in ["linear", "probability"]:
1614+
raise ValueError("scale must either be 'linear' or 'probability'")
1615+
is_probit = self.probit_outcome_model
1616+
if (scale == "probability") and (not is_probit):
1617+
raise ValueError(
1618+
"scale cannot be 'probability' for models not fit with a probit outcome model"
1619+
)
1620+
probability_scale = scale == "probability"
1621+
16071622
# Handle prediction type
16081623
if not isinstance(type, str):
16091624
raise ValueError("type must be a string")
@@ -1635,6 +1650,13 @@ def predict(
16351650
predict_rfx_intermediate = predict_y_hat and has_rfx
16361651
predict_mean_forest_intermediate = predict_y_hat and has_mean_forest
16371652

1653+
# Check that we have at least one term to predict on probability scale
1654+
if (probability_scale and not predict_y_hat and not predict_mean_forest and not predict_rfx):
1655+
raise ValueError(
1656+
"scale can only be 'probability' if at least one mean term is requested"
1657+
)
1658+
1659+
# Check the model is valid
16381660
if not self.is_sampled():
16391661
msg = (
16401662
"This BARTModel instance is not fitted yet. Call 'fit' with "
@@ -1690,22 +1712,7 @@ def predict(
16901712
if basis is not None:
16911713
pred_dataset.add_basis(basis)
16921714

1693-
# Forest predictions
1694-
if predict_mean_forest or predict_mean_forest_intermediate:
1695-
mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(
1696-
pred_dataset.dataset_cpp
1697-
)
1698-
mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar
1699-
if predict_mean:
1700-
mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1)
1701-
1702-
if predict_rfx or predict_rfx_intermediate:
1703-
rfx_predictions = (
1704-
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
1705-
)
1706-
if predict_mean:
1707-
rfx_predictions = np.mean(rfx_predictions, axis = 1)
1708-
1715+
# Variance forest predictions
17091716
if predict_variance_forest:
17101717
variance_pred_raw = (
17111718
self.forest_container_variance.forest_container_cpp.Predict(
@@ -1725,13 +1732,59 @@ def predict(
17251732
if predict_mean:
17261733
variance_forest_predictions = np.mean(variance_forest_predictions, axis = 1)
17271734

1735+
# Forest predictions
1736+
if predict_mean_forest or predict_mean_forest_intermediate:
1737+
mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(
1738+
pred_dataset.dataset_cpp
1739+
)
1740+
mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar
1741+
# if predict_mean:
1742+
# mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1)
1743+
1744+
# Random effects predictions
1745+
if predict_rfx or predict_rfx_intermediate:
1746+
rfx_predictions = (
1747+
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
1748+
)
1749+
# if predict_mean:
1750+
# rfx_predictions = np.mean(rfx_predictions, axis = 1)
1751+
1752+
# Combine into y hat predictions
17281753
if predict_y_hat and has_mean_forest and has_rfx:
17291754
y_hat = mean_forest_predictions + rfx_predictions
17301755
elif predict_y_hat and has_mean_forest:
17311756
y_hat = mean_forest_predictions
17321757
elif predict_y_hat and has_rfx:
17331758
y_hat = rfx_predictions
17341759

1760+
if probability_scale:
1761+
if predict_y_hat and has_mean_forest and has_rfx:
1762+
y_hat = norm.ppf(mean_forest_predictions + rfx_predictions)
1763+
mean_forest_predictions = norm.ppf(mean_forest_predictions)
1764+
rfx_predictions = norm.ppf(rfx_predictions)
1765+
elif predict_y_hat and has_mean_forest:
1766+
y_hat = norm.ppf(mean_forest_predictions)
1767+
mean_forest_predictions = norm.ppf(mean_forest_predictions)
1768+
elif predict_y_hat and has_rfx:
1769+
y_hat = norm.ppf(rfx_predictions)
1770+
rfx_predictions = norm.ppf(rfx_predictions)
1771+
else:
1772+
if predict_y_hat and has_mean_forest and has_rfx:
1773+
y_hat = mean_forest_predictions + rfx_predictions
1774+
elif predict_y_hat and has_mean_forest:
1775+
y_hat = mean_forest_predictions
1776+
elif predict_y_hat and has_rfx:
1777+
y_hat = rfx_predictions
1778+
1779+
# Collapse to posterior mean predictions if requested
1780+
if predict_mean:
1781+
if predict_mean_forest:
1782+
mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1)
1783+
if predict_rfx:
1784+
rfx_predictions = np.mean(rfx_predictions, axis = 1)
1785+
if predict_y_hat:
1786+
y_hat = np.mean(y_hat, axis = 1)
1787+
17351788
if predict_count == 1:
17361789
if predict_y_hat:
17371790
return y_hat

0 commit comments

Comments
 (0)