@@ -50,7 +50,6 @@ class ModelBuilder:
5050
5151 def __init__ (
5252 self ,
53- data : Union [np .ndarray , pd .DataFrame , pd .Series ] = None ,
5453 model_config : Dict = None ,
5554 sampler_config : Dict = None ,
5655 ):
@@ -77,10 +76,8 @@ def __init__(
7776
7877 self .model_config = model_config # parameters for priors etc.
7978 self .model = None # Set by build_model
80- self .output_var = "" # Set by build_model
8179 self .idata : Optional [az .InferenceData ] = None # idata is generated during fitting
8280 self .is_fitted_ = False
83- self .data = data
8481
8582 def _validate_data (self , X , y = None ):
8683 if y is not None :
@@ -122,6 +119,19 @@ def _data_setter(
122119
123120 raise NotImplementedError
124121
122+ @property
123+ @abstractmethod
124+ def output_var (self ):
125+ """
126+ Returns the name of the output variable of the model.
127+
128+ Returns
129+ -------
130+ output_var : str
131+ Name of the output variable of the model.
132+ """
133+ raise NotImplementedError
134+
125135 @property
126136 @abstractmethod
127137 def default_model_config (self ) -> Dict :
@@ -176,39 +186,41 @@ def default_sampler_config(self) -> Dict:
176186 raise NotImplementedError
177187
178188 @abstractmethod
179- def generate_model_data (
180- self , data : Union [np . ndarray , pd .DataFrame , pd .Series ] = None
181- ) -> pd . DataFrame :
189+ def generate_and_preprocess_model_data (
190+ self , X : Union [pd .DataFrame , pd .Series ], y : pd . Series
191+ ) -> None :
182192 """
183- Returns a default dataset for a class, can be used as a hint to data formatting required for the class
184- If data is not None, dataset will be created from it's content.
193+ Applies preprocessing to the data before fitting the model.
194+ if validate is True, it will check if the data is valid for the model.
195+ sets self.model_coords based on provided dataset
185196
186197 Parameters:
187- data : Union[np.ndarray, pd.DataFrame, pd.Series], optional
188- dataset that will replace the default sample data
189-
198+ X : array, shape (n_obs, n_features)
199+ y : array, shape (n_obs,)
190200
191201 Examples
192202 --------
193203 >>> @classmethod
194- >>> def generate_model_data (self):
204+ >>> def generate_and_preprocess_model_data (self, X, y ):
195205 >>> x = np.linspace(start=1, stop=50, num=100)
196206 >>> y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
197- >>> data = pd.DataFrame({'input': x, 'output': y})
207+ >>> X = pd.DataFrame(x, columns=['x'])
208+ >>> y = pd.Series(y, name='y')
209+ >>> self.X = X
210+ >>> self.y = y
198211
199212 Returns
200213 -------
201- data : pd.DataFrame
202- The data we want to train the model on.
214+ None
203215
204216 """
205217 raise NotImplementedError
206218
207219 @abstractmethod
208220 def build_model (
209221 self ,
210- data : Union [ np . ndarray , pd .DataFrame , pd . Series ] = None ,
211- model_config : Dict = None ,
222+ X : pd .DataFrame ,
223+ y : pd . Series ,
212224 ** kwargs ,
213225 ) -> None :
214226 """
@@ -217,22 +229,31 @@ def build_model(
217229
218230 Parameters
219231 ----------
220- data : dict
221- Preformated data that is going to be used in the model. For efficiency reasons it should contain only the necesary data columns,
222- not entire available dataset since it's going to be encoded into data used to recreate the model.
223- If not provided uses data from self.data
224- model_config : dict
225- Dictionary where keys are strings representing names of parameters of the model, values are dictionaries of parameters
226- needed for creating model parameters. If not provided uses data from self.model_config
232+ X : pd.DataFrame
233+ The input data that is going to be used in the model. This should be a DataFrame
234+ containing the features (predictors) for the model. For efficiency reasons, it should
235+ only contain the necessary data columns, not the entire available dataset, as this
236+ will be encoded into the data used to recreate the model.
237+
238+ y : pd.Series
239+ The target data for the model. This should be a Series representing the output
240+ or dependent variable for the model.
241+
242+ kwargs : dict
243+ Additional keyword arguments that may be used for model configuration.
227244
228245 See Also
229246 --------
230247 default_model_config : returns default model config
231248
232- Returns:
233- ----------
249+ Returns
250+ -------
234251 None
235252
253+ Raises
254+ ------
255+ NotImplementedError
256+ This is an abstract method and must be implemented in a subclass.
236257 """
237258 raise NotImplementedError
238259
@@ -248,7 +269,7 @@ def sample_model(self, **kwargs):
248269 Returns
249270 -------
250271 xarray.Dataset
251- The PyMC3 samples dataset.
272+ The PyMC samples dataset.
252273
253274 Raises
254275 ------
@@ -383,12 +404,14 @@ def load(cls, fname: str):
383404 filepath = Path (str (fname ))
384405 idata = az .from_netcdf (filepath )
385406 model = cls (
386- data = idata .fit_data .to_dataframe (),
387407 model_config = json .loads (idata .attrs ["model_config" ]),
388408 sampler_config = json .loads (idata .attrs ["sampler_config" ]),
389409 )
390410 model .idata = idata
391- model .build_model ()
411+ dataset = idata .fit_data .to_dataframe ()
412+ X = dataset .drop (columns = [model .output_var ])
413+ y = dataset [model .output_var ]
414+ model .build_model (X , y )
392415 # All previously used data is in idata.
393416
394417 if model .id != idata .attrs ["id" ]:
@@ -400,8 +423,8 @@ def load(cls, fname: str):
400423
401424 def fit (
402425 self ,
403- X : Union [ np . ndarray , pd .DataFrame , pd . Series ] ,
404- y : Union [ np . ndarray , pd .Series ] ,
426+ X : pd .DataFrame ,
427+ y : pd .Series ,
405428 progressbar : bool = True ,
406429 predictor_names : List [str ] = None ,
407430 random_seed : RandomState = None ,
@@ -442,25 +465,19 @@ def fit(
442465 if predictor_names is None :
443466 predictor_names = []
444467
445- X , y = X , y
446-
447- self .build_model (data = self .data )
448- self ._data_setter (X , y )
468+ y = pd .DataFrame ({self .output_var : y })
469+ self .generate_and_preprocess_model_data (X , y .values .flatten ())
470+ self .build_model (self .X , self .y )
449471
450472 sampler_config = self .sampler_config .copy ()
451473 sampler_config ["progressbar" ] = progressbar
452474 sampler_config ["random_seed" ] = random_seed
453475 sampler_config .update (** kwargs )
454-
455476 self .idata = self .sample_model (** sampler_config )
456- if type (X ) is np .ndarray :
457- if len (predictor_names ) > 0 :
458- X = pd .DataFrame (X , columns = predictor_names )
459- else :
460- X = pd .DataFrame (X , columns = [f"predictor{ x } " for x in range (1 , X .shape [1 ] + 1 )])
461- if type (y ) is np .ndarray :
462- y = pd .Series (y , name = "target" )
463- combined_data = pd .concat ([X , y ], axis = 1 )
477+
478+ X_df = pd .DataFrame (X , columns = X .columns )
479+ combined_data = pd .concat ([X_df , y ], axis = 1 )
480+ assert all (combined_data .columns ), "All columns must have non-empty names"
464481 self .idata .add_groups (fit_data = combined_data .to_xarray ()) # type: ignore
465482 return self .idata # type: ignore
466483
@@ -513,6 +530,7 @@ def predict(
513530 def sample_prior_predictive (
514531 self ,
515532 X_pred ,
533+ y_pred = None ,
516534 samples : Optional [int ] = None ,
517535 extend_idata : bool = False ,
518536 combined : bool = True ,
@@ -539,13 +557,15 @@ def sample_prior_predictive(
539557 prior_predictive_samples : DataArray, shape (n_pred, samples)
540558 Prior predictive samples for each input X_pred
541559 """
560+ if y_pred is None :
561+ y_pred = np .zeros (len (X_pred ))
542562 if samples is None :
543563 samples = self .sampler_config .get ("draws" , 500 )
544564
545565 if self .model is None :
546- self .build_model ()
566+ self .build_model (X_pred , y_pred )
547567
548- self ._data_setter (X_pred )
568+ self ._data_setter (X_pred , y_pred )
549569 if self .model is not None :
550570 with self .model : # sample with new input data
551571 prior_pred : az .InferenceData = pm .sample_prior_predictive (samples , ** kwargs )
0 commit comments