You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
171
171
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
172
172
173
-
rfx_params : dict, optional
173
+
random_effects_params : dict, optional
174
174
Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional.
175
175
176
+
* `model_spec`: Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored.
176
177
* `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
177
178
* `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
178
179
* `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
Copy file name to clipboardExpand all lines: stochtree/bcf.py
+25-14Lines changed: 25 additions & 14 deletions
Original file line number
Diff line number
Diff line change
@@ -208,7 +208,7 @@ def sample(
208
208
* `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
209
209
* `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
210
210
211
-
rfx_params : dict, optional
211
+
random_effects_params : dict, optional
212
212
Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional.
213
213
214
214
* `model_spec`: Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored.
@@ -2504,13 +2504,13 @@ def predict(
2504
2504
raiseValueError(
2505
2505
"rfx_group_ids must be provided if rfx_basis is provided"
2506
2506
)
2507
-
ifrfx_basisisnotNone:
2508
-
ifrfx_basis.ndim==1:
2509
-
rfx_basis=np.expand_dims(rfx_basis, 1)
2510
-
ifrfx_basis.shape[0] !=X.shape[0]:
2511
-
raiseValueError("X and rfx_basis must have the same number of rows")
2512
-
ifrfx_basis.shape[1] !=self.num_rfx_basis:
2513
-
raiseValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
2507
+
ifrfx_basisisnotNone:
2508
+
ifrfx_basis.ndim==1:
2509
+
rfx_basis=np.expand_dims(rfx_basis, 1)
2510
+
ifrfx_basis.shape[0] !=X.shape[0]:
2511
+
raiseValueError("X and rfx_basis must have the same number of rows")
2512
+
ifrfx_basis.shape[1] !=self.num_rfx_basis:
2513
+
raiseValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model")
0 commit comments