Skip to content

Commit 2ff2f49

Browse files
committed
Updated BART and BCF caching logic, added unit tests, and fixed bugs
1 parent 005f700 commit 2ff2f49

File tree

7 files changed

+323
-47
lines changed

7 files changed

+323
-47
lines changed

R/bart.R

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -760,7 +760,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
760760
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
761761
)
762762

763-
# Cache predictions
763+
# Cache train set predictions since they are already computed during sampling
764764
if (keep_sample) {
765765
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
766766
}
@@ -772,7 +772,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
772772
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
773773
)
774774

775-
# Cache predictions
775+
# Cache train set predictions since they are already computed during sampling
776776
if (keep_sample) {
777777
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
778778
}
@@ -923,6 +923,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
923923
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
924924
)
925925

926+
# Cache train set predictions since they are already computed during sampling
926927
if (keep_sample) {
927928
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
928929
}
@@ -934,6 +935,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
934935
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
935936
)
936937

938+
# Cache train set predictions since they are already computed during sampling
937939
if (keep_sample) {
938940
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
939941
}

R/bcf.R

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
885885
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
886886
if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples)
887887
if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples)
888+
muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
889+
if (include_variance_forest) sigma2_x_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
888890
sample_counter <- 0
889891

890892
# Prepare adaptive coding structure
@@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
997999
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
9981000
)
9991001

1002+
# Cache train set predictions since they are already computed during sampling
1003+
if (keep_sample) {
1004+
muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
1005+
}
1006+
10001007
# Sample variance parameters (if requested)
10011008
if (sample_sigma2_global) {
10021009
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10161023
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
10171024
)
10181025

1026+
# Cannot cache train set predictions for tau because the cached predictions in the
1027+
# tracking data structures are pre-multiplied by the basis (treatment)
1028+
# ...
1029+
10191030
# Sample coding parameters (if requested)
10201031
if (adaptive_coding) {
10211032
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10601071
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
10611072
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
10621073
)
1074+
1075+
# Cache train set predictions since they are already computed during sampling
1076+
if (keep_sample) {
1077+
sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
1078+
}
10631079
}
10641080
if (sample_sigma2_global) {
10651081
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12631279
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
12641280
)
12651281

1282+
# Cache train set predictions since they are already computed during sampling
1283+
if (keep_sample) {
1284+
muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
1285+
}
1286+
12661287
# Sample variance parameters (if requested)
12671288
if (sample_sigma2_global) {
12681289
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12821303
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
12831304
)
12841305

1306+
# Cannot cache train set predictions for tau because the cached predictions in the
1307+
# tracking data structures are pre-multiplied by the basis (treatment)
1308+
# ...
1309+
12851310
# Sample coding parameters (if requested)
12861311
if (adaptive_coding) {
12871312
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13261351
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
13271352
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
13281353
)
1354+
1355+
# Cache train set predictions since they are already computed during sampling
1356+
if (keep_sample) {
1357+
sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
1358+
}
13291359
}
13301360
if (sample_sigma2_global) {
13311361
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13721402
b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)]
13731403
b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)]
13741404
}
1405+
muhat_train_raw <- muhat_train_raw[,(num_gfr+1):ncol(muhat_train_raw)]
1406+
if (include_variance_forest) {
1407+
sigma2_x_train_raw <- sigma2_x_train_raw[,(num_gfr+1):ncol(sigma2_x_train_raw)]
1408+
}
13751409
num_retained_samples <- num_retained_samples - num_gfr
13761410
}
13771411

13781412
# Forest predictions
1379-
mu_hat_train <- forest_samples_mu$predict(forest_dataset_train)*y_std_train + y_bar_train
1413+
mu_hat_train <- muhat_train_raw*y_std_train + y_bar_train
13801414
if (adaptive_coding) {
13811415
tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train)
13821416
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train
@@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13951429
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
13961430
}
13971431
if (include_variance_forest) {
1398-
sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
1432+
sigma2_x_hat_train <- exp(sigma2_x_train_raw)
13991433
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
14001434
}
14011435

src/container.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
206206
CHECK_EQ(this->num_trees_, forest_container_json.at("num_trees"));
207207
CHECK_EQ(this->output_dimension_, forest_container_json.at("output_dimension"));
208208
CHECK_EQ(this->is_leaf_constant_, forest_container_json.at("is_leaf_constant"));
209+
CHECK_EQ(this->is_exponentiated_, forest_container_json.at("is_exponentiated"));
209210
CHECK_EQ(this->initialized_, forest_container_json.at("initialized"));
210211
int new_num_samples = forest_container_json.at("num_samples");
211212

@@ -215,8 +216,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
215216
for (int i = 0; i < forest_container_json.at("num_samples"); i++) {
216217
forest_ind = this->num_samples_ + i;
217218
forest_label = "forest_" + std::to_string(i);
218-
// forests_[forest_ind] = std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_);
219-
forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_));
219+
forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_, this->is_exponentiated_));
220220
forests_[forest_ind]->from_json(forest_container_json.at(forest_label));
221221
}
222222
this->num_samples_ += new_num_samples;

