Skip to content

Commit 0726617

Browse files
committed
Updated BCF interface to allow user-provided RFX parameters
1 parent 5814f41 commit 0726617

File tree

8 files changed

+357
-34
lines changed

8 files changed

+357
-34
lines changed

R/bart.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
690690

691691
# Random effects initialization
692692
if (has_rfx) {
693-
# Prior parameters - use user-provided values or defaults
693+
# Prior parameters
694694
if (is.null(rfx_working_parameter_prior_mean)) {
695695
if (num_rfx_components == 1) {
696696
alpha_init <- c(1)

R/bcf.R

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
4848
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
4949
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
50+
#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
51+
#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
52+
#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
53+
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
54+
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
55+
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
5056
#'
5157
#' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional.
5258
#'
@@ -162,7 +168,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
162168
treated_coding_init = 0.5, rfx_prior_var = NULL,
163169
random_seed = -1, keep_burnin = FALSE, keep_gfr = FALSE,
164170
keep_every = 1, num_chains = 1, verbose = FALSE,
165-
probit_outcome_model = FALSE
171+
probit_outcome_model = FALSE,
172+
rfx_working_parameter_prior_mean = NULL,
173+
rfx_group_parameter_prior_mean = NULL,
174+
rfx_working_parameter_prior_cov = NULL,
175+
rfx_group_parameter_prior_cov = NULL,
176+
rfx_variance_prior_shape = 1,
177+
rfx_variance_prior_scale = 1
166178
)
167179
general_params_updated <- preprocessParams(
168180
general_params_default, general_params
@@ -230,6 +242,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
230242
num_chains <- general_params_updated$num_chains
231243
verbose <- general_params_updated$verbose
232244
probit_outcome_model <- general_params_updated$probit_outcome_model
245+
rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean
246+
rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean
247+
rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov
248+
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
249+
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
250+
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
233251

234252
# 2. Mu forest parameters
235253
num_trees_mu <- prognostic_forest_params_updated$num_trees
@@ -842,24 +860,39 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
842860

843861
# Random effects prior parameters
844862
if (has_rfx) {
845-
# Initialize the working parameter to 1
846-
if (num_rfx_components < 1) {
847-
stop("There must be at least 1 random effect component")
863+
# Prior parameters
864+
if (is.null(rfx_working_parameter_prior_mean)) {
865+
if (num_rfx_components == 1) {
866+
alpha_init <- c(1)
867+
} else if (num_rfx_components > 1) {
868+
alpha_init <- c(1,rep(0,num_rfx_components-1))
869+
} else {
870+
stop("There must be at least 1 random effect component")
871+
}
872+
} else {
873+
alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
874+
}
875+
876+
if (is.null(rfx_group_parameter_prior_mean)) {
877+
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
878+
} else {
879+
xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
848880
}
849-
alpha_init <- rep(1,num_rfx_components)
850-
# Initialize each group parameter based on a regression of outcome on basis in that grou
851-
xi_init <- matrix(0,num_rfx_components,num_rfx_groups)
852-
for (i in 1:num_rfx_groups) {
853-
group_subset_indices <- rfx_group_ids_train == i
854-
basis_group <- rfx_basis_train[group_subset_indices,]
855-
resid_group <- resid_train[group_subset_indices]
856-
rfx_group_model <- lm(resid_group ~ 0+basis_group)
857-
xi_init[,i] <- unname(coef(rfx_group_model))
881+
882+
if (is.null(rfx_working_parameter_prior_cov)) {
883+
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
884+
} else {
885+
sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
886+
}
887+
888+
if (is.null(rfx_group_parameter_prior_cov)) {
889+
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
890+
} else {
891+
sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
858892
}
859-
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
860-
sigma_xi_init <- diag(rfx_prior_var)
861-
sigma_xi_shape <- 1
862-
sigma_xi_scale <- 1
893+
894+
sigma_xi_shape <- rfx_variance_prior_shape
895+
sigma_xi_scale <- rfx_variance_prior_scale
863896
}
864897

865898
# Random effects data structure and storage container

man/bcf.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

stochtree/bart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ def sample(
972972

973973
# Set up random effects structures
974974
if self.has_rfx:
975-
# Use user-provided values or defaults
975+
# Prior parameters
976976
if rfx_working_parameter_prior_mean is None:
977977
if num_rfx_components == 1:
978978
alpha_init = np.array([1])
@@ -1006,6 +1006,7 @@ def sample(
10061006
sigma_xi_shape = rfx_variance_prior_shape
10071007
sigma_xi_scale = rfx_variance_prior_scale
10081008

1009+
# Random effects sampling data structures
10091010
rfx_dataset_train = RandomEffectsDataset()
10101011
rfx_dataset_train.add_group_labels(rfx_group_ids_train)
10111012
rfx_dataset_train.add_basis(rfx_basis_train)

stochtree/bcf.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel
2525
from .serialization import JSONSerializer
26-
from .utils import NotSampledError
26+
from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag
2727

2828

2929
class BCFModel:
@@ -152,6 +152,13 @@ def sample(
152152
* `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
153153
* `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`.
154154
* `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`.
155+
* `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
156+
* `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
157+
* `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
158+
* `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
159+
* `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
160+
* `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
161+
155162
156163
prognostic_forest_params : dict, optional
157164
Dictionary of prognostic forest model parameters, each of which has a default value processed internally, so this argument is optional.
@@ -227,6 +234,12 @@ def sample(
227234
"keep_every": 1,
228235
"num_chains": 1,
229236
"probit_outcome_model": False,
237+
"rfx_working_parameter_prior_mean": None,
238+
"rfx_group_parameter_prior_mean": None,
239+
"rfx_working_parameter_prior_cov": None,
240+
"rfx_group_parameter_prior_cov": None,
241+
"rfx_variance_prior_shape": 1.0,
242+
"rfx_variance_prior_scale": 1.0,
230243
}
231244
general_params_updated = _preprocess_params(
232245
general_params_default, general_params
@@ -309,6 +322,12 @@ def sample(
309322
keep_every = general_params_updated["keep_every"]
310323
num_chains = general_params_updated["num_chains"]
311324
self.probit_outcome_model = general_params_updated["probit_outcome_model"]
325+
rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"]
326+
rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"]
327+
rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"]
328+
rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"]
329+
rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"]
330+
rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"]
312331

313332
# 2. Mu forest parameters
314333
num_trees_mu = prognostic_forest_params_updated["num_trees"]
@@ -1354,19 +1373,41 @@ def sample(
13541373

13551374
# Set up random effects structures
13561375
if self.has_rfx:
1357-
if num_rfx_components == 1:
1358-
alpha_init = np.array([1])
1359-
elif num_rfx_components > 1:
1360-
alpha_init = np.concatenate(
1361-
(np.ones(1, dtype=float), np.zeros(num_rfx_components - 1, dtype=float))
1362-
)
1376+
# Prior parameters
1377+
if rfx_working_parameter_prior_mean is None:
1378+
if num_rfx_components == 1:
1379+
alpha_init = np.array([1])
1380+
elif num_rfx_components > 1:
1381+
alpha_init = np.concatenate(
1382+
(
1383+
np.ones(1, dtype=float),
1384+
np.zeros(num_rfx_components - 1, dtype=float),
1385+
)
1386+
)
1387+
else:
1388+
raise ValueError("There must be at least 1 random effect component")
13631389
else:
1364-
raise ValueError("There must be at least 1 random effect component")
1365-
xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
1366-
sigma_alpha_init = np.identity(num_rfx_components)
1367-
sigma_xi_init = np.identity(num_rfx_components)
1368-
sigma_xi_shape = 1.0
1369-
sigma_xi_scale = 1.0
1390+
alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
1391+
1392+
if rfx_group_parameter_prior_mean is None:
1393+
xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
1394+
else:
1395+
xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
1396+
1397+
if rfx_working_parameter_prior_cov is None:
1398+
sigma_alpha_init = np.identity(num_rfx_components)
1399+
else:
1400+
sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
1401+
1402+
if rfx_group_parameter_prior_cov is None:
1403+
sigma_xi_init = np.identity(num_rfx_components)
1404+
else:
1405+
sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
1406+
1407+
sigma_xi_shape = rfx_variance_prior_shape
1408+
sigma_xi_scale = rfx_variance_prior_scale
1409+
1410+
# Random effects sampling data structures
13701411
rfx_dataset_train = RandomEffectsDataset()
13711412
rfx_dataset_train.add_group_labels(rfx_group_ids_train)
13721413
rfx_dataset_train.add_basis(rfx_basis_train)

0 commit comments

Comments
 (0)