Skip to content

Commit 7dc71ce

Browse files
committed
Deprecate predict_mean and predict_variance methods in python BART
1 parent 32f13a6 commit 7dc71ce

File tree

2 files changed

+1
-184
lines changed

2 files changed

+1
-184
lines changed

stochtree/bart.py

Lines changed: 0 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,189 +1867,6 @@ def predict(
18671867
result["variance_forest_predictions"] = None
18681868
return result
18691869

1870-
def predict_mean(
1871-
self,
1872-
covariates: np.array,
1873-
basis: np.array = None,
1874-
rfx_group_ids: np.array = None,
1875-
rfx_basis: np.array = None,
1876-
) -> np.array:
1877-
"""Predict expected conditional outcome from a BART model.
1878-
1879-
Parameters
1880-
----------
1881-
covariates : np.array
1882-
Test set covariates.
1883-
basis : np.array, optional
1884-
Optional test set basis vector, must be provided if the model was trained with a leaf regression basis.
1885-
1886-
Returns
1887-
-------
1888-
np.array
1889-
Mean forest predictions.
1890-
"""
1891-
if not self.is_sampled():
1892-
msg = (
1893-
"This BARTModel instance is not fitted yet. Call 'fit' with "
1894-
"appropriate arguments before using this model."
1895-
)
1896-
raise NotSampledError(msg)
1897-
1898-
has_mean_predictions = self.include_mean_forest or self.has_rfx
1899-
if not has_mean_predictions:
1900-
msg = (
1901-
"This BARTModel instance was not sampled with a mean forest or random effects. "
1902-
"Call 'fit' with appropriate arguments before using this model."
1903-
)
1904-
raise NotSampledError(msg)
1905-
1906-
# Data checks
1907-
if not isinstance(covariates, pd.DataFrame) and not isinstance(
1908-
covariates, np.ndarray
1909-
):
1910-
raise ValueError("covariates must be a pandas dataframe or numpy array")
1911-
if basis is not None:
1912-
if not isinstance(basis, np.ndarray):
1913-
raise ValueError("basis must be a numpy array")
1914-
if basis.shape[0] != covariates.shape[0]:
1915-
raise ValueError(
1916-
"covariates and basis must have the same number of rows"
1917-
)
1918-
1919-
# Convert everything to standard shape (2-dimensional)
1920-
if isinstance(covariates, np.ndarray):
1921-
if covariates.ndim == 1:
1922-
covariates = np.expand_dims(covariates, 1)
1923-
if basis is not None:
1924-
if basis.ndim == 1:
1925-
basis = np.expand_dims(basis, 1)
1926-
1927-
# Covariate preprocessing
1928-
if not self._covariate_preprocessor._check_is_fitted():
1929-
if not isinstance(covariates, np.ndarray):
1930-
raise ValueError(
1931-
"Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
1932-
)
1933-
else:
1934-
warnings.warn(
1935-
"This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
1936-
RuntimeWarning,
1937-
)
1938-
if not np.issubdtype(
1939-
covariates.dtype, np.floating
1940-
) and not np.issubdtype(covariates.dtype, np.integer):
1941-
raise ValueError(
1942-
"Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
1943-
)
1944-
covariates_processed = covariates
1945-
else:
1946-
covariates_processed = self._covariate_preprocessor.transform(covariates)
1947-
1948-
# Dataset construction
1949-
pred_dataset = Dataset()
1950-
pred_dataset.add_covariates(covariates_processed)
1951-
if basis is not None:
1952-
pred_dataset.add_basis(basis)
1953-
1954-
# Mean forest predictions
1955-
if self.include_mean_forest:
1956-
mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(
1957-
pred_dataset.dataset_cpp
1958-
)
1959-
mean_pred = mean_pred_raw * self.y_std + self.y_bar
1960-
1961-
# RFX predictions
1962-
if self.has_rfx:
1963-
rfx_preds = (
1964-
self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
1965-
)
1966-
if self.include_mean_forest:
1967-
mean_pred = mean_pred + rfx_preds
1968-
else:
1969-
mean_pred = rfx_preds + self.y_bar
1970-
1971-
return mean_pred
1972-
1973-
def predict_variance(self, covariates: np.array) -> np.array:
1974-
"""Predict expected conditional variance from a BART model.
1975-
1976-
Parameters
1977-
----------
1978-
covariates : np.array
1979-
Test set covariates.
1980-
1981-
Returns
1982-
-------
1983-
np.array
1984-
Variance forest predictions.
1985-
"""
1986-
if not self.is_sampled():
1987-
msg = (
1988-
"This BARTModel instance is not fitted yet. Call 'fit' with "
1989-
"appropriate arguments before using this model."
1990-
)
1991-
raise NotSampledError(msg)
1992-
1993-
if not self.include_variance_forest:
1994-
msg = (
1995-
"This BARTModel instance was not sampled with a variance forest. "
1996-
"Call 'fit' with appropriate arguments before using this model."
1997-
)
1998-
raise NotSampledError(msg)
1999-
2000-
# Data checks
2001-
if not isinstance(covariates, pd.DataFrame) and not isinstance(
2002-
covariates, np.ndarray
2003-
):
2004-
raise ValueError("covariates must be a pandas dataframe or numpy array")
2005-
2006-
# Convert everything to standard shape (2-dimensional)
2007-
if isinstance(covariates, np.ndarray):
2008-
if covariates.ndim == 1:
2009-
covariates = np.expand_dims(covariates, 1)
2010-
2011-
# Covariate preprocessing
2012-
if not self._covariate_preprocessor._check_is_fitted():
2013-
if not isinstance(covariates, np.ndarray):
2014-
raise ValueError(
2015-
"Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
2016-
)
2017-
else:
2018-
warnings.warn(
2019-
"This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
2020-
RuntimeWarning,
2021-
)
2022-
if not np.issubdtype(
2023-
covariates.dtype, np.floating
2024-
) and not np.issubdtype(covariates.dtype, np.integer):
2025-
raise ValueError(
2026-
"Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
2027-
)
2028-
covariates_processed = covariates
2029-
else:
2030-
covariates_processed = self._covariate_preprocessor.transform(covariates)
2031-
2032-
# Dataset construction
2033-
pred_dataset = Dataset()
2034-
pred_dataset.add_covariates(covariates_processed)
2035-
2036-
# Variance forest predictions
2037-
variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(
2038-
pred_dataset.dataset_cpp
2039-
)
2040-
if self.sample_sigma2_global:
2041-
variance_pred = np.empty_like(variance_pred_raw)
2042-
for i in range(self.num_samples):
2043-
variance_pred[:, i] = (
2044-
variance_pred_raw[:, i] * self.global_var_samples[i]
2045-
)
2046-
else:
2047-
variance_pred = (
2048-
variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
2049-
)
2050-
2051-
return variance_pred
2052-
20531870
def to_json(self) -> str:
20541871
"""
20551872
Converts a sampled BART model to JSON string representation (which can then be saved to a file or

stochtree/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union, Optional
1+
from typing import Union
22

33
import numpy as np
44

0 commit comments

Comments
 (0)