stochtree/bart.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,7 @@ def sample(
11911191
True,
11921192
)
11931193

1194+
# Cache train set predictions since they are already computed during sampling
11941195
if keep_sample:
11951196
yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
11961197

@@ -1208,6 +1209,7 @@ def sample(
12081209
True,
12091210
)
12101211

1212+
# Cache train set predictions since they are already computed during sampling
12111213
if keep_sample:
12121214
sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
12131215

@@ -1389,6 +1391,9 @@ def sample(
13891391
False,
13901392
)
13911393

1394+
if keep_sample:
1395+
yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions()
1396+
13921397
# Sample the variance forest
13931398
if self.include_variance_forest:
13941399
forest_sampler_variance.sample_one_iteration(
@@ -1403,6 +1408,9 @@ def sample(
14031408
False,
14041409
)
14051410

1411+
if keep_sample:
1412+
sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions()
1413+
14061414
# Sample variance parameters (if requested)
14071415
if self.sample_sigma2_global:
14081416
current_sigma2 = global_var_model.sample_one_iteration(
@@ -1493,16 +1501,15 @@ def sample(
14931501
self.y_hat_test = rfx_preds_test
14941502

14951503
if self.include_variance_forest:
1496-
sigma2_x_train_raw = np.exp(sigma2_x_train_raw)
14971504
if self.sample_sigma2_global:
1498-
self.sigma2_x_train = sigma2_x_train_raw
1505+
self.sigma2_x_train = np.empty_like(sigma2_x_train_raw)
14991506
for i in range(self.num_samples):
15001507
self.sigma2_x_train[:, i] = (
1501-
sigma2_x_train_raw[:, i] * self.global_var_samples[i]
1508+
np.exp(sigma2_x_train_raw[:, i]) * self.global_var_samples[i]
15021509
)
15031510
else:
15041511
self.sigma2_x_train = (
1505-
sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std
1512+
np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std
15061513
)
15071514
if self.has_test:
15081515
sigma2_x_test_raw = (
@@ -1628,14 +1635,14 @@ def predict(
16281635
)
16291636
)
16301637
if self.sample_sigma2_global:
1631-
variance_pred = variance_pred_raw
1638+
variance_pred = np.empty_like(variance_pred_raw)
16321639
for i in range(self.num_samples):
1633-
variance_pred[:, i] = np.sqrt(
1640+
variance_pred[:, i] = (
16341641
variance_pred_raw[:, i] * self.global_var_samples[i]
16351642
)
16361643
else:
16371644
variance_pred = (
1638-
np.sqrt(variance_pred_raw * self.sigma2_init) * self.y_std
1645+
variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
16391646
)
16401647

16411648
has_mean_predictions = self.include_mean_forest or self.has_rfx
@@ -1817,7 +1824,7 @@ def predict_variance(self, covariates: np.array) -> np.array:
18171824
pred_dataset.dataset_cpp
18181825
)
18191826
if self.sample_sigma2_global:
1820-
variance_pred = variance_pred_raw
1827+
variance_pred = np.empty_like(variance_pred_raw)
18211828
for i in range(self.num_samples):
18221829
variance_pred[:, i] = (
18231830
variance_pred_raw[:, i] * self.global_var_samples[i]
@@ -2024,11 +2031,11 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
20242031
for i in range(len(json_object_list)):
20252032
if i == 0:
20262033
self.forest_container_variance.forest_container_cpp.LoadFromJson(
2027-
json_object_list[i].json_cpp, "forest_1"
2034+
json_object_list[i].json_cpp, "forest_0"
20282035
)
20292036
else:
20302037
self.forest_container_variance.forest_container_cpp.AppendFromJson(
2031-
json_object_list[i].json_cpp, "forest_1"
2038+
json_object_list[i].json_cpp, "forest_0"
20322039
)
20332040

20342041
# Unpack random effects
@@ -2053,13 +2060,19 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
20532060
self.num_gfr = json_object_default.get_integer("num_gfr")
20542061
self.num_burnin = json_object_default.get_integer("num_burnin")
20552062
self.num_mcmc = json_object_default.get_integer("num_mcmc")
2056-
self.num_samples = json_object_default.get_integer("num_samples")
20572063
self.num_basis = json_object_default.get_integer("num_basis")
20582064
self.has_basis = json_object_default.get_boolean("requires_basis")
20592065
self.probit_outcome_model = json_object_default.get_boolean(
20602066
"probit_outcome_model"
20612067
)
20622068

2069+
# Unpack number of samples
2070+
for i in range(len(json_object_list)):
2071+
if i == 0:
2072+
self.num_samples = json_object_list[i].get_integer("num_samples")
2073+
else:
2074+
self.num_samples += json_object_list[i].get_integer("num_samples")
2075+
20632076
# Unpack parameter samples
20642077
if self.sample_sigma2_global:
20652078
for i in range(len(json_object_list)):

0 commit comments

Comments
 (0)