Skip to content

Commit 5814f41

Browse files
committed
Updated the Python BART interface to support user-provided random effects (RFX) parameters, and added unit tests
1 parent f9e6206 commit 5814f41

File tree

2 files changed: 163 additions (+163), 13 deletions (-13)

stochtree/bart.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
)
2424
from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel
2525
from .serialization import JSONSerializer
26-
from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d_diag
26+
from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag
2727

2828

2929
class BARTModel:
@@ -260,6 +260,12 @@ def sample(
260260
keep_every = general_params_updated["keep_every"]
261261
num_chains = general_params_updated["num_chains"]
262262
self.probit_outcome_model = general_params_updated["probit_outcome_model"]
263+
rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"]
264+
rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"]
265+
rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"]
266+
rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"]
267+
rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"]
268+
rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"]
263269

264270
# 2. Mean forest parameters
265271
num_trees_mean = mean_forest_params_updated["num_trees"]
@@ -291,14 +297,6 @@ def sample(
291297
drop_vars_variance = variance_forest_params_updated["drop_vars"]
292298
num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"]
293299

294-
# 4. Random effects parameters
295-
rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"]
296-
rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"]
297-
rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"]
298-
rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"]
299-
rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"]
300-
rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"]
301-
302300
# Override keep_gfr if there are no MCMC samples
303301
if num_mcmc == 0:
304302
keep_gfr = True
@@ -993,10 +991,7 @@ def sample(
993991
if rfx_group_parameter_prior_mean is None:
994992
xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups))
995993
else:
996-
xi_init = _expand_dims_1d(rfx_group_parameter_prior_mean, num_rfx_components)
997-
# If it's a vector, expand to matrix
998-
if xi_init.ndim == 1:
999-
xi_init = np.tile(np.expand_dims(xi_init, 1), (1, num_rfx_groups))
994+
xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
1000995

1001996
if rfx_working_parameter_prior_cov is None:
1002997
sigma_alpha_init = np.identity(num_rfx_components)

test/python/test_bart.py

Lines changed: 155 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
from sklearn.model_selection import train_test_split
34

45
from stochtree import BARTModel
@@ -1009,3 +1010,157 @@ def conditional_stddev(X):
10091010
np.testing.assert_allclose(
10101011
rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2
10111012
)
1013+
1014+
def test_bart_rfx_parameters(self):
1015+
# RNG
1016+
random_seed = 101
1017+
rng = np.random.default_rng(random_seed)
1018+
1019+
# Generate covariates and basis
1020+
n = 100
1021+
p_X = 10
1022+
p_W = 1
1023+
X = rng.uniform(0, 1, (n, p_X))
1024+
W = rng.uniform(0, 1, (n, p_W))
1025+
1026+
# Generate RFX group labels and basis term
1027+
num_rfx_basis = 2
1028+
num_rfx_groups = 4
1029+
group_labels = rng.choice(num_rfx_groups, size=n)
1030+
rfx_basis = np.empty((n, num_rfx_basis))
1031+
rfx_basis[:, 0] = 1.0
1032+
if num_rfx_basis > 1:
1033+
rfx_basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1))
1034+
1035+
# Define the outcome mean function
1036+
def outcome_mean(X, W):
1037+
return np.where(
1038+
(X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
1039+
-7.5 * W[:, 0],
1040+
np.where(
1041+
(X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
1042+
-2.5 * W[:, 0],
1043+
np.where(
1044+
(X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
1045+
2.5 * W[:, 0],
1046+
7.5 * W[:, 0],
1047+
),
1048+
),
1049+
)
1050+
1051+
# Define the group rfx function
1052+
def rfx_term(group_labels, basis):
1053+
return np.where(
1054+
group_labels == 0, -5 + 1. * basis[:,1], 5 - 1. * basis[:,1]
1055+
)
1056+
1057+
# Define the conditional standard deviation function
1058+
def conditional_stddev(X):
1059+
return np.where(
1060+
(X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
1061+
0.25,
1062+
np.where(
1063+
(X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
1064+
0.5,
1065+
np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 1, 2),
1066+
),
1067+
)
1068+
1069+
# Generate outcome
1070+
epsilon = rng.normal(0, 1, n)
1071+
y = (
1072+
outcome_mean(X, W)
1073+
+ rfx_term(group_labels, rfx_basis)
1074+
+ epsilon * conditional_stddev(X)
1075+
)
1076+
1077+
# Test-train split
1078+
sample_inds = np.arange(n)
1079+
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
1080+
X_train = X[train_inds, :]
1081+
X_test = X[test_inds, :]
1082+
basis_train = W[train_inds, :]
1083+
basis_test = W[test_inds, :]
1084+
group_labels_train = group_labels[train_inds]
1085+
group_labels_test = group_labels[test_inds]
1086+
rfx_basis_train = rfx_basis[train_inds, :]
1087+
rfx_basis_test = rfx_basis[test_inds, :]
1088+
y_train = y[train_inds]
1089+
n_train = X_train.shape[0]
1090+
n_test = X_test.shape[0]
1091+
1092+
# BART settings
1093+
num_gfr = 10
1094+
num_burnin = 0
1095+
num_mcmc = 10
1096+
1097+
# Specify no rfx parameters
1098+
general_params = {}
1099+
bart_model = BARTModel()
1100+
bart_model.sample(
1101+
X_train=X_train,
1102+
y_train=y_train,
1103+
leaf_basis_train=basis_train,
1104+
rfx_group_ids_train=group_labels_train,
1105+
rfx_basis_train=rfx_basis_train,
1106+
X_test=X_test,
1107+
leaf_basis_test=basis_test,
1108+
rfx_group_ids_test=group_labels_test,
1109+
rfx_basis_test=rfx_basis_test,
1110+
num_gfr=num_gfr,
1111+
num_burnin=num_burnin,
1112+
num_mcmc=num_mcmc,
1113+
general_params=general_params,
1114+
)
1115+
1116+
# Specify scalar rfx parameters
1117+
general_params = {
1118+
"rfx_working_parameter_prior_mean": 1.,
1119+
"rfx_group_parameter_prior_mean": 1.,
1120+
"rfx_working_parameter_prior_cov": 1.,
1121+
"rfx_group_parameter_prior_cov": 1.,
1122+
"rfx_variance_prior_shape": 1,
1123+
"rfx_variance_prior_scale": 1
1124+
}
1125+
bart_model_2 = BARTModel()
1126+
bart_model_2.sample(
1127+
X_train=X_train,
1128+
y_train=y_train,
1129+
leaf_basis_train=basis_train,
1130+
rfx_group_ids_train=group_labels_train,
1131+
rfx_basis_train=rfx_basis_train,
1132+
X_test=X_test,
1133+
leaf_basis_test=basis_test,
1134+
rfx_group_ids_test=group_labels_test,
1135+
rfx_basis_test=rfx_basis_test,
1136+
num_gfr=num_gfr,
1137+
num_burnin=num_burnin,
1138+
num_mcmc=num_mcmc,
1139+
general_params=general_params,
1140+
)
1141+
1142+
# Specify rfx prior means as vectors and prior covariances as matrices (variance prior shape and scale stay scalar)
1143+
general_params = {
1144+
"rfx_working_parameter_prior_mean": np.repeat(1., num_rfx_basis),
1145+
"rfx_group_parameter_prior_mean": np.repeat(1., num_rfx_basis),
1146+
"rfx_working_parameter_prior_cov": np.identity(num_rfx_basis),
1147+
"rfx_group_parameter_prior_cov": np.identity(num_rfx_basis),
1148+
"rfx_variance_prior_shape": 1,
1149+
"rfx_variance_prior_scale": 1
1150+
}
1151+
bart_model_2 = BARTModel()
1152+
bart_model_2.sample(
1153+
X_train=X_train,
1154+
y_train=y_train,
1155+
leaf_basis_train=basis_train,
1156+
rfx_group_ids_train=group_labels_train,
1157+
rfx_basis_train=rfx_basis_train,
1158+
X_test=X_test,
1159+
leaf_basis_test=basis_test,
1160+
rfx_group_ids_test=group_labels_test,
1161+
rfx_basis_test=rfx_basis_test,
1162+
num_gfr=num_gfr,
1163+
num_burnin=num_burnin,
1164+
num_mcmc=num_mcmc,
1165+
general_params=general_params,
1166+
)

0 commit comments

Comments (0)