@@ -707,6 +707,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
707707 num_retained_samples <- num_gfr + ifelse(keep_burnin , num_burnin , 0 ) + num_mcmc * num_chains
708708 if (sample_sigma2_global ) global_var_samples <- rep(NA , num_retained_samples )
709709 if (sample_sigma2_leaf ) leaf_scale_samples <- rep(NA , num_retained_samples )
710+ if (include_mean_forest ) mean_forest_pred_train <- matrix (NA_real_ , nrow(X_train ), num_retained_samples )
711+ if (include_variance_forest ) variance_forest_pred_train <- matrix (NA_real_ , nrow(X_train ), num_retained_samples )
710712 sample_counter <- 0
711713
712714 # Initialize the leaves of each tree in the mean forest
@@ -757,13 +759,23 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
757759 active_forest = active_forest_mean , rng = rng , forest_model_config = forest_model_config_mean ,
758760 global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
759761 )
762+
763+ # Cache predictions
764+ if (keep_sample ) {
765+ mean_forest_pred_train [,sample_counter ] <- forest_model_mean $ get_cached_forest_predictions()
766+ }
760767 }
761768 if (include_variance_forest ) {
762769 forest_model_variance $ sample_one_iteration(
763770 forest_dataset = forest_dataset_train , residual = outcome_train , forest_samples = forest_samples_variance ,
764771 active_forest = active_forest_variance , rng = rng , forest_model_config = forest_model_config_variance ,
765772 global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
766773 )
774+
775+ # Cache predictions
776+ if (keep_sample ) {
777+ variance_forest_pred_train [,sample_counter ] <- forest_model_variance $ get_cached_forest_predictions()
778+ }
767779 }
768780 if (sample_sigma2_global ) {
769781 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -910,13 +922,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
910922 active_forest = active_forest_mean , rng = rng , forest_model_config = forest_model_config_mean ,
911923 global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
912924 )
925+
926+ if (keep_sample ) {
927+ mean_forest_pred_train [,sample_counter ] <- forest_model_mean $ get_cached_forest_predictions()
928+ }
913929 }
914930 if (include_variance_forest ) {
915931 forest_model_variance $ sample_one_iteration(
916932 forest_dataset = forest_dataset_train , residual = outcome_train , forest_samples = forest_samples_variance ,
917933 active_forest = active_forest_variance , rng = rng , forest_model_config = forest_model_config_variance ,
918934 global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
919935 )
936+
937+ if (keep_sample ) {
938+ variance_forest_pred_train [,sample_counter ] <- forest_model_variance $ get_cached_forest_predictions()
939+ }
920940 }
921941 if (sample_sigma2_global ) {
922942 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -949,6 +969,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
949969 rfx_samples $ delete_sample(0 )
950970 }
951971 }
972+ if (include_mean_forest ) {
973+ mean_forest_pred_train <- mean_forest_pred_train [,(num_gfr + 1 ): ncol(mean_forest_pred_train )]
974+ }
975+ if (include_variance_forest ) {
976+ variance_forest_pred_train <- variance_forest_pred_train [,(num_gfr + 1 ): ncol(variance_forest_pred_train )]
977+ }
952978 if (sample_sigma2_global ) {
953979 global_var_samples <- global_var_samples [(num_gfr + 1 ): length(global_var_samples )]
954980 }
@@ -960,13 +986,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
960986
961987 # Mean forest predictions
962988 if (include_mean_forest ) {
963- y_hat_train <- forest_samples_mean $ predict(forest_dataset_train )* y_std_train + y_bar_train
989+ # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
990+ y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train
964991 if (has_test ) y_hat_test <- forest_samples_mean $ predict(forest_dataset_test )* y_std_train + y_bar_train
965992 }
966993
967994 # Variance forest predictions
968995 if (include_variance_forest ) {
969- sigma2_x_hat_train <- forest_samples_variance $ predict(forest_dataset_train )
996+ # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
997+ sigma2_x_hat_train <- variance_forest_pred_train
970998 if (has_test ) sigma2_x_hat_test <- forest_samples_variance $ predict(forest_dataset_test )
971999 }
9721000
0 commit comments