Skip to content

Commit f572cb9

Browse files
committed
Reformat python code
1 parent 51d3da7 commit f572cb9

File tree

2 files changed

+70
-47
lines changed

2 files changed

+70
-47
lines changed

stochtree/bart.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,7 @@ def sample(
258258
"variance_prior_shape": 1.0,
259259
"variance_prior_scale": 1.0,
260260
}
261-
rfx_params_updated = _preprocess_params(
262-
rfx_params_default, rfx_params
263-
)
261+
rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params)
264262

265263
### Unpack all parameter values
266264
# 1. General parameters
@@ -323,9 +321,7 @@ def sample(
323321
rfx_working_parameter_prior_cov = rfx_params_updated[
324322
"working_parameter_prior_cov"
325323
]
326-
rfx_group_parameter_prior_cov = rfx_params_updated[
327-
"group_parameter_prior_cov"
328-
]
324+
rfx_group_parameter_prior_cov = rfx_params_updated["group_parameter_prior_cov"]
329325
rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"]
330326
rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"]
331327

@@ -1885,13 +1881,13 @@ def compute_contrast(
18851881
type: str = "posterior",
18861882
scale: str = "linear",
18871883
) -> Union[np.array, tuple]:
1888-
"""Compute a contrast using a BART model by making two sets of outcome predictions and taking their
1889-
difference. This function provides the flexibility to compute any contrast of interest by specifying
1890-
covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
1891-
For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
1892-
of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
1893-
terminology of a classic two-treatment causal inference problem. We mirror the function calls and
1894-
terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote
1884+
"""Compute a contrast using a BART model by making two sets of outcome predictions and taking their
1885+
difference. This function provides the flexibility to compute any contrast of interest by specifying
1886+
covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
1887+
For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
1888+
of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
1889+
terminology of a classic two-treatment causal inference problem. We mirror the function calls and
1890+
terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote
18951891
its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction.
18961892
18971893
Parameters
@@ -1905,12 +1901,12 @@ def compute_contrast(
19051901
basis_1 : np.array, optional
19061902
Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values).
19071903
rfx_group_ids_0 : np.array, optional
1908-
Test set group labels used for prediction from an additive random effects model in the "control" case.
1909-
We do not currently support (but plan to in the near future), test set evaluation for group labels that
1904+
Test set group labels used for prediction from an additive random effects model in the "control" case.
1905+
We do not currently support (but plan to in the near future), test set evaluation for group labels that
19101906
were not in the training set. Must be a numpy array.
19111907
rfx_group_ids_1 : np.array, optional
1912-
Test set group labels used for prediction from an additive random effects model in the "treatment" case.
1913-
We do not currently support (but plan to in the near future), test set evaluation for group labels that
1908+
Test set group labels used for prediction from an additive random effects model in the "treatment" case.
1909+
We do not currently support (but plan to in the near future), test set evaluation for group labels that
19141910
were not in the training set. Must be a numpy array.
19151911
rfx_basis_0 : np.array, optional
19161912
Test set basis for used for prediction from an additive random effects model in the "control" case.
@@ -1949,10 +1945,7 @@ def compute_contrast(
19491945
has_rfx = self.has_rfx
19501946

19511947
# Check that we have at least one term to predict on probability scale
1952-
if (
1953-
not has_mean_forest
1954-
and not has_rfx
1955-
):
1948+
if not has_mean_forest and not has_rfx:
19561949
raise ValueError(
19571950
"Contrast cannot be computed as the model does not have a mean forest or random effects term"
19581951
)
@@ -1988,12 +1981,28 @@ def compute_contrast(
19881981
raise ValueError(
19891982
"covariates_1 and basis_1 must have the same number of rows"
19901983
)
1991-
1984+
19921985
# Predict for the control arm
1993-
control_preds = self.predict(covariates=covariates_0, basis=basis_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", terms="y_hat", scale="linear")
1986+
control_preds = self.predict(
1987+
covariates=covariates_0,
1988+
basis=basis_0,
1989+
rfx_group_ids=rfx_group_ids_0,
1990+
rfx_basis=rfx_basis_0,
1991+
type="posterior",
1992+
terms="y_hat",
1993+
scale="linear",
1994+
)
19941995

19951996
# Predict for the treatment arm
1996-
treatment_preds = self.predict(covariates=covariates_1, basis=basis_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", terms="y_hat", scale="linear")
1997+
treatment_preds = self.predict(
1998+
covariates=covariates_1,
1999+
basis=basis_1,
2000+
rfx_group_ids=rfx_group_ids_1,
2001+
rfx_basis=rfx_basis_1,
2002+
type="posterior",
2003+
terms="y_hat",
2004+
scale="linear",
2005+
)
19972006

19982007
# Transform to probability scale if requested
19992008
if probability_scale:
@@ -2002,9 +2011,9 @@ def compute_contrast(
20022011

20032012
# Compute and return contrast
20042013
if predict_mean:
2005-
return(np.mean(treatment_preds - control_preds, axis=1))
2014+
return np.mean(treatment_preds - control_preds, axis=1)
20062015
else:
2007-
return(treatment_preds - control_preds)
2016+
return treatment_preds - control_preds
20082017

20092018
def compute_posterior_interval(
20102019
self,

stochtree/bcf.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,7 @@ def sample(
315315
"variance_prior_shape": 1.0,
316316
"variance_prior_scale": 1.0,
317317
}
318-
rfx_params_updated = _preprocess_params(
319-
rfx_params_default, rfx_params
320-
)
318+
rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params)
321319

322320
### Unpack all parameter values
323321
# 1. General parameters
@@ -405,9 +403,7 @@ def sample(
405403
rfx_working_parameter_prior_cov = rfx_params_updated[
406404
"working_parameter_prior_cov"
407405
]
408-
rfx_group_parameter_prior_cov = rfx_params_updated[
409-
"group_parameter_prior_cov"
410-
]
406+
rfx_group_parameter_prior_cov = rfx_params_updated["group_parameter_prior_cov"]
411407
rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"]
412408
rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"]
413409

@@ -2556,13 +2552,13 @@ def compute_contrast(
25562552
type: str = "posterior",
25572553
scale: str = "linear",
25582554
) -> dict:
2559-
"""Compute a contrast using a BCF model by making two sets of outcome predictions and taking their
2560-
difference. This function provides the flexibility to compute any contrast of interest by specifying
2561-
covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
2562-
For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
2563-
of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
2564-
terminology of a classic two-treatment causal inference problem. We mirror the function calls and
2565-
terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote
2555+
"""Compute a contrast using a BCF model by making two sets of outcome predictions and taking their
2556+
difference. This function provides the flexibility to compute any contrast of interest by specifying
2557+
covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
2558+
For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
2559+
of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
2560+
terminology of a classic two-treatment causal inference problem. We mirror the function calls and
2561+
terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote
25662562
its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction.
25672563
25682564
Parameters
@@ -2580,12 +2576,12 @@ def compute_contrast(
25802576
propensity_1 : `np.array`, optional
25812577
Propensities used for prediction in the "treatment" case. Must be a numpy array or vector.
25822578
rfx_group_ids_0 : np.array, optional
2583-
Test set group labels used for prediction from an additive random effects model in the "control" case.
2584-
We do not currently support (but plan to in the near future), test set evaluation for group labels that
2579+
Test set group labels used for prediction from an additive random effects model in the "control" case.
2580+
We do not currently support (but plan to in the near future), test set evaluation for group labels that
25852581
were not in the training set. Must be a numpy array.
25862582
rfx_group_ids_1 : np.array, optional
2587-
Test set group labels used for prediction from an additive random effects model in the "control" case.
2588-
We do not currently support (but plan to in the near future), test set evaluation for group labels that
2583+
Test set group labels used for prediction from an additive random effects model in the "control" case.
2584+
We do not currently support (but plan to in the near future), test set evaluation for group labels that
25892585
were not in the training set. Must be a numpy array.
25902586
rfx_basis_0 : np.array, optional
25912587
Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a numpy array.
@@ -2634,10 +2630,28 @@ def compute_contrast(
26342630
raise ValueError("X_1 and Z_1 must have the same number of rows")
26352631

26362632
# Predict for the control arm
2637-
control_preds = self.predict(X=X_0, Z=Z_0, propensity=propensity_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", terms="y_hat", scale="linear")
2633+
control_preds = self.predict(
2634+
X=X_0,
2635+
Z=Z_0,
2636+
propensity=propensity_0,
2637+
rfx_group_ids=rfx_group_ids_0,
2638+
rfx_basis=rfx_basis_0,
2639+
type="posterior",
2640+
terms="y_hat",
2641+
scale="linear",
2642+
)
26382643

26392644
# Predict for the treatment arm
2640-
treatment_preds = self.predict(X=X_1, Z=Z_1, propensity=propensity_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", terms="y_hat", scale="linear")
2645+
treatment_preds = self.predict(
2646+
X=X_1,
2647+
Z=Z_1,
2648+
propensity=propensity_1,
2649+
rfx_group_ids=rfx_group_ids_1,
2650+
rfx_basis=rfx_basis_1,
2651+
type="posterior",
2652+
terms="y_hat",
2653+
scale="linear",
2654+
)
26412655

26422656
# Transform to probability scale if requested
26432657
if probability_scale:
@@ -2646,9 +2660,9 @@ def compute_contrast(
26462660

26472661
# Compute and return contrast
26482662
if predict_mean:
2649-
return(np.mean(treatment_preds - control_preds, axis=1))
2663+
return np.mean(treatment_preds - control_preds, axis=1)
26502664
else:
2651-
return(treatment_preds - control_preds)
2665+
return treatment_preds - control_preds
26522666

26532667
def compute_posterior_interval(
26542668
self,

0 commit comments

Comments
 (0)