|
23 | 23 | ) |
24 | 24 | from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel |
25 | 25 | from .serialization import JSONSerializer |
26 | | -from .utils import NotSampledError |
| 26 | +from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d_diag |
27 | 27 |
|
28 | 28 |
|
29 | 29 | class BARTModel: |
@@ -132,6 +132,12 @@ def sample( |
132 | 132 | * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. |
133 | 133 | * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. |
134 | 134 | * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. |
| 135 | + * `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. |
| 136 | + * `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. |
| 137 | + * `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. |
| 138 | + * `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. |
| 139 | + * `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. |
| 140 | + * `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. |
135 | 141 |
|
136 | 142 | mean_forest_params : dict, optional |
137 | 143 | Dictionary of mean forest model parameters, each of which has a default value processed internally, so this argument is optional. |
@@ -190,6 +196,12 @@ def sample( |
190 | 196 | "keep_every": 1, |
191 | 197 | "num_chains": 1, |
192 | 198 | "probit_outcome_model": False, |
| 199 | + "rfx_working_parameter_prior_mean": None, |
| 200 | + "rfx_group_parameter_prior_mean": None, |
| 201 | + "rfx_working_parameter_prior_cov": None, |
| 202 | + "rfx_group_parameter_prior_cov": None, |
| 203 | + "rfx_variance_prior_shape": 1.0, |
| 204 | + "rfx_variance_prior_scale": 1.0, |
193 | 205 | } |
194 | 206 | general_params_updated = _preprocess_params( |
195 | 207 | general_params_default, general_params |
@@ -279,6 +291,14 @@ def sample( |
279 | 291 | drop_vars_variance = variance_forest_params_updated["drop_vars"] |
280 | 292 | num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"] |
281 | 293 |
|
| 294 | + # 4. Random effects parameters |
| 295 | + rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"] |
| 296 | + rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"] |
| 297 | + rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"] |
| 298 | + rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"] |
| 299 | + rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] |
| 300 | + rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] |
| 301 | + |
282 | 302 | # Override keep_gfr if there are no MCMC samples |
283 | 303 | if num_mcmc == 0: |
284 | 304 | keep_gfr = True |
@@ -954,22 +974,43 @@ def sample( |
954 | 974 |
|
955 | 975 | # Set up random effects structures |
956 | 976 | if self.has_rfx: |
957 | | - if num_rfx_components == 1: |
958 | | - alpha_init = np.array([1]) |
959 | | - elif num_rfx_components > 1: |
960 | | - alpha_init = np.concatenate( |
961 | | - ( |
962 | | - np.ones(1, dtype=float), |
963 | | - np.zeros(num_rfx_components - 1, dtype=float), |
| 977 | + # Use user-provided values or defaults |
| 978 | + if rfx_working_parameter_prior_mean is None: |
| 979 | + if num_rfx_components == 1: |
| 980 | + alpha_init = np.array([1]) |
| 981 | + elif num_rfx_components > 1: |
| 982 | + alpha_init = np.concatenate( |
| 983 | + ( |
| 984 | + np.ones(1, dtype=float), |
| 985 | + np.zeros(num_rfx_components - 1, dtype=float), |
| 986 | + ) |
964 | 987 | ) |
965 | | - ) |
| 988 | + else: |
| 989 | + raise ValueError("There must be at least 1 random effect component") |
| 990 | + else: |
| 991 | + alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) |
| 992 | + |
| 993 | + if rfx_group_parameter_prior_mean is None: |
| 994 | + xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) |
| 995 | + else: |
| 996 | + xi_init = _expand_dims_1d(rfx_group_parameter_prior_mean, num_rfx_components) |
| 997 | + # If it's a vector, expand to matrix |
| 998 | + if xi_init.ndim == 1: |
| 999 | + xi_init = np.tile(np.expand_dims(xi_init, 1), (1, num_rfx_groups)) |
| 1000 | + |
| 1001 | + if rfx_working_parameter_prior_cov is None: |
| 1002 | + sigma_alpha_init = np.identity(num_rfx_components) |
| 1003 | + else: |
| 1004 | + sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) |
| 1005 | + |
| 1006 | + if rfx_group_parameter_prior_cov is None: |
| 1007 | + sigma_xi_init = np.identity(num_rfx_components) |
966 | 1008 | else: |
967 | | - raise ValueError("There must be at least 1 random effect component") |
968 | | - xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) |
969 | | - sigma_alpha_init = np.identity(num_rfx_components) |
970 | | - sigma_xi_init = np.identity(num_rfx_components) |
971 | | - sigma_xi_shape = 1.0 |
972 | | - sigma_xi_scale = 1.0 |
| 1009 | + sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) |
| 1010 | + |
| 1011 | + sigma_xi_shape = rfx_variance_prior_shape |
| 1012 | + sigma_xi_scale = rfx_variance_prior_scale |
| 1013 | + |
973 | 1014 | rfx_dataset_train = RandomEffectsDataset() |
974 | 1015 | rfx_dataset_train.add_group_labels(rfx_group_ids_train) |
975 | 1016 | rfx_dataset_train.add_basis(rfx_basis_train) |
|
0 commit comments