@@ -2386,13 +2386,16 @@ def to_json(self) -> str:
23862386 bart_json .add_boolean ("include_mean_forest" , self .include_mean_forest )
23872387 bart_json .add_boolean ("include_variance_forest" , self .include_variance_forest )
23882388 bart_json .add_boolean ("has_rfx" , self .has_rfx )
2389+ bart_json .add_boolean ("has_rfx_basis" , self .has_rfx_basis )
2390+ bart_json .add_scalar ("num_rfx_basis" , self .num_rfx_basis )
23892391 bart_json .add_integer ("num_gfr" , self .num_gfr )
23902392 bart_json .add_integer ("num_burnin" , self .num_burnin )
23912393 bart_json .add_integer ("num_mcmc" , self .num_mcmc )
23922394 bart_json .add_integer ("num_samples" , self .num_samples )
23932395 bart_json .add_integer ("num_basis" , self .num_basis )
23942396 bart_json .add_boolean ("requires_basis" , self .has_basis )
23952397 bart_json .add_boolean ("probit_outcome_model" , self .probit_outcome_model )
2398+ bart_json .add_string ("rfx_model_spec" , self .rfx_model_spec )
23962399
23972400 # Add parameter samples
23982401 if self .sample_sigma2_global :
@@ -2427,6 +2430,8 @@ def from_json(self, json_string: str) -> None:
24272430 self .include_mean_forest = bart_json .get_boolean ("include_mean_forest" )
24282431 self .include_variance_forest = bart_json .get_boolean ("include_variance_forest" )
24292432 self .has_rfx = bart_json .get_boolean ("has_rfx" )
2433+ self .has_rfx_basis = bart_json .get_boolean ("has_rfx_basis" )
2434+ self .num_rfx_basis = bart_json .get_scalar ("num_rfx_basis" )
24302435 if self .include_mean_forest :
24312436 # TODO: don't just make this a placeholder that we overwrite
24322437 self .forest_container_mean = ForestContainer (0 , 0 , False , False )
@@ -2465,6 +2470,7 @@ def from_json(self, json_string: str) -> None:
24652470 self .num_basis = bart_json .get_integer ("num_basis" )
24662471 self .has_basis = bart_json .get_boolean ("requires_basis" )
24672472 self .probit_outcome_model = bart_json .get_boolean ("probit_outcome_model" )
2473+ self .rfx_model_spec = bart_json .get_string ("rfx_model_spec" )
24682474
24692475 # Unpack parameter samples
24702476 if self .sample_sigma2_global :
@@ -2550,6 +2556,8 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
25502556
25512557 # Unpack random effects
25522558 self .has_rfx = json_object_default .get_boolean ("has_rfx" )
2559+ self .has_rfx_basis = json_object_default .get_boolean ("has_rfx_basis" )
2560+ self .num_rfx_basis = json_object_default .get_scalar ("num_rfx_basis" )
25532561 if self .has_rfx :
25542562 self .rfx_container = RandomEffectsContainer ()
25552563 for i in range (len (json_object_list )):
@@ -2575,6 +2583,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None:
25752583 self .probit_outcome_model = json_object_default .get_boolean (
25762584 "probit_outcome_model"
25772585 )
2586+ self .rfx_model_spec = json_object_default .get_string (
2587+ "rfx_model_spec"
2588+ )
25782589
25792590 # Unpack number of samples
25802591 for i in range (len (json_object_list )):
0 commit comments