@@ -1005,6 +1005,10 @@ def sample(
10051005 self .global_var_samples = np .empty (self .num_samples , dtype = np .float64 )
10061006 if sample_sigma2_leaf :
10071007 self .leaf_scale_samples = np .empty (self .num_samples , dtype = np .float64 )
1008+ if self .include_mean_forest :
1009+ yhat_train_raw = np .empty ((self .n_train , self .num_samples ), dtype = np .float64 )
1010+ if self .include_variance_forest :
1011+ sigma2_x_train_raw = np .empty ((self .n_train , self .num_samples ), dtype = np .float64 )
10081012 sample_counter = - 1
10091013
10101014 # Forest Dataset (covariates and optional basis)
@@ -1187,6 +1191,9 @@ def sample(
11871191 True ,
11881192 )
11891193
1194+ if keep_sample :
1195+ yhat_train_raw [:,sample_counter ] = forest_sampler_mean .get_cached_forest_predictions ()
1196+
11901197 # Sample the variance forest
11911198 if self .include_variance_forest :
11921199 forest_sampler_variance .sample_one_iteration (
@@ -1201,6 +1208,9 @@ def sample(
12011208 True ,
12021209 )
12031210
1211+ if keep_sample :
1212+ sigma2_x_train_raw [:,sample_counter ] = forest_sampler_variance .get_cached_forest_predictions ()
1213+
12041214 # Sample variance parameters (if requested)
12051215 if self .sample_sigma2_global :
12061216 current_sigma2 = global_var_model .sample_one_iteration (
@@ -1441,6 +1451,10 @@ def sample(
14411451 self .global_var_samples = self .global_var_samples [num_gfr :]
14421452 if self .sample_sigma2_leaf :
14431453 self .leaf_scale_samples = self .leaf_scale_samples [num_gfr :]
1454+ if self .include_mean_forest :
1455+ yhat_train_raw = yhat_train_raw [:,num_gfr :]
1456+ if self .include_variance_forest :
1457+ sigma2_x_train_raw = sigma2_x_train_raw [:,num_gfr :]
14441458 self .num_samples -= num_gfr
14451459
14461460 # Store predictions
@@ -1451,9 +1465,6 @@ def sample(
14511465 self .leaf_scale_samples = self .leaf_scale_samples
14521466
14531467 if self .include_mean_forest :
1454- yhat_train_raw = self .forest_container_mean .forest_container_cpp .Predict (
1455- forest_dataset_train .dataset_cpp
1456- )
14571468 self .y_hat_train = yhat_train_raw * self .y_std + self .y_bar
14581469 if self .has_test :
14591470 yhat_test_raw = self .forest_container_mean .forest_container_cpp .Predict (
@@ -1482,11 +1493,7 @@ def sample(
14821493 self .y_hat_test = rfx_preds_test
14831494
14841495 if self .include_variance_forest :
1485- sigma2_x_train_raw = (
1486- self .forest_container_variance .forest_container_cpp .Predict (
1487- forest_dataset_train .dataset_cpp
1488- )
1489- )
1496+ sigma2_x_train_raw = np .exp (sigma2_x_train_raw )
14901497 if self .sample_sigma2_global :
14911498 self .sigma2_x_train = sigma2_x_train_raw
14921499 for i in range (self .num_samples ):
0 commit comments