@@ -1696,7 +1696,7 @@ bart <- function(
16961696 " is_leaf_constant" = is_leaf_constant ,
16971697 " leaf_regression" = leaf_regression ,
16981698 " requires_basis" = requires_basis ,
1699- " num_covariates" = ncol( X_train ) ,
1699+ " num_covariates" = num_cov_orig ,
17001700 " num_basis" = ifelse(
17011701 is.null(leaf_basis_train ),
17021702 0 ,
@@ -1896,12 +1896,10 @@ predict.bartmodel <- function(
18961896 )
18971897 }
18981898
1899- # Preprocess covariates
1899+ # Check that covariates are matrix or data frame
19001900 if ((! is.data.frame(covariates )) && (! is.matrix(covariates ))) {
19011901 stop(" covariates must be a matrix or dataframe" )
19021902 }
1903- train_set_metadata <- object $ train_set_metadata
1904- X <- preprocessPredictionData(covariates , train_set_metadata )
19051903
19061904 # Convert all input data to matrices if not already converted
19071905 if ((is.null(dim(leaf_basis ))) && (! is.null(leaf_basis ))) {
@@ -1915,11 +1913,13 @@ predict.bartmodel <- function(
19151913 if ((object $ model_params $ requires_basis ) && (is.null(leaf_basis ))) {
19161914 stop(" Basis (leaf_basis) must be provided for this model" )
19171915 }
1918- if ((! is.null(leaf_basis )) && (nrow(X ) != nrow(leaf_basis ))) {
1919- stop(" X and leaf_basis must have the same number of rows" )
1916+ if ((! is.null(leaf_basis )) && (nrow(covariates ) != nrow(leaf_basis ))) {
1917+ stop(" covariates and leaf_basis must have the same number of rows" )
19201918 }
1921- if (object $ model_params $ num_covariates != ncol(X )) {
1922- stop(" X and leaf_basis must have the same number of rows" )
1919+ if (object $ model_params $ num_covariates != ncol(covariates )) {
1920+ stop(
1921+ " covariates must contain the same number of columns as the BART model's training dataset"
1922+ )
19231923 }
19241924 if ((predict_rfx ) && (is.null(rfx_group_ids ))) {
19251925 stop(
@@ -1938,6 +1938,10 @@ predict.bartmodel <- function(
19381938 )
19391939 }
19401940
1941+ # Preprocess covariates
1942+ train_set_metadata <- object $ train_set_metadata
1943+ covariates <- preprocessPredictionData(covariates , train_set_metadata )
1944+
19411945 # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
19421946 has_rfx <- FALSE
19431947 if (predict_rfx ) {
@@ -1956,14 +1960,14 @@ predict.bartmodel <- function(
19561960
19571961 # Produce basis for the "intercept-only" random effects case
19581962 if ((predict_rfx ) && (is.null(rfx_basis ))) {
1959- rfx_basis <- matrix (rep(1 , nrow(X )), ncol = 1 )
1963+ rfx_basis <- matrix (rep(1 , nrow(covariates )), ncol = 1 )
19601964 }
19611965
19621966 # Create prediction dataset
19631967 if (! is.null(leaf_basis )) {
1964- prediction_dataset <- createForestDataset(X , leaf_basis )
1968+ prediction_dataset <- createForestDataset(covariates , leaf_basis )
19651969 } else {
1966- prediction_dataset <- createForestDataset(X )
1970+ prediction_dataset <- createForestDataset(covariates )
19671971 }
19681972
19691973 # Compute variance forest predictions
0 commit comments