Skip to content

Commit 3c2d4ce

Browse files
committed
Updated serialization methods to reflect python BART updates
1 parent 8b2d62e commit 3c2d4ce

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

stochtree/bart.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,13 +2386,16 @@ def to_json(self) -> str:
23862386
bart_json.add_boolean("include_mean_forest", self.include_mean_forest)
23872387
bart_json.add_boolean("include_variance_forest", self.include_variance_forest)
23882388
bart_json.add_boolean("has_rfx", self.has_rfx)
2389+
bart_json.add_boolean("has_rfx_basis", self.has_rfx_basis)
2390+
bart_json.add_scalar("num_rfx_basis", self.num_rfx_basis)
23892391
bart_json.add_integer("num_gfr", self.num_gfr)
23902392
bart_json.add_integer("num_burnin", self.num_burnin)
23912393
bart_json.add_integer("num_mcmc", self.num_mcmc)
23922394
bart_json.add_integer("num_samples", self.num_samples)
23932395
bart_json.add_integer("num_basis", self.num_basis)
23942396
bart_json.add_boolean("requires_basis", self.has_basis)
23952397
bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model)
2398+
bart_json.add_string("rfx_model_spec", self.rfx_model_spec)
23962399

23972400
# Add parameter samples
23982401
if self.sample_sigma2_global:
@@ -2427,6 +2430,8 @@ def from_json(self, json_string: str) -> None:
24272430
self.include_mean_forest = bart_json.get_boolean("include_mean_forest")
24282431
self.include_variance_forest = bart_json.get_boolean("include_variance_forest")
24292432
self.has_rfx = bart_json.get_boolean("has_rfx")
2433+
self.has_rfx_basis = bart_json.get_boolean("has_rfx_basis")
2434+
self.num_rfx_basis = bart_json.get_scalar("num_rfx_basis")
24302435
if self.include_mean_forest:
24312436
# TODO: don't just make this a placeholder that we overwrite
24322437
self.forest_container_mean = ForestContainer(0, 0, False, False)
@@ -2465,6 +2470,7 @@ def from_json(self, json_string: str) -> None:
24652470
self.num_basis = bart_json.get_integer("num_basis")
24662471
self.has_basis = bart_json.get_boolean("requires_basis")
24672472
self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model")
2473+
self.rfx_model_spec = bart_json.get_string("rfx_model_spec")
24682474

24692475
# Unpack parameter samples
24702476
if self.sample_sigma2_global:
@@ -2550,6 +2556,8 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
25502556

25512557
# Unpack random effects
25522558
self.has_rfx = json_object_default.get_boolean("has_rfx")
2559+
self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis")
2560+
self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis")
25532561
if self.has_rfx:
25542562
self.rfx_container = RandomEffectsContainer()
25552563
for i in range(len(json_object_list)):
@@ -2575,6 +2583,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
25752583
self.probit_outcome_model = json_object_default.get_boolean(
25762584
"probit_outcome_model"
25772585
)
2586+
self.rfx_model_spec = json_object_default.get_string(
2587+
"rfx_model_spec"
2588+
)
25782589

25792590
# Unpack number of samples
25802591
for i in range(len(json_object_list)):

test/python/test_bart.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,6 +1123,7 @@ def conditional_stddev(X):
11231123

11241124
# Specify scalar rfx parameters
11251125
rfx_params = {
1126+
"model_spec": "custom",
11261127
"working_parameter_prior_mean": 1.,
11271128
"group_parameter_prior_mean": 1.,
11281129
"working_parameter_prior_cov": 1.,
@@ -1144,11 +1145,12 @@ def conditional_stddev(X):
11441145
num_gfr=num_gfr,
11451146
num_burnin=num_burnin,
11461147
num_mcmc=num_mcmc,
1147-
rfx_params=rfx_params,
1148+
random_effects_params=rfx_params,
11481149
)
11491150

11501151
# Specify all relevant rfx parameters as vectors
11511152
rfx_params = {
1153+
"model_spec": "custom",
11521154
"working_parameter_prior_mean": np.repeat(1., num_rfx_basis),
11531155
"group_parameter_prior_mean": np.repeat(1., num_rfx_basis),
11541156
"working_parameter_prior_cov": np.identity(num_rfx_basis),
@@ -1170,5 +1172,32 @@ def conditional_stddev(X):
11701172
num_gfr=num_gfr,
11711173
num_burnin=num_burnin,
11721174
num_mcmc=num_mcmc,
1173-
rfx_params=rfx_params,
1175+
random_effects_params=rfx_params,
11741176
)
1177+
1178+
# Fit a simpler intercept-only RFX model
1179+
rfx_params = {
1180+
"model_spec": "intercept_only"
1181+
}
1182+
bart_model_4 = BARTModel()
1183+
bart_model_4.sample(
1184+
X_train=X_train,
1185+
y_train=y_train,
1186+
leaf_basis_train=basis_train,
1187+
rfx_group_ids_train=group_labels_train,
1188+
X_test=X_test,
1189+
leaf_basis_test=basis_test,
1190+
rfx_group_ids_test=group_labels_test,
1191+
num_gfr=num_gfr,
1192+
num_burnin=num_burnin,
1193+
num_mcmc=num_mcmc,
1194+
random_effects_params=rfx_params,
1195+
)
1196+
preds = bart_model_4.predict(
1197+
covariates=X_test,
1198+
basis=basis_test,
1199+
rfx_group_ids=group_labels_test,
1200+
type="posterior",
1201+
terms="rfx"
1202+
)
1203+
assert preds.shape == (n_test, num_mcmc)

0 commit comments

Comments
 (0)