|
@@ -9,7 +9,7 @@
 from .forest import ForestContainer
 
 
-def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None):
+def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, propensity: np.array = None, forest_inds: Union[int, np.ndarray] = None):
     """
     Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset.
 
|
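The only signature change is the new optional `propensity` argument, which defaults to `None`, so existing calls that omit it keep working. A minimal sketch of such a call, assuming `BARTModel` and `compute_forest_leaf_indices` are importable from the top-level `stochtree` package and that `BARTModel.sample` accepts the `X_train`/`y_train`/`num_gfr`/`num_mcmc` keywords used below:

import numpy as np
from stochtree import BARTModel, compute_forest_leaf_indices

# Toy data; shapes and keywords below are assumptions about the usual stochtree workflow
rng = np.random.default_rng(0)
X = rng.uniform(size=(100, 5))
y = X[:, 0] + rng.normal(scale=0.1, size=100)

bart_model = BARTModel()
bart_model.sample(X_train=X, y_train=y, num_gfr=10, num_mcmc=100)

# Leaf indices for the mean forest; `propensity` is omitted and defaults to None
leaf_inds = compute_forest_leaf_indices(bart_model, X, forest_type="mean")
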
@@ -37,6 +37,8 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
         * **ForestContainer**
             * `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this parameter.
 
+    propensity : `np.array`, optional
+        Optional test set propensities. Must be provided if propensities were provided when the model was sampled.
     forest_inds : int or np.ndarray
         Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest.
         This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
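
For a BCF model that was sampled with propensity scores, the caller can now supply test-set propensities directly; otherwise the function falls back to the internally trained propensity model (see the next hunk). A hedged sketch of the call, where `bcf_model`, `X_test`, and `pi_test` stand in for a fitted `BCFModel`, test covariates, and test propensities, and `"prognostic"` is assumed to be an accepted `forest_type` for BCF models:

# bcf_model, X_test, and pi_test are assumed to already exist (hypothetical objects)
leaf_inds_bcf = compute_forest_leaf_indices(
    bcf_model,
    X_test,
    forest_type="prognostic",
    propensity=pi_test,  # shape (n_test, 1)
)
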
@@ -88,6 +90,19 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
     else:
         covariates_processed = covariates
     covariates_processed = np.asfortranarray(covariates_processed)
+
+    # Handle BCF propensity covariate
+    if model_type == "bcf":
+        if model_object.propensity_covariate != "none":
+            if propensity is None:
+                if not model_object.internal_propensity_model:
+                    raise ValueError(
+                        "Propensity scores not provided, but no propensity model was trained during sampling"
+                    )
+                propensity = np.mean(
+                    model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True
+                )
+            covariates_processed = np.c_[covariates_processed, propensity]
 
     # Preprocess forest indices
     num_forests = forest_container.num_samples()
|
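When propensities are omitted but an internal propensity model exists, the new branch averages that model's per-sample predictions and appends the result as an extra covariate column before leaf indices are computed. A standalone numpy sketch of that transformation, with `raw_preds` standing in for the output of `model_object.bart_propensity_model.predict(covariates)`:

import numpy as np

rng = np.random.default_rng(0)

# Stand-in for bart_propensity_model.predict(covariates):
# one propensity draw per posterior sample, shape (n_obs, n_samples)
raw_preds = rng.uniform(size=(5, 100))

# Posterior-mean propensity, kept as an (n_obs, 1) column
propensity = np.mean(raw_preds, axis=1, keepdims=True)

# Append the propensity column to the processed covariates, as the patch does
covariates_processed = np.asfortranarray(rng.uniform(size=(5, 3)))
covariates_processed = np.c_[covariates_processed, propensity]
print(covariates_processed.shape)  # (5, 4)

Keeping `keepdims=True` leaves the averaged propensity as an explicit two-dimensional column, matching the shape a caller would pass for the `propensity` argument.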