@@ -258,9 +258,7 @@ def sample(
258258 "variance_prior_shape" : 1.0 ,
259259 "variance_prior_scale" : 1.0 ,
260260 }
261- rfx_params_updated = _preprocess_params (
262- rfx_params_default , rfx_params
263- )
261+ rfx_params_updated = _preprocess_params (rfx_params_default , rfx_params )
264262
265263 ### Unpack all parameter values
266264 # 1. General parameters
@@ -323,9 +321,7 @@ def sample(
323321 rfx_working_parameter_prior_cov = rfx_params_updated [
324322 "working_parameter_prior_cov"
325323 ]
326- rfx_group_parameter_prior_cov = rfx_params_updated [
327- "group_parameter_prior_cov"
328- ]
324+ rfx_group_parameter_prior_cov = rfx_params_updated ["group_parameter_prior_cov" ]
329325 rfx_variance_prior_shape = rfx_params_updated ["variance_prior_shape" ]
330326 rfx_variance_prior_scale = rfx_params_updated ["variance_prior_scale" ]
331327
@@ -1885,13 +1881,13 @@ def compute_contrast(
18851881 type : str = "posterior" ,
18861882 scale : str = "linear" ,
18871883 ) -> Union [np .array , tuple ]:
1888- """Compute a contrast using a BART model by making two sets of outcome predictions and taking their
1889- difference. This function provides the flexibility to compute any contrast of interest by specifying
1890- covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
1891- For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
1892- of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
1893- terminology of a classic two-treatment causal inference problem. We mirror the function calls and
1894- terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote
1884+ """Compute a contrast using a BART model by making two sets of outcome predictions and taking their
1885+ difference. This function provides the flexibility to compute any contrast of interest by specifying
1886+ covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
1887+ For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
1888+ of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
1889+ terminology of a classic two-treatment causal inference problem. We mirror the function calls and
1890+ terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote
18951891 its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction.
18961892
18971893 Parameters
@@ -1905,12 +1901,12 @@ def compute_contrast(
19051901 basis_1 : np.array, optional
19061902 Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values).
19071903 rfx_group_ids_0 : np.array, optional
1908- Test set group labels used for prediction from an additive random effects model in the "control" case.
1909- We do not currently support (but plan to in the near future), test set evaluation for group labels that
1904+ Test set group labels used for prediction from an additive random effects model in the "control" case.
1905+ We do not currently support (but plan to in the near future), test set evaluation for group labels that
19101906 were not in the training set. Must be a numpy array.
19111907 rfx_group_ids_1 : np.array, optional
1912- Test set group labels used for prediction from an additive random effects model in the "treatment" case.
1913- We do not currently support (but plan to in the near future), test set evaluation for group labels that
1908+ Test set group labels used for prediction from an additive random effects model in the "treatment" case.
1909+ We do not currently support (but plan to in the near future), test set evaluation for group labels that
19141910 were not in the training set. Must be a numpy array.
19151911 rfx_basis_0 : np.array, optional
19161912 Test set basis for used for prediction from an additive random effects model in the "control" case.
@@ -1949,10 +1945,7 @@ def compute_contrast(
19491945 has_rfx = self .has_rfx
19501946
19511947 # Check that we have at least one term to predict on probability scale
1952- if (
1953- not has_mean_forest
1954- and not has_rfx
1955- ):
1948+ if not has_mean_forest and not has_rfx :
19561949 raise ValueError (
19571950 "Contrast cannot be computed as the model does not have a mean forest or random effects term"
19581951 )
@@ -1988,12 +1981,28 @@ def compute_contrast(
19881981 raise ValueError (
19891982 "covariates_1 and basis_1 must have the same number of rows"
19901983 )
1991-
1984+
19921985 # Predict for the control arm
1993- control_preds = self .predict (covariates = covariates_0 , basis = basis_0 , rfx_group_ids = rfx_group_ids_0 , rfx_basis = rfx_basis_0 , type = "posterior" , terms = "y_hat" , scale = "linear" )
1986+ control_preds = self .predict (
1987+ covariates = covariates_0 ,
1988+ basis = basis_0 ,
1989+ rfx_group_ids = rfx_group_ids_0 ,
1990+ rfx_basis = rfx_basis_0 ,
1991+ type = "posterior" ,
1992+ terms = "y_hat" ,
1993+ scale = "linear" ,
1994+ )
19941995
19951996 # Predict for the treatment arm
1996- treatment_preds = self .predict (covariates = covariates_1 , basis = basis_1 , rfx_group_ids = rfx_group_ids_1 , rfx_basis = rfx_basis_1 , type = "posterior" , terms = "y_hat" , scale = "linear" )
1997+ treatment_preds = self .predict (
1998+ covariates = covariates_1 ,
1999+ basis = basis_1 ,
2000+ rfx_group_ids = rfx_group_ids_1 ,
2001+ rfx_basis = rfx_basis_1 ,
2002+ type = "posterior" ,
2003+ terms = "y_hat" ,
2004+ scale = "linear" ,
2005+ )
19972006
19982007 # Transform to probability scale if requested
19992008 if probability_scale :
@@ -2002,9 +2011,9 @@ def compute_contrast(
20022011
20032012 # Compute and return contrast
20042013 if predict_mean :
2005- return ( np .mean (treatment_preds - control_preds , axis = 1 ) )
2014+ return np .mean (treatment_preds - control_preds , axis = 1 )
20062015 else :
2007- return ( treatment_preds - control_preds )
2016+ return treatment_preds - control_preds
20082017
20092018 def compute_posterior_interval (
20102019 self ,
0 commit comments