@@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
885885 if (sample_sigma2_global ) global_var_samples <- rep(NA , num_retained_samples )
886886 if (sample_sigma2_leaf_mu ) leaf_scale_mu_samples <- rep(NA , num_retained_samples )
887887 if (sample_sigma2_leaf_tau ) leaf_scale_tau_samples <- rep(NA , num_retained_samples )
888+ muhat_train_raw <- matrix (NA_real_ , nrow(X_train ), num_retained_samples )
889+ if (include_variance_forest ) sigma2_x_train_raw <- matrix (NA_real_ , nrow(X_train ), num_retained_samples )
888890 sample_counter <- 0
889891
890892 # Prepare adaptive coding structure
@@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
997999 global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
9981000 )
9991001
1002+ # Cache train set predictions since they are already computed during sampling
1003+ if (keep_sample ) {
1004+ muhat_train_raw [,sample_counter ] <- forest_model_mu $ get_cached_forest_predictions()
1005+ }
1006+
10001007 # Sample variance parameters (if requested)
10011008 if (sample_sigma2_global ) {
10021009 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10161023 global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
10171024 )
10181025
1026+ # Train set predictions for tau are not cached here: the predictions cached in the
1027+ # tracking data structures are pre-multiplied by the basis (treatment), so they are
1028+ # not the raw tau(X) values
1029+
10191030 # Sample coding parameters (if requested)
10201031 if (adaptive_coding ) {
10211032 # Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10601071 active_forest = active_forest_variance , rng = rng , forest_model_config = forest_model_config_variance ,
10611072 global_model_config = global_model_config , keep_forest = keep_sample , gfr = TRUE
10621073 )
1074+
1075+ # Cache train set predictions since they are already computed during sampling
1076+ if (keep_sample ) {
1077+ sigma2_x_train_raw [,sample_counter ] <- forest_model_variance $ get_cached_forest_predictions()
1078+ }
10631079 }
10641080 if (sample_sigma2_global ) {
10651081 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12631279 global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
12641280 )
12651281
1282+ # Cache train set predictions since they are already computed during sampling
1283+ if (keep_sample ) {
1284+ muhat_train_raw [,sample_counter ] <- forest_model_mu $ get_cached_forest_predictions()
1285+ }
1286+
12661287 # Sample variance parameters (if requested)
12671288 if (sample_sigma2_global ) {
12681289 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12821303 global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
12831304 )
12841305
1306+ # Train set predictions for tau are not cached here: the predictions cached in the
1307+ # tracking data structures are pre-multiplied by the basis (treatment), so they are
1308+ # not the raw tau(X) values
1309+
12851310 # Sample coding parameters (if requested)
12861311 if (adaptive_coding ) {
12871312 # Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13261351 active_forest = active_forest_variance , rng = rng , forest_model_config = forest_model_config_variance ,
13271352 global_model_config = global_model_config , keep_forest = keep_sample , gfr = FALSE
13281353 )
1354+
1355+ # Cache train set predictions since they are already computed during sampling
1356+ if (keep_sample ) {
1357+ sigma2_x_train_raw [,sample_counter ] <- forest_model_variance $ get_cached_forest_predictions()
1358+ }
13291359 }
13301360 if (sample_sigma2_global ) {
13311361 current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train , forest_dataset_train , rng , a_global , b_global )
@@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13721402 b_1_samples <- b_1_samples [(num_gfr + 1 ): length(b_1_samples )]
13731403 b_0_samples <- b_0_samples [(num_gfr + 1 ): length(b_0_samples )]
13741404 }
1405+ muhat_train_raw <- muhat_train_raw [,(num_gfr + 1 ): ncol(muhat_train_raw )]
1406+ if (include_variance_forest ) {
1407+ sigma2_x_train_raw <- sigma2_x_train_raw [,(num_gfr + 1 ): ncol(sigma2_x_train_raw )]
1408+ }
13751409 num_retained_samples <- num_retained_samples - num_gfr
13761410 }
13771411
13781412 # Forest predictions
1379- mu_hat_train <- forest_samples_mu $ predict( forest_dataset_train ) * y_std_train + y_bar_train
1413+ mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train
13801414 if (adaptive_coding ) {
13811415 tau_hat_train_raw <- forest_samples_tau $ predict_raw(forest_dataset_train )
13821416 tau_hat_train <- t(t(tau_hat_train_raw ) * (b_1_samples - b_0_samples ))* y_std_train
@@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13951429 y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test )
13961430 }
13971431 if (include_variance_forest ) {
1398- sigma2_x_hat_train <- forest_samples_variance $ predict( forest_dataset_train )
1432+ sigma2_x_hat_train <- exp( sigma2_x_train_raw )
13991433 if (has_test ) sigma2_x_hat_test <- forest_samples_variance $ predict(forest_dataset_test )
14001434 }
14011435
0 commit comments