2323)
2424from .sampler import RNG , ForestSampler , GlobalVarianceModel , LeafVarianceModel
2525from .serialization import JSONSerializer
26- from .utils import NotSampledError , _expand_dims_1d , _expand_dims_2d , _expand_dims_2d_diag
26+ from .utils import (
27+ NotSampledError ,
28+ _expand_dims_1d ,
29+ _expand_dims_2d ,
30+ _expand_dims_2d_diag ,
31+ )
2732
2833
2934class BARTModel :
@@ -262,10 +267,18 @@ def sample(
262267 keep_every = general_params_updated ["keep_every" ]
263268 num_chains = general_params_updated ["num_chains" ]
264269 self .probit_outcome_model = general_params_updated ["probit_outcome_model" ]
265- rfx_working_parameter_prior_mean = general_params_updated ["rfx_working_parameter_prior_mean" ]
266- rfx_group_parameter_prior_mean = general_params_updated ["rfx_group_parameter_prior_mean" ]
267- rfx_working_parameter_prior_cov = general_params_updated ["rfx_working_parameter_prior_cov" ]
268- rfx_group_parameter_prior_cov = general_params_updated ["rfx_group_parameter_prior_cov" ]
270+ rfx_working_parameter_prior_mean = general_params_updated [
271+ "rfx_working_parameter_prior_mean"
272+ ]
273+ rfx_group_parameter_prior_mean = general_params_updated [
274+ "rfx_group_parameter_prior_mean"
275+ ]
276+ rfx_working_parameter_prior_cov = general_params_updated [
277+ "rfx_working_parameter_prior_cov"
278+ ]
279+ rfx_group_parameter_prior_cov = general_params_updated [
280+ "rfx_group_parameter_prior_cov"
281+ ]
269282 rfx_variance_prior_shape = general_params_updated ["rfx_variance_prior_shape" ]
270283 rfx_variance_prior_scale = general_params_updated ["rfx_variance_prior_scale" ]
271284 num_threads = general_params_updated ["num_threads" ]
@@ -282,7 +295,9 @@ def sample(
282295 b_leaf = mean_forest_params_updated ["sigma2_leaf_scale" ]
283296 keep_vars_mean = mean_forest_params_updated ["keep_vars" ]
284297 drop_vars_mean = mean_forest_params_updated ["drop_vars" ]
285- num_features_subsample_mean = mean_forest_params_updated ["num_features_subsample" ]
298+ num_features_subsample_mean = mean_forest_params_updated [
299+ "num_features_subsample"
300+ ]
286301
287302 # 3. Variance forest parameters
288303 num_trees_variance = variance_forest_params_updated ["num_trees" ]
@@ -298,7 +313,9 @@ def sample(
298313 b_forest = variance_forest_params_updated ["var_forest_prior_scale" ]
299314 keep_vars_variance = variance_forest_params_updated ["keep_vars" ]
300315 drop_vars_variance = variance_forest_params_updated ["drop_vars" ]
301- num_features_subsample_variance = variance_forest_params_updated ["num_features_subsample" ]
316+ num_features_subsample_variance = variance_forest_params_updated [
317+ "num_features_subsample"
318+ ]
302319
303320 # Override keep_gfr if there are no MCMC samples
304321 if num_mcmc == 0 :
@@ -989,26 +1006,34 @@ def sample(
9891006 else :
9901007 raise ValueError ("There must be at least 1 random effect component" )
9911008 else :
992- alpha_init = _expand_dims_1d (rfx_working_parameter_prior_mean , num_rfx_components )
993-
1009+ alpha_init = _expand_dims_1d (
1010+ rfx_working_parameter_prior_mean , num_rfx_components
1011+ )
1012+
9941013 if rfx_group_parameter_prior_mean is None :
9951014 xi_init = np .tile (np .expand_dims (alpha_init , 1 ), (1 , num_rfx_groups ))
9961015 else :
997- xi_init = _expand_dims_2d (rfx_group_parameter_prior_mean , num_rfx_components , num_rfx_groups )
998-
1016+ xi_init = _expand_dims_2d (
1017+ rfx_group_parameter_prior_mean , num_rfx_components , num_rfx_groups
1018+ )
1019+
9991020 if rfx_working_parameter_prior_cov is None :
10001021 sigma_alpha_init = np .identity (num_rfx_components )
10011022 else :
1002- sigma_alpha_init = _expand_dims_2d_diag (rfx_working_parameter_prior_cov , num_rfx_components )
1003-
1023+ sigma_alpha_init = _expand_dims_2d_diag (
1024+ rfx_working_parameter_prior_cov , num_rfx_components
1025+ )
1026+
10041027 if rfx_group_parameter_prior_cov is None :
10051028 sigma_xi_init = np .identity (num_rfx_components )
10061029 else :
1007- sigma_xi_init = _expand_dims_2d_diag (rfx_group_parameter_prior_cov , num_rfx_components )
1008-
1030+ sigma_xi_init = _expand_dims_2d_diag (
1031+ rfx_group_parameter_prior_cov , num_rfx_components
1032+ )
1033+
10091034 sigma_xi_shape = rfx_variance_prior_shape
10101035 sigma_xi_scale = rfx_variance_prior_scale
1011-
1036+
10121037 # Random effects sampling data structures
10131038 rfx_dataset_train = RandomEffectsDataset ()
10141039 rfx_dataset_train .add_group_labels (rfx_group_ids_train )
@@ -1046,9 +1071,13 @@ def sample(
10461071 if sample_sigma2_leaf :
10471072 self .leaf_scale_samples = np .empty (self .num_samples , dtype = np .float64 )
10481073 if self .include_mean_forest :
1049- yhat_train_raw = np .empty ((self .n_train , self .num_samples ), dtype = np .float64 )
1074+ yhat_train_raw = np .empty (
1075+ (self .n_train , self .num_samples ), dtype = np .float64
1076+ )
10501077 if self .include_variance_forest :
1051- sigma2_x_train_raw = np .empty ((self .n_train , self .num_samples ), dtype = np .float64 )
1078+ sigma2_x_train_raw = np .empty (
1079+ (self .n_train , self .num_samples ), dtype = np .float64
1080+ )
10521081 sample_counter = - 1
10531082
10541083 # Forest Dataset (covariates and optional basis)
@@ -1104,8 +1133,8 @@ def sample(
11041133 max_depth = max_depth_mean ,
11051134 leaf_model_type = leaf_model_mean_forest ,
11061135 leaf_model_scale = current_leaf_scale ,
1107- cutpoint_grid_size = cutpoint_grid_size ,
1108- num_features_subsample = num_features_subsample_mean
1136+ cutpoint_grid_size = cutpoint_grid_size ,
1137+ num_features_subsample = num_features_subsample_mean ,
11091138 )
11101139 forest_sampler_mean = ForestSampler (
11111140 forest_dataset_train ,
@@ -1128,7 +1157,7 @@ def sample(
11281157 cutpoint_grid_size = cutpoint_grid_size ,
11291158 variance_forest_shape = a_forest ,
11301159 variance_forest_scale = b_forest ,
1131- num_features_subsample = num_features_subsample_variance
1160+ num_features_subsample = num_features_subsample_variance ,
11321161 )
11331162 forest_sampler_variance = ForestSampler (
11341163 forest_dataset_train ,
@@ -1234,7 +1263,9 @@ def sample(
12341263
12351264 # Cache train set predictions since they are already computed during sampling
12361265 if keep_sample :
1237- yhat_train_raw [:,sample_counter ] = forest_sampler_mean .get_cached_forest_predictions ()
1266+ yhat_train_raw [:, sample_counter ] = (
1267+ forest_sampler_mean .get_cached_forest_predictions ()
1268+ )
12381269
12391270 # Sample the variance forest
12401271 if self .include_variance_forest :
@@ -1253,7 +1284,9 @@ def sample(
12531284
12541285 # Cache train set predictions since they are already computed during sampling
12551286 if keep_sample :
1256- sigma2_x_train_raw [:,sample_counter ] = forest_sampler_variance .get_cached_forest_predictions ()
1287+ sigma2_x_train_raw [:, sample_counter ] = (
1288+ forest_sampler_variance .get_cached_forest_predictions ()
1289+ )
12571290
12581291 # Sample variance parameters (if requested)
12591292 if self .sample_sigma2_global :
@@ -1435,7 +1468,9 @@ def sample(
14351468 )
14361469
14371470 if keep_sample :
1438- yhat_train_raw [:,sample_counter ] = forest_sampler_mean .get_cached_forest_predictions ()
1471+ yhat_train_raw [:, sample_counter ] = (
1472+ forest_sampler_mean .get_cached_forest_predictions ()
1473+ )
14391474
14401475 # Sample the variance forest
14411476 if self .include_variance_forest :
@@ -1453,7 +1488,9 @@ def sample(
14531488 )
14541489
14551490 if keep_sample :
1456- sigma2_x_train_raw [:,sample_counter ] = forest_sampler_variance .get_cached_forest_predictions ()
1491+ sigma2_x_train_raw [:, sample_counter ] = (
1492+ forest_sampler_variance .get_cached_forest_predictions ()
1493+ )
14571494
14581495 # Sample variance parameters (if requested)
14591496 if self .sample_sigma2_global :
@@ -1504,9 +1541,9 @@ def sample(
15041541 if self .sample_sigma2_leaf :
15051542 self .leaf_scale_samples = self .leaf_scale_samples [num_gfr :]
15061543 if self .include_mean_forest :
1507- yhat_train_raw = yhat_train_raw [:,num_gfr :]
1544+ yhat_train_raw = yhat_train_raw [:, num_gfr :]
15081545 if self .include_variance_forest :
1509- sigma2_x_train_raw = sigma2_x_train_raw [:,num_gfr :]
1546+ sigma2_x_train_raw = sigma2_x_train_raw [:, num_gfr :]
15101547 self .num_samples -= num_gfr
15111548
15121549 # Store predictions
@@ -1553,7 +1590,10 @@ def sample(
15531590 )
15541591 else :
15551592 self .sigma2_x_train = (
1556- np .exp (sigma2_x_train_raw ) * self .sigma2_init * self .y_std * self .y_std
1593+ np .exp (sigma2_x_train_raw )
1594+ * self .sigma2_init
1595+ * self .y_std
1596+ * self .y_std
15571597 )
15581598 if self .has_test :
15591599 sigma2_x_test_raw = (
@@ -1577,10 +1617,10 @@ def predict(
15771617 covariates : Union [np .array , pd .DataFrame ],
15781618 basis : np .array = None ,
15791619 rfx_group_ids : np .array = None ,
1580- rfx_basis : np .array = None ,
1581- type : str = "posterior" ,
1582- terms : Union [list [str ], str ] = "all" ,
1583- scale : str = "linear"
1620+ rfx_basis : np .array = None ,
1621+ type : str = "posterior" ,
1622+ terms : Union [list [str ], str ] = "all" ,
1623+ scale : str = "linear" ,
15841624 ) -> Union [np .array , tuple ]:
15851625 """Return predictions from every forest sampled (either / both of mean and variance).
15861626 Return type is either a single array of predictions, if a BART model only includes a
@@ -1634,28 +1674,39 @@ def predict(
16341674 has_variance_forest = self .include_variance_forest
16351675 has_rfx = self .has_rfx
16361676 has_y_hat = has_mean_forest or has_rfx
1637- predict_y_hat = ((has_y_hat and ("y_hat" in terms )) or
1638- (has_y_hat and ("all" in terms )))
1639- predict_mean_forest = ((has_mean_forest and ("mean_forest" in terms )) or
1640- (has_mean_forest and ("all" in terms )))
1641- predict_rfx = ((has_rfx and ("rfx" in terms )) or
1642- (has_rfx and ("all" in terms )))
1643- predict_variance_forest = ((has_variance_forest and ("variance_forest" in terms )) or
1644- (has_variance_forest and ("all" in terms )))
1645- predict_count = (predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest )
1677+ predict_y_hat = (has_y_hat and ("y_hat" in terms )) or (
1678+ has_y_hat and ("all" in terms )
1679+ )
1680+ predict_mean_forest = (has_mean_forest and ("mean_forest" in terms )) or (
1681+ has_mean_forest and ("all" in terms )
1682+ )
1683+ predict_rfx = (has_rfx and ("rfx" in terms )) or (has_rfx and ("all" in terms ))
1684+ predict_variance_forest = (
1685+ has_variance_forest and ("variance_forest" in terms )
1686+ ) or (has_variance_forest and ("all" in terms ))
1687+ predict_count = (
1688+ predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest
1689+ )
16461690 if predict_count == 0 :
16471691 term_list = ", " .join (terms )
1648- warnings .warn (f"None of the requested model terms, { term_list } , were fit in this model" )
1692+ warnings .warn (
1693+ f"None of the requested model terms, { term_list } , were fit in this model"
1694+ )
16491695 return None
16501696 predict_rfx_intermediate = predict_y_hat and has_rfx
16511697 predict_mean_forest_intermediate = predict_y_hat and has_mean_forest
16521698
16531699 # Check that we have at least one term to predict on probability scale
1654- if (probability_scale and not predict_y_hat and not predict_mean_forest and not predict_rfx ):
1700+ if (
1701+ probability_scale
1702+ and not predict_y_hat
1703+ and not predict_mean_forest
1704+ and not predict_rfx
1705+ ):
16551706 raise ValueError (
16561707 "scale can only be 'probability' if at least one mean term is requested"
16571708 )
1658-
1709+
16591710 # Check the model is valid
16601711 if not self .is_sampled ():
16611712 msg = (
@@ -1730,7 +1781,9 @@ def predict(
17301781 variance_pred_raw * self .sigma2_init * self .y_std * self .y_std
17311782 )
17321783 if predict_mean :
1733- variance_forest_predictions = np .mean (variance_forest_predictions , axis = 1 )
1784+ variance_forest_predictions = np .mean (
1785+ variance_forest_predictions , axis = 1
1786+ )
17341787
17351788 # Forest predictions
17361789 if predict_mean_forest or predict_mean_forest_intermediate :
@@ -1756,7 +1809,7 @@ def predict(
17561809 y_hat = mean_forest_predictions
17571810 elif predict_y_hat and has_rfx :
17581811 y_hat = rfx_predictions
1759-
1812+
17601813 if probability_scale :
17611814 if predict_y_hat and has_mean_forest and has_rfx :
17621815 y_hat = norm .ppf (mean_forest_predictions + rfx_predictions )
@@ -1775,16 +1828,16 @@ def predict(
17751828 y_hat = mean_forest_predictions
17761829 elif predict_y_hat and has_rfx :
17771830 y_hat = rfx_predictions
1778-
1831+
17791832 # Collapse to posterior mean predictions if requested
17801833 if predict_mean :
17811834 if predict_mean_forest :
1782- mean_forest_predictions = np .mean (mean_forest_predictions , axis = 1 )
1835+ mean_forest_predictions = np .mean (mean_forest_predictions , axis = 1 )
17831836 if predict_rfx :
1784- rfx_predictions = np .mean (rfx_predictions , axis = 1 )
1837+ rfx_predictions = np .mean (rfx_predictions , axis = 1 )
17851838 if predict_y_hat :
1786- y_hat = np .mean (y_hat , axis = 1 )
1787-
1839+ y_hat = np .mean (y_hat , axis = 1 )
1840+
17881841 if predict_count == 1 :
17891842 if predict_y_hat :
17901843 return y_hat
@@ -1813,7 +1866,7 @@ def predict(
18131866 else :
18141867 result ["variance_forest_predictions" ] = None
18151868 return result
1816-
1869+
18171870 def predict_mean (
18181871 self ,
18191872 covariates : np .array ,
0 commit comments