Skip to content

Commit 69e6d07

Browse files
authored
Merge pull request #176 from StochasticTree/gfr-feature-subsampling
Support probability-weighted feature subsampling in the GFR algorithm
2 parents 778c9f9 + ee95a2f commit 69e6d07

32 files changed

+882
-84
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ export(rootResetRandomEffectsModel)
5757
export(rootResetRandomEffectsTracker)
5858
export(sampleGlobalErrorVarianceOneIteration)
5959
export(sampleLeafVarianceOneIteration)
60+
export(sample_without_replacement)
6061
export(saveBARTModelToJson)
6162
export(saveBARTModelToJsonFile)
6263
export(saveBARTModelToJsonString)

R/bart.R

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default: `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
4545
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
4646
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
47+
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y_train` must only contain the values `0` and `1`. Default: `FALSE`.
4748
#'
4849
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
4950
#'
@@ -58,7 +59,7 @@
5859
#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
5960
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
6061
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
61-
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
62+
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
6263
#'
6364
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
6465
#'
@@ -73,6 +74,7 @@
7374
#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set.
7475
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
7576
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
77+
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
7678
#'
7779
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
7880
#' @export
@@ -98,6 +100,7 @@
98100
#' X_train <- X[train_inds,]
99101
#' y_test <- y[test_inds]
100102
#' y_train <- y[train_inds]
103+
#'
101104
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
102105
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
103106
bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL,
@@ -114,7 +117,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
114117
sigma2_global_shape = 0, sigma2_global_scale = 0,
115118
variable_weights = NULL, random_seed = -1,
116119
keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
117-
num_chains = 1, verbose = FALSE
120+
num_chains = 1, verbose = FALSE,
121+
probit_outcome_model = FALSE
118122
)
119123
general_params_updated <- preprocessParams(
120124
general_params_default, general_params
@@ -127,7 +131,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
127131
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
128132
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
129133
keep_vars = NULL, drop_vars = NULL,
130-
probit_outcome_model = FALSE
134+
num_features_subsample = NULL
131135
)
132136
mean_forest_params_updated <- preprocessParams(
133137
mean_forest_params_default, mean_forest_params
@@ -141,7 +145,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
141145
var_forest_leaf_init = NULL,
142146
var_forest_prior_shape = NULL,
143147
var_forest_prior_scale = NULL,
144-
keep_vars = NULL, drop_vars = NULL
148+
keep_vars = NULL, drop_vars = NULL,
149+
num_features_subsample = NULL
145150
)
146151
variance_forest_params_updated <- preprocessParams(
147152
variance_forest_params_default, variance_forest_params
@@ -162,6 +167,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
162167
keep_every <- general_params_updated$keep_every
163168
num_chains <- general_params_updated$num_chains
164169
verbose <- general_params_updated$verbose
170+
probit_outcome_model <- general_params_updated$probit_outcome_model
165171

166172
# 2. Mean forest parameters
167173
num_trees_mean <- mean_forest_params_updated$num_trees
@@ -175,7 +181,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
175181
b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
176182
keep_vars_mean <- mean_forest_params_updated$keep_vars
177183
drop_vars_mean <- mean_forest_params_updated$drop_vars
178-
probit_outcome_model <- mean_forest_params_updated$probit_outcome_model
184+
num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample
179185

180186
# 3. Variance forest parameters
181187
num_trees_variance <- variance_forest_params_updated$num_trees
@@ -189,6 +195,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
189195
b_forest <- variance_forest_params_updated$var_forest_prior_scale
190196
keep_vars_variance <- variance_forest_params_updated$keep_vars
191197
drop_vars_variance <- variance_forest_params_updated$drop_vars
198+
num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample
192199

193200
# Check if there are enough GFR samples to seed num_chains samplers
194201
if (num_gfr > 0) {
@@ -373,6 +380,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
373380
variable_weights_variance <- variable_weights_variance[original_var_indices]*variable_weights_adj
374381
variable_weights_variance[!(original_var_indices %in% variable_subset_variance)] <- 0
375382
}
383+
384+
# Set num_features_subsample to default, ncol(X_train), if not already set
385+
if (is.null(num_features_subsample_mean)) {
386+
num_features_subsample_mean <- ncol(X_train)
387+
}
388+
if (is.null(num_features_subsample_variance)) {
389+
num_features_subsample_variance <- ncol(X_train)
390+
}
376391

377392
# Convert all input data to matrices if not already converted
378393
if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) {
@@ -633,15 +648,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
633648
num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension,
634649
alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean,
635650
leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale,
636-
cutpoint_grid_size=cutpoint_grid_size)
651+
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample=num_features_subsample_mean)
637652
forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config)
638653
}
639654
if (include_variance_forest) {
640655
forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train),
641656
num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1,
642657
alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance,
643658
max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest,
644-
cutpoint_grid_size=cutpoint_grid_size)
659+
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample=num_features_subsample_variance)
645660
forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config)
646661
}
647662

R/bcf.R

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
6363
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
6464
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
65+
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
6566
#'
6667
#' @param treatment_effect_forest_params (Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional.
6768
#'
@@ -78,6 +79,7 @@
7879
#' - `delta_max` Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Default: `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`.
7980
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
8081
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
82+
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
8183
#'
8284
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
8385
#'
@@ -92,6 +94,7 @@
9294
#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set.
9395
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
9496
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
97+
#' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
9598
#'
9699
#' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
97100
#' @export
@@ -171,7 +174,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
171174
min_samples_leaf = 5, max_depth = 10,
172175
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
173176
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
174-
keep_vars = NULL, drop_vars = NULL
177+
keep_vars = NULL, drop_vars = NULL,
178+
num_features_subsample = NULL
175179
)
176180
prognostic_forest_params_updated <- preprocessParams(
177181
prognostic_forest_params_default, prognostic_forest_params
@@ -183,8 +187,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
183187
min_samples_leaf = 5, max_depth = 5,
184188
sample_sigma2_leaf = FALSE, sigma2_leaf_init = NULL,
185189
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
186-
keep_vars = NULL, drop_vars = NULL,
187-
delta_max = 0.9
190+
keep_vars = NULL, drop_vars = NULL, delta_max = 0.9,
191+
num_features_subsample = NULL
188192
)
189193
treatment_effect_forest_params_updated <- preprocessParams(
190194
treatment_effect_forest_params_default, treatment_effect_forest_params
@@ -198,7 +202,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
198202
variance_forest_init = NULL,
199203
var_forest_prior_shape = NULL,
200204
var_forest_prior_scale = NULL,
201-
keep_vars = NULL, drop_vars = NULL
205+
keep_vars = NULL, drop_vars = NULL,
206+
num_features_subsample = NULL
202207
)
203208
variance_forest_params_updated <- preprocessParams(
204209
variance_forest_params_default, variance_forest_params
@@ -238,6 +243,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
238243
b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale
239244
keep_vars_mu <- prognostic_forest_params_updated$keep_vars
240245
drop_vars_mu <- prognostic_forest_params_updated$drop_vars
246+
num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample
241247

242248
# 3. Tau forest parameters
243249
num_trees_tau <- treatment_effect_forest_params_updated$num_trees
@@ -252,6 +258,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
252258
keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars
253259
drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars
254260
delta_max <- treatment_effect_forest_params_updated$delta_max
261+
num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample
255262

256263
# 4. Variance forest parameters
257264
num_trees_variance <- variance_forest_params_updated$num_trees
@@ -265,6 +272,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
265272
b_forest <- variance_forest_params_updated$var_forest_prior_scale
266273
keep_vars_variance <- variance_forest_params_updated$keep_vars
267274
drop_vars_variance <- variance_forest_params_updated$drop_vars
275+
num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample
268276

269277
# Check if there are enough GFR samples to seed num_chains samplers
270278
if (num_gfr > 0) {
@@ -477,6 +485,17 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
477485
X_test_raw <- X_test
478486
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
479487

488+
# Set num_features_subsample to default, ncol(X_train), if not already set
489+
if (is.null(num_features_subsample_mu)) {
490+
num_features_subsample_mu <- ncol(X_train)
491+
}
492+
if (is.null(num_features_subsample_tau)) {
493+
num_features_subsample_tau <- ncol(X_train)
494+
}
495+
if (is.null(num_features_subsample_variance)) {
496+
num_features_subsample_variance <- ncol(X_train)
497+
}
498+
480499
# Convert all input data to matrices if not already converted
481500
if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
482501
Z_train <- as.matrix(as.numeric(Z_train))
@@ -899,20 +918,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
899918
num_observations=nrow(X_train), variable_weights=variable_weights_mu, leaf_dimension=leaf_dimension_mu_forest,
900919
alpha=alpha_mu, beta=beta_mu, min_samples_leaf=min_samples_leaf_mu, max_depth=max_depth_mu,
901920
leaf_model_type=leaf_model_mu_forest, leaf_model_scale=current_leaf_scale_mu,
902-
cutpoint_grid_size=cutpoint_grid_size)
921+
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample = num_features_subsample_mu)
903922
forest_model_config_tau <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_tau, num_features=ncol(X_train),
904923
num_observations=nrow(X_train), variable_weights=variable_weights_tau, leaf_dimension=leaf_dimension_tau_forest,
905924
alpha=alpha_tau, beta=beta_tau, min_samples_leaf=min_samples_leaf_tau, max_depth=max_depth_tau,
906925
leaf_model_type=leaf_model_tau_forest, leaf_model_scale=current_leaf_scale_tau,
907-
cutpoint_grid_size=cutpoint_grid_size)
926+
cutpoint_grid_size=cutpoint_grid_size, num_features_subsample = num_features_subsample_tau)
908927
forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu, global_model_config)
909928
forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau, global_model_config)
910929
if (include_variance_forest) {
911930
forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train),
912931
num_observations=nrow(X_train), variable_weights=variable_weights_variance,
913932
leaf_dimension=leaf_dimension_variance_forest, alpha=alpha_variance, beta=beta_variance,
914933
min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance,
915-
leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size)
934+
leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size,
935+
num_features_subsample=num_features_subsample_variance)
916936
forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config)
917937
}
918938

0 commit comments

Comments
 (0)