Skip to content

Commit 58d221c

Browse files
committed
Added rfx model_spec to Python BCF interface and included demo script to test it
1 parent 998ebc3 commit 58d221c

File tree

5 files changed

+277
-37
lines changed

5 files changed

+277
-37
lines changed

R/bcf.R

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,19 @@ bcf <- function(
358358
rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape
359359
rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale
360360

361+
# Handle random effects specification
362+
if (!is.character(rfx_model_spec)) {
363+
stop("rfx_model_spec must be a string or character vector")
364+
}
365+
if (
366+
!(rfx_model_spec %in%
367+
c("custom", "intercept_only", "intercept_plus_treatment"))
368+
) {
369+
stop(
370+
"rfx_model_spec must either be 'custom', 'intercept_only', or 'intercept_plus_treatment'"
371+
)
372+
}
373+
361374
# Set a function-scoped RNG if user provided a random seed
362375
custom_rng <- random_seed >= 0
363376
if (custom_rng) {
@@ -2760,9 +2773,8 @@ predict.bcfmodel <- function(
27602773
))
27612774
return(NULL)
27622775
}
2763-
predict_rfx_intermediate <- ((predict_y_hat && has_rfx))
2764-
predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept_only) ||
2765-
(predict_mu_forest && has_rfx && rfx_intercept_plus_treatment) ||
2776+
predict_rfx_intermediate <- (predict_y_hat && has_rfx)
2777+
predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept) ||
27662778
(predict_tau_forest && has_rfx && rfx_intercept_plus_treatment))
27672779
predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest)
27682780
predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest)
@@ -2946,12 +2958,12 @@ predict.bcfmodel <- function(
29462958

29472959
# Extract "raw" rfx coefficients for each rfx basis term if needed
29482960
if (predict_rfx_raw) {
2949-
# Extract the raw RFX samples and scale by train set outcome sd
2961+
# Extract the raw RFX samples and scale by train set outcome standard deviation
29502962
rfx_param_list <- object$rfx_samples$extract_parameter_samples()
29512963
rfx_beta_draws <- rfx_param_list$beta_samples *
29522964
object$model_params$outcome_scale
29532965

2954-
# Construct a matrix with the correct random effects
2966+
# Construct a matrix with the appropriate group random effects arranged for each observation
29552967
rfx_predictions_raw <- array(
29562968
NA,
29572969
dim = c(

demo/debug/bcf_contrast_debug.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,47 @@
153153
contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
154154
np.allclose(contrast_diff, 0, atol=0.001)
155155

156+
# Now repeat the same process but via random effects model spec
157+
bcf_model = BCFModel()
158+
bcf_model.sample(
159+
X_train=X_train,
160+
Z_train=Z_train,
161+
y_train=y_train,
162+
rfx_group_ids_train=group_ids_train,
163+
num_gfr=10,
164+
num_burnin=0,
165+
num_mcmc=1000,
166+
random_effects_params={"model_spec": "intercept_plus_treatment"},
167+
)
168+
169+
# Compute CATE posterior
170+
tau_hat_posterior_test = bcf_model.compute_contrast(
171+
X_0=X_test,
172+
X_1=X_test,
173+
Z_0=np.zeros((n_test, 1)),
174+
Z_1=np.ones((n_test, 1)),
175+
rfx_group_ids_0=group_ids_test,
176+
rfx_group_ids_1=group_ids_test,
177+
rfx_basis_0=np.concatenate((np.ones((n_test, 1)), np.zeros((n_test, 1))), axis=1),
178+
rfx_basis_1=np.ones((n_test, 2)),
179+
type="posterior",
180+
scale="linear",
181+
)
182+
183+
# Compute the same quantity via predict
184+
tau_hat_posterior_test_comparison = bcf_model.predict(
185+
X=X_test,
186+
Z=Z_test,
187+
rfx_group_ids=group_ids_test,
188+
type="posterior",
189+
terms="cate",
190+
scale="linear",
191+
)
192+
193+
# Compare results
194+
contrast_diff = tau_hat_posterior_test_comparison - tau_hat_posterior_test
195+
np.allclose(contrast_diff, 0, atol=0.001)
196+
156197
# Generate data for a probit BCF model with random effects
157198
X = rng.uniform(low=0.0, high=1.0, size=(n, p))
158199
mu_x = X[:, 0]

src/py_stochtree.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,64 @@ class RandomEffectsContainerCpp {
14181418
int NumGroups() {
14191419
return rfx_container_->NumGroups();
14201420
}
1421+
// Copies the flat "beta" buffer exposed by rfx_container_ into a freshly allocated
// numpy array of shape (num_components, num_groups, num_samples).
//
// The flat buffer is read with the component index varying fastest, then the group
// index, then the sample index (see the offset computation in the innermost loop).
py::array_t<double> GetBeta() {
  // Dimensions are kept as py::ssize_t so the flat offset below is computed in the
  // array's native (wide) index type rather than in int.
  const py::ssize_t num_samples = rfx_container_->NumSamples();
  const py::ssize_t num_components = rfx_container_->NumComponents();
  const py::ssize_t num_groups = rfx_container_->NumGroups();
  // Bind by const reference: no duplicate of the whole buffer is made if the container
  // returns a reference, and a by-value return is kept alive by lifetime extension.
  const auto& beta_raw = rfx_container_->GetBeta();
  auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_groups, num_samples}));
  auto accessor = result.mutable_unchecked<3>();
  const py::ssize_t sample_stride = num_groups * num_components;
  for (py::ssize_t i = 0; i < num_components; i++) {
    for (py::ssize_t j = 0; j < num_groups; j++) {
      for (py::ssize_t k = 0; k < num_samples; k++) {
        accessor(i, j, k) = beta_raw[k * sample_stride + j * num_components + i];
      }
    }
  }
  return result;
}
1437+
// Copies the flat "xi" buffer exposed by rfx_container_ into a freshly allocated
// numpy array of shape (num_components, num_groups, num_samples).
//
// The flat buffer is read with the component index varying fastest, then the group
// index, then the sample index (see the offset computation in the innermost loop).
py::array_t<double> GetXi() {
  // Dimensions are kept as py::ssize_t so the flat offset below is computed in the
  // array's native (wide) index type rather than in int.
  const py::ssize_t num_samples = rfx_container_->NumSamples();
  const py::ssize_t num_components = rfx_container_->NumComponents();
  const py::ssize_t num_groups = rfx_container_->NumGroups();
  // Bind by const reference: no duplicate of the whole buffer is made if the container
  // returns a reference, and a by-value return is kept alive by lifetime extension.
  const auto& xi_raw = rfx_container_->GetXi();
  auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_groups, num_samples}));
  auto accessor = result.mutable_unchecked<3>();
  const py::ssize_t sample_stride = num_groups * num_components;
  for (py::ssize_t i = 0; i < num_components; i++) {
    for (py::ssize_t j = 0; j < num_groups; j++) {
      for (py::ssize_t k = 0; k < num_samples; k++) {
        accessor(i, j, k) = xi_raw[k * sample_stride + j * num_components + i];
      }
    }
  }
  return result;
}
1453+
// Copies the flat "alpha" buffer exposed by rfx_container_ into a freshly allocated
// numpy array of shape (num_components, num_samples).
//
// The flat buffer is read with the component index varying fastest and the sample
// index slowest (see the offset computation in the inner loop).
py::array_t<double> GetAlpha() {
  // Dimensions are kept as py::ssize_t so the flat offset below is computed in the
  // array's native (wide) index type rather than in int.
  const py::ssize_t num_samples = rfx_container_->NumSamples();
  const py::ssize_t num_components = rfx_container_->NumComponents();
  // Bind by const reference: no duplicate of the whole buffer is made if the container
  // returns a reference, and a by-value return is kept alive by lifetime extension.
  const auto& alpha_raw = rfx_container_->GetAlpha();
  auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_samples}));
  auto accessor = result.mutable_unchecked<2>();
  for (py::ssize_t i = 0; i < num_components; i++) {
    for (py::ssize_t j = 0; j < num_samples; j++) {
      accessor(i, j) = alpha_raw[j * num_components + i];
    }
  }
  return result;
}
1466+
// Copies the flat "sigma" buffer exposed by rfx_container_ into a freshly allocated
// numpy array of shape (num_components, num_samples).
//
// The flat buffer is read with the component index varying fastest and the sample
// index slowest (see the offset computation in the inner loop).
py::array_t<double> GetSigma() {
  // Dimensions are kept as py::ssize_t so the flat offset below is computed in the
  // array's native (wide) index type rather than in int.
  const py::ssize_t num_samples = rfx_container_->NumSamples();
  const py::ssize_t num_components = rfx_container_->NumComponents();
  // Bind by const reference: no duplicate of the whole buffer is made if the container
  // returns a reference, and a by-value return is kept alive by lifetime extension.
  const auto& sigma_raw = rfx_container_->GetSigma();
  auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_samples}));
  auto accessor = result.mutable_unchecked<2>();
  for (py::ssize_t i = 0; i < num_components; i++) {
    for (py::ssize_t j = 0; j < num_samples; j++) {
      accessor(i, j) = sigma_raw[j * num_components + i];
    }
  }
  return result;
}
14211479
void DeleteSample(int sample_num) {
14221480
rfx_container_->DeleteSample(sample_num);
14231481
}
@@ -2294,6 +2352,10 @@ PYBIND11_MODULE(stochtree_cpp, m) {
22942352
.def("NumSamples", &RandomEffectsContainerCpp::NumSamples)
22952353
.def("NumComponents", &RandomEffectsContainerCpp::NumComponents)
22962354
.def("NumGroups", &RandomEffectsContainerCpp::NumGroups)
2355+
.def("GetBeta", &RandomEffectsContainerCpp::GetBeta)
2356+
.def("GetXi", &RandomEffectsContainerCpp::GetXi)
2357+
.def("GetAlpha", &RandomEffectsContainerCpp::GetAlpha)
2358+
.def("GetSigma", &RandomEffectsContainerCpp::GetSigma)
22972359
.def("DeleteSample", &RandomEffectsContainerCpp::DeleteSample)
22982360
.def("Predict", &RandomEffectsContainerCpp::Predict)
22992361
.def("SaveToJsonFile", &RandomEffectsContainerCpp::SaveToJsonFile)

0 commit comments

Comments
 (0)