@@ -1867,189 +1867,6 @@ def predict(
18671867 result ["variance_forest_predictions" ] = None
18681868 return result
18691869
def predict_mean(
    self,
    covariates: np.ndarray,
    basis: np.ndarray = None,
    rfx_group_ids: np.ndarray = None,
    rfx_basis: np.ndarray = None,
) -> np.ndarray:
    """Predict expected conditional outcome from a BART model.

    Parameters
    ----------
    covariates : np.array or pd.DataFrame
        Test set covariates.
    basis : np.array, optional
        Optional test set basis vector, must be provided if the model was trained with a leaf regression basis.
    rfx_group_ids : np.array, optional
        Group labels forwarded to the random effects container. Must be provided if the model
        was sampled with random effects; ignored otherwise.
    rfx_basis : np.array, optional
        Basis forwarded to the random effects container alongside ``rfx_group_ids``.
        Ignored if the model was not sampled with random effects.

    Returns
    -------
    np.array
        Mean predictions on the outcome scale: mean forest predictions (if the model has a mean
        forest) rescaled by ``y_std`` and shifted by ``y_bar``, plus random effects predictions
        (if the model has random effects) rescaled by ``y_std``.

    Raises
    ------
    NotSampledError
        If the model has not been sampled, or was sampled with neither a mean forest nor random effects.
    ValueError
        If ``covariates`` or ``basis`` have an unsupported type, if their row counts disagree,
        if ``rfx_group_ids`` is missing for a model with random effects, or if the covariates
        cannot be used without a fitted covariate preprocessor.
    """
    if not self.is_sampled():
        msg = (
            "This BARTModel instance is not fitted yet. Call 'fit' with "
            "appropriate arguments before using this model."
        )
        raise NotSampledError(msg)

    has_mean_predictions = self.include_mean_forest or self.has_rfx
    if not has_mean_predictions:
        msg = (
            "This BARTModel instance was not sampled with a mean forest or random effects. "
            "Call 'fit' with appropriate arguments before using this model."
        )
        raise NotSampledError(msg)

    # Data checks
    if not isinstance(covariates, (pd.DataFrame, np.ndarray)):
        raise ValueError("covariates must be a pandas dataframe or numpy array")
    if basis is not None:
        if not isinstance(basis, np.ndarray):
            raise ValueError("basis must be a numpy array")
        if basis.shape[0] != covariates.shape[0]:
            raise ValueError(
                "covariates and basis must have the same number of rows"
            )
    # Fail early with a clear message instead of forwarding None into the
    # random effects container.
    if self.has_rfx and rfx_group_ids is None:
        raise ValueError(
            "rfx_group_ids must be provided for a model sampled with random effects"
        )

    # Convert everything to standard shape (2-dimensional)
    if isinstance(covariates, np.ndarray) and covariates.ndim == 1:
        covariates = np.expand_dims(covariates, 1)
    if basis is not None and basis.ndim == 1:
        basis = np.expand_dims(basis, 1)

    # Covariate preprocessing
    if not self._covariate_preprocessor._check_is_fitted():
        # Without a fitted preprocessor only raw numeric numpy arrays can be used.
        if not isinstance(covariates, np.ndarray):
            raise ValueError(
                "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
            )
        warnings.warn(
            "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data as a Pandas dataframe.",
            RuntimeWarning,
        )
        is_numeric = np.issubdtype(covariates.dtype, np.floating) or np.issubdtype(
            covariates.dtype, np.integer
        )
        if not is_numeric:
            raise ValueError(
                "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
            )
        covariates_processed = covariates
    else:
        covariates_processed = self._covariate_preprocessor.transform(covariates)

    # Dataset construction
    pred_dataset = Dataset()
    pred_dataset.add_covariates(covariates_processed)
    if basis is not None:
        pred_dataset.add_basis(basis)

    # Mean forest predictions, mapped back to the outcome scale
    if self.include_mean_forest:
        mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(
            pred_dataset.dataset_cpp
        )
        mean_pred = mean_pred_raw * self.y_std + self.y_bar

    # RFX predictions; y_bar is added here only when no mean forest has
    # already contributed it.
    if self.has_rfx:
        rfx_preds = (
            self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
        )
        if self.include_mean_forest:
            mean_pred = mean_pred + rfx_preds
        else:
            mean_pred = rfx_preds + self.y_bar

    return mean_pred
1972-
def predict_variance(self, covariates: np.ndarray) -> np.ndarray:
    """Predict expected conditional variance from a BART model.

    Parameters
    ----------
    covariates : np.array or pd.DataFrame
        Test set covariates.

    Returns
    -------
    np.array
        Variance forest predictions. When a global error variance was sampled, column ``i`` of
        the raw forest predictions is scaled by ``global_var_samples[i]``; otherwise the raw
        predictions are scaled by ``sigma2_init * y_std**2``.

    Raises
    ------
    NotSampledError
        If the model has not been sampled, or was sampled without a variance forest.
    ValueError
        If ``covariates`` has an unsupported type, or cannot be used without a fitted
        covariate preprocessor.
    """
    if not self.is_sampled():
        msg = (
            "This BARTModel instance is not fitted yet. Call 'fit' with "
            "appropriate arguments before using this model."
        )
        raise NotSampledError(msg)

    if not self.include_variance_forest:
        msg = (
            "This BARTModel instance was not sampled with a variance forest. "
            "Call 'fit' with appropriate arguments before using this model."
        )
        raise NotSampledError(msg)

    # Data checks
    if not isinstance(covariates, (pd.DataFrame, np.ndarray)):
        raise ValueError("covariates must be a pandas dataframe or numpy array")

    # Convert everything to standard shape (2-dimensional)
    if isinstance(covariates, np.ndarray) and covariates.ndim == 1:
        covariates = np.expand_dims(covariates, 1)

    # Covariate preprocessing
    if not self._covariate_preprocessor._check_is_fitted():
        # Without a fitted preprocessor only raw numeric numpy arrays can be used.
        if not isinstance(covariates, np.ndarray):
            raise ValueError(
                "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
            )
        warnings.warn(
            "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data as a Pandas dataframe.",
            RuntimeWarning,
        )
        is_numeric = np.issubdtype(covariates.dtype, np.floating) or np.issubdtype(
            covariates.dtype, np.integer
        )
        if not is_numeric:
            raise ValueError(
                "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
            )
        covariates_processed = covariates
    else:
        covariates_processed = self._covariate_preprocessor.transform(covariates)

    # Dataset construction
    pred_dataset = Dataset()
    pred_dataset.add_covariates(covariates_processed)

    # Variance forest predictions
    variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(
        pred_dataset.dataset_cpp
    )
    if self.sample_sigma2_global:
        # Scale each retained sample's column by that sample's global variance draw.
        variance_pred = np.empty_like(variance_pred_raw)
        for i in range(self.num_samples):
            variance_pred[:, i] = (
                variance_pred_raw[:, i] * self.global_var_samples[i]
            )
    else:
        # No sampled global variance: use the initial value on the outcome scale.
        variance_pred = (
            variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
        )

    return variance_pred
2052-
20531870 def to_json (self ) -> str :
20541871 """
20551872 Converts a sampled BART model to JSON string representation (which can then be saved to a file or
0 commit comments