diff --git a/ezyrb/database.py b/ezyrb/database.py index 53bc15b9..7f31ef52 100644 --- a/ezyrb/database.py +++ b/ezyrb/database.py @@ -15,10 +15,6 @@ class Database: :param array_like parameters: the input parameters :param array_like snapshots: the input snapshots - :param Scale scaler_parameters: the scaler for the parameters. Default - is None meaning no scaling. - :param Scale scaler_snapshots: the scaler for the snapshots. Default is - None meaning no scaling. :param array_like space: the input spatial data :Example: @@ -46,6 +42,7 @@ def __init__(self, parameters=None, snapshots=None, space=None): ) self._pairs = [] + if parameters is None and snapshots is None: logger.debug("Empty database created") return @@ -149,11 +146,11 @@ def add(self, parameter, snapshot): """ if not isinstance(parameter, Parameter): logger.error("Invalid parameter type: %s", type(parameter)) - raise ValueError + raise TypeError(f"Expected a Parameter object, got {type(parameter)}") if not isinstance(snapshot, Snapshot): logger.error("Invalid snapshot type: %s", type(snapshot)) - raise ValueError + raise TypeError(f"Expected a Snapshot object, got {type(snapshot)}") self._pairs.append((parameter, snapshot)) logger.debug( @@ -161,7 +158,7 @@ def add(self, parameter, snapshot): ) return self - + def split(self, chunks, seed=None): """ @@ -209,7 +206,7 @@ def split(self, chunks, seed=None): else: logger.error("Invalid chunk type") - ValueError + raise TypeError(f"Invalid chunk type. Expected a list of integers or floats, but got {type(chunks)}.") new_database = [Database() for _ in range(len(chunks))] for i, chunk in enumerate(chunks): @@ -235,4 +232,4 @@ def get_snapshot_space(self, index): """ if index < 0 or index >= len(self._pairs): raise IndexError("Snapshot index out of range.") - return self._pairs[index][1].space + return self._pairs[index][1].space \ No newline at end of file diff --git a/ezyrb/plugin/scaler.py b/ezyrb/plugin/scaler.py index 588f54e7..7b6f40a5 100644 --- a/ezyrb/plugin/scaler.py +++ b/ezyrb/plugin/scaler.py @@ -10,13 +10,12 @@ class DatabaseScaler(Plugin): """ The plugin to rescale the database of the reduced order model. It uses a user defined `scaler`, which has to have implemented the `fit`, `transform` - and `inverse_trasform` methods (i.e. `sklearn` interface), to rescale + and `inverse_transform` methods (i.e. `sklearn` interface), to rescale the parameters and/or the snapshots. It can be applied at the full order - (`mode='full'`), at the reduced one (`mode='reduced'`) or both of them - (`mode='both'`). + (`mode='full'`) or at the reduced one (`mode='reduced'`). :param obj scaler: a generic object which has to have implemented the - `fit`, `transform` and `inverse_trasform` methods (i.e. `sklearn` + `fit`, `transform` and `inverse_transform` methods (i.e. `sklearn` interface). :param {'full', 'reduced'} mode: define if the rescaling has to be applied at the full order ('full') or at the reduced one ('reduced'). @@ -62,11 +61,14 @@ def target(self): rtype: str """ return self._target + @target.setter def target(self, new_target): if new_target not in ["snapshots", "parameters"]: - raise ValueError + error_msg = f"Invalid target: '{new_target}' must be 'snapshots' or 'parameters'." + logger.error(error_msg) + raise ValueError(error_msg) self._target = new_target @@ -82,10 +84,13 @@ def mode(self): @mode.setter def mode(self, new_mode): if new_mode not in ["full", "reduced"]: - raise ValueError + error_msg = f"Invalid mode: '{new_mode}' must be 'full' or 'reduced'." + logger.error(error_msg) + raise ValueError(error_msg) self._mode = new_mode + def _select_matrix(self, db): """ Helper function to select the proper matrix to rescale.