Skip to content

Commit 6998a6f

Browse files
authored
Merge pull request #174 from StochasticTree/kernel-hotfix
Handle BCF models with propensity covariates in the kernel prediction functions
2 parents 1ff234b + 5fd4ae1 commit 6998a6f

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

R/forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ Forest <- R6::R6Class(
785785
#' Return constant leaf status of trees in a `Forest` object
786786
#' @return `TRUE` if leaves are constant, `FALSE` otherwise
787787
is_constant_leaf = function() {
788-
return(is_leaf_constant_forest_container_cpp(self$forest_ptr))
788+
return(is_leaf_constant_active_forest_cpp(self$forest_ptr))
789789
},
790790

791791
#' @description

R/kernel.R

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#'
3333
#' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this argument.
3434
#'
35+
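#' @param propensity (Optional) Propensity scores for the observations in `covariates` (BCF-only). Must be provided if the model uses a propensity covariate and was sampled without an internal propensity model; if omitted for a model with an internal propensity model, propensities are predicted from that model.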
3536
#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided,
3637
#' this function will return leaf indices for every sample of a forest.
3738
#' This function uses 0-indexing, so the first forest sample corresponds to `forest_inds = 0`, and so on.
@@ -46,7 +47,7 @@
4647
#' computeForestLeafIndices(bart_model, X, "mean")
4748
#' computeForestLeafIndices(bart_model, X, "mean", forest_inds = 0)
4849
#' computeForestLeafIndices(bart_model, X, "mean", forest_inds = c(1,3,9))
49-
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
50+
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, propensity=NULL, forest_inds=NULL) {
5051
# Extract relevant forest container
5152
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
5253
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
@@ -93,6 +94,21 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
9394
covariates_processed <- covariates
9495
}
9596

97+
# Handle BCF propensity covariate
98+
if (model_type == "bcf") {
99+
# Add propensities to covariate set if necessary
100+
if (model_object$model_params$propensity_covariate != "none") {
101+
if (is.null(propensity)) {
102+
if (!model_object$model_params$internal_propensity_model) {
103+
stop("propensity must be provided for this model")
104+
}
105+
# Compute propensity score using the internal bart model
106+
propensity <- rowMeans(predict(model_object$bart_propensity_model, covariates)$y_hat)
107+
}
108+
covariates_processed <- cbind(covariates_processed, propensity)
109+
}
110+
}
111+
96112
# Preprocess forest indices
97113
num_forests <- forest_container$num_samples()
98114
if (is.null(forest_inds)) {

man/computeForestLeafIndices.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

stochtree/kernel.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .forest import ForestContainer
1010

1111

12-
def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None):
12+
def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, propensity: np.array = None, forest_inds: Union[int, np.ndarray] = None):
1313
"""
1414
Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset.
1515
@@ -37,6 +37,8 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
3737
* **ForestContainer**
3838
* `None`: It is not necessary to disambiguate when this function is called directly on a `ForestContainer` object. This is the default value of this argument.
3939
40+
propensity : `np.array`, optional
41+
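Optional test set propensities. Must be provided if propensities were provided when the model was sampled. If omitted for a BCF model that trained an internal propensity model, propensities are predicted from that model.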
4042
forest_inds : int or np.ndarray
4143
Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest.
4244
This function uses 0-indexing, so the first forest sample corresponds to `forest_inds = 0`, and so on.
@@ -88,6 +90,19 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
8890
else:
8991
covariates_processed = covariates
9092
covariates_processed = np.asfortranarray(covariates_processed)
93+
94+
# Handle BCF propensity covariate
95+
if model_type == "bcf":
96+
if model_object.propensity_covariate != "none":
97+
if propensity is None:
98+
if not model_object.internal_propensity_model:
99+
raise ValueError(
100+
"Propensity scores not provided, but no propensity model was trained during sampling"
101+
)
102+
propensity = np.mean(
103+
model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True
104+
)
105+
covariates_processed = np.c_[covariates_processed, propensity]
91106

92107
# Preprocess forest indices
93108
num_forests = forest_container.num_samples()

0 commit comments

Comments
 (0)