Skip to content

Commit 76d8b06

Browse files
committed
Deploying to r-dev from @ 96f78f0 🚀
1 parent 022320a commit 76d8b06

12 files changed

+186
-208
lines changed

R/bart.R

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,11 @@ bart <- function(
418418
# Raise a warning if the data have ties and only GFR is being run
419419
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
420420
num_values <- nrow(X_train)
421-
max_grid_size <- floor(num_values / cutpoint_grid_size)
421+
max_grid_size <- ifelse(
422+
num_values > cutpoint_grid_size,
423+
floor(num_values / cutpoint_grid_size),
424+
1
425+
)
422426
covs_warning_1 <- NULL
423427
covs_warning_2 <- NULL
424428
covs_warning_3 <- NULL
@@ -1924,7 +1928,7 @@ bart <- function(
19241928
#' Predict from a sampled BART model on new data
19251929
#'
19261930
#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
1927-
#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
1931+
#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
19281932
#' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`.
19291933
#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
19301934
#' We do not currently support (but plan to in the near future), test set evaluation for group labels
@@ -1961,10 +1965,10 @@ bart <- function(
19611965
#' y_train <- y[train_inds]
19621966
#' bart_model <- bart(X_train = X_train, y_train = y_train,
19631967
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
1964-
#' y_hat_test <- predict(bart_model, X_test)$y_hat
1968+
#' y_hat_test <- predict(bart_model, X=X_test)$y_hat
19651969
predict.bartmodel <- function(
19661970
object,
1967-
covariates,
1971+
X,
19681972
leaf_basis = NULL,
19691973
rfx_group_ids = NULL,
19701974
rfx_basis = NULL,
@@ -2047,8 +2051,8 @@ predict.bartmodel <- function(
20472051
}
20482052

20492053
# Check that covariates are matrix or data frame
2050-
if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
2051-
stop("covariates must be a matrix or dataframe")
2054+
if ((!is.data.frame(X)) && (!is.matrix(X))) {
2055+
stop("X must be a matrix or dataframe")
20522056
}
20532057

20542058
# Convert all input data to matrices if not already converted
@@ -2063,12 +2067,12 @@ predict.bartmodel <- function(
20632067
if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
20642068
stop("Basis (leaf_basis) must be provided for this model")
20652069
}
2066-
if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) {
2067-
stop("covariates and leaf_basis must have the same number of rows")
2070+
if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) {
2071+
stop("X and leaf_basis must have the same number of rows")
20682072
}
2069-
if (object$model_params$num_covariates != ncol(covariates)) {
2073+
if (object$model_params$num_covariates != ncol(X)) {
20702074
stop(
2071-
"covariates must contain the same number of columns as the BART model's training dataset"
2075+
"X must contain the same number of columns as the BART model's training dataset"
20722076
)
20732077
}
20742078
if ((predict_rfx) && (is.null(rfx_group_ids))) {
@@ -2089,7 +2093,7 @@ predict.bartmodel <- function(
20892093

20902094
# Preprocess covariates
20912095
train_set_metadata <- object$train_set_metadata
2092-
covariates <- preprocessPredictionData(covariates, train_set_metadata)
2096+
X <- preprocessPredictionData(X, train_set_metadata)
20932097

20942098
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
20952099
has_rfx <- FALSE
@@ -2119,8 +2123,8 @@ predict.bartmodel <- function(
21192123
# Only construct a basis if user-provided basis missing
21202124
if (is.null(rfx_basis)) {
21212125
rfx_basis <- matrix(
2122-
rep(1, nrow(covariates)),
2123-
nrow = nrow(covariates),
2126+
rep(1, nrow(X)),
2127+
nrow = nrow(X),
21242128
ncol = 1
21252129
)
21262130
}
@@ -2129,9 +2133,9 @@ predict.bartmodel <- function(
21292133

21302134
# Create prediction dataset
21312135
if (!is.null(leaf_basis)) {
2132-
prediction_dataset <- createForestDataset(covariates, leaf_basis)
2136+
prediction_dataset <- createForestDataset(X, leaf_basis)
21332137
} else {
2134-
prediction_dataset <- createForestDataset(covariates)
2138+
prediction_dataset <- createForestDataset(X)
21352139
}
21362140

21372141
# Compute variance forest predictions
@@ -2843,7 +2847,7 @@ createBARTModelFromJsonFile <- function(json_filename) {
28432847
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
28442848
#' bart_json <- saveBARTModelToJsonString(bart_model)
28452849
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
2846-
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
2850+
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat)
28472851
createBARTModelFromJsonString <- function(json_string) {
28482852
# Load a `CppJson` object from string
28492853
bart_json <- createCppJsonString(json_string)

R/bcf.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,11 @@ bcf <- function(
522522
# Raise a warning if the data have ties and only GFR is being run
523523
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
524524
num_values <- nrow(X_train)
525-
max_grid_size <- floor(num_values / cutpoint_grid_size)
525+
max_grid_size <- ifelse(
526+
num_values > cutpoint_grid_size,
527+
floor(num_values / cutpoint_grid_size),
528+
1
529+
)
526530
covs_warning_1 <- NULL
527531
covs_warning_2 <- NULL
528532
covs_warning_3 <- NULL

R/kernel.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ computeForestLeafIndices <- function(
129129
propensity <- rowMeans(
130130
predict(
131131
model_object$bart_propensity_model,
132-
covariates
132+
X = covariates
133133
)$y_hat
134134
)
135135
}

0 commit comments

Comments
 (0)