Skip to content

Commit 005f700

Browse files
committed
Update Python and R train set prediction caching logic
1 parent cb164de commit 005f700

File tree

4 files changed

+37
-9
lines changed

4 files changed

+37
-9
lines changed

R/bart.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
994994
# Variance forest predictions
995995
if (include_variance_forest) {
996996
# sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
997-
sigma2_x_hat_train <- variance_forest_pred_train
997+
sigma2_x_hat_train <- exp(variance_forest_pred_train)
998998
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
999999
}
10001000

src/py_stochtree.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,16 @@ class ForestSamplerCpp {
11661166
}
11671167
}
11681168

1169+
/// Copy the per-observation predictions currently held by the sampler's tracker
/// into a newly allocated 1-D NumPy array.
///
/// @return Array of length `tracker_->GetNumObservations()`; element `i` is
///         `tracker_->GetSamplePrediction(i)`.
py::array_t<double> GetCachedForestPredictions() {
  const int n_train = tracker_->GetNumObservations();
  auto output = py::array_t<double>(py::detail::any_container<py::ssize_t>({n_train}));
  auto accessor = output.mutable_unchecked<1>();
  // Index with the same signed type as n_train to avoid a signed/unsigned comparison.
  for (int i = 0; i < n_train; ++i) {
    accessor(i) = tracker_->GetSamplePrediction(i);
  }
  return output;
}
1178+
11691179
void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest) {
11701180
// Perform the update operation
11711181
StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest.GetEnsemble());

stochtree/bart.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,10 @@ def sample(
10051005
self.global_var_samples = np.empty(self.num_samples, dtype=np.float64)
10061006
if sample_sigma2_leaf:
10071007
self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64)
1008+
if self.include_mean_forest:
1009+
yhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
1010+
if self.include_variance_forest:
1011+
sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
10081012
sample_counter = -1
10091013

10101014
# Forest Dataset (covariates and optional basis)
@@ -1187,6 +1191,9 @@ def sample(
11871191
True,
11881192
)
11891193

1194+
if keep_sample:
1195+
yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
1196+
11901197
# Sample the variance forest
11911198
if self.include_variance_forest:
11921199
forest_sampler_variance.sample_one_iteration(
@@ -1201,6 +1208,9 @@ def sample(
12011208
True,
12021209
)
12031210

1211+
if keep_sample:
1212+
sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
1213+
12041214
# Sample variance parameters (if requested)
12051215
if self.sample_sigma2_global:
12061216
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1441,6 +1451,10 @@ def sample(
14411451
self.global_var_samples = self.global_var_samples[num_gfr:]
14421452
if self.sample_sigma2_leaf:
14431453
self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:]
1454+
if self.include_mean_forest:
1455+
yhat_train_raw = yhat_train_raw[:,num_gfr:]
1456+
if self.include_variance_forest:
1457+
sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:]
14441458
self.num_samples -= num_gfr
14451459

14461460
# Store predictions
@@ -1451,9 +1465,6 @@ def sample(
14511465
self.leaf_scale_samples = self.leaf_scale_samples
14521466

14531467
if self.include_mean_forest:
1454-
yhat_train_raw = self.forest_container_mean.forest_container_cpp.Predict(
1455-
forest_dataset_train.dataset_cpp
1456-
)
14571468
self.y_hat_train = yhat_train_raw * self.y_std + self.y_bar
14581469
if self.has_test:
14591470
yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict(
@@ -1482,11 +1493,7 @@ def sample(
14821493
self.y_hat_test = rfx_preds_test
14831494

14841495
if self.include_variance_forest:
1485-
sigma2_x_train_raw = (
1486-
self.forest_container_variance.forest_container_cpp.Predict(
1487-
forest_dataset_train.dataset_cpp
1488-
)
1489-
)
1496+
sigma2_x_train_raw = np.exp(sigma2_x_train_raw)
14901497
if self.sample_sigma2_global:
14911498
self.sigma2_x_train = sigma2_x_train_raw
14921499
for i in range(self.num_samples):

stochtree/sampler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,17 @@ def propagate_basis_update(
266266
self.forest_sampler_cpp.PropagateBasisUpdate(
267267
dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp
268268
)
269+
270+
def get_cached_forest_predictions(self) -> np.ndarray:
    """
    Extract an internally-cached prediction of a forest on the training dataset in a sampler.

    Returns
    -------
    np.ndarray
        Numpy 1D array with as many elements as observations in the training dataset
    """
    # Thin pass-through: the array is built by the underlying C++ sampler object.
    return self.forest_sampler_cpp.GetCachedForestPredictions()
269280

270281
def update_alpha(self, alpha: float) -> None:
271282
"""

0 commit comments

Comments
 (0)