|
1 | | -""" |
2 | | -Bayesian Additive Regression Trees (BART) module |
3 | | -""" |
4 | | - |
5 | 1 | import warnings |
6 | 2 | from math import log |
7 | 3 | from numbers import Integral |
|
28 | 24 | _expand_dims_1d, |
29 | 25 | _expand_dims_2d, |
30 | 26 | _expand_dims_2d_diag, |
| 27 | + _posterior_predictive_heuristic_multiplier, |
| 28 | + _summarize_interval |
31 | 29 | ) |
32 | 30 |
|
33 | 31 |
|
@@ -1860,6 +1858,114 @@ def predict( |
1860 | 1858 | result["variance_forest_predictions"] = None |
1861 | 1859 | return result |
1862 | 1860 |
|
| 1861 | + def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict: |
| 1862 | + """ |
| 1863 | + Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. |
| 1864 | +
|
| 1865 | + Parameters |
| 1866 | + ---------- |
| 1867 | + terms : str, optional |
| 1868 | + Character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Defaults to `"all"`. |
| 1869 | + scale : str, optional |
| 1870 | + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. |
| 1871 | + level : float, optional |
| 1872 | + A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. |
| 1873 | + covariates : np.array, optional |
| 1874 | + Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). |
| 1875 | + basis : np.array, optional |
| 1876 | + Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. |
| 1877 | + rfx_group_ids : np.array, optional |
| 1878 | + Optional vector of group IDs for random effects. Required if the requested term includes random effects. |
| 1879 | + rfx_basis : np.array, optional |
| 1880 | + Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. |
| 1881 | +
|
| 1882 | + Returns |
| 1883 | + ------- |
| 1884 | + dict |
| 1885 | + A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. |
| 1886 | + """ |
| 1887 | + # Check the provided model object and requested term |
| 1888 | + self.is_sampled() |
| 1889 | + for term in terms: |
| 1890 | + self.has_term(term) |
| 1891 | + |
| 1892 | + # Handle mean function scale |
| 1893 | + if not isinstance(scale, str): |
| 1894 | + raise ValueError("scale must be a string") |
| 1895 | + if scale not in ["linear", "probability"]: |
| 1896 | + raise ValueError("scale must either be 'linear' or 'probability'") |
| 1897 | + is_probit = self.probit_outcome_model |
| 1898 | + if (scale == "probability") and (not is_probit): |
| 1899 | + raise ValueError( |
| 1900 | + "scale cannot be 'probability' for models not fit with a probit outcome model" |
| 1901 | + ) |
| 1902 | + |
| 1903 | + # Check that all the necessary inputs were provided for interval computation |
| 1904 | + needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.include_mean_forest |
| 1905 | + needs_covariates = ("mean_forest" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate |
| 1906 | + if needs_covariates: |
| 1907 | + if covariates is None: |
| 1908 | + raise ValueError( |
| 1909 | + "'covariates' must be provided in order to compute the requested intervals" |
| 1910 | + ) |
| 1911 | + if not isinstance(covariates, np.ndarray) and not isinstance( |
| 1912 | + covariates, pd.DataFrame |
| 1913 | + ): |
| 1914 | + raise ValueError("'covariates' must be a matrix or data frame") |
| 1915 | + needs_basis = needs_covariates and self.has_basis |
| 1916 | + if needs_basis: |
| 1917 | + if basis is None: |
| 1918 | + raise ValueError( |
| 1919 | + "'basis' must be provided in order to compute the requested intervals" |
| 1920 | + ) |
| 1921 | + if not isinstance(basis, np.ndarray): |
| 1922 | + raise ValueError("'basis' must be a numpy array") |
| 1923 | + if basis.shape[0] != covariates.shape[0]: |
| 1924 | + raise ValueError( |
| 1925 | + "'basis' must have the same number of rows as 'covariates'" |
| 1926 | + ) |
| 1927 | + needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx |
| 1928 | + needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate |
| 1929 | + if needs_rfx_data: |
| 1930 | + if rfx_group_ids is None: |
| 1931 | + raise ValueError( |
| 1932 | + "'rfx_group_ids' must be provided in order to compute the requested intervals" |
| 1933 | + ) |
| 1934 | + if not isinstance(rfx_group_ids, np.ndarray): |
| 1935 | + raise ValueError("'rfx_group_ids' must be a numpy array") |
| 1936 | + if rfx_group_ids.shape[0] != covariates.shape[0]: |
| 1937 | + raise ValueError( |
| 1938 | + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" |
| 1939 | + ) |
| 1940 | + if rfx_basis is None: |
| 1941 | + raise ValueError( |
| 1942 | + "'rfx_basis' must be provided in order to compute the requested intervals" |
| 1943 | + ) |
| 1944 | + if not isinstance(rfx_basis, np.ndarray): |
| 1945 | + raise ValueError("'rfx_basis' must be a numpy array") |
| 1946 | + if rfx_basis.shape[0] != covariates.shape[0]: |
| 1947 | + raise ValueError( |
| 1948 | + "'rfx_basis' must have the same number of rows as 'covariates'" |
| 1949 | + ) |
| 1950 | + |
| 1951 | + # Compute posterior matrices for the requested model terms |
| 1952 | + predictions = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale) |
| 1953 | + has_multiple_terms = True if isinstance(predictions, dict) else False |
| 1954 | + |
| 1955 | + # Compute posterior intervals |
| 1956 | + if has_multiple_terms: |
| 1957 | + result = dict() |
| 1958 | + for term in predictions.keys(): |
| 1959 | + if predictions[term] is not None: |
| 1960 | + result[term] = _summarize_interval( |
| 1961 | + predictions[term], 1, level=level |
| 1962 | + ) |
| 1963 | + return result |
| 1964 | + else: |
| 1965 | + return _summarize_interval( |
| 1966 | + predictions, 1, level=level |
| 1967 | + ) |
| 1968 | + |
1863 | 1969 | def to_json(self) -> str: |
1864 | 1970 | """ |
1865 | 1971 | Converts a sampled BART model to JSON string representation (which can then be saved to a file or |
@@ -2145,3 +2251,30 @@ def is_sampled(self) -> bool: |
2145 | 2251 | `True` if a BART model has been sampled, `False` otherwise |
2146 | 2252 | """ |
2147 | 2253 | return self.sampled |
| 2254 | + |
| 2255 | + def has_term(self, term: str) -> bool: |
| 2256 | + """ |
| 2257 | + Whether or not a model includes a term. |
| 2258 | +
|
| 2259 | + Parameters |
| 2260 | + ---------- |
| 2261 | + term : str |
| 2262 | + Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. |
| 2263 | + |
| 2264 | + Returns |
| 2265 | + ------- |
| 2266 | + bool |
| 2267 | + `True` if the model includes the specified term, `False` otherwise |
| 2268 | + """ |
| 2269 | + if term == "mean_forest": |
| 2270 | + return self.include_mean_forest |
| 2271 | + elif term == "variance_forest": |
| 2272 | + return self.include_variance_forest |
| 2273 | + elif term == "rfx": |
| 2274 | + return self.has_rfx |
| 2275 | + elif term == "y_hat": |
| 2276 | + return self.include_mean_forest or self.has_rfx |
| 2277 | + elif term == "all": |
| 2278 | + return True |
| 2279 | + else: |
| 2280 | + return False |
0 commit comments