Skip to content

Commit 574fb90

Browse files
committed
Initial BART refactor that avoids double predicting forests on the train dataset
1 parent 33174a6 commit 574fb90

File tree

5 files changed

+52
-2
lines changed

5 files changed

+52
-2
lines changed

R/bart.R

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
707707
num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains
708708
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
709709
if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples)
710+
if (include_mean_forest) mean_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
711+
if (include_variance_forest) variance_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
710712
sample_counter <- 0
711713

712714
# Initialize the leaves of each tree in the mean forest
@@ -757,13 +759,23 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
757759
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
758760
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
759761
)
762+
763+
# Cache predictions
764+
if (keep_sample) {
765+
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
766+
}
760767
}
761768
if (include_variance_forest) {
762769
forest_model_variance$sample_one_iteration(
763770
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
764771
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
765772
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
766773
)
774+
775+
# Cache predictions
776+
if (keep_sample) {
777+
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
778+
}
767779
}
768780
if (sample_sigma2_global) {
769781
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -910,13 +922,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
910922
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
911923
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
912924
)
925+
926+
if (keep_sample) {
927+
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
928+
}
913929
}
914930
if (include_variance_forest) {
915931
forest_model_variance$sample_one_iteration(
916932
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
917933
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
918934
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
919935
)
936+
937+
if (keep_sample) {
938+
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
939+
}
920940
}
921941
if (sample_sigma2_global) {
922942
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -949,6 +969,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
949969
rfx_samples$delete_sample(0)
950970
}
951971
}
972+
if (include_mean_forest) {
973+
mean_forest_pred_train <- mean_forest_pred_train[,(num_gfr+1):ncol(mean_forest_pred_train)]
974+
}
975+
if (include_variance_forest) {
976+
variance_forest_pred_train <- variance_forest_pred_train[,(num_gfr+1):ncol(variance_forest_pred_train)]
977+
}
952978
if (sample_sigma2_global) {
953979
global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)]
954980
}
@@ -960,13 +986,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
960986

961987
# Mean forest predictions
962988
if (include_mean_forest) {
963-
y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
989+
# y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
990+
y_hat_train <- mean_forest_pred_train*y_std_train + y_bar_train
964991
if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train
965992
}
966993

967994
# Variance forest predictions
968995
if (include_variance_forest) {
969-
sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
996+
# sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
997+
sigma2_x_hat_train <- variance_forest_pred_train
970998
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
971999
}
9721000

R/model.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ ForestModel <- R6::R6Class(
126126
}
127127
},
128128

129+
#' @description
130+
#' Extract an internally-cached prediction of a forest on the training dataset in a sampler.
131+
#' @return Vector with as many elements as observations in the training dataset
132+
get_cached_forest_predictions = function() {
133+
get_cached_forest_predictions_cpp(self$tracker_ptr)
134+
},
135+
129136
#' @description
130137
#' Propagates basis update through to the (full/partial) residual by iteratively
131138
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions

include/stochtree/partition_tracker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class ForestTracker {
9191
SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();}
9292
UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();}
9393
SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();}
94+
int GetNumObservations() {return num_observations_;}
95+
int GetNumTrees() {return num_trees_;}
96+
int GetNumFeatures() {return num_features_;}
97+
bool Initialized() {return initialized_;}
9498

9599
private:
96100
/*! \brief Mapper from observations to predicted values summed over every tree in a forest */

src/py_stochtree.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
21472147
.def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest)
21482148
.def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration)
21492149
.def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel)
2150+
.def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions)
21502151
.def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate)
21512152
.def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate)
21522153
.def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha)

src/sampler.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,16 @@ cpp11::external_pointer<StochTree::ForestTracker> forest_tracker_cpp(cpp11::exte
284284
return cpp11::external_pointer<StochTree::ForestTracker>(tracker_ptr_.release());
285285
}
286286

287+
[[cpp11::register]]
288+
cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker_ptr) {
289+
int n_train = tracker_ptr->GetNumObservations();
290+
cpp11::writable::doubles output(n_train);
291+
for (int i = 0; i < n_train; i++) {
292+
output[i] = tracker_ptr->GetSamplePrediction(i);
293+
}
294+
return output;
295+
}
296+
287297
[[cpp11::register]]
288298
cpp11::writable::integers sample_without_replacement_integer_cpp(
289299
cpp11::integers population_vector,

0 commit comments

Comments
 (0)