2424 _expand_dims_1d ,
2525 _expand_dims_2d ,
2626 _expand_dims_2d_diag ,
27- _posterior_predictive_heuristic_multiplier ,
28- _summarize_interval
27+ _posterior_predictive_heuristic_multiplier ,
28+ _summarize_interval ,
2929)
3030
3131
@@ -1858,7 +1858,16 @@ def predict(
18581858 result ["variance_forest_predictions" ] = None
18591859 return result
18601860
1861- def compute_posterior_interval (self , terms : Union [list [str ], str ] = "all" , scale : str = "linear" , level : float = 0.95 , covariates : np .array = None , basis : np .array = None , rfx_group_ids : np .array = None , rfx_basis : np .array = None ) -> dict :
1861+ def compute_posterior_interval (
1862+ self ,
1863+ terms : Union [list [str ], str ] = "all" ,
1864+ scale : str = "linear" ,
1865+ level : float = 0.95 ,
1866+ covariates : np .array = None ,
1867+ basis : np .array = None ,
1868+ rfx_group_ids : np .array = None ,
1869+ rfx_basis : np .array = None ,
1870+ ) -> dict :
18621871 """
18631872 Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
18641873
@@ -1889,7 +1898,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
18891898 raise ValueError ("Model has not yet been sampled" )
18901899 for term in terms :
18911900 if not self .has_term (term ):
1892- warnings .warn (f"Term { term } was not sampled in this model and its intervals will not be returned." )
1901+ warnings .warn (
1902+ f"Term { term } was not sampled in this model and its intervals will not be returned."
1903+ )
18931904
18941905 # Handle mean function scale
18951906 if not isinstance (scale , str ):
@@ -1903,8 +1914,14 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
19031914 )
19041915
19051916 # Check that all the necessary inputs were provided for interval computation
1906- needs_covariates_intermediate = (("y_hat" in terms ) or ("all" in terms )) and self .include_mean_forest
1907- needs_covariates = ("mean_forest" in terms ) or ("variance_forest" in terms ) or needs_covariates_intermediate
1917+ needs_covariates_intermediate = (
1918+ ("y_hat" in terms ) or ("all" in terms )
1919+ ) and self .include_mean_forest
1920+ needs_covariates = (
1921+ ("mean_forest" in terms )
1922+ or ("variance_forest" in terms )
1923+ or needs_covariates_intermediate
1924+ )
19081925 if needs_covariates :
19091926 if covariates is None :
19101927 raise ValueError (
@@ -1926,7 +1943,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
19261943 raise ValueError (
19271944 "'basis' must have the same number of rows as 'covariates'"
19281945 )
1929- needs_rfx_data_intermediate = (("y_hat" in terms ) or ("all" in terms )) and self .has_rfx
1946+ needs_rfx_data_intermediate = (
1947+ ("y_hat" in terms ) or ("all" in terms )
1948+ ) and self .has_rfx
19301949 needs_rfx_data = ("rfx" in terms ) or needs_rfx_data_intermediate
19311950 if needs_rfx_data :
19321951 if rfx_group_ids is None :
@@ -1951,7 +1970,15 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
19511970 )
19521971
19531972 # Compute posterior matrices for the requested model terms
1954- predictions = self .predict (covariates = covariates , basis = basis , rfx_group_ids = rfx_group_ids , rfx_basis = rfx_basis , type = "posterior" , terms = terms , scale = scale )
1973+ predictions = self .predict (
1974+ covariates = covariates ,
1975+ basis = basis ,
1976+ rfx_group_ids = rfx_group_ids ,
1977+ rfx_basis = rfx_basis ,
1978+ type = "posterior" ,
1979+ terms = terms ,
1980+ scale = scale ,
1981+ )
19551982 has_multiple_terms = True if isinstance (predictions , dict ) else False
19561983
19571984 # Compute posterior intervals
@@ -1964,11 +1991,16 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale
19641991 )
19651992 return result
19661993 else :
1967- return _summarize_interval (
1968- predictions , 1 , level = level
1969- )
1970-
1971- def sample_posterior_predictive (self , covariates : np .array = None , basis : np .array = None , rfx_group_ids : np .array = None , rfx_basis : np .array = None , num_draws_per_sample : int = None ) -> np .array :
1994+ return _summarize_interval (predictions , 1 , level = level )
1995+
1996+ def sample_posterior_predictive (
1997+ self ,
1998+ covariates : np .array = None ,
1999+ basis : np .array = None ,
2000+ rfx_group_ids : np .array = None ,
2001+ rfx_basis : np .array = None ,
2002+ num_draws_per_sample : int = None ,
2003+ ) -> np .array :
19722004 """
19732005 Sample from the posterior predictive distribution for outcomes modeled by BART
19742006
@@ -1984,7 +2016,7 @@ def sample_posterior_predictive(self, covariates: np.array = None, basis: np.arr
19842016 An array of basis function evaluations for random effects. Required if the BART model includes random effects.
19852017 num_draws_per_sample : int, optional
19862018 The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
1987-
2019+
19882020 Returns
19892021 -------
19902022 np.array
@@ -2044,58 +2076,61 @@ def sample_posterior_predictive(self, covariates: np.array = None, basis: np.arr
20442076 )
20452077
20462078 # Compute posterior predictive samples
2047- bart_preds = self .predict (covariates = covariates , basis = basis , rfx_group_ids = rfx_group_ids , rfx_basis = rfx_basis , type = "posterior" , terms = "all" )
2079+ bart_preds = self .predict (
2080+ covariates = covariates ,
2081+ basis = basis ,
2082+ rfx_group_ids = rfx_group_ids ,
2083+ rfx_basis = rfx_basis ,
2084+ type = "posterior" ,
2085+ terms = "all" ,
2086+ )
20482087
20492088 # Compute outcome mean and variance for posterior predictive distribution
2050- has_mean_term = ( self .include_mean_forest or self .has_rfx )
2089+ has_mean_term = self .include_mean_forest or self .has_rfx
20512090 has_variance_forest = self .include_variance_forest
20522091 samples_global_variance = self .sample_sigma2_global
20532092 num_posterior_draws = self .num_samples
20542093 num_observations = covariates .shape [0 ]
20552094 if has_mean_term :
20562095 ppd_mean = bart_preds ["y_hat" ]
20572096 else :
2058- ppd_mean = 0.
2097+ ppd_mean = 0.0
20592098 if has_variance_forest :
20602099 ppd_variance = bart_preds ["variance_forest_predictions" ]
20612100 else :
20622101 if samples_global_variance :
2063- ppd_variance = np .tile (
2064- self .global_var_samples ,
2065- (num_observations , 1 )
2066- )
2102+ ppd_variance = np .tile (self .global_var_samples , (num_observations , 1 ))
20672103 else :
20682104 ppd_variance = self .sigma2_init
2069-
2105+
20702106 # Sample from the posterior predictive distribution
20712107 if num_draws_per_sample is None :
20722108 ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier (
2073- num_posterior_draws ,
2074- num_observations
2109+ num_posterior_draws , num_observations
20752110 )
20762111 else :
20772112 ppd_draw_multiplier = num_draws_per_sample
20782113 if ppd_draw_multiplier > 1 :
20792114 ppd_mean = np .tile (ppd_mean , (ppd_draw_multiplier , 1 , 1 ))
20802115 ppd_variance = np .tile (ppd_variance , (ppd_draw_multiplier , 1 , 1 ))
20812116 ppd_array = np .random .normal (
2082- loc = ppd_mean ,
2083- scale = np .sqrt (ppd_variance ),
2084- size = (ppd_draw_multiplier , num_observations , num_posterior_draws )
2117+ loc = ppd_mean ,
2118+ scale = np .sqrt (ppd_variance ),
2119+ size = (ppd_draw_multiplier , num_observations , num_posterior_draws ),
20852120 )
20862121 else :
20872122 ppd_array = np .random .normal (
2088- loc = ppd_mean ,
2089- scale = np .sqrt (ppd_variance ),
2090- size = (num_observations , num_posterior_draws )
2123+ loc = ppd_mean ,
2124+ scale = np .sqrt (ppd_variance ),
2125+ size = (num_observations , num_posterior_draws ),
20912126 )
2092-
2127+
20932128 # Binarize outcome for probit models
20942129 if is_probit :
20952130 ppd_array = (ppd_array > 0.0 ) * 1
2096-
2131+
20972132 return ppd_array
2098-
2133+
20992134 def to_json (self ) -> str :
21002135 """
21012136 Converts a sampled BART model to JSON string representation (which can then be saved to a file or
@@ -2381,7 +2416,7 @@ def is_sampled(self) -> bool:
23812416 `True` if a BART model has been sampled, `False` otherwise
23822417 """
23832418 return self .sampled
2384-
2419+
23852420 def has_term (self , term : str ) -> bool :
23862421 """
23872422 Whether or not a model includes a term.
@@ -2390,7 +2425,7 @@ def has_term(self, term: str) -> bool:
23902425 ----------
23912426 term : str
23922427 Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`.
2393-
2428+
23942429 Returns
23952430 -------
23962431 bool
0 commit comments