Skip to content

Commit 1ff234b

Browse files
authored
Merge pull request #173 from StochasticTree/restricted-sweep-sampling
Enable updating only a user-provided subset of trees in a forest sweep through the custom sampler interface
2 parents 795b1e8 + 4e63606 commit 1ff234b

File tree

15 files changed

+549
-56
lines changed

15 files changed

+549
-56
lines changed

R/config.R

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ ForestModelConfig <- R6::R6Class(
1717
#' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
1818
feature_types = NULL,
1919

20+
#' @field sweep_update_indices Vector of trees to update in a sweep
21+
sweep_update_indices = NULL,
22+
2023
#' @field num_trees Number of trees in the forest being sampled
2124
num_trees = NULL,
2225

@@ -62,6 +65,7 @@ ForestModelConfig <- R6::R6Class(
6265
#' Create a new ForestModelConfig object.
6366
#'
6467
#' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
68+
#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
6569
#' @param num_trees Number of trees in the forest being sampled
6670
#' @param num_features Number of features in training dataset
6771
#' @param num_observations Number of observations in training dataset
@@ -78,7 +82,7 @@ ForestModelConfig <- R6::R6Class(
7882
#' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`)
7983
#'
8084
#' @return A new ForestModelConfig object.
81-
initialize = function(feature_types = NULL, num_trees = NULL, num_features = NULL,
85+
initialize = function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL,
8286
num_observations = NULL, variable_weights = NULL, leaf_dimension = 1,
8387
alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1,
8488
leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0,
@@ -101,6 +105,10 @@ ForestModelConfig <- R6::R6Class(
101105
if (is.null(num_trees)) {
102106
stop("num_trees must be provided")
103107
}
108+
if (!is.null(sweep_update_indices)) {
109+
stopifnot(min(sweep_update_indices) >= 0)
110+
stopifnot(max(sweep_update_indices) < num_trees)
111+
}
104112
if (is.null(num_observations)) {
105113
stop("num_observations must be provided")
106114
}
@@ -111,6 +119,7 @@ ForestModelConfig <- R6::R6Class(
111119
stop("`variable_weights` must have `num_features` total elements")
112120
}
113121
self$feature_types <- feature_types
122+
self$sweep_update_indices <- sweep_update_indices
114123
self$variable_weights <- variable_weights
115124
self$num_trees <- num_trees
116125
self$num_features <- num_features
@@ -158,6 +167,17 @@ ForestModelConfig <- R6::R6Class(
158167
self$feature_types <- feature_types
159168
},
160169

170+
#' @description
171+
#' Update sweep update indices
172+
#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
173+
update_sweep_indices = function(sweep_update_indices) {
174+
if (!is.null(sweep_update_indices)) {
175+
stopifnot(min(sweep_update_indices) >= 0)
176+
stopifnot(max(sweep_update_indices) < self$num_trees)
177+
}
178+
self$sweep_update_indices <- sweep_update_indices
179+
},
180+
161181
#' @description
162182
#' Update variable weights
163183
#' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
@@ -242,6 +262,13 @@ ForestModelConfig <- R6::R6Class(
242262
return(self$feature_types)
243263
},
244264

265+
#' @description
266+
#' Query sweep update indices for this ForestModelConfig object
267+
#' @returns Vector of (0-indexed) indices of trees to update in a sweep
268+
get_sweep_indices = function() {
269+
return(self$sweep_update_indices)
270+
},
271+
245272
#' @description
246273
#' Query variable weights for this ForestModelConfig object
247274
#' @returns Vector specifying sampling probability for all p covariates in ForestDataset
@@ -382,6 +409,7 @@ GlobalModelConfig <- R6::R6Class(
382409
#' Create a forest model config object
383410
#'
384411
#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
412+
#' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
385413
#' @param num_trees Number of trees in the forest being sampled
386414
#' @param num_features Number of features in training dataset
387415
#' @param num_observations Number of observations in training dataset
@@ -401,13 +429,13 @@ GlobalModelConfig <- R6::R6Class(
401429
#'
402430
#' @examples
403431
#' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100)
404-
createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_features = NULL,
432+
createForestModelConfig <- function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL,
405433
num_observations = NULL, variable_weights = NULL, leaf_dimension = 1,
406434
alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1,
407435
leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0,
408436
variance_forest_scale = 1.0, cutpoint_grid_size = 100){
409437
return(invisible((
410-
ForestModelConfig$new(feature_types, num_trees, num_features, num_observations,
438+
ForestModelConfig$new(feature_types, sweep_update_indices, num_trees, num_features, num_observations,
411439
variable_weights, leaf_dimension, alpha, beta, min_samples_leaf,
412440
max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape,
413441
variance_forest_scale, cutpoint_grid_size)

R/cpp11.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,12 +580,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
580580
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
581581
}
582582

583-
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
584-
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
583+
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
584+
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
585585
}
586586

587-
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
588-
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
587+
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
588+
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
589589
}
590590

591591
sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {

R/model.R

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,15 @@ ForestModel <- R6::R6Class(
7070
#' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`.
7171
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`.
7272
sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest,
73-
rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = TRUE) {
73+
rng, forest_model_config, global_model_config,
74+
keep_forest = TRUE, gfr = TRUE) {
7475
if (active_forest$is_empty()) {
7576
stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.")
7677
}
7778

7879
# Unpack parameters from model config object
7980
feature_types <- forest_model_config$feature_types
81+
sweep_update_indices <- forest_model_config$sweep_update_indices
8082
leaf_model_int <- forest_model_config$leaf_model_type
8183
leaf_model_scale <- forest_model_config$leaf_model_scale
8284
variable_weights <- forest_model_config$variable_weights
@@ -85,6 +87,12 @@ ForestModel <- R6::R6Class(
8587
global_scale <- global_model_config$global_error_variance
8688
cutpoint_grid_size <- forest_model_config$cutpoint_grid_size
8789

90+
# Default to empty integer vector if sweep_update_indices is NULL
91+
if (is.null(sweep_update_indices)) {
92+
# sweep_update_indices <- integer(0)
93+
sweep_update_indices <- 0:(forest_model_config$num_trees - 1)
94+
}
95+
8896
# Detect changes to tree prior
8997
if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) {
9098
update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha)
@@ -104,14 +112,14 @@ ForestModel <- R6::R6Class(
104112
sample_gfr_one_iteration_cpp(
105113
forest_dataset$data_ptr, residual$data_ptr,
106114
forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr,
107-
self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
115+
self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale,
108116
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest
109117
)
110118
} else {
111119
sample_mcmc_one_iteration_cpp(
112120
forest_dataset$data_ptr, residual$data_ptr,
113121
forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr,
114-
self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
122+
self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale,
115123
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest
116124
)
117125
}

debug/api_debug.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,10 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
669669
// Prepare the samplers
670670
LeafModelVariant leaf_model = leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest);
671671

672+
// Initialize vector of sweep update indices
673+
std::vector<int> sweep_indices(num_trees);
674+
std::iota(sweep_indices.begin(), sweep_indices.end(), 0);
675+
672676
// Run the GFR sampler
673677
if (num_gfr > 0) {
674678
for (int i = 0; i < num_gfr; i++) {
@@ -683,13 +687,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
683687

684688
// Sample tree ensemble
685689
if (model_type == ModelType::kConstantLeafGaussian) {
686-
GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true);
690+
GFRSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true);
687691
} else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) {
688-
GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true);
692+
GFRSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true);
689693
} else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) {
690-
GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols);
694+
GFRSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, true, omega_cols);
691695
} else if (model_type == ModelType::kLogLinearVariance) {
692-
GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, feature_types, cutpoint_grid_size, true, true, false);
696+
GFRSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, feature_types, cutpoint_grid_size, true, true, false);
693697
}
694698

695699
if (rfx_included) {
@@ -720,13 +724,13 @@ void RunDebug(int dgp_num = 0, const ModelType model_type = kConstantLeafGaussia
720724

721725
// Sample tree ensemble
722726
if (model_type == ModelType::kConstantLeafGaussian) {
723-
MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true);
727+
MCMCSampleOneIter<GaussianConstantLeafModel, GaussianConstantSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianConstantLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true);
724728
} else if (model_type == ModelType::kUnivariateRegressionLeafGaussian) {
725-
MCMCSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true);
729+
MCMCSampleOneIter<GaussianUnivariateRegressionLeafModel, GaussianUnivariateRegressionSuffStat>(active_forest, tracker, forest_samples, std::get<GaussianUnivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true);
726730
} else if (model_type == ModelType::kMultivariateRegressionLeafGaussian) {
727-
MCMCSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, true, omega_cols);
731+
MCMCSampleOneIter<GaussianMultivariateRegressionLeafModel, GaussianMultivariateRegressionSuffStat, int>(active_forest, tracker, forest_samples, std::get<GaussianMultivariateRegressionLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, true, omega_cols);
728732
} else if (model_type == ModelType::kLogLinearVariance) {
729-
MCMCSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, global_variance, true, true, false);
733+
MCMCSampleOneIter<LogLinearVarianceLeafModel, LogLinearVarianceSuffStat>(active_forest, tracker, forest_samples, std::get<LogLinearVarianceLeafModel>(leaf_model), dataset, residual, tree_prior, gen, variable_weights, sweep_indices, global_variance, true, true, false);
730734
}
731735

732736
if (rfx_included) {

0 commit comments

Comments
 (0)