You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: stochtree/bart.py
+70-17Lines changed: 70 additions & 17 deletions
Original file line number
Diff line number
Diff line change
@@ -1579,7 +1579,8 @@ def predict(
1579
1579
rfx_group_ids: np.array=None,
1580
1580
rfx_basis: np.array=None,
1581
1581
type: str="posterior",
1582
-
terms: Union[list[str], str] ="all"
1582
+
terms: Union[list[str], str] ="all",
1583
+
scale: str="linear"
1583
1584
) ->Union[np.array, tuple]:
1584
1585
"""Return predictions from every forest sampled (either / both of mean and variance).
1585
1586
Return type is either a single array of predictions, if a BART model only includes a
@@ -1599,11 +1600,25 @@ def predict(
1599
1600
Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
1600
1601
terms : str, optional
1601
1602
Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
1603
+
scale : str, optional
1604
+
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
1602
1605
1603
1606
Returns
1604
1607
-------
1605
1608
Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested.
1606
1609
"""
1610
+
# Handle mean function scale
1611
+
ifnotisinstance(scale, str):
1612
+
raiseValueError("scale must be a string")
1613
+
ifscalenotin ["linear", "probability"]:
1614
+
raiseValueError("scale must either be 'linear' or 'probability'")
1615
+
is_probit=self.probit_outcome_model
1616
+
if (scale=="probability") and (notis_probit):
1617
+
raiseValueError(
1618
+
"scale cannot be 'probability' for models not fit with a probit outcome model"
0 commit comments