@@ -381,14 +381,21 @@ def save(self, fname: str) -> None:
381381 raise RuntimeError ("The model hasn't been fit yet, call .fit() first" )
382382
383383 @classmethod
384- def _convert_dims_to_tuple (cls , model_config : Dict ) -> Dict :
384+ def _model_config_formatting (cls , model_config : Dict ) -> Dict :
385+ """
386+ Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists.
387+ This function converts them back to tuples and numpy arrays to ensure correct id encoding.
388+ """
385389 for key in model_config :
386- if (
387- isinstance (model_config [key ], dict )
388- and "dims" in model_config [key ]
389- and isinstance (model_config [key ]["dims" ], list )
390- ):
391- model_config [key ]["dims" ] = tuple (model_config [key ]["dims" ])
390+ if isinstance (model_config [key ], dict ):
391+ for sub_key in model_config [key ]:
392+ if isinstance (model_config [key ][sub_key ], list ):
393+ # Check if "dims" key to convert it to tuple
394+ if sub_key == "dims" :
395+ model_config [key ][sub_key ] = tuple (model_config [key ][sub_key ])
396+ # Convert all other lists to numpy arrays
397+ else :
398+ model_config [key ][sub_key ] = np .array (model_config [key ][sub_key ])
392399 return model_config
393400
394401 @classmethod
@@ -420,7 +427,7 @@ def load(cls, fname: str):
420427 filepath = Path (str (fname ))
421428 idata = az .from_netcdf (filepath )
422429 # needs to be converted, because json.loads was changing tuple to list
423- model_config = cls ._convert_dims_to_tuple (json .loads (idata .attrs ["model_config" ]))
430+ model_config = cls ._model_config_formatting (json .loads (idata .attrs ["model_config" ]))
424431 model = cls (
425432 model_config = model_config ,
426433 sampler_config = json .loads (idata .attrs ["sampler_config" ]),
0 commit comments