From 8a3593194de34b3ac8d130671381019f0ba88ced Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Mon, 24 Nov 2025 20:00:49 +0530 Subject: [PATCH 1/2] ENH: Standardise y validation and conversion via _convert_y --- aeon/anomaly_detection/collection/base.py | 13 +++++++++---- aeon/anomaly_detection/series/base.py | 15 ++++++++++++--- aeon/regression/base.py | 15 ++++++++++----- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/aeon/anomaly_detection/collection/base.py b/aeon/anomaly_detection/collection/base.py index f7a046d5bf..4a17a96ee0 100644 --- a/aeon/anomaly_detection/collection/base.py +++ b/aeon/anomaly_detection/collection/base.py @@ -109,7 +109,8 @@ def fit(self, X, y=None): X = self._preprocess_collection(X) if y is not None: - y = self._check_y(y, self.metadata_["n_cases"]) + self._check_y(y, self.metadata_["n_cases"]) + y = self._convert_y(y) self._fit(X, y) @@ -208,7 +209,8 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray: return self._predict(X) if y is not None: - y = self._check_y(y) + self._check_y(y, self.metadata_["n_cases"]) + y = self._convert_y(y) pred = self._fit_predict(X, y) @@ -251,7 +253,10 @@ def _check_y(self, y, n_cases): f"Mismatch in number of cases. Found X = {n_cases} and y = {n_labels}" ) + return y + + def _convert_y(self, y): + """Convert y to correct anomaly detection format after validation.""" if isinstance(y, pd.Series): y = pd.Series.to_numpy(y) - - return y + return y.astype(bool) diff --git a/aeon/anomaly_detection/series/base.py b/aeon/anomaly_detection/series/base.py index 4b68ee2a34..78d85b807b 100644 --- a/aeon/anomaly_detection/series/base.py +++ b/aeon/anomaly_detection/series/base.py @@ -123,7 +123,8 @@ def fit(self, X, y=None, axis=1): X = self._preprocess_series(X, axis, True) if y is not None: - y = self._check_y(y) + self._check_y(y) + y = self._convert_y(y) self._fit(X=X, y=y) @@ -203,7 +204,8 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray: return self._predict(X) if y is not None: - y = self._check_y(y) + self._check_y(y) + y = self._convert_y(y) pred = self._fit_predict(X, y) @@ -284,5 +286,12 @@ def _check_y(self, y: VALID_SERIES_INPUT_TYPES) -> np.ndarray: f"{VALID_SERIES_INPUT_TYPES}, saw {type(y)}" ) - new_y = new_y.astype(bool) return new_y + + def _convert_y(self, y): + """Convert y to correct anomaly detection format after validation.""" + if isinstance(y, pd.Series): + y = y.values + elif isinstance(y, pd.DataFrame): + y = y.squeeze().values + return y.astype(bool) diff --git a/aeon/regression/base.py b/aeon/regression/base.py index 84025decc5..8045a1162e 100644 --- a/aeon/regression/base.py +++ b/aeon/regression/base.py @@ -231,7 +231,8 @@ def score(self, X, y, metric="r2", metric_params=None) -> float: MSE score of predict(X) vs y """ self._check_is_fitted() - y = self._check_y(y, len(X)) + self._check_y(y, len(X)) + y = self._convert_y(y) _metric_params = metric_params if metric_params is None: _metric_params = {} @@ -351,7 +352,8 @@ def _fit_setup(self, X, y): self.reset() X = self._preprocess_collection(X) - y = self._check_y(y, self.metadata_["n_cases"]) + self._check_y(y, self.metadata_["n_cases"]) + y = self._convert_y(y) # return processed X and y return X, y @@ -380,13 +382,16 @@ def _check_y(self, y, n_cases): f"sklearn.utils.multiclass.type_of_target" ) - if isinstance(y, pd.Series): - y = pd.Series.to_numpy(y) - if any([isinstance(label, str) for label in y]): raise ValueError( "y contains strings, cannot fit a regressor. If suitable, convert " "to floats or consider classification." ) + return y + + def _convert_y(self, y): + """Convert y to the correct regression format after validation.""" + if isinstance(y, pd.Series): + y = pd.Series.to_numpy(y) return y.astype(float) From 156869a81306a9999aac9f9a1c3e3f72bbda4860 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Mon, 24 Nov 2025 20:45:25 +0530 Subject: [PATCH 2/2] ENH: update test__check_y to reflect standardised conversion logic --- aeon/regression/tests/test_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aeon/regression/tests/test_base.py b/aeon/regression/tests/test_base.py index d04469db8b..758efbba18 100644 --- a/aeon/regression/tests/test_base.py +++ b/aeon/regression/tests/test_base.py @@ -98,7 +98,8 @@ def test__check_y(): reg._check_y(y, 100) assert isinstance(y, np.ndarray) y = pd.Series(y) - y = reg._check_y(y, 100) + reg._check_y(y, 100) + y = reg._convert_y(y) assert isinstance(y, np.ndarray) # Test error raising # y wrong length