From 1a8c44ae7da7ffbc5512f410c538540f5ae50c20 Mon Sep 17 00:00:00 2001 From: Anushree Bondia Date: Tue, 13 Jan 2026 21:31:04 +0530 Subject: [PATCH 1/7] fix: added save and read functionality in SSD issue #13328 --- mne/decoding/__init__.pyi | 2 +- mne/decoding/ssd.py | 115 +++++++++++++++++++++++++++++++++ mne/decoding/tests/test_ssd.py | 41 ++++++++++++ 3 files changed, 157 insertions(+), 1 deletion(-) diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 1131f1597c5..e242ee17708 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -36,7 +36,7 @@ from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator -from .ssd import SSD +from .ssd import SSD, read_ssd from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency from .transformer import ( diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 41d67ece8c6..1353c071622 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -18,6 +18,8 @@ from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer +from mne.utils import _check_fname, check_version +from mne import __version__ as mne_version @fill_doc @@ -236,6 +238,66 @@ def fit(self, X, y=None): logger.info("Done.") return self + + def save(self, fname, overwrite=False): + """Save the SSD object to disk. + + Parameters + ---------- + fname : path-like + Output filename. Must end with ``.h5``. + overwrite : bool + If True, overwrite the file. + """ + from ..utils.check import _import_h5io_funcs + + _validate_type(fname, "path-like", "fname") + _check_fname(fname, overwrite=overwrite) + check_version("h5py") + + if not hasattr(self, "filters_"): + raise RuntimeError( + "Cannot save an unfitted SSD object. " + "Call `fit` before saving." + ) + + state = dict( + class_name="SSD", + mne_version=mne_version, + + # init params + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + reg=self.reg, + n_components=self.n_components, + picks=self.picks, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, + return_filtered=self.return_filtered, + n_fft=self.n_fft, + cov_method_params=self.cov_method_params, + restr_type=self.restr_type, + rank=self.rank, + + # fitted attributes + filters=self.filters_, + patterns=self.patterns_, + eigenvalues=self.evals_, + picks_=self.picks_, + freqs_signal_=self.freqs_signal_, + freqs_noise_=self.freqs_noise_, + n_fft_=self.n_fft_, + sfreq_=self.sfreq_, + info=self.info, + ) + + _, write_hdf5 = _import_h5io_funcs() + write_hdf5( + fname, + state, + title="mne-python SSD", + overwrite=overwrite, + ) + def transform(self, X): """Estimate epochs sources given the SSD filters. @@ -350,3 +412,56 @@ def apply(self, X): pick_patterns = self.patterns_[: self.n_components].T X = pick_patterns @ X_ssd return X + +def read_ssd(fname): + """Read an SSD object from disk. + + Parameters + ---------- + fname : path-like + Path to ``.h5`` file. + + Returns + ------- + ssd : SSD + The loaded SSD object. + """ + from ..utils.check import _import_h5io_funcs + + _validate_type(fname, "path-like", "fname") + check_version("h5py") + + read_hdf5, _ = _import_h5io_funcs() + state = read_hdf5(fname, title="mne-python SSD") + + if state.get("class_name") != "SSD": + raise RuntimeError( + "The file does not contain a valid SSD object." + ) + + ssd = SSD( + info=state["info"], + filt_params_signal=state["filt_params_signal"], + filt_params_noise=state["filt_params_noise"], + reg=state["reg"], + n_components=state["n_components"], + picks=state["picks"], + sort_by_spectral_ratio=state["sort_by_spectral_ratio"], + return_filtered=state["return_filtered"], + n_fft=state["n_fft"], + cov_method_params=state["cov_method_params"], + restr_type=state["restr_type"], + rank=state["rank"], + ) + + # restore fitted state + ssd.filters_ = state["filters"] + ssd.patterns_ = state["patterns"] + ssd.eigenvalues_ = state["eigenvalues"] + ssd.picks_ = state["picks_"] + ssd.freqs_signal_ = state["freqs_signal_"] + ssd.freqs_noise_ = state["freqs_noise_"] + ssd.n_fft_ = state["n_fft_"] + ssd.sfreq_ = state["sfreq_"] + + return ssd diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 236e65b82fd..e5df51aef07 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -21,6 +21,7 @@ from mne.decoding.ssd import SSD from mne.filter import filter_data from mne.time_frequency import psd_array_welch +from mne.decoding import read_ssd freqs_sig = 9, 12 freqs_noise = 8, 13 @@ -360,6 +361,46 @@ def test_ssd_pipeline(): assert out.shape == (100, 2) assert pipe.get_params()["SSD__n_components"] == 5 +def test_ssd_save_load(tmp_path): + """Test saving and loading of SSD.""" + X, _, _ = simulate_data() + sf = 250 + n_channels = X.shape[0] + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=5, + sort_by_spectral_ratio=True, + ) + ssd.fit(X) + + fname = tmp_path / "ssd.h5" + ssd.save(fname) + + ssd_loaded = read_ssd(fname) + + # Check numerical equivalence + X_orig = ssd.transform(X) + X_loaded = ssd_loaded.transform(X) + + assert_array_almost_equal(X_orig, X_loaded) + def test_sorting(): """Test sorting learning during training.""" From 00a4e662d76edbc6e1681920b434b1864967a520 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 16:03:45 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/decoding/__init__.pyi | 2 +- mne/decoding/ssd.py | 18 +++++++----------- mne/decoding/tests/test_ssd.py | 4 ++-- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index e242ee17708..1131f1597c5 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -36,7 +36,7 @@ from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator -from .ssd import SSD, read_ssd +from .ssd import SSD from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency from .transformer import ( diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 1353c071622..af0ee5643fe 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -7,6 +7,9 @@ import numpy as np +from mne import __version__ as mne_version +from mne.utils import _check_fname, check_version + from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..filter import filter_data @@ -18,8 +21,6 @@ from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer -from mne.utils import _check_fname, check_version -from mne import __version__ as mne_version @fill_doc @@ -238,7 +239,7 @@ def fit(self, X, y=None): logger.info("Done.") return self - + def save(self, fname, overwrite=False): """Save the SSD object to disk. @@ -257,14 +258,12 @@ def save(self, fname, overwrite=False): if not hasattr(self, "filters_"): raise RuntimeError( - "Cannot save an unfitted SSD object. " - "Call `fit` before saving." + "Cannot save an unfitted SSD object. Call `fit` before saving." ) state = dict( class_name="SSD", mne_version=mne_version, - # init params filt_params_signal=self.filt_params_signal, filt_params_noise=self.filt_params_noise, @@ -277,7 +276,6 @@ def save(self, fname, overwrite=False): cov_method_params=self.cov_method_params, restr_type=self.restr_type, rank=self.rank, - # fitted attributes filters=self.filters_, patterns=self.patterns_, @@ -298,7 +296,6 @@ def save(self, fname, overwrite=False): overwrite=overwrite, ) - def transform(self, X): """Estimate epochs sources given the SSD filters. @@ -413,6 +410,7 @@ def apply(self, X): X = pick_patterns @ X_ssd return X + def read_ssd(fname): """Read an SSD object from disk. @@ -435,9 +433,7 @@ def read_ssd(fname): state = read_hdf5(fname, title="mne-python SSD") if state.get("class_name") != "SSD": - raise RuntimeError( - "The file does not contain a valid SSD object." - ) + raise RuntimeError("The file does not contain a valid SSD object.") ssd = SSD( info=state["info"], diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index e5df51aef07..ea577544fa2 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -16,12 +16,11 @@ from mne import Epochs, create_info, io, pick_types, read_events from mne._fiff.pick import _picks_to_idx -from mne.decoding import CSP +from mne.decoding import CSP, read_ssd from mne.decoding._mod_ged import _get_spectral_ratio from mne.decoding.ssd import SSD from mne.filter import filter_data from mne.time_frequency import psd_array_welch -from mne.decoding import read_ssd freqs_sig = 9, 12 freqs_noise = 8, 13 @@ -361,6 +360,7 @@ def test_ssd_pipeline(): assert out.shape == (100, 2) assert pipe.get_params()["SSD__n_components"] == 5 + def test_ssd_save_load(tmp_path): """Test saving and loading of SSD.""" X, _, _ = simulate_data() From c0e20ef69ce58ebc6f2e0da2848e08d4f5bd42e4 Mon Sep 17 00:00:00 2001 From: Anushree Bondia Date: Tue, 13 Jan 2026 21:45:22 +0530 Subject: [PATCH 3/7] fix: evals updated --- mne/decoding/ssd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 1353c071622..b3a48e5cbfc 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -457,7 +457,7 @@ def read_ssd(fname): # restore fitted state ssd.filters_ = state["filters"] ssd.patterns_ = state["patterns"] - ssd.eigenvalues_ = state["eigenvalues"] + ssd.evals_ = state["eigenvalues"] ssd.picks_ = state["picks_"] ssd.freqs_signal_ = state["freqs_signal_"] ssd.freqs_noise_ = state["freqs_noise_"] From 615ed8e122f9c4d113d503a098f28950976bf11c Mon Sep 17 00:00:00 2001 From: Anushree Bondia Date: Wed, 14 Jan 2026 09:59:28 +0530 Subject: [PATCH 4/7] fix:added changes --- mne/decoding/base.py | 25 +++++++++ mne/decoding/ssd.py | 126 ++++++++++++++++++++++--------------------- 2 files changed, 89 insertions(+), 62 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 3a51a04bed7..2e8c775be2f 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -135,6 +135,31 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._is_base_ged = False + def __getstate__(self): + """Get state for serialization. + + This explicitly drops callables and other runtime-only attributes. + Subclasses can extend this to add estimator-specific state. + """ + state = self.__dict__.copy() + + # Callables are not serializable and must be reconstructed + state.pop("cov_callable", None) + state.pop("mod_ged_callable", None) + + return state + + def __setstate__(self, state): + """Restore state from serialization. + + Subclasses are responsible for reconstructing dropped callables. + """ + self.__dict__.update(state) + + # Ensure attributes exist even before reconstruction + self.cov_callable = None + self.mod_ged_callable = None + def fit(self, X, y=None): """...""" # Let the inheriting transformers check data by themselves diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 96e571aee28..8b2561321c0 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -148,6 +148,63 @@ def __init__( mod_ged_callable=_ssd_mod, restr_type=restr_type, ) + + def __getstate__(self): + """Get state for serialization.""" + state = super().__getstate__() + + # init parameters + state.update( + info=self.info, + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + reg=self.reg, + n_components=self.n_components, + picks=self.picks, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, + return_filtered=self.return_filtered, + n_fft=self.n_fft, + cov_method_params=self.cov_method_params, + restr_type=self.restr_type, + rank=self.rank, + ) + + # fitted attributes (only if present) + for attr in ( + "filters_", + "patterns_", + "evals_", + "picks_", + "freqs_signal_", + "freqs_noise_", + "n_fft_", + "sfreq_", + ): + if hasattr(self, attr): + state[attr] = getattr(self, attr) + + return state + + def __setstate__(self, state): + """Restore state from serialization.""" + super().__setstate__(state) + + # Restore attributes + self.__dict__.update(state) + + # Rebuild covariance callable exactly as in __init__ + self.cov_callable = partial( + _ssd_estimate, + reg=self.reg, + cov_method_params=self.cov_method_params, + info=self.info, + picks=self.picks, + n_fft=self.n_fft, + filt_params_signal=self.filt_params_signal, + filt_params_noise=self.filt_params_noise, + rank=self.rank, + sort_by_spectral_ratio=self.sort_by_spectral_ratio, + ) def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing @@ -241,60 +298,12 @@ def fit(self, X, y=None): return self def save(self, fname, overwrite=False): - """Save the SSD object to disk. - - Parameters - ---------- - fname : path-like - Output filename. Must end with ``.h5``. - overwrite : bool - If True, overwrite the file. - """ - from ..utils.check import _import_h5io_funcs - - _validate_type(fname, "path-like", "fname") - _check_fname(fname, overwrite=overwrite) - check_version("h5py") - - if not hasattr(self, "filters_"): - raise RuntimeError( - "Cannot save an unfitted SSD object. Call `fit` before saving." - ) - - state = dict( + state = self.__getstate__() + state.update( class_name="SSD", mne_version=mne_version, - # init params - filt_params_signal=self.filt_params_signal, - filt_params_noise=self.filt_params_noise, - reg=self.reg, - n_components=self.n_components, - picks=self.picks, - sort_by_spectral_ratio=self.sort_by_spectral_ratio, - return_filtered=self.return_filtered, - n_fft=self.n_fft, - cov_method_params=self.cov_method_params, - restr_type=self.restr_type, - rank=self.rank, - # fitted attributes - filters=self.filters_, - patterns=self.patterns_, - eigenvalues=self.evals_, - picks_=self.picks_, - freqs_signal_=self.freqs_signal_, - freqs_noise_=self.freqs_noise_, - n_fft_=self.n_fft_, - sfreq_=self.sfreq_, - info=self.info, ) - _, write_hdf5 = _import_h5io_funcs() - write_hdf5( - fname, - state, - title="mne-python SSD", - overwrite=overwrite, - ) def transform(self, X): """Estimate epochs sources given the SSD filters. @@ -435,7 +444,7 @@ def read_ssd(fname): if state.get("class_name") != "SSD": raise RuntimeError("The file does not contain a valid SSD object.") - ssd = SSD( + ssd = SSD( info=state["info"], filt_params_signal=state["filt_params_signal"], filt_params_noise=state["filt_params_noise"], @@ -450,14 +459,7 @@ def read_ssd(fname): rank=state["rank"], ) - # restore fitted state - ssd.filters_ = state["filters"] - ssd.patterns_ = state["patterns"] - ssd.evals_ = state["eigenvalues"] - ssd.picks_ = state["picks_"] - ssd.freqs_signal_ = state["freqs_signal_"] - ssd.freqs_noise_ = state["freqs_noise_"] - ssd.n_fft_ = state["n_fft_"] - ssd.sfreq_ = state["sfreq_"] - - return ssd + # restore full state (fitted attributes + callables) + ssd.__setstate__(state) + + return ssd \ No newline at end of file From 71db9928c626196d91126e150ab5209709e28b0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 04:29:58 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/decoding/ssd.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 8b2561321c0..b1b0f433b07 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -8,7 +8,7 @@ import numpy as np from mne import __version__ as mne_version -from mne.utils import _check_fname, check_version +from mne.utils import check_version from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx @@ -148,7 +148,7 @@ def __init__( mod_ged_callable=_ssd_mod, restr_type=restr_type, ) - + def __getstate__(self): """Get state for serialization.""" state = super().__getstate__() @@ -184,7 +184,7 @@ def __getstate__(self): state[attr] = getattr(self, attr) return state - + def __setstate__(self, state): """Restore state from serialization.""" super().__setstate__(state) @@ -304,7 +304,6 @@ def save(self, fname, overwrite=False): mne_version=mne_version, ) - def transform(self, X): """Estimate epochs sources given the SSD filters. @@ -445,21 +444,21 @@ def read_ssd(fname): raise RuntimeError("The file does not contain a valid SSD object.") ssd = SSD( - info=state["info"], - filt_params_signal=state["filt_params_signal"], - filt_params_noise=state["filt_params_noise"], - reg=state["reg"], - n_components=state["n_components"], - picks=state["picks"], - sort_by_spectral_ratio=state["sort_by_spectral_ratio"], - return_filtered=state["return_filtered"], - n_fft=state["n_fft"], - cov_method_params=state["cov_method_params"], - restr_type=state["restr_type"], - rank=state["rank"], - ) + info=state["info"], + filt_params_signal=state["filt_params_signal"], + filt_params_noise=state["filt_params_noise"], + reg=state["reg"], + n_components=state["n_components"], + picks=state["picks"], + sort_by_spectral_ratio=state["sort_by_spectral_ratio"], + return_filtered=state["return_filtered"], + n_fft=state["n_fft"], + cov_method_params=state["cov_method_params"], + restr_type=state["restr_type"], + rank=state["rank"], + ) # restore full state (fitted attributes + callables) ssd.__setstate__(state) - return ssd \ No newline at end of file + return ssd From 38039097fa0b6e24e38489f2855cb8b7b9c062e5 Mon Sep 17 00:00:00 2001 From: Anushree Bondia Date: Sun, 18 Jan 2026 12:17:04 +0530 Subject: [PATCH 6/7] fix:added changes --- mne/decoding/__init__.pyi | 3 +- mne/decoding/ssd.py | 120 +++++++++++--------------------------- 2 files changed, 37 insertions(+), 86 deletions(-) diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 1131f1597c5..f08e6f35e9a 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -23,6 +23,7 @@ __all__ = [ "cross_val_multiscore", "get_coef", "get_spatial_filter_from_estimator", + "read_ssd", ] from .base import ( BaseEstimator, @@ -36,7 +37,7 @@ from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator -from .ssd import SSD +from .ssd import SSD, read_ssd from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency from .transformer import ( diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index b1b0f433b07..d29ad19cc63 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -21,7 +21,7 @@ from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer - +from ..utils.check import _check_fname, _import_h5io_funcs, check_fname @fill_doc class SSD(_GEDTransformer): @@ -149,62 +149,13 @@ def __init__( restr_type=restr_type, ) - def __getstate__(self): - """Get state for serialization.""" - state = super().__getstate__() - - # init parameters - state.update( - info=self.info, - filt_params_signal=self.filt_params_signal, - filt_params_noise=self.filt_params_noise, - reg=self.reg, - n_components=self.n_components, - picks=self.picks, - sort_by_spectral_ratio=self.sort_by_spectral_ratio, - return_filtered=self.return_filtered, - n_fft=self.n_fft, - cov_method_params=self.cov_method_params, - restr_type=self.restr_type, - rank=self.rank, - ) - - # fitted attributes (only if present) - for attr in ( - "filters_", - "patterns_", - "evals_", - "picks_", - "freqs_signal_", - "freqs_noise_", - "n_fft_", - "sfreq_", - ): - if hasattr(self, attr): - state[attr] = getattr(self, attr) - - return state - def __setstate__(self, state): """Restore state from serialization.""" - super().__setstate__(state) - - # Restore attributes + # Since read_ssd creates a new instance via __init__ first, + # callables are already set correctly. We just restore fitted attributes. + # Don't call super().__setstate__() as it would set callables to None. self.__dict__.update(state) - - # Rebuild covariance callable exactly as in __init__ - self.cov_callable = partial( - _ssd_estimate, - reg=self.reg, - cov_method_params=self.cov_method_params, - info=self.info, - picks=self.picks, - n_fft=self.n_fft, - filt_params_signal=self.filt_params_signal, - filt_params_noise=self.filt_params_noise, - rank=self.rank, - sort_by_spectral_ratio=self.sort_by_spectral_ratio, - ) + return self def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing @@ -297,12 +248,33 @@ def fit(self, X, y=None): logger.info("Done.") return self - def save(self, fname, overwrite=False): + @fill_doc + def save(self, fname, *, overwrite=False, verbose=None): + """Save the SSD decomposition to disk (in HDF5 format). + + Parameters + ---------- + fname : path-like + Path of file to save to. Should end with ``'.h5'`` or ``'.hdf5'``. + %(overwrite)s + %(verbose)s + + See Also + -------- + mne.decoding.read_ssd + """ + + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "SSD", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) state = self.__getstate__() state.update( class_name="SSD", mne_version=mne_version, ) + write_hdf5( + fname, state, overwrite=overwrite, title="mnepython", slash="replace" + ) def transform(self, X): """Estimate epochs sources given the SSD filters. @@ -425,40 +397,18 @@ def read_ssd(fname): Parameters ---------- fname : path-like - Path to ``.h5`` file. + Path to an SSD file in HDF5 format, which should end with ``.h5`` or + ``.hdf5``. Returns ------- ssd : SSD The loaded SSD object. """ - from ..utils.check import _import_h5io_funcs - - _validate_type(fname, "path-like", "fname") - check_version("h5py") - read_hdf5, _ = _import_h5io_funcs() - state = read_hdf5(fname, title="mne-python SSD") - - if state.get("class_name") != "SSD": - raise RuntimeError("The file does not contain a valid SSD object.") - - ssd = SSD( - info=state["info"], - filt_params_signal=state["filt_params_signal"], - filt_params_noise=state["filt_params_noise"], - reg=state["reg"], - n_components=state["n_components"], - picks=state["picks"], - sort_by_spectral_ratio=state["sort_by_spectral_ratio"], - return_filtered=state["return_filtered"], - n_fft=state["n_fft"], - cov_method_params=state["cov_method_params"], - restr_type=state["restr_type"], - rank=state["rank"], - ) - - # restore full state (fitted attributes + callables) - ssd.__setstate__(state) - - return ssd + _validate_type(fname, "path-like", "fname") + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + state = read_hdf5(fname, title="mnepython", slash="replace") + return SSD( + info=None, filt_params_signal=None, filt_params_noise=None + ).__setstate__(state) From 1bc0fd978da4133e4f1eefc6c848aed7c2c2df84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 13:53:04 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/decoding/ssd.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index d29ad19cc63..0606aa468e8 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -8,7 +8,6 @@ import numpy as np from mne import __version__ as mne_version -from mne.utils import check_version from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx @@ -18,10 +17,11 @@ fill_doc, logger, ) +from ..utils.check import _check_fname, _import_h5io_funcs, check_fname from ._covs_ged import _ssd_estimate from ._mod_ged import _get_spectral_ratio, _ssd_mod from .base import _GEDTransformer -from ..utils.check import _check_fname, _import_h5io_funcs, check_fname + @fill_doc class SSD(_GEDTransformer): @@ -263,7 +263,6 @@ def save(self, fname, *, overwrite=False, verbose=None): -------- mne.decoding.read_ssd """ - _, write_hdf5 = _import_h5io_funcs() check_fname(fname, "SSD", (".h5", ".hdf5")) fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) @@ -409,6 +408,6 @@ def read_ssd(fname): _validate_type(fname, "path-like", "fname") fname = _check_fname(fname=fname, overwrite="read", must_exist=False) state = read_hdf5(fname, title="mnepython", slash="replace") - return SSD( - info=None, filt_params_signal=None, filt_params_noise=None - ).__setstate__(state) + return SSD(info=None, filt_params_signal=None, filt_params_noise=None).__setstate__( + state + )