Skip to content

Commit 5c2dbfc

Browse files
committed
Handle propensity covariates for BCF models in the leaf index prediction Python function
1 parent 774af3d commit 5c2dbfc

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

stochtree/kernel.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .forest import ForestContainer
1010

1111

12-
def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None):
12+
def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, propensity: np.array = None, forest_inds: Union[int, np.ndarray] = None):
1313
"""
1414
Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset.
1515
@@ -37,6 +37,8 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
3737
* **ForestContainer**
3838
* `None`: It is not necessary to disambiguate when this function is called directly on a `ForestContainer` object. This is the default value of this argument.
3939
40+
propensity : `np.array`, optional
41+
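Optional test set propensities. Must be provided if propensities were provided when the model was sampled. If omitted for a BCF model that trained its own propensity model during sampling, propensities are estimated as the mean prediction of that model.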
4042
forest_inds : int or np.ndarray
4143
Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest.
4244
This function uses 0-indexing, so the first forest sample corresponds to `forest_inds = 0`, and so on.
@@ -88,6 +90,19 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
8890
else:
8991
covariates_processed = covariates
9092
covariates_processed = np.asfortranarray(covariates_processed)
93+
94+
# Handle BCF propensity covariate
95+
if model_type == "bcf":
96+
if model_object.propensity_covariate != "none":
97+
if propensity is None:
98+
if not model_object.internal_propensity_model:
99+
raise ValueError(
100+
"Propensity scores not provided, but no propensity model was trained during sampling"
101+
)
102+
propensity = np.mean(
103+
model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True
104+
)
105+
covariates_processed = np.c_[covariates_processed, propensity]
91106

92107
# Preprocess forest indices
93108
num_forests = forest_container.num_samples()

0 commit comments

Comments
 (0)