Skip to content

Commit f11a0f4

Browse files
committed
Fixed several prediction bugs in R for BART / BCF
1 parent d1ab690 commit f11a0f4

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

R/bart.R

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1696,7 +1696,7 @@ bart <- function(
16961696
"is_leaf_constant" = is_leaf_constant,
16971697
"leaf_regression" = leaf_regression,
16981698
"requires_basis" = requires_basis,
1699-
"num_covariates" = ncol(X_train),
1699+
"num_covariates" = num_cov_orig,
17001700
"num_basis" = ifelse(
17011701
is.null(leaf_basis_train),
17021702
0,
@@ -1896,12 +1896,10 @@ predict.bartmodel <- function(
18961896
)
18971897
}
18981898

1899-
# Preprocess covariates
1899+
# Check that covariates are matrix or data frame
19001900
if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
19011901
stop("covariates must be a matrix or dataframe")
19021902
}
1903-
train_set_metadata <- object$train_set_metadata
1904-
X <- preprocessPredictionData(covariates, train_set_metadata)
19051903

19061904
# Convert all input data to matrices if not already converted
19071905
if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) {
@@ -1915,11 +1913,13 @@ predict.bartmodel <- function(
19151913
if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
19161914
stop("Basis (leaf_basis) must be provided for this model")
19171915
}
1918-
if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) {
1919-
stop("X and leaf_basis must have the same number of rows")
1916+
if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) {
1917+
stop("covariates and leaf_basis must have the same number of rows")
19201918
}
1921-
if (object$model_params$num_covariates != ncol(X)) {
1922-
stop("X and leaf_basis must have the same number of rows")
1919+
if (object$model_params$num_covariates != ncol(covariates)) {
1920+
stop(
1921+
"covariates must contain the same number of columns as the BART model's training dataset"
1922+
)
19231923
}
19241924
if ((predict_rfx) && (is.null(rfx_group_ids))) {
19251925
stop(
@@ -1938,6 +1938,10 @@ predict.bartmodel <- function(
19381938
)
19391939
}
19401940

1941+
# Preprocess covariates
1942+
train_set_metadata <- object$train_set_metadata
1943+
covariates <- preprocessPredictionData(covariates, train_set_metadata)
1944+
19411945
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
19421946
has_rfx <- FALSE
19431947
if (predict_rfx) {
@@ -1956,14 +1960,14 @@ predict.bartmodel <- function(
19561960

19571961
# Produce basis for the "intercept-only" random effects case
19581962
if ((predict_rfx) && (is.null(rfx_basis))) {
1959-
rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1)
1963+
rfx_basis <- matrix(rep(1, nrow(covariates)), ncol = 1)
19601964
}
19611965

19621966
# Create prediction dataset
19631967
if (!is.null(leaf_basis)) {
1964-
prediction_dataset <- createForestDataset(X, leaf_basis)
1968+
prediction_dataset <- createForestDataset(covariates, leaf_basis)
19651969
} else {
1966-
prediction_dataset <- createForestDataset(X)
1970+
prediction_dataset <- createForestDataset(covariates)
19671971
}
19681972

19691973
# Compute variance forest predictions

R/bcf.R

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2708,12 +2708,10 @@ predict.bcfmodel <- function(
27082708
predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest)
27092709
predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest)
27102710

2711-
# Preprocess covariates
2711+
# Make sure covariates are matrix or data frame
27122712
if ((!is.data.frame(X)) && (!is.matrix(X))) {
27132713
stop("X must be a matrix or dataframe")
27142714
}
2715-
train_set_metadata <- object$train_set_metadata
2716-
X <- preprocessPredictionData(X, train_set_metadata)
27172715

27182716
# Convert all input data to matrices if not already converted
27192717
if ((is.null(dim(Z))) && (!is.null(Z))) {
@@ -2762,6 +2760,10 @@ predict.bcfmodel <- function(
27622760
)
27632761
}
27642762

2763+
# Preprocess covariates
2764+
train_set_metadata <- object$train_set_metadata
2765+
X <- preprocessPredictionData(X, train_set_metadata)
2766+
27652767
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
27662768
has_rfx <- FALSE
27672769
if (!is.null(rfx_group_ids)) {
@@ -2846,7 +2848,7 @@ predict.bcfmodel <- function(
28462848
}
28472849

28482850
# Compute rfx predictions
2849-
if (predict_rfx) {
2851+
if (predict_rfx || predict_rfx_intermediate) {
28502852
rfx_predictions <- object$rfx_samples$predict(
28512853
rfx_group_ids,
28522854
rfx_basis

0 commit comments

Comments
 (0)