Skip to content

Commit 19b8645

Browse files
committed
Initial implementation of flexible random effects parameters through the BART interface
1 parent 78cc4c6 commit 19b8645

File tree

7 files changed

+400
-28
lines changed

7 files changed

+400
-28
lines changed

R/bart.R

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@
4545
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
4646
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
4747
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
48+
#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
49+
#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
50+
#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
51+
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
52+
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
53+
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
4854
#'
4955
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
5056
#'
@@ -118,7 +124,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
118124
variable_weights = NULL, random_seed = -1,
119125
keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
120126
num_chains = 1, verbose = FALSE,
121-
probit_outcome_model = FALSE
127+
probit_outcome_model = FALSE,
128+
rfx_working_parameter_prior_mean = NULL,
129+
rfx_group_parameter_prior_mean = NULL,
130+
rfx_working_parameter_prior_cov = NULL,
131+
rfx_group_parameter_prior_cov = NULL,
132+
rfx_variance_prior_shape = 1,
133+
rfx_variance_prior_scale = 1
122134
)
123135
general_params_updated <- preprocessParams(
124136
general_params_default, general_params
@@ -168,6 +180,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
168180
num_chains <- general_params_updated$num_chains
169181
verbose <- general_params_updated$verbose
170182
probit_outcome_model <- general_params_updated$probit_outcome_model
183+
rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean
184+
rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean
185+
rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov
186+
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
187+
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
188+
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
171189

172190
# 2. Mean forest parameters
173191
num_trees_mean <- mean_forest_params_updated$num_trees
@@ -672,19 +690,43 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
672690

673691
# Random effects initialization
674692
if (has_rfx) {
675-
# Prior parameters
676-
if (num_rfx_components == 1) {
677-
alpha_init <- c(1)
678-
} else if (num_rfx_components > 1) {
679-
alpha_init <- c(1,rep(0,num_rfx_components-1))
693+
# Prior parameters - use user-provided values or defaults
694+
if (is.null(rfx_working_parameter_prior_mean)) {
695+
if (num_rfx_components == 1) {
696+
alpha_init <- c(1)
697+
} else if (num_rfx_components > 1) {
698+
alpha_init <- c(1,rep(0,num_rfx_components-1))
699+
} else {
700+
stop("There must be at least 1 random effect component")
701+
}
680702
} else {
681-
stop("There must be at least 1 random effect component")
703+
alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
682704
}
683-
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
684-
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
685-
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
686-
sigma_xi_shape <- 1
687-
sigma_xi_scale <- 1
705+
706+
if (is.null(rfx_group_parameter_prior_mean)) {
707+
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
708+
} else {
709+
xi_init <- expand_dims_1d(rfx_group_parameter_prior_mean, num_rfx_components)
710+
# If it's a vector, expand to matrix
711+
if (is.vector(xi_init)) {
712+
xi_init <- matrix(rep(xi_init, num_rfx_groups), num_rfx_components, num_rfx_groups)
713+
}
714+
}
715+
716+
if (is.null(rfx_working_parameter_prior_cov)) {
717+
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
718+
} else {
719+
sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
720+
}
721+
722+
if (is.null(rfx_group_parameter_prior_cov)) {
723+
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
724+
} else {
725+
sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
726+
}
727+
728+
sigma_xi_shape <- rfx_variance_prior_shape
729+
sigma_xi_scale <- rfx_variance_prior_scale
688730

689731
# Random effects data structure and storage container
690732
rfx_dataset_train <- createRandomEffectsDataset(rfx_group_ids_train, rfx_basis_train)

R/utils.R

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,3 +855,86 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) {
855855
}
856856
return(x_preprocessed)
857857
}
858+
859+
#' Convert scalar input to vector of dimension `output_size`,
860+
#' or check that input array is equivalent to a vector of dimension `output_size`.
861+
#'
862+
#' @param input Input to be converted to a vector (or passed through as-is)
863+
#' @param output_size Intended size of the output vector
864+
#' @return A vector of length `output_size`
865+
#' @export
866+
expand_dims_1d <- function(input, output_size) {
867+
if (length(input) == 1) {
868+
output <- rep(input, output_size)
869+
} else if (is.numeric(input)) {
870+
if (length(input) != output_size) {
871+
stop("`input` must be a 1D numpy array with `output_size` elements")
872+
}
873+
output <- input
874+
} else {
875+
stop("`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times")
876+
}
877+
return(output)
878+
}
879+
880+
#' Ensures that input is propagated appropriately to a matrix of dimension `output_rows` x `output_cols`.
881+
#' Handles the following cases:
882+
#' 1. `input` is a scalar: output is simply a (`output_rows`, `output_cols`) matrix with `input` repeated for each element
883+
#' 2. `input` is a vector of length `output_rows`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_cols` columns
884+
#' 3. `input` is a vector of length `output_cols`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_rows` rows
885+
#' 4. `input` is a matrix of dimension (`output_rows`, `output_cols`): input is passed through as-is
886+
#' All other cases throw an error.
887+
#'
888+
#' @param input Input to be converted to a matrix (or passed through as-is)
889+
#' @param output_rows Intended number of rows in the output array
890+
#' @param output_cols Intended number of columns in the output array
891+
#' @return A matrix of dimension `output_rows` x `output_cols`
892+
#' @export
893+
expand_dims_2d <- function(input, output_rows, output_cols) {
894+
if (length(input) == 1) {
895+
output <- as.matrix(rep(input, output_rows * output_cols), ncol = output_cols)
896+
} else if (is.numeric(input)) {
897+
if (length(input) == output_cols) {
898+
output <- matrix(rep(x, output_rows), nrow=output_rows, byrow = T)
899+
} else if (length(input) == output_rows) {
900+
output <- matrix(rep(x, output_cols), ncol=output_cols, byrow = F)
901+
} else {
902+
stop("If `input` is a vector, it must either contain `output_rows` or `output_cols` elements")
903+
}
904+
} else if (is.matrix(input)) {
905+
if (nrow(input) != output_rows) {
906+
stop("`input` must be a matrix with `output_rows` rows")
907+
}
908+
if (ncol(input) != output_cols) {
909+
stop("`input` must be a matrix with `output_cols` columns")
910+
}
911+
output <- input
912+
} else {
913+
stop("`input` must be either a matrix, vector or a scalar")
914+
}
915+
return(output)
916+
}
917+
918+
#' Convert scalar input to square matrix of dimension `output_size` x `output_size` with `input` along the diagonal,
919+
#' or check that input array is equivalent to a square matrix of dimension `output_size` x `output_size`.
920+
#'
921+
#' @param input Input to be converted to a square matrix (or passed through as-is)
922+
#' @param output_size Intended row and column dimension of the square output matrix
923+
#' @return A square matrix of dimension `output_size` x `output_size`
924+
#' @export
925+
expand_dims_2d_diag <- function(input, output_size) {
926+
if (length(input) == 1) {
927+
output <- as.matrix(diag(input, output_size))
928+
} else if (is.matrix(input)) {
929+
if (nrow(input) != ncol(input)) {
930+
stop("`input` must be a square matrix")
931+
}
932+
if (nrow(input) != output_size) {
933+
stop("`input` must be a square matrix with `output_size` rows and columns")
934+
}
935+
output <- input
936+
} else {
937+
stop("`input` must be either a square matrix or a scalar")
938+
}
939+
return(output)
940+
}

stochtree/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
_check_matrix_square,
3232
_standardize_array_to_list,
3333
_standardize_array_to_np,
34+
_expand_dims_1d,
35+
_expand_dims_2d,
36+
_expand_dims_2d_diag
3437
)
3538

3639
__all__ = [

stochtree/bart.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel
2525
from .serialization import JSONSerializer
26-
from .utils import NotSampledError
26+
from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d_diag
2727

2828

2929
class BARTModel:
@@ -132,6 +132,12 @@ def sample(
132132
* `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
133133
* `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`.
134134
* `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`.
135+
* `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
136+
* `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
137+
* `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
138+
* `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
139+
* `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
140+
* `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
135141
136142
mean_forest_params : dict, optional
137143
Dictionary of mean forest model parameters, each of which has a default value processed internally, so this argument is optional.
@@ -190,6 +196,12 @@ def sample(
190196
"keep_every": 1,
191197
"num_chains": 1,
192198
"probit_outcome_model": False,
199+
"rfx_working_parameter_prior_mean": None,
200+
"rfx_group_parameter_prior_mean": None,
201+
"rfx_working_parameter_prior_cov": None,
202+
"rfx_group_parameter_prior_cov": None,
203+
"rfx_variance_prior_shape": 1.0,
204+
"rfx_variance_prior_scale": 1.0,
193205
}
194206
general_params_updated = _preprocess_params(
195207
general_params_default, general_params
@@ -279,6 +291,14 @@ def sample(
279291
drop_vars_variance = variance_forest_params_updated["drop_vars"]
280292
num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"]
281293

294+
# 4. Random effects parameters
295+
rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"]
296+
rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"]
297+
rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"]
298+
rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"]
299+
rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"]
300+
rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"]
301+
282302
# Override keep_gfr if there are no MCMC samples
283303
if num_mcmc == 0:
284304
keep_gfr = True
@@ -954,22 +974,43 @@ def sample(
954974

955975
# Set up random effects structures
956976
if self.has_rfx:
957-
if num_rfx_components == 1:
958-
alpha_init = np.array([1])
959-
elif num_rfx_components > 1:
960-
alpha_init = np.concatenate(
961-
(
962-
np.ones(1, dtype=float),
963-
np.zeros(num_rfx_components - 1, dtype=float),
977+
# Use user-provided values or defaults
978+
if rfx_working_parameter_prior_mean is None:
979+
if num_rfx_components == 1:
980+
alpha_init = np.array([1])
981+
elif num_rfx_components > 1:
982+
alpha_init = np.concatenate(
983+
(
984+
np.ones(1, dtype=float),
985+
np.zeros(num_rfx_components - 1, dtype=float),
986+
)
964987
)
965-
)
988+
else:
989+
raise ValueError("There must be at least 1 random effect component")
990+
else:
991+
alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
992+
993+
if rfx_group_parameter_prior_mean is None:
994+
xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
995+
else:
996+
xi_init = _expand_dims_1d(rfx_group_parameter_prior_mean, num_rfx_components)
997+
# If it's a vector, expand to matrix
998+
if xi_init.ndim == 1:
999+
xi_init = np.tile(np.expand_dims(xi_init, 1), (1, num_rfx_groups))
1000+
1001+
if rfx_working_parameter_prior_cov is None:
1002+
sigma_alpha_init = np.identity(num_rfx_components)
1003+
else:
1004+
sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
1005+
1006+
if rfx_group_parameter_prior_cov is None:
1007+
sigma_xi_init = np.identity(num_rfx_components)
9661008
else:
967-
raise ValueError("There must be at least 1 random effect component")
968-
xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
969-
sigma_alpha_init = np.identity(num_rfx_components)
970-
sigma_xi_init = np.identity(num_rfx_components)
971-
sigma_xi_shape = 1.0
972-
sigma_xi_scale = 1.0
1009+
sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
1010+
1011+
sigma_xi_shape = rfx_variance_prior_shape
1012+
sigma_xi_scale = rfx_variance_prior_scale
1013+
9731014
rfx_dataset_train = RandomEffectsDataset()
9741015
rfx_dataset_train.add_group_labels(rfx_group_ids_train)
9751016
rfx_dataset_train.add_basis(rfx_basis_train)

0 commit comments

Comments
 (0)