|
32 | 32 | #' |
33 | 33 | #' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this |
34 | 34 | #' |
| 35 | +#' @param propensity (Optional) Propensities used for prediction (BCF-only). |
35 | 36 | #' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, |
36 | 37 | #' this function will return leaf indices for every sample of a forest. |
37 | 38 | #' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. |
|
46 | 47 | #' computeForestLeafIndices(bart_model, X, "mean") |
47 | 48 | #' computeForestLeafIndices(bart_model, X, "mean", 0) |
48 | 49 | #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) |
49 | | -computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { |
| 50 | +computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, propensity=NULL, forest_inds=NULL) { |
50 | 51 | # Extract relevant forest container |
51 | 52 | stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) |
52 | 53 | model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) |
@@ -93,6 +94,21 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, |
93 | 94 | covariates_processed <- covariates |
94 | 95 | } |
95 | 96 |
|
| 97 | + # Handle BCF propensity covariate |
| 98 | + if (model_type == "bcf") { |
| 99 | + # Add propensities to covariate set if necessary |
| 100 | + if (model_object$model_params$propensity_covariate != "none") { |
| 101 | + if (is.null(propensity)) { |
| 102 | + if (!model_object$model_params$internal_propensity_model) { |
| 103 | + stop("propensity must be provided for this model") |
| 104 | + } |
| 105 | + # Compute propensity score using the internal bart model |
| 106 | + propensity <- rowMeans(predict(model_object$bart_propensity_model, covariates)$y_hat) |
| 107 | + } |
| 108 | + covariates_processed <- cbind(covariates_processed, propensity) |
| 109 | + } |
| 110 | + } |
| 111 | + |
96 | 112 | # Preprocess forest indices |
97 | 113 | num_forests <- forest_container$num_samples() |
98 | 114 | if (is.null(forest_inds)) { |
|
0 commit comments