Skip to content

Commit 774af3d

Browse files
committed
Handle propensity covariates for BCF models in the leaf index prediction R function
1 parent 1ff234b commit 774af3d

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

R/kernel.R

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#'
3333
#' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this
3434
#'
35+
#' @param propensity (Optional) Propensities used for prediction (BCF-only).
3536
#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided,
3637
#' this function will return leaf indices for every sample of a forest.
3738
#' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
@@ -46,7 +47,7 @@
4647
#' computeForestLeafIndices(bart_model, X, "mean")
4748
#' computeForestLeafIndices(bart_model, X, "mean", 0)
4849
#' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9))
49-
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
50+
computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, propensity=NULL, forest_inds=NULL) {
5051
# Extract relevant forest container
5152
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
5253
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
@@ -93,6 +94,21 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
9394
covariates_processed <- covariates
9495
}
9596

97+
# Handle BCF propensity covariate
98+
if (model_type == "bcf") {
99+
# Add propensities to covariate set if necessary
100+
if (model_object$model_params$propensity_covariate != "none") {
101+
if (is.null(propensity)) {
102+
if (!model_object$model_params$internal_propensity_model) {
103+
stop("propensity must be provided for this model")
104+
}
105+
# Compute propensity score using the internal bart model
106+
propensity <- rowMeans(predict(model_object$bart_propensity_model, X)$y_hat)
107+
}
108+
covariates_processed <- cbind(covariates_processed, propensity)
109+
}
110+
}
111+
96112
# Preprocess forest indices
97113
num_forests <- forest_container$num_samples()
98114
if (is.null(forest_inds)) {

0 commit comments

Comments
 (0)