
Commit 78ceee2

Reformat python code
1 parent c0cbdeb commit 78ceee2

File tree

1 file changed: +105 −52 lines

stochtree/bart.py

Lines changed: 105 additions & 52 deletions
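Note: every change in this diff is a formatting-only rewrap: long statements are split across bracketed lines with trailing commas, and spacing around keyword arguments and slice indices is normalized, consistent with Black's default 88-character style. The commit message does not name a tool, so Black is an assumption here; a minimal sketch of reproducing this kind of pass:

    # Sketch only: assumes Black is the formatter (the commit does not say so).
    import black
    from pathlib import Path

    path = Path("stochtree/bart.py")
    source = path.read_text()
    # format_str applies Black's default style (line length 88, trailing commas)
    formatted = black.format_str(source, mode=black.FileMode())
    if formatted != source:
        path.write_text(formatted)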
@@ -23,7 +23,12 @@
 )
 from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel
 from .serialization import JSONSerializer
-from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag
+from .utils import (
+    NotSampledError,
+    _expand_dims_1d,
+    _expand_dims_2d,
+    _expand_dims_2d_diag,
+)
 
 
 class BARTModel:
@@ -262,10 +267,18 @@ def sample(
         keep_every = general_params_updated["keep_every"]
         num_chains = general_params_updated["num_chains"]
         self.probit_outcome_model = general_params_updated["probit_outcome_model"]
-        rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"]
-        rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"]
-        rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"]
-        rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"]
+        rfx_working_parameter_prior_mean = general_params_updated[
+            "rfx_working_parameter_prior_mean"
+        ]
+        rfx_group_parameter_prior_mean = general_params_updated[
+            "rfx_group_parameter_prior_mean"
+        ]
+        rfx_working_parameter_prior_cov = general_params_updated[
+            "rfx_working_parameter_prior_cov"
+        ]
+        rfx_group_parameter_prior_cov = general_params_updated[
+            "rfx_group_parameter_prior_cov"
+        ]
         rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"]
         rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"]
         num_threads = general_params_updated["num_threads"]
@@ -282,7 +295,9 @@ def sample(
         b_leaf = mean_forest_params_updated["sigma2_leaf_scale"]
         keep_vars_mean = mean_forest_params_updated["keep_vars"]
         drop_vars_mean = mean_forest_params_updated["drop_vars"]
-        num_features_subsample_mean = mean_forest_params_updated["num_features_subsample"]
+        num_features_subsample_mean = mean_forest_params_updated[
+            "num_features_subsample"
+        ]
 
         # 3. Variance forest parameters
         num_trees_variance = variance_forest_params_updated["num_trees"]
@@ -298,7 +313,9 @@ def sample(
         b_forest = variance_forest_params_updated["var_forest_prior_scale"]
         keep_vars_variance = variance_forest_params_updated["keep_vars"]
         drop_vars_variance = variance_forest_params_updated["drop_vars"]
-        num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"]
+        num_features_subsample_variance = variance_forest_params_updated[
+            "num_features_subsample"
+        ]
 
         # Override keep_gfr if there are no MCMC samples
         if num_mcmc == 0:
@@ -989,26 +1006,34 @@ def sample(
                 else:
                     raise ValueError("There must be at least 1 random effect component")
             else:
-                alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
-
+                alpha_init = _expand_dims_1d(
+                    rfx_working_parameter_prior_mean, num_rfx_components
+                )
+
             if rfx_group_parameter_prior_mean is None:
                 xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
             else:
-                xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
-
+                xi_init = _expand_dims_2d(
+                    rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups
+                )
+
             if rfx_working_parameter_prior_cov is None:
                 sigma_alpha_init = np.identity(num_rfx_components)
             else:
-                sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
-
+                sigma_alpha_init = _expand_dims_2d_diag(
+                    rfx_working_parameter_prior_cov, num_rfx_components
+                )
+
             if rfx_group_parameter_prior_cov is None:
                 sigma_xi_init = np.identity(num_rfx_components)
             else:
-                sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
-
+                sigma_xi_init = _expand_dims_2d_diag(
+                    rfx_group_parameter_prior_cov, num_rfx_components
+                )
+
             sigma_xi_shape = rfx_variance_prior_shape
             sigma_xi_scale = rfx_variance_prior_scale
-
+
             # Random effects sampling data structures
             rfx_dataset_train = RandomEffectsDataset()
             rfx_dataset_train.add_group_labels(rfx_group_ids_train)
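Note: the hunk above only rewraps the `_expand_dims_*` calls; the fallback behavior when no random-effects priors are supplied is unchanged. A small illustration of those defaults, using only the numpy calls visible in the context lines (the component and group counts, and the working-parameter mean, are made-up values):

    import numpy as np

    num_rfx_components, num_rfx_groups = 2, 3
    # Default working-parameter prior covariance: identity (as in the context above)
    sigma_alpha_init = np.identity(num_rfx_components)
    # Default group-parameter initialization: tile the working-parameter mean
    alpha_init = np.ones(num_rfx_components)  # illustrative value, not from the source
    xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
    print(xi_init.shape)  # (2, 3): one column per random-effects group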
@@ -1046,9 +1071,13 @@ def sample(
         if sample_sigma2_leaf:
             self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64)
         if self.include_mean_forest:
-            yhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
+            yhat_train_raw = np.empty(
+                (self.n_train, self.num_samples), dtype=np.float64
+            )
         if self.include_variance_forest:
-            sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
+            sigma2_x_train_raw = np.empty(
+                (self.n_train, self.num_samples), dtype=np.float64
+            )
         sample_counter = -1
 
         # Forest Dataset (covariates and optional basis)
@@ -1104,8 +1133,8 @@ def sample(
             max_depth=max_depth_mean,
             leaf_model_type=leaf_model_mean_forest,
             leaf_model_scale=current_leaf_scale,
-            cutpoint_grid_size=cutpoint_grid_size,
-            num_features_subsample=num_features_subsample_mean
+            cutpoint_grid_size=cutpoint_grid_size,
+            num_features_subsample=num_features_subsample_mean,
         )
         forest_sampler_mean = ForestSampler(
             forest_dataset_train,
@@ -1128,7 +1157,7 @@ def sample(
             cutpoint_grid_size=cutpoint_grid_size,
             variance_forest_shape=a_forest,
             variance_forest_scale=b_forest,
-            num_features_subsample=num_features_subsample_variance
+            num_features_subsample=num_features_subsample_variance,
         )
         forest_sampler_variance = ForestSampler(
             forest_dataset_train,
@@ -1234,7 +1263,9 @@ def sample(
 
                 # Cache train set predictions since they are already computed during sampling
                 if keep_sample:
-                    yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
+                    yhat_train_raw[:, sample_counter] = (
+                        forest_sampler_mean.get_cached_forest_predictions()
+                    )
 
             # Sample the variance forest
             if self.include_variance_forest:
@@ -1253,7 +1284,9 @@ def sample(
 
                 # Cache train set predictions since they are already computed during sampling
                 if keep_sample:
-                    sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
+                    sigma2_x_train_raw[:, sample_counter] = (
+                        forest_sampler_variance.get_cached_forest_predictions()
+                    )
 
             # Sample variance parameters (if requested)
             if self.sample_sigma2_global:
@@ -1435,7 +1468,9 @@ def sample(
                 )
 
                 if keep_sample:
-                    yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
+                    yhat_train_raw[:, sample_counter] = (
+                        forest_sampler_mean.get_cached_forest_predictions()
+                    )
 
             # Sample the variance forest
             if self.include_variance_forest:
@@ -1453,7 +1488,9 @@ def sample(
                 )
 
                 if keep_sample:
-                    sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
+                    sigma2_x_train_raw[:, sample_counter] = (
+                        forest_sampler_variance.get_cached_forest_predictions()
+                    )
 
             # Sample variance parameters (if requested)
             if self.sample_sigma2_global:
@@ -1504,9 +1541,9 @@ def sample(
             if self.sample_sigma2_leaf:
                 self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:]
             if self.include_mean_forest:
-                yhat_train_raw = yhat_train_raw[:,num_gfr:]
+                yhat_train_raw = yhat_train_raw[:, num_gfr:]
             if self.include_variance_forest:
-                sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:]
+                sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:]
             self.num_samples -= num_gfr
 
         # Store predictions
@@ -1553,7 +1590,10 @@ def sample(
                 )
             else:
                 self.sigma2_x_train = (
-                    np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std
+                    np.exp(sigma2_x_train_raw)
+                    * self.sigma2_init
+                    * self.y_std
+                    * self.y_std
                 )
             if self.has_test:
                 sigma2_x_test_raw = (
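Note: the wrapped expression above is the variance forest back-transformation taken directly from the diff: the raw forest output is exponentiated, then rescaled by the initial variance and the squared outcome standard deviation. A quick numeric illustration (all values below are made up, not from the repository):

    import numpy as np

    sigma2_x_train_raw = np.array([0.0, 0.5, -0.5])  # hypothetical raw draws
    sigma2_init, y_std = 0.25, 2.0                   # hypothetical scaling constants
    sigma2_x_train = np.exp(sigma2_x_train_raw) * sigma2_init * y_std * y_std
    print(sigma2_x_train)  # [1.0, ~1.65, ~0.61]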
@@ -1577,10 +1617,10 @@ def predict(
         covariates: Union[np.array, pd.DataFrame],
         basis: np.array = None,
         rfx_group_ids: np.array = None,
-        rfx_basis: np.array = None,
-        type: str = "posterior",
-        terms: Union[list[str], str] = "all",
-        scale: str = "linear"
+        rfx_basis: np.array = None,
+        type: str = "posterior",
+        terms: Union[list[str], str] = "all",
+        scale: str = "linear",
     ) -> Union[np.array, tuple]:
         """Return predictions from every forest sampled (either / both of mean and variance).
         Return type is either a single array of predictions, if a BART model only includes a
@@ -1634,28 +1674,39 @@ def predict(
         has_variance_forest = self.include_variance_forest
         has_rfx = self.has_rfx
         has_y_hat = has_mean_forest or has_rfx
-        predict_y_hat = ((has_y_hat and ("y_hat" in terms)) or
-                         (has_y_hat and ("all" in terms)))
-        predict_mean_forest = ((has_mean_forest and ("mean_forest" in terms)) or
-                               (has_mean_forest and ("all" in terms)))
-        predict_rfx = ((has_rfx and ("rfx" in terms)) or
-                       (has_rfx and ("all" in terms)))
-        predict_variance_forest = ((has_variance_forest and ("variance_forest" in terms)) or
-                                   (has_variance_forest and ("all" in terms)))
-        predict_count = (predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest)
+        predict_y_hat = (has_y_hat and ("y_hat" in terms)) or (
+            has_y_hat and ("all" in terms)
+        )
+        predict_mean_forest = (has_mean_forest and ("mean_forest" in terms)) or (
+            has_mean_forest and ("all" in terms)
+        )
+        predict_rfx = (has_rfx and ("rfx" in terms)) or (has_rfx and ("all" in terms))
+        predict_variance_forest = (
+            has_variance_forest and ("variance_forest" in terms)
+        ) or (has_variance_forest and ("all" in terms))
+        predict_count = (
+            predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest
+        )
         if predict_count == 0:
             term_list = ", ".join(terms)
-            warnings.warn(f"None of the requested model terms, {term_list}, were fit in this model")
+            warnings.warn(
+                f"None of the requested model terms, {term_list}, were fit in this model"
+            )
             return None
         predict_rfx_intermediate = predict_y_hat and has_rfx
         predict_mean_forest_intermediate = predict_y_hat and has_mean_forest
 
         # Check that we have at least one term to predict on probability scale
-        if (probability_scale and not predict_y_hat and not predict_mean_forest and not predict_rfx):
+        if (
+            probability_scale
+            and not predict_y_hat
+            and not predict_mean_forest
+            and not predict_rfx
+        ):
             raise ValueError(
                 "scale can only be 'probability' if at least one mean term is requested"
             )
-
+
         # Check the model is valid
         if not self.is_sampled():
             msg = (
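Note: the rewrapped predicates above are a formatting change only; each one still reads as "the term was fit and was requested, either directly or via 'all'". An equivalent, shorter way to read them, shown only for illustration with made-up flag values:

    # Not part of the commit; illustrative values stand in for the fitted model's flags.
    has_y_hat, has_mean_forest, has_rfx, has_variance_forest = True, True, False, True
    terms = ["y_hat", "variance_forest"]  # hypothetical request

    # (A and B) or (A and C) simplifies to A and (B or C)
    predict_y_hat = has_y_hat and ("y_hat" in terms or "all" in terms)
    predict_mean_forest = has_mean_forest and ("mean_forest" in terms or "all" in terms)
    predict_rfx = has_rfx and ("rfx" in terms or "all" in terms)
    predict_variance_forest = has_variance_forest and (
        "variance_forest" in terms or "all" in terms
    )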
@@ -1730,7 +1781,9 @@ def predict(
                     variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
                 )
                 if predict_mean:
-                    variance_forest_predictions = np.mean(variance_forest_predictions, axis = 1)
+                    variance_forest_predictions = np.mean(
+                        variance_forest_predictions, axis=1
+                    )
 
         # Forest predictions
         if predict_mean_forest or predict_mean_forest_intermediate:
@@ -1756,7 +1809,7 @@ def predict(
                 y_hat = mean_forest_predictions
             elif predict_y_hat and has_rfx:
                 y_hat = rfx_predictions
-
+
             if probability_scale:
                 if predict_y_hat and has_mean_forest and has_rfx:
                     y_hat = norm.ppf(mean_forest_predictions + rfx_predictions)
@@ -1775,16 +1828,16 @@ def predict(
                 y_hat = mean_forest_predictions
             elif predict_y_hat and has_rfx:
                 y_hat = rfx_predictions
-
+
         # Collapse to posterior mean predictions if requested
         if predict_mean:
             if predict_mean_forest:
-                mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1)
+                mean_forest_predictions = np.mean(mean_forest_predictions, axis=1)
             if predict_rfx:
-                rfx_predictions = np.mean(rfx_predictions, axis = 1)
+                rfx_predictions = np.mean(rfx_predictions, axis=1)
             if predict_y_hat:
-                y_hat = np.mean(y_hat, axis = 1)
-
+                y_hat = np.mean(y_hat, axis=1)
+
         if predict_count == 1:
             if predict_y_hat:
                 return y_hat
@@ -1813,7 +1866,7 @@ def predict(
             else:
                 result["variance_forest_predictions"] = None
         return result
-
+
     def predict_mean(
         self,
         covariates: np.array,

0 commit comments
