@@ -423,37 +423,53 @@ bart <- function(
423423 floor(num_values / cutpoint_grid_size ),
424424 1
425425 )
426+ x_is_df <- is.data.frame(X_train )
426427 covs_warning_1 <- NULL
427428 covs_warning_2 <- NULL
428429 covs_warning_3 <- NULL
430+ covs_warning_4 <- NULL
429431 for (i in 1 : num_cov_orig ) {
430- # Determine the number of unique values
431- num_unique_values <- length(unique(X_train [, i ]))
432-
433- # Determine a "name" for the covariate
434- cov_name <- ifelse(
435- is.null(colnames(X_train )),
436- paste0(" X" , i ),
437- colnames(X_train )[i ]
438- )
439-
440- # Check for a small relative number of unique values
441- unique_full_ratio <- num_unique_values / num_values
442- if (unique_full_ratio < 0.2 ) {
443- covs_warning_1 <- c(covs_warning_1 , cov_name )
432+ # Skip check for variables that are treated as categorical
433+ x_numeric <- T
434+ if (x_is_df ) {
435+ if (is.factor(X_train [, i ])) {
436+ x_numeric <- F
437+ }
444438 }
439+ if (x_numeric ) {
440+ # Determine the number of unique values
441+ num_unique_values <- length(unique(X_train [, i ]))
442+
443+ # Determine a "name" for the covariate
444+ cov_name <- ifelse(
445+ is.null(colnames(X_train )),
446+ paste0(" X" , i ),
447+ colnames(X_train )[i ]
448+ )
445449
446- # Check for a small absolute number of unique values
447- if (num_values > 100 ) {
448- if (num_unique_values < 20 ) {
449- covs_warning_2 <- c(covs_warning_2 , cov_name )
450+ # Check for a small relative number of unique values
451+ unique_full_ratio <- num_unique_values / num_values
452+ if (unique_full_ratio < 0.2 ) {
453+ covs_warning_1 <- c(covs_warning_1 , cov_name )
454+ }
455+
456+ # Check for a small absolute number of unique values
457+ if (num_values > 100 ) {
458+ if (num_unique_values < 20 ) {
459+ covs_warning_2 <- c(covs_warning_2 , cov_name )
460+ }
461+ }
462+
463+ # Check for a large number of duplicates of any individual value
464+ x_j_hist <- table(X_train [, i ])
465+ if (any(x_j_hist > 2 * max_grid_size )) {
466+ covs_warning_3 <- c(covs_warning_3 , cov_name )
450467 }
451- }
452468
453- # Check for a large number of duplicates of any individual value
454- x_j_hist <- table( X_train [, i ])
455- if (any( x_j_hist > 2 * max_grid_size )) {
456- covs_warning_3 <- c( covs_warning_3 , cov_name )
469+ # Check for binary variables
470+ if ( num_unique_values == 2 ) {
471+ covs_warning_4 <- c( covs_warning_4 , cov_name )
472+ }
457473 }
458474 }
459475
@@ -494,6 +510,18 @@ bart <- function(
494510 )
495511 )
496512 }
513+
514+ if (! is.null(covs_warning_4 )) {
515+ warning(
516+ paste0(
517+ " Covariates " ,
518+ paste(covs_warning_4 , collapse = " , " ),
519+ " appear to be binary but are currently treated by stochtree as continuous. " ,
520+ " This might present some issues with the grow-from-root (GFR) algorithm. " ,
521+ " Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
522+ )
523+ )
524+ }
497525 }
498526
499527 # Standardize the keep variable lists to numeric indices
0 commit comments