diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 168494caf7..66b2e0ef10 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -95,7 +95,7 @@ get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, - ChunkRecordingExecutor, + ChunkExecutor, split_job_kwargs, fix_job_kwargs, ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f23b524271..068d2a047c 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -5,14 +5,13 @@ import numpy as np from probeinterface import read_probeinterface, write_probeinterface -from .base import BaseSegment +from .chunkable import ChunkableSegment, ChunkableMixin from .baserecordingsnippets import BaseRecordingSnippets from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs -from .recording_tools import write_binary_recording -class BaseRecording(BaseRecordingSnippets): +class BaseRecording(BaseRecordingSnippets, ChunkableMixin): """ Abstract class representing several a multichannel timeseries (or block of raw ephys traces). Internally handle list of RecordingSegment @@ -183,110 +182,40 @@ def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> No """ super().add_segment(recording_segment) - def get_num_samples(self, segment_index: int | None = None) -> int: + def get_sample_size_in_bytes(self): """ - Returns the number of samples for a segment. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index to retrieve the number of samples for. - For multi-segment objects, it is required, default: None - With single segment recording returns the number of samples in the segment + Returns the size of a single sample across all channels in bytes. 
Returns ------- int - The number of samples - """ - segment_index = self._check_segment_index(segment_index) - return int(self.segments[segment_index].get_num_samples()) - - get_num_frames = get_num_samples - - def get_total_samples(self) -> int: - """ - Returns the sum of the number of samples in each segment. - - Returns - ------- - int - The total number of samples - """ - num_segments = self.get_num_segments() - samples_per_segment = (self.get_num_samples(segment_index) for segment_index in range(num_segments)) - - return sum(samples_per_segment) - - def get_duration(self, segment_index=None) -> float: + The size of a single sample in bytes """ - Returns the duration in seconds. - - Parameters - ---------- - segment_index : int or None, default: None - The sample index to retrieve the duration for. - For multi-segment objects, it is required, default: None - With single segment recording returns the duration of the single segment - - Returns - ------- - float - The duration in seconds - """ - segment_duration = ( - self.get_end_time(segment_index) - self.get_start_time(segment_index) + (1 / self.get_sampling_frequency()) - ) - return segment_duration - - def get_total_duration(self) -> float: - """ - Returns the total duration in seconds - - Returns - ------- - float - The duration in seconds - """ - duration = sum([self.get_duration(segment_index) for segment_index in range(self.get_num_segments())]) - return duration + num_channels = self.get_num_channels() + dtype_size_bytes = self.get_dtype().itemsize + sample_size = num_channels * dtype_size_bytes + return sample_size - def get_memory_size(self, segment_index=None) -> int: + def get_num_samples(self, segment_index: int | None = None) -> int: """ - Returns the memory size of segment_index in bytes. + Returns the number of samples for a segment. Parameters ---------- segment_index : int or None, default: None - The index of the segment for which the memory size should be calculated. 
+ The segment index to retrieve the number of samples for. For multi-segment objects, it is required, default: None - With single segment recording returns the memory size of the single segment + With single segment recording returns the number of samples in the segment Returns ------- int - The memory size of the specified segment in bytes. + The number of samples """ segment_index = self._check_segment_index(segment_index) - num_samples = self.get_num_samples(segment_index=segment_index) - num_channels = self.get_num_channels() - dtype_size_bytes = self.get_dtype().itemsize - - memory_bytes = num_samples * num_channels * dtype_size_bytes - - return memory_bytes - - def get_total_memory_size(self) -> int: - """ - Returns the sum in bytes of all the memory sizes of the segments. + return int(self.segments[segment_index].get_num_samples()) - Returns - ------- - int - The total memory size in bytes for all segments. - """ - memory_per_segment = (self.get_memory_size(segment_index) for segment_index in range(self.get_num_segments())) - return sum(memory_per_segment) + get_num_frames = get_num_samples def get_traces( self, @@ -369,228 +298,33 @@ def get_traces( traces = traces.astype("float32", copy=False) * gains + offsets return traces - def get_time_info(self, segment_index=None) -> dict: - """ - Retrieves the timing attributes for a given segment index. As with - other recorders this method only needs a segment index in the case - of multi-segment recordings. - - Returns - ------- - dict - A dictionary containing the following key-value pairs: - - - "sampling_frequency" : The sampling frequency of the RecordingSegment. - - "t_start" : The start time of the RecordingSegment. - - "time_vector" : The time vector of the RecordingSegment. - - Notes - ----- - The keys are always present, but the values may be None. 
- """ - - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - time_kwargs = rs.get_times_kwargs() - - return time_kwargs - - def get_times(self, segment_index=None) -> np.ndarray: - """Get time vector for a recording segment. - - If the segment has a time_vector, then it is returned. Otherwise - a time_vector is constructed on the fly with sampling frequency. - If t_start is defined and the time vector is constructed on the fly, - the first time will be t_start. Otherwise it will start from 0. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - np.array - The 1d times array - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - times = rs.get_times() - return times - - def get_start_time(self, segment_index=None) -> float: - """Get the start time of the recording segment. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - float - The start time in seconds - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.get_start_time() - - def get_end_time(self, segment_index=None) -> float: - """Get the stop time of the recording segment. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - float - The stop time in seconds - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.get_end_time() - - def has_time_vector(self, segment_index: int | None = None): - """Check if the segment of the recording has a time vector. 
- - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - bool - True if the recording has time vectors, False otherwise - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - d = rs.get_times_kwargs() - return d["time_vector"] is not None - - def set_times(self, times, segment_index=None, with_warning=True): - """Set times for a recording segment. - - Parameters - ---------- - times : 1d np.array - The time vector - segment_index : int or None, default: None - The segment index (required for multi-segment) - with_warning : bool, default: True - If True, a warning is printed - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - - assert times.ndim == 1, "Time must have ndim=1" - assert rs.get_num_samples() == times.shape[0], "times have wrong shape" - - rs.t_start = None - rs.time_vector = times.astype("float64", copy=False) - - if with_warning: - warnings.warn( - "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated across preprocessing" - "Use this carefully!" - ) - - def reset_times(self): - """ - Reset time information in-memory for all segments that have a time vector. - If the timestamps come from a file, the files won't be modified. but only the in-memory - attributes of the recording objects are deleted. Also `t_start` is set to None and the - segment's sampling frequency is set to the recording's sampling frequency. 
+ def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: """ - for segment_index in range(self.get_num_segments()): - rs = self.segments[segment_index] - if self.has_time_vector(segment_index): - rs.time_vector = None - rs.t_start = None - rs.sampling_frequency = self.sampling_frequency - - def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + General retrieval function for chunkable objects """ - Shift all times by a scalar value. + return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs) - Parameters - ---------- - shift : int | float - The shift to apply. If positive, times will be increased by `shift`. - e.g. shifting by 1 will be like the recording started 1 second later. - If negative, the start time will be decreased i.e. as if the recording - started earlier. - - segment_index : int | None - The segment on which to shift the times. - If `None`, all segments will be shifted. 
- """ - if segment_index is None: - segments_to_shift = range(self.get_num_segments()) - else: - segments_to_shift = (segment_index,) - - for segment_index in segments_to_shift: - rs = self.segments[segment_index] - - if self.has_time_vector(segment_index=segment_index): - rs.time_vector += shift - else: - new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift - rs.t_start = new_start_time - - def sample_index_to_time(self, sample_ind, segment_index=None): - """ - Transform sample index into time in seconds - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.sample_index_to_time(sample_ind) - - def time_to_sample_index(self, time_s, segment_index=None): - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.time_to_sample_index(time_s) - - def _get_t_starts(self): - # handle t_starts - t_starts = [] - has_time_vectors = [] - for rs in self.segments: - d = rs.get_times_kwargs() - t_starts.append(d["t_start"]) - - if all(t_start is None for t_start in t_starts): - t_starts = None - return t_starts - - def _get_time_vectors(self): - time_vectors = [] - for rs in self.segments: - d = rs.get_times_kwargs() - time_vectors.append(d["time_vector"]) - if all(time_vector is None for time_vector in time_vectors): - time_vectors = None - return time_vectors + def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]: + return (self.get_num_samples(segment_index=segment_index), self.get_num_channels()) def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": + from .chunkable_tools import write_binary + folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() t_starts = self._get_t_starts() - write_binary_recording(self, 
file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) + write_binary(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) from .binaryrecordingextractor import BinaryRecordingExtractor # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading # See the __init__ of `BinaryFolderRecording` + binary_rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=self.get_sampling_frequency(), @@ -610,6 +344,13 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) + # timestamps are not saved in binary, so we have to set them explicitly + for segment_index in range(self.get_num_segments()): + if self.has_time_vector(segment_index): + # the use of get_times is preferred since timestamps are converted to array + time_vector = self.get_times(segment_index=segment_index) + cached.set_times(time_vector, segment_index=segment_index) + elif format == "memory": if kwargs.get("sharedmem", True): from .numpyextractors import SharedMemoryRecording @@ -620,6 +361,13 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): cached = NumpyRecording.from_recording(self, **job_kwargs) + # timestamps are not saved in memory, so we have to set them explicitly + for segment_index in range(self.get_num_segments()): + if self.has_time_vector(segment_index): + # the use of get_times is preferred since timestamps are converted to array + time_vector = self.get_times(segment_index=segment_index) + cached.set_times(time_vector, segment_index=segment_index) + elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor @@ -630,6 +378,8 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): ) cached = ZarrRecordingExtractor(zarr_path, storage_options) + # timestamps are saved and restored in zarr, so no need to set them explicitly + elif format == "nwb": # TODO implement a format based on 
zarr raise NotImplementedError @@ -641,12 +391,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) - for segment_index in range(self.get_num_segments()): - if self.has_time_vector(segment_index): - # the use of get_times is preferred since timestamps are converted to array - time_vector = self.get_times(segment_index=segment_index) - cached.set_times(time_vector, segment_index=segment_index) - return cached def _extra_metadata_from_folder(self, folder): @@ -893,110 +637,11 @@ def astype(self, dtype, round: bool | None = None): return astype(self, dtype=dtype, round=round) -class BaseRecordingSegment(BaseSegment): +class BaseRecordingSegment(ChunkableSegment): """ Abstract class representing a multichannel timeseries, or block of raw ephys traces """ - def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): - # sampling_frequency and time_vector are exclusive - if sampling_frequency is None: - assert time_vector is not None, "Pass either 'sampling_frequency' or 'time_vector'" - assert time_vector.ndim == 1, "time_vector should be a 1D array" - - if time_vector is None: - assert sampling_frequency is not None, "Pass either 'sampling_frequency' or 'time_vector'" - - self.sampling_frequency = sampling_frequency - self.t_start = t_start - self.time_vector = time_vector - - BaseSegment.__init__(self) - - def get_times(self) -> np.ndarray: - if self.time_vector is not None: - self.time_vector = np.asarray(self.time_vector) - return self.time_vector - else: - time_vector = np.arange(self.get_num_samples(), dtype="float64") - time_vector /= self.sampling_frequency - if self.t_start is not None: - time_vector += self.t_start - return time_vector - - def get_start_time(self) -> float: - if self.time_vector is not None: - return self.time_vector[0] - else: - return self.t_start if self.t_start is not None else 0.0 - - def get_end_time(self) -> float: - if 
self.time_vector is not None: - return self.time_vector[-1] - else: - t_stop = (self.get_num_samples() - 1) / self.sampling_frequency - if self.t_start is not None: - t_stop += self.t_start - return t_stop - - def get_times_kwargs(self) -> dict: - """ - Retrieves the timing attributes characterizing a RecordingSegment - - Returns - ------- - dict - A dictionary containing the following key-value pairs: - - - "sampling_frequency" : The sampling frequency of the RecordingSegment. - - "t_start" : The start time of the RecordingSegment. - - "time_vector" : The time vector of the RecordingSegment. - - Notes - ----- - The keys are always present, but the values may be None. - """ - time_kwargs = dict( - sampling_frequency=self.sampling_frequency, t_start=self.t_start, time_vector=self.time_vector - ) - return time_kwargs - - def sample_index_to_time(self, sample_ind): - """ - Transform sample index into time in seconds - """ - if self.time_vector is None: - time_s = sample_ind / self.sampling_frequency - if self.t_start is not None: - time_s += self.t_start - else: - time_s = self.time_vector[sample_ind] - return time_s - - def time_to_sample_index(self, time_s): - """ - Transform time in seconds into sample index - """ - if self.time_vector is None: - if self.t_start is None: - sample_index = time_s * self.sampling_frequency - else: - sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = np.round(sample_index).astype(np.int64) - else: - sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - - return sample_index - - def get_num_samples(self) -> int: - """Returns the number of samples in this signal segment - - Returns: - SampleIndex : Number of samples in the signal segment - """ - # must be implemented in subclass - raise NotImplementedError - def get_traces( self, start_frame: int | None = None, @@ -1022,3 +667,11 @@ def get_traces( """ # must be implemented in subclass raise NotImplementedError + + def get_data( + 
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None + ) -> np.ndarray: + """ + General retrieval function for chunkable objects + """ + return self.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=indices) diff --git a/src/spikeinterface/core/chunkable.py b/src/spikeinterface/core/chunkable.py new file mode 100644 index 0000000000..e7cddbc65e --- /dev/null +++ b/src/spikeinterface/core/chunkable.py @@ -0,0 +1,469 @@ +from abc import ABC, abstractmethod +from typing import Optional +import warnings + +import numpy as np + +from spikeinterface.core.base import BaseExtractor, BaseSegment + + +class ChunkableMixin(ABC): + """ + Abstract mixin class for chunkable objects. Note that the mixin can only be used + for classes that inherit from BaseExtractor. + Provides methods to handle chunked data access, that can be used for parallelization. + In addition, since chunkable objects are continuous data, time handling methods are provided. + + The Mixin is abstract since all methods need to be implemented in the child class in order + for it to function properly. 
+ """ + + _preferred_mp_context = None + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not issubclass(cls, BaseExtractor): + raise TypeError(f"{cls.__name__} must inherit from BaseExtractor to use Chunkable mixin.") + + @abstractmethod + def get_sampling_frequency(self) -> float: + raise NotImplementedError + + @abstractmethod + def get_num_samples(self, segment_index: int | None = None) -> int: + raise NotImplementedError + + @abstractmethod + def get_sample_size_in_bytes(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]: + raise NotImplementedError + + @abstractmethod + def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: + raise NotImplementedError + + def _extra_copy_metadata(self, other: "ChunkableMixin", **kwargs) -> None: + """ + Copy metadata from another Chunkable object. + + Parameters + ---------- + other : ChunkableMixin + The object from which to copy metadata. + """ + # inherit preferred mp context if any + if self.__class__._preferred_mp_context is not None: + other.__class__._preferred_mp_context = self.__class__._preferred_mp_context + + def get_preferred_mp_context(self): + """ + Get the preferred context for multiprocessing. + If None, the context is set by the multiprocessing package. + """ + return self.__class__._preferred_mp_context + + def get_memory_size(self, segment_index=None) -> int: + """ + Returns the memory size of segment_index in bytes. + + Parameters + ---------- + segment_index : int or None, default: None + The index of the segment for which the memory size should be calculated. + For multi-segment objects, it is required, default: None + With single segment recording returns the memory size of the single segment + + Returns + ------- + int + The memory size of the specified segment in bytes. 
+ """ + segment_index = self._check_segment_index(segment_index) + num_samples = self.get_num_samples(segment_index=segment_index) + sample_size_in_bytes = self.get_sample_size_in_bytes() + + memory_bytes = num_samples * sample_size_in_bytes + + return memory_bytes + + def get_total_memory_size(self) -> int: + """ + Returns the sum in bytes of all the memory sizes of the segments. + + Returns + ------- + int + The total memory size in bytes for all segments. + """ + memory_per_segment = (self.get_memory_size(segment_index) for segment_index in range(self.get_num_segments())) + return sum(memory_per_segment) + + # Add time handling + def get_time_info(self, segment_index=None) -> dict: + """ + Retrieves the timing attributes for a given segment index. As with + other recorders this method only needs a segment index in the case + of multi-segment recordings. + + Returns + ------- + dict + A dictionary containing the following key-value pairs: + + - "sampling_frequency" : The sampling frequency of the RecordingSegment. + - "t_start" : The start time of the RecordingSegment. + - "time_vector" : The time vector of the RecordingSegment. + + Notes + ----- + The keys are always present, but the values may be None. + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + time_kwargs = rs.get_times_kwargs() + + return time_kwargs + + def get_times(self, segment_index=None, start_frame=None, end_frame=None) -> np.ndarray: + """Get time vector for a recording segment. + + If the segment has a time_vector, then it is returned. Otherwise + a time_vector is constructed on the fly with sampling frequency. + If t_start is defined and the time vector is constructed on the fly, + the first time will be t_start. Otherwise it will start from 0. 
+ + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + start_frame : int or None, default: None + The start frame for the time vector + end_frame : int or None, default: None + The end frame for the time vector + + Returns + ------- + np.array + The 1d times array + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + times = rs.get_times(start_frame=start_frame, end_frame=end_frame) + return times + + def get_start_time(self, segment_index=None) -> float: + """Get the start time of the recording segment. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float + The start time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.get_start_time() + + def get_end_time(self, segment_index=None) -> float: + """Get the stop time of the recording segment. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float + The stop time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.get_end_time() + + def has_time_vector(self, segment_index: Optional[int] = None): + """Check if the segment of the recording has a time vector. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + bool + True if the recording has time vectors, False otherwise + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + d = rs.get_times_kwargs() + return d["time_vector"] is not None + + def set_times(self, times, segment_index=None, with_warning=True): + """Set times for a recording segment. 
+ + Parameters + ---------- + times : 1d np.array + The time vector + segment_index : int or None, default: None + The segment index (required for multi-segment) + with_warning : bool, default: True + If True, a warning is printed + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + + assert times.ndim == 1, "Time must have ndim=1" + assert rs.get_num_samples() == times.shape[0], "times have wrong shape" + + rs.t_start = None + rs.time_vector = times.astype("float64", copy=False) + + if with_warning: + warnings.warn( + "Setting times with Recording.set_times() is not recommended because " + "times are not always propagated across preprocessing" + "Use this carefully!" + ) + + def reset_times(self): + """ + Reset time information in-memory for all segments that have a time vector. + If the timestamps come from a file, the files won't be modified. but only the in-memory + attributes of the recording objects are deleted. Also `t_start` is set to None and the + segment's sampling frequency is set to the recording's sampling frequency. + """ + for segment_index in range(self.get_num_segments()): + rs = self.segments[segment_index] + if self.has_time_vector(segment_index): + rs.time_vector = None + rs.t_start = None + rs.sampling_frequency = self.sampling_frequency + + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. + + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. 
+ """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for segment_index in segments_to_shift: + rs = self.segments[segment_index] + + if self.has_time_vector(segment_index=segment_index): + rs.time_vector += shift + else: + new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift + rs.t_start = new_start_time + + def sample_index_to_time(self, sample_ind, segment_index=None): + """ + Transform sample index into time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.sample_index_to_time(sample_ind) + + def time_to_sample_index(self, time_s, segment_index=None): + """ + Transform time in seconds into sample index + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.time_to_sample_index(time_s) + + def get_total_samples(self) -> int: + """ + Returns the sum of the number of samples in each segment. + + Returns + ------- + int + The total number of samples + """ + num_segments = self.get_num_segments() + samples_per_segment = (self.get_num_samples(segment_index) for segment_index in range(num_segments)) + + return sum(samples_per_segment) + + def get_duration(self, segment_index=None) -> float: + """ + Returns the duration in seconds. + + Parameters + ---------- + segment_index : int or None, default: None + The sample index to retrieve the duration for. 
+ For multi-segment objects, it is required, default: None + With single segment recording returns the duration of the single segment + + Returns + ------- + float + The duration in seconds + """ + segment_duration = ( + self.get_end_time(segment_index) - self.get_start_time(segment_index) + (1 / self.get_sampling_frequency()) + ) + return segment_duration + + def get_total_duration(self) -> float: + """ + Returns the total duration in seconds + + Returns + ------- + float + The duration in seconds + """ + duration = sum([self.get_duration(segment_index) for segment_index in range(self.get_num_segments())]) + return duration + + def _get_t_starts(self): + # handle t_starts + t_starts = [] + for rs in self.segments: + d = rs.get_times_kwargs() + t_starts.append(d["t_start"]) + + if all(t_start is None for t_start in t_starts): + t_starts = None + return t_starts + + def _get_time_vectors(self): + time_vectors = [] + for rs in self.segments: + d = rs.get_times_kwargs() + time_vectors.append(d["time_vector"]) + if all(time_vector is None for time_vector in time_vectors): + time_vectors = None + return time_vectors + + +class ChunkableSegment(BaseSegment): + """Class for chunkable segments, which provide methods to handle time kwargs.""" + + def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): + # sampling_frequency and time_vector are exclusive + if sampling_frequency is None: + assert time_vector is not None, "Pass either 'sampling_frequency' or 'time_vector'" + assert time_vector.ndim == 1, "time_vector should be a 1D array" + + if time_vector is None: + assert sampling_frequency is not None, "Pass either 'sampling_frequency' or 'time_vector'" + + self.sampling_frequency = sampling_frequency + self.t_start = t_start + self.time_vector = time_vector + + BaseSegment.__init__(self) + + def get_times(self, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: + if start_frame is None: + start_frame = 0 + if end_frame is 
# NOTE(review): the methods below belong to the body of the `ChunkableSegment`
# class in chunkable.py; the enclosing `class` statement lies outside this chunk.


def get_start_time(self) -> float:
    """Return the time (s) of the first sample of this segment."""
    if self.time_vector is not None:
        return self.time_vector[0]
    else:
        return self.t_start if self.t_start is not None else 0.0


def get_end_time(self) -> float:
    """Return the time (s) of the last sample of this segment."""
    if self.time_vector is not None:
        return self.time_vector[-1]
    else:
        t_stop = (self.get_num_samples() - 1) / self.sampling_frequency
        if self.t_start is not None:
            t_stop += self.t_start
        return t_stop


def get_times_kwargs(self) -> dict:
    """
    Retrieves the timing attributes characterizing a RecordingSegment

    Returns
    -------
    dict
        A dictionary containing the following key-value pairs:

        - "sampling_frequency" : The sampling frequency of the RecordingSegment.
        - "t_start" : The start time of the RecordingSegment.
        - "time_vector" : The time vector of the RecordingSegment.

    Notes
    -----
    The keys are always present, but the values may be None.
    """
    time_kwargs = dict(
        sampling_frequency=self.sampling_frequency, t_start=self.t_start, time_vector=self.time_vector
    )
    return time_kwargs


def sample_index_to_time(self, sample_ind):
    """
    Transform sample index into time in seconds
    """
    if self.time_vector is None:
        time_s = sample_ind / self.sampling_frequency
        if self.t_start is not None:
            time_s += self.t_start
    else:
        time_s = self.time_vector[sample_ind]
    return time_s


def time_to_sample_index(self, time_s):
    """
    Transform time in seconds into sample index
    """
    if self.time_vector is None:
        if self.t_start is None:
            sample_index = time_s * self.sampling_frequency
        else:
            sample_index = (time_s - self.t_start) * self.sampling_frequency
        sample_index = np.round(sample_index).astype(np.int64)
    else:
        # index of the last time <= time_s (searchsorted right minus one)
        sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1

    return sample_index


def get_num_samples(self) -> int:
    """Returns the number of samples in this signal segment

    Returns
    -------
    int
        Number of samples in the signal segment
    """
    # must be implemented in subclass
    raise NotImplementedError


# ---- chunkable_tools.py (new file in this diff) ----

from pathlib import Path
import warnings


import numpy as np

from .core_tools import add_suffix, make_shared_array
from .job_tools import (
    chunk_duration_to_chunk_size,
    ensure_n_jobs,
    fix_job_kwargs,
    ChunkExecutor,
    _shared_job_kwargs_doc,
)
from .chunkable import ChunkableMixin, ChunkableSegment


def write_binary(
    chunkable: "ChunkableMixin",
    file_paths,
    file_timestamps_paths=None,
    dtype=None,
    add_file_extension: bool = True,
    byte_offset: int = 0,
    verbose: bool = False,
    **job_kwargs,
):
    """
    Save the data of a chunkable object to binary format.

    Note: time_axis is always 0 (contrary to previous versions).
    To get time_axis=1 (which is a bad idea) use `write_binary_file_handle()`.

    Parameters
    ----------
    chunkable : ChunkableMixin
        The chunkable object to be saved to binary file
    file_paths : list[Path | str] | Path | str
        The path to the files to save data for each segment.
    file_timestamps_paths : list[Path | str] | Path | str | None, default: None
        The path to the timestamps file(s). If None, timestamps are not saved.
    dtype : dtype or None, default: None
        Type of the saved data
    add_file_extension : bool, default: True
        If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension.
    byte_offset : int, default: 0
        Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data
        to an existing file where you wrote a header or other data before.
    verbose : bool
        This is the verbosity of the ChunkExecutor
    {}
    """
    job_kwargs = fix_job_kwargs(job_kwargs)

    file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths
    num_segments = chunkable.get_num_segments()
    if len(file_path_list) != num_segments:
        raise ValueError("'file_paths' must be a list of the same size as the number of segments in the chunkable")

    file_path_list = [Path(file_path) for file_path in file_path_list]
    if add_file_extension:
        file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list]

    dtype = dtype if dtype is not None else chunkable.get_dtype()

    sample_size_bytes = chunkable.get_sample_size_in_bytes()

    file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)}
    if file_timestamps_paths is not None:
        # robustness fix: normalize and validate timestamps paths like file_paths
        if not isinstance(file_timestamps_paths, list):
            file_timestamps_paths = [file_timestamps_paths]
        if len(file_timestamps_paths) != num_segments:
            raise ValueError(
                "'file_timestamps_paths' must be a list of the same size as the number of segments in the chunkable"
            )
        file_timestamps_path_dict = {
            segment_index: file_path for segment_index, file_path in enumerate(file_timestamps_paths)
        }
    else:
        file_timestamps_path_dict = None

    for segment_index, file_path in file_path_dict.items():
        num_samples = chunkable.get_num_samples(segment_index=segment_index)
        data_size_bytes = sample_size_bytes * num_samples
        file_size_bytes = data_size_bytes + byte_offset

        # Create an empty file with file_size_bytes
        with open(file_path, "wb+") as file:
            # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408)
            # guard against empty segments: seek(-1) would raise
            if file_size_bytes > 0:
                file.seek(file_size_bytes - 1)
                file.write(b"\0")

        if file_timestamps_path_dict is not None:
            file_timestamps_path = file_timestamps_path_dict[segment_index]
            with open(file_timestamps_path, "wb+") as file:
                if num_samples > 0:
                    file.seek(num_samples * 8 - 1)  # 8 bytes for float64 timestamps
                    file.write(b"\0")

        assert Path(file_path).is_file()

    # use executor (loop or workers)
    func = _write_binary_chunk
    init_func = _init_binary_worker
    init_args = (chunkable, file_path_dict, dtype, byte_offset, file_timestamps_path_dict)
    executor = ChunkExecutor(
        chunkable, func, init_func, init_args, job_name="write_binary", verbose=verbose, **job_kwargs
    )
    executor.run()


# used by write_binary + ChunkExecutor
def _init_binary_worker(chunkable, file_path_dict, dtype, byte_offset, file_timestamps_path_dict=None):
    # create a local dict per worker
    worker_ctx = {}
    worker_ctx["chunkable"] = chunkable
    worker_ctx["byte_offset"] = byte_offset
    worker_ctx["dtype"] = np.dtype(dtype)

    file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()}
    worker_ctx["file_dict"] = file_dict
    # BUGFIX: open the timestamps files here (the previous code stored the raw
    # paths, and the chunk writer then wrote the timestamps into the DATA file,
    # corrupting the traces).
    if file_timestamps_path_dict is not None:
        worker_ctx["file_timestamps_dict"] = {
            segment_index: open(file_path, "rb+")
            for segment_index, file_path in file_timestamps_path_dict.items()
        }
    else:
        worker_ctx["file_timestamps_dict"] = None

    return worker_ctx


# used by write_binary + ChunkExecutor
def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
    # recover variables of the worker
    chunkable = worker_ctx["chunkable"]
    dtype = worker_ctx["dtype"]
    byte_offset = worker_ctx["byte_offset"]
    file = worker_ctx["file_dict"][segment_index]
    file_timestamps_dict = worker_ctx["file_timestamps_dict"]
    sample_size_bytes = chunkable.get_sample_size_in_bytes()

    # Calculate byte offsets for the start frames relative to the entire recording
    start_byte = byte_offset + start_frame * sample_size_bytes

    data = chunkable.get_data(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
    data = data.astype(dtype, order="C", copy=False)

    file.seek(start_byte)
    file.write(data.data)
    # flush is important!!
    file.flush()

    if file_timestamps_dict is not None:
        # BUGFIX: write timestamps into the timestamps file handle, not into the
        # data file (which silently overwrote traces before).
        timestamps_file = file_timestamps_dict[segment_index]
        timestamps = chunkable.get_times(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
        timestamps = timestamps.astype("float64", order="C", copy=False)
        timestamps_file.seek(start_frame * 8)  # 8 bytes for float64
        timestamps_file.write(timestamps.data)
        timestamps_file.flush()


write_binary.__doc__ = write_binary.__doc__.format(_shared_job_kwargs_doc)


# used by write_memory
def _init_memory_worker(chunkable, arrays, shm_names, shapes, dtype):
    # create a local dict per worker
    worker_ctx = {}
    worker_ctx["chunkable"] = chunkable
    worker_ctx["dtype"] = np.dtype(dtype)

    if arrays is None:
        # create it from shared memory names
        from multiprocessing.shared_memory import SharedMemory

        arrays = []
        # keep shm alive for the lifetime of the worker
        worker_ctx["shms"] = []
        for i in range(len(shm_names)):
            shm = SharedMemory(shm_names[i])
            worker_ctx["shms"].append(shm)
            arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf)
            arrays.append(arr)

    worker_ctx["arrays"] = arrays

    return worker_ctx


# used by write_memory
def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx):
    # recover variables of the worker
    chunkable = worker_ctx["chunkable"]
    dtype = worker_ctx["dtype"]
    arr = worker_ctx["arrays"][segment_index]

    # apply function
    traces = chunkable.get_data(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
    traces = traces.astype(dtype, copy=False)
    arr[start_frame:end_frame, :] = traces


def write_memory(
    chunkable: "ChunkableMixin", dtype=None, verbose=False, buffer_type="auto", job_name="write_memory", **job_kwargs
):
    """
    Save the traces into numpy arrays (memory).
    Tries to use SharedMemory (introduced in py3.8) if n_jobs > 1.

    Parameters
    ----------
    chunkable : ChunkableMixin
        The chunkable object to be saved to memory
    dtype : dtype, default: None
        Type of the saved data
    verbose : bool, default: False
        If True, output is verbose (when chunks are used)
    buffer_type : "auto" | "numpy" | "sharedmem"
        The type of buffer to use for storing the data.
    job_name : str, default: "write_memory"
        Name of the job
    {}

    Returns
    -------
    arrays : one array per segment
    """
    job_kwargs = fix_job_kwargs(job_kwargs)

    if dtype is None:
        dtype = chunkable.get_dtype()

    # create shared memory buffers (or plain numpy arrays)
    arrays = []
    shm_names = []
    shms = []
    shapes = []

    n_jobs = ensure_n_jobs(chunkable, n_jobs=job_kwargs.get("n_jobs", 1))
    if buffer_type == "auto":
        if n_jobs > 1:
            buffer_type = "sharedmem"
        else:
            buffer_type = "numpy"

    for segment_index in range(chunkable.get_num_segments()):
        shape = chunkable.get_shape(segment_index=segment_index)
        shapes.append(shape)
        if buffer_type == "sharedmem":
            arr, shm = make_shared_array(shape, dtype)
            shm_names.append(shm.name)
            shms.append(shm)
        else:
            arr = np.zeros(shape, dtype=dtype)
            shms.append(None)
        arrays.append(arr)

    # use executor (loop or workers)
    func = _write_memory_chunk
    init_func = _init_memory_worker
    if n_jobs > 1:
        # workers re-attach to the shared memory by name
        init_args = (chunkable, None, shm_names, shapes, dtype)
    else:
        init_args = (chunkable, arrays, None, None, dtype)

    executor = ChunkExecutor(chunkable, func, init_func, init_args, verbose=verbose, job_name=job_name, **job_kwargs)
    executor.run()

    return arrays, shms


write_memory.__doc__ = write_memory.__doc__.format(_shared_job_kwargs_doc)
def write_chunkable_to_zarr(
    chunkable: "ChunkableMixin",
    zarr_group,
    dataset_paths,
    dataset_timestamps_paths=None,
    extra_chunks=None,
    dtype=None,
    compressor_data=None,
    filters_data=None,
    compressor_times=None,
    filters_times=None,
    verbose=False,
    **job_kwargs,
):
    """
    Save the traces of a chunkable object in zarr format (one dataset per segment).

    Parameters
    ----------
    chunkable : ChunkableMixin
        The chunkable object to be saved in zarr format
    zarr_group : zarr.Group
        The zarr group to add traces to
    dataset_paths : list
        List of paths to traces datasets in the zarr group
    dataset_timestamps_paths : list or None, default: None
        List of paths to timestamps datasets in the zarr group. If None, timestamps are not saved.
    extra_chunks : tuple or None, default: None
        Extra chunking dimensions to use for the zarr dataset.
        The first dimension is always time and controlled by the job_kwargs.
        This is for example useful to chunk by channel, with `extra_chunks=(channel_chunk_size,)`.
    dtype : dtype, default: None
        Type of the saved data
    compressor_data : zarr compressor or None, default: None
        Zarr compressor for data
    filters_data : list, default: None
        List of zarr filters for data
    compressor_times : zarr compressor or None, default: None
        Zarr compressor for timestamps
    filters_times : list, default: None
        List of zarr filters for timestamps
    verbose : bool, default: False
        If True, output is verbose (when chunks are used)
    {}
    """
    from .job_tools import (
        ensure_chunk_size,
        fix_job_kwargs,
        ChunkExecutor,
    )

    assert dataset_paths is not None, "Provide 'dataset_paths' to save data in zarr format"
    if dataset_timestamps_paths is not None:
        assert (
            len(dataset_timestamps_paths) == chunkable.get_num_segments()
        ), "dataset_timestamps_paths should have the same length as the number of segments in the chunkable"
    else:
        dataset_timestamps_paths = [None] * chunkable.get_num_segments()

    if not isinstance(dataset_paths, list):
        dataset_paths = [dataset_paths]
    assert len(dataset_paths) == chunkable.get_num_segments()

    if dtype is None:
        dtype = chunkable.get_dtype()

    job_kwargs = fix_job_kwargs(job_kwargs)
    chunk_size = ensure_chunk_size(chunkable, **job_kwargs)

    if extra_chunks is not None:
        assert len(extra_chunks) == len(chunkable.get_shape(0)[1:]), (
            "extra_chunks should have the same length as the number of dimensions "
            "of the chunkable minus one (time axis)"
        )

    # create zarr datasets (one traces + optionally one timestamps per segment)
    zarr_datasets = []
    zarr_timestamps_datasets = []

    for segment_index in range(chunkable.get_num_segments()):
        num_samples = chunkable.get_num_samples(segment_index)
        dset_name = dataset_paths[segment_index]
        shape = chunkable.get_shape(segment_index)
        dset = zarr_group.create_dataset(
            name=dset_name,
            shape=shape,
            chunks=(chunk_size,) + extra_chunks if extra_chunks is not None else (chunk_size,),
            dtype=dtype,
            filters=filters_data,
            compressor=compressor_data,
        )
        zarr_datasets.append(dset)
        if dataset_timestamps_paths[segment_index] is not None:
            tset_name = dataset_timestamps_paths[segment_index]
            zarr_timestamps_datasets.append(
                zarr_group.create_dataset(
                    name=tset_name,
                    shape=(num_samples,),
                    chunks=(chunk_size,),
                    dtype="float64",
                    filters=filters_times,
                    compressor=compressor_times,
                )
            )
        else:
            zarr_timestamps_datasets.append(None)

    # use executor (loop or workers)
    func = _write_zarr_chunk
    init_func = _init_zarr_worker
    init_args = (chunkable, zarr_datasets, dtype, zarr_timestamps_datasets)
    executor = ChunkExecutor(
        chunkable, func, init_func, init_args, verbose=verbose, job_name="write_zarr", **job_kwargs
    )
    executor.run()

    # save t_starts (NaN marks "no t_start" for a segment)
    t_starts = np.full(chunkable.get_num_segments(), np.nan, dtype="float64")
    for segment_index in range(chunkable.get_num_segments()):
        time_info = chunkable.get_time_info(segment_index)
        if time_info["t_start"] is not None:
            t_starts[segment_index] = time_info["t_start"]

    if np.any(~np.isnan(t_starts)):
        zarr_group.create_dataset(name="t_starts", data=t_starts, compressor=None)


# used by write_chunkable_to_zarr + ChunkExecutor
def _init_zarr_worker(chunkable, zarr_datasets, dtype, zarr_timestamps_datasets=None):
    import zarr

    # create a local dict per worker
    worker_ctx = {}
    worker_ctx["chunkable"] = chunkable
    worker_ctx["zarr_datasets"] = zarr_datasets
    if zarr_timestamps_datasets is not None and len(zarr_timestamps_datasets) > 0:
        worker_ctx["zarr_timestamps_datasets"] = zarr_timestamps_datasets
    else:
        worker_ctx["zarr_timestamps_datasets"] = None
    worker_ctx["dtype"] = np.dtype(dtype)

    return worker_ctx


# used by write_chunkable_to_zarr + ChunkExecutor
def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx):
    import gc

    # recover variables of the worker
    chunkable = worker_ctx["chunkable"]
    dtype = worker_ctx["dtype"]
    zarr_dataset = worker_ctx["zarr_datasets"][segment_index]
    if worker_ctx["zarr_timestamps_datasets"] is not None:
        zarr_timestamps_dataset = worker_ctx["zarr_timestamps_datasets"][segment_index]
    else:
        zarr_timestamps_dataset = None

    # apply function
    data = chunkable.get_data(
        start_frame=start_frame,
        end_frame=end_frame,
        segment_index=segment_index,
    )
    data = data.astype(dtype)
    zarr_dataset[start_frame:end_frame, :] = data

    if zarr_timestamps_dataset is not None:
        timestamps = chunkable.get_times(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
        zarr_timestamps_dataset[start_frame:end_frame] = timestamps

    # fix memory leak by forcing garbage collection
    del data
    gc.collect()


def get_random_sample_slices(
    chunkable: "ChunkableMixin",
    method="full_random",
    num_chunks_per_segment=20,
    chunk_duration="500ms",
    chunk_size=None,
    margin_frames=0,
    seed=None,
):
    """
    Get random slices of a chunkable object across segments.

    Parameters
    ----------
    chunkable : ChunkableMixin
        The chunkable object to get random chunks from
    method : "full_random"
        The method used to get random slices.
        * "full_random" : legacy method, used until version 0.101.0, there is no constraint on slices
          and they can overlap.
    num_chunks_per_segment : int, default: 20
        Number of chunks per segment
    chunk_duration : str | float | None, default: "500ms"
        The duration of each chunk in 's' or 'ms'
    chunk_size : int | None
        Size of a chunk in number of frames. This is used only if chunk_duration is None.
        This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead.
    margin_frames : int, default: 0
        Margin in number of frames to avoid edge effects
    seed : int, default: None
        Random seed

    Returns
    -------
    slices : list of (segment_index, start_frame, end_frame)
        One slice tuple per sampled chunk
    """
    # TODO: if segments have different lengths make another sampling that depends on the length of the segment
    # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY
    # And randomize the number of chunks per segment weighted by segment duration

    if method == "full_random":
        if chunk_size is None:
            if chunk_duration is not None:
                chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable)
            else:
                raise ValueError("get_random_sample_slices need chunk_size or chunk_duration")

        # check chunk size
        num_segments = chunkable.get_num_segments()
        for segment_index in range(num_segments):
            chunk_size_limit = chunkable.get_num_samples(segment_index) - 2 * margin_frames
            if chunk_size > chunk_size_limit:
                chunk_size = chunk_size_limit - 1
                warnings.warn(
                    f"chunk_size is greater than the number "
                    f"of samples for segment index {segment_index}. "
                    f"Using {chunk_size}."
                )
        rng = np.random.default_rng(seed)
        slices = []
        low = margin_frames
        size = num_chunks_per_segment
        for segment_index in range(num_segments):
            num_frames = chunkable.get_num_samples(segment_index)
            high = num_frames - chunk_size - margin_frames
            # here we set endpoint to True, because this represents the start of the
            # chunk, and should be inclusive
            random_starts = rng.integers(low=low, high=high, size=size, endpoint=True)
            random_starts = np.sort(random_starts)
            slices += [(segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts]
    else:
        raise ValueError(f"get_random_sample_slices : wrong method {method}")

    return slices


def get_chunks(chunkable: "ChunkableMixin", concatenated=True, get_data_kwargs=None, **random_slices_kwargs):
    """
    Extract random chunks across segments.

    Internally, it uses `get_random_sample_slices()` and retrieves the traces chunks as a list
    or a concatenated unique array.

    Please read `get_random_sample_slices()` for more details on parameters.

    Parameters
    ----------
    chunkable : ChunkableMixin
        The chunkable object to get random chunks from
    concatenated : bool, default: True
        If True chunks are concatenated along time axis
    get_data_kwargs : dict | None, default: None
        Extra keyword arguments forwarded to `chunkable.get_data()`
    **random_slices_kwargs : dict
        Options transmitted to get_random_sample_slices(), please read documentation from this
        function for more details.

    Returns
    -------
    chunk_list : np.ndarray | list of np.array
        Array of concatenated chunks per segment
    """
    slices = get_random_sample_slices(chunkable, **random_slices_kwargs)

    chunk_list = []
    get_data_kwargs = get_data_kwargs if get_data_kwargs is not None else {}
    for segment_index, start_frame, end_frame in slices:
        traces_chunk = chunkable.get_data(
            start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, **get_data_kwargs
        )
        chunk_list.append(traces_chunk)

    if concatenated:
        return np.concatenate(chunk_list, axis=0)
    else:
        return chunk_list


def get_chunk_with_margin(
    chunkable_segment: "ChunkableSegment",
    start_frame,
    end_frame,
    last_dimension_indices,
    margin,
    add_zeros=False,
    add_reflect_padding=False,
    window_on_margin=False,
    dtype=None,
):
    """
    Helper to get chunk with margin

    The margin is extracted from the recording when possible. If
    at the edge of the recording, no margin is used unless one
    of `add_zeros` or `add_reflect_padding` is True. In the first
    case zero padding is used, in the second case np.pad is called
    with mode="reflect".
    """
    length = int(chunkable_segment.get_num_samples())

    if last_dimension_indices is None:
        last_dimension_indices = slice(None)

    if not (add_zeros or add_reflect_padding):
        if window_on_margin and not add_zeros:
            raise ValueError("window_on_margin requires add_zeros=True")

        if start_frame is None:
            left_margin = 0
            start_frame = 0
        elif start_frame < margin:
            left_margin = start_frame
        else:
            left_margin = margin

        if end_frame is None:
            right_margin = 0
            end_frame = length
        elif end_frame > (length - margin):
            right_margin = length - end_frame
        else:
            right_margin = margin

        data_chunk = chunkable_segment.get_data(
            start_frame - left_margin,
            end_frame + right_margin,
            last_dimension_indices,
        )

    else:
        # either add_zeros or reflect_padding
        if start_frame is None:
            start_frame = 0
        if end_frame is None:
            end_frame = length

        chunk_size = end_frame - start_frame
        full_size = chunk_size + 2 * margin

        if start_frame < margin:
            start_frame2 = 0
            left_pad = margin - start_frame
        else:
            start_frame2 = start_frame - margin
            left_pad = 0

        if end_frame > (length - margin):
            end_frame2 = length
            right_pad = end_frame + margin - length
        else:
            end_frame2 = end_frame + margin
            right_pad = 0

        data_chunk = chunkable_segment.get_data(start_frame2, end_frame2, last_dimension_indices)

        if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0:
            need_copy = True
        else:
            need_copy = False

        left_margin = margin
        right_margin = margin

        if need_copy:
            if dtype is None:
                dtype = data_chunk.dtype

            left_margin = margin
            if end_frame < (length + margin):
                right_margin = margin
            else:
                right_margin = end_frame + margin - length

            if add_zeros:
                data_chunk2 = np.zeros((full_size, data_chunk.shape[1]), dtype=dtype)
                i0 = left_pad
                i1 = left_pad + data_chunk.shape[0]
                data_chunk2[i0:i1, :] = data_chunk
                if window_on_margin:
                    # apply inplace taper on border
                    taper = (1 - np.cos(np.arange(margin) / margin * np.pi)) / 2
                    taper = taper[:, np.newaxis]
                    data_chunk2[:margin] *= taper
                    data_chunk2[-margin:] *= taper[::-1]
                data_chunk = data_chunk2
            elif add_reflect_padding:
                # in this case, we don't want to taper
                data_chunk = np.pad(
                    data_chunk.astype(dtype, copy=False),
                    [(left_pad, right_pad)] + [(0, 0)] * (data_chunk.ndim - 1),
                    mode="reflect",
                )
            else:
                # we need a copy to change the dtype
                data_chunk = np.asarray(data_chunk, dtype=dtype)

    return data_chunk, left_margin, right_margin


# ---- job_tools.py (from the diff hunk in this region) ----


def divide_chunkable_into_chunks(recording, chunk_size):
    # flatten (segment, chunk) pairs into a single list of slices
    slices = []
    for segment_index in range(recording.get_num_segments()):
        num_frames = recording.get_num_samples(segment_index)
        chunks = divide_segment_into_chunks(num_frames, chunk_size)
        slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
    return slices
def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"):
    """
    Convert a chunk duration into a chunk size in samples.

    Parameters
    ----------
    chunk_duration : float | int | str
        Duration of one chunk, either in seconds (numeric) or as a string
        ending in "s" (seconds) or "ms" (milliseconds), e.g. "500ms".
    chunkable : ChunkableMixin
        Object providing `get_sampling_frequency()`.

    Returns
    -------
    int
        The chunk size in number of samples.

    Raises
    ------
    ValueError
        If a string duration does not end in "s" or "ms", or if
        chunk_duration is neither numeric nor a string.
    """
    # generalization: accept int durations as well as float (an int number of
    # seconds previously raised "chunk_duration must be str or float")
    if isinstance(chunk_duration, (int, float)) and not isinstance(chunk_duration, bool):
        chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
    elif isinstance(chunk_duration, str):
        # check "ms" before "s" since "ms" also ends with "s"
        if chunk_duration.endswith("ms"):
            chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
        elif chunk_duration.endswith("s"):
            chunk_duration = float(chunk_duration.replace("s", ""))
        else:
            raise ValueError("chunk_duration must ends with s or ms")
        chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
    else:
        raise ValueError("chunk_duration must be str or float")
    return chunk_size
Flexible chunk_size setter with 3 ways: * "chunk_size" : is the length in sample for each chunk independently of channel count and dtype. @@ -292,24 +299,20 @@ def ensure_chunk_size( assert total_memory is None # set by memory per worker size chunk_memory = convert_string_to_bytes(chunk_memory) - n_bytes = np.dtype(recording.get_dtype()).itemsize - num_channels = recording.get_num_channels() - chunk_size = int(chunk_memory / (num_channels * n_bytes)) + chunk_size = int(chunk_memory / chunkable.get_sample_size_in_bytes()) elif total_memory is not None: # clip by total memory size - n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs) + n_jobs = ensure_n_jobs(chunkable, n_jobs=n_jobs) total_memory = convert_string_to_bytes(total_memory) - n_bytes = np.dtype(recording.get_dtype()).itemsize - num_channels = recording.get_num_channels() - chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) + chunk_size = int(total_memory / (chunkable.get_sample_size_in_bytes() * n_jobs)) elif chunk_duration is not None: - chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment if n_jobs == 1: - num_segments = recording.get_num_segments() - samples_in_larger_segment = max([recording.get_num_samples(segment) for segment in range(num_segments)]) + num_segments = chunkable.get_num_segments() + samples_in_larger_segment = max([chunkable.get_num_samples(segment) for segment in range(num_segments)]) chunk_size = samples_in_larger_segment else: raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory") @@ -317,9 +320,9 @@ def ensure_chunk_size( return chunk_size -class ChunkRecordingExecutor: +class ChunkExecutor: """ - Core class for parallel processing to run a "function" over chunks on a recording. 
+ Core class for parallel processing to run a "function" over chunks on a chunkable extractor. It supports running a function: * in loop with chunk processing (low RAM usage) @@ -331,8 +334,8 @@ class ChunkRecordingExecutor: Parameters ---------- - recording : RecordingExtractor - The recording to be processed + chunkable : ChunkableMixin + The chunkable object to be processed. func : function Function that runs on each chunk init_func : function @@ -380,7 +383,7 @@ class ChunkRecordingExecutor: def __init__( self, - recording, + chunkable: "ChunkableMixin", func, init_func, init_args, @@ -399,7 +402,7 @@ def __init__( max_threads_per_worker=1, need_worker_index=False, ): - self.recording = recording + self.chunkable = chunkable self.func = func self.init_func = init_func self.init_args = init_args @@ -418,7 +421,7 @@ def __init__( else: mp_context = "spawn" - preferred_mp_context = recording.get_preferred_mp_context() + preferred_mp_context = chunkable.get_preferred_mp_context() if preferred_mp_context is not None and preferred_mp_context != mp_context: warnings.warn( f"Your processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible." 
@@ -434,9 +437,8 @@ def __init__( self.handle_returns = handle_returns self.gather_func = gather_func - self.n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs) - self.chunk_size = ensure_chunk_size( - recording, + self.n_jobs = ensure_n_jobs(self.chunkable, n_jobs=n_jobs) + self.chunk_size = self.ensure_chunk_size( total_memory=total_memory, chunk_size=chunk_size, chunk_memory=chunk_memory, @@ -451,9 +453,9 @@ def __init__( self.need_worker_index = need_worker_index if verbose: - chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize + chunk_memory = self.get_chunk_memory() total_memory = chunk_memory * self.n_jobs - chunk_duration = self.chunk_size / recording.get_sampling_frequency() + chunk_duration = self.chunk_size / chunkable.sampling_frequency chunk_memory_str = convert_bytes_to_str(chunk_memory) total_memory_str = convert_bytes_to_str(total_memory) chunk_duration_str = convert_seconds_to_str(chunk_duration) @@ -468,13 +470,24 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self, recording_slices=None): + def get_chunk_memory(self): + return self.chunk_size * self.chunkable.get_sample_size_in_bytes() + + def ensure_chunk_size( + self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs + ): + return ensure_chunk_size( + self.chunkable, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs + ) + + def run(self, slices=None): """ Runs the defined jobs. 
""" - if recording_slices is None: - recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) + if slices is None: + # TODO: rename + slices = divide_chunkable_into_chunks(self.chunkable, self.chunk_size) if self.handle_returns: returns = [] @@ -483,9 +496,7 @@ def run(self, recording_slices=None): if self.n_jobs == 1: if self.progress_bar: - recording_slices = tqdm( - recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices) - ) + slices = tqdm(slices, desc=f"{self.job_name} (no parallelization)", total=len(slices)) init_args = self.init_args if self.need_worker_index: @@ -496,7 +507,7 @@ def run(self, recording_slices=None): if self.need_worker_index: worker_dict["worker_index"] = worker_index - for segment_index, frame_start, frame_stop in recording_slices: + for segment_index, frame_start, frame_stop in slices: res = self.func(segment_index, frame_start, frame_stop, worker_dict) if self.handle_returns: returns.append(res) @@ -504,7 +515,7 @@ def run(self, recording_slices=None): self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(recording_slices)) + n_jobs = min(self.n_jobs, len(slices)) if self.pool_engine == "process": @@ -534,13 +545,13 @@ def run(self, recording_slices=None): array_pid, ), ) as executor: - results = executor.map(process_function_wrapper, recording_slices) + results = executor.map(process_function_wrapper, slices) if self.progress_bar: results = tqdm( results, desc=f"{self.job_name} (workers: {n_jobs} processes {self.mp_context})", - total=len(recording_slices), + total=len(slices), ) for res in results: @@ -559,7 +570,7 @@ def run(self, recording_slices=None): if self.progress_bar: # here the tqdm threading do not work (maybe collision) so we need to create a pbar # before thread spawning - pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(recording_slices)) + pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(slices)) if 
self.need_worker_index: lock = threading.Lock() @@ -580,8 +591,8 @@ def run(self, recording_slices=None): ), ) as executor: - recording_slices2 = [(thread_local_data,) + tuple(args) for args in recording_slices] - results = executor.map(thread_function_wrapper, recording_slices2) + slices2 = [(thread_local_data,) + tuple(args) for args in slices] + results = executor.map(thread_function_wrapper, slices2) for res in results: if self.progress_bar: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 43cdd30c87..a91a4909b0 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -11,8 +11,9 @@ import numpy as np from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype +from spikeinterface.core.chunkable import ChunkableMixin from spikeinterface.core import BaseRecording, get_chunk_with_margin -from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc +from spikeinterface.core.job_tools import ChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc from spikeinterface.core import get_channel_distances @@ -24,7 +25,7 @@ class PipelineNode: def __init__( self, - recording: BaseRecording, + chunkable: ChunkableMixin, return_output: bool | tuple[bool] = True, parents: list[Type["PipelineNode"]] | None = None, ): @@ -36,8 +37,8 @@ def __init__( Parameters ---------- - recording : BaseRecording - The recording object. + chunkable : ChunkableMixin + The chunkable object. return_output : bool or tuple[bool], default: True Whether or not the output of the node is returned by the pipeline. When a Node have several toutputs then this can be a tuple of bool @@ -45,7 +46,7 @@ def __init__( Pass parents nodes to perform a previous computation. 
""" - self.recording = recording + self.chunkable = chunkable self.return_output = return_output if isinstance(parents, str): # only one parents is allowed @@ -54,14 +55,14 @@ def __init__( self._kwargs = dict() - def get_trace_margin(self): + def get_data_margin(self): # can optionaly be overwritten return 0 def get_dtype(self): raise NotImplementedError - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *args): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, *args): raise NotImplementedError @@ -76,7 +77,7 @@ class PeakSource(PipelineNode): # between processes or threads need_first_call_before_pipeline = False - def get_trace_margin(self): + def get_data_margin(self): raise NotImplementedError def get_dtype(self): @@ -93,7 +94,7 @@ def get_peak_slice( def _first_call_before_pipeline(self): # see need_first_call_before_pipeline = True - margin = self.get_trace_margin() + margin = self.get_data_margin() traces = self.recording.get_traces(start_frame=0, end_frame=margin * 2 + 1, segment_index=0) self.compute(traces, 0, margin * 2 + 1, 0, margin) @@ -116,7 +117,7 @@ def __init__(self, recording, peaks): i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) - def get_trace_margin(self): + def get_data_margin(self): return 0 def get_dtype(self): @@ -128,7 +129,7 @@ def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) return i0, i1 - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, peak_slice): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] @@ -154,6 +155,9 @@ class SpikeRetriever(PeakSource): * compute_spike_amplitudes() * compute_principal_components() + Parameters + 
---------- + sorting : BaseSorting The sorting object. recording : BaseRecording @@ -208,7 +212,7 @@ def __init__( i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) - def get_trace_margin(self): + def get_data_margin(self): return 0 def get_dtype(self): @@ -225,7 +229,7 @@ def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) return i0, i1 - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, peak_slice): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] @@ -242,14 +246,14 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea local_peaks["in_margin"][:] = False mask = local_peaks["sample_index"] < max_margin local_peaks["in_margin"][mask] = True - mask = local_peaks["sample_index"] >= traces.shape[0] - max_margin + mask = local_peaks["sample_index"] >= chunk.shape[0] - max_margin local_peaks["in_margin"][mask] = True if not self.channel_from_template: # handle channel spike per spike for i, peak in enumerate(local_peaks): chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) - sparse_wfs = traces[peak["sample_index"], chans] + sparse_wfs = chunk[peak["sample_index"], chans] if self.peak_sign == "neg": local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] elif self.peak_sign == "pos": @@ -259,7 +263,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea # handle amplitude for i, peak in enumerate(local_peaks): - local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]] + local_peaks["amplitude"][i] = chunk[peak["sample_index"], peak["channel_index"]] return (local_peaks,) @@ -311,7 +315,8 @@ def __init__( Whether or 
not the output of the node is returned by the pipeline """ - PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output) + PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) + self.recording = recording self.ms_before = ms_before self.ms_after = ms_after self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0) @@ -350,18 +355,18 @@ def __init__( WaveformsNode.__init__( self, - recording=recording, + recording, parents=parents, ms_before=ms_before, ms_after=ms_after, return_output=return_output, ) - def get_trace_margin(self): + def get_data_margin(self): return max(self.nbefore, self.nafter) - def compute(self, traces, peaks): - waveforms = traces[peaks["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] + def compute(self, chunk, peaks): + waveforms = chunk[peaks["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] return waveforms @@ -407,7 +412,7 @@ def __init__( """ WaveformsNode.__init__( self, - recording=recording, + recording, parents=parents, ms_before=ms_before, ms_after=ms_after, @@ -425,15 +430,15 @@ def __init__( self.neighbours_mask = self.channel_distance <= radius_um self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1)) - def get_trace_margin(self): + def get_data_margin(self): return max(self.nbefore, self.nafter) - def compute(self, traces, peaks): - sparse_wfs = np.zeros((peaks.shape[0], self.nbefore + self.nafter, self.max_num_chans), dtype=traces.dtype) + def compute(self, chunk, peaks): + sparse_wfs = np.zeros((peaks.shape[0], self.nbefore + self.nafter, self.max_num_chans), dtype=chunk.dtype) for i, peak in enumerate(peaks): (chans,) = np.nonzero(self.neighbours_mask[peak["channel_index"]]) - sparse_wfs[i, :, : len(chans)] = traces[ + sparse_wfs[i, :, : len(chans)] = chunk[ peak["sample_index"] - self.nbefore : peak["sample_index"] + self.nafter, : ][:, chans] @@ -500,12 +505,11 @@ def check_graph(nodes, 
check_for_peak_source=True): Check that node list is orderd in a good (parents are before children) """ - if check_for_peak_source: - node0 = nodes[0] - if not isinstance(node0, PeakSource): - raise ValueError( - "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" - ) + node0 = nodes[0] + if not isinstance(node0, PeakSource) and check_for_peak_source: + raise ValueError( + "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever)" + ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode" @@ -521,19 +525,19 @@ def check_graph(nodes, check_for_peak_source=True): def run_node_pipeline( - recording, - nodes, - job_kwargs, - job_name="pipeline", - gather_mode="memory", - gather_kwargs={}, - squeeze_output=True, - folder=None, - names=None, - verbose=False, - skip_after_n_peaks=None, - recording_slices=None, - check_for_peak_source=True, + chunkable: ChunkableMixin, + nodes: list[PipelineNode], + job_kwargs: dict, + job_name: str = "pipeline", + gather_mode: str = "memory", + gather_kwargs: dict = {}, + squeeze_output: bool = True, + folder: str | None = None, + names: list[str] | None = None, + verbose: bool = False, + skip_after_n_peaks: int | None = None, + slices: list[tuple] | None = None, + check_for_peak_source: bool = False, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -561,11 +565,12 @@ def run_node_pipeline( Parameters ---------- - - recording: Recording - + chunkable: ChunkableMixin + The chunkable object to run the pipeline on. This is typically a recording but it can be anything that has the + same interface for getting chunks with margin. nodes: a list of PipelineNode - + The list of nodes to run in the pipeline. The order of the nodes is important as it defines + the order of computation. 
job_kwargs: dict The classical job_kwargs job_name : str @@ -585,12 +590,12 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. - recording_slices : None | list[tuple] + slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. - check_for_peak_source : bool, default True - Whether to check that the first node is a PeakSource (PeakDetector or PeakRetriever or + check_for_peak_source : bool, default False + Whether to check the graph of PeakSource nodes. Returns ------- @@ -598,7 +603,6 @@ def run_node_pipeline( a tuple of vector for the output of nodes having return_output=True. If squeeze_output=True and only one output then directly np.array. """ - check_graph(nodes, check_for_peak_source=check_for_peak_source) job_kwargs = fix_job_kwargs(job_kwargs) @@ -621,10 +625,10 @@ def run_node_pipeline( # See need_first_call_before_pipeline : this trigger numba compilation before the run node0._first_call_before_pipeline() - init_args = (recording, nodes, skip_after_n_peaks_per_worker) + init_args = (chunkable, nodes, skip_after_n_peaks_per_worker) - processor = ChunkRecordingExecutor( - recording, + processor = ChunkExecutor( + chunkable, _compute_peak_pipeline_chunk, _init_peak_pipeline, init_args, @@ -634,30 +638,30 @@ def run_node_pipeline( **job_kwargs, ) - processor.run(recording_slices=recording_slices) + processor.run(slices=slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs -def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker): +def _init_peak_pipeline(chunkable, nodes, skip_after_n_peaks_per_worker): # create a local dict per worker worker_ctx = {} - worker_ctx["recording"] = 
recording + worker_ctx["chunkable"] = chunkable worker_ctx["nodes"] = nodes - worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) + worker_ctx["max_margin"] = max(node.get_data_margin() for node in nodes) worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker worker_ctx["num_peaks"] = 0 return worker_ctx def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx): - recording = worker_ctx["recording"] + chunkable = worker_ctx["chunkable"] max_margin = worker_ctx["max_margin"] nodes = worker_ctx["nodes"] skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] - recording_segment = recording.segments[segment_index] + chunkable_segment = chunkable.segments[segment_index] retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever)) # get peak slices once for all retrievers peak_slice_by_retriever = {} @@ -678,7 +682,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c if load_trace_and_compute: traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True + chunkable_segment, start_frame, end_frame, None, max_margin, add_zeros=True ) # compute the graph pipeline_outputs = {} @@ -693,7 +697,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c # to handle compatibility peak detector is a special case # with specific margin # TODO later when in master: change this later - extra_margin = max_margin - node.get_trace_margin() + extra_margin = max_margin - node.get_data_margin() if extra_margin: trace_detection = traces_chunk[extra_margin:-extra_margin] else: diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 48eb2d7fd4..084cac6aba 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -3,7 +3,6 @@ import warnings from pathlib 
import Path import os -import mmap import tqdm import numpy.typing as npt @@ -12,15 +11,19 @@ from .core_tools import add_suffix, make_shared_array from .job_tools import ( ensure_chunk_size, - ensure_n_jobs, divide_segment_into_chunks, fix_job_kwargs, - ChunkRecordingExecutor, + ChunkExecutor, _shared_job_kwargs_doc, - chunk_duration_to_chunk_size, split_job_kwargs, ) +from .chunkable_tools import get_random_sample_slices, get_chunks, get_chunk_with_margin + +# for back-compatibility imports +from .chunkable_tools import write_binary as write_binary_recording +from .chunkable_tools import write_memory as write_memory_recording + def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): """ @@ -52,124 +55,11 @@ def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): return samples -# used by write_binary_recording + ChunkRecordingExecutor -def _init_binary_worker(recording, file_path_dict, dtype, byte_offest): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["byte_offset"] = byte_offest - worker_ctx["dtype"] = np.dtype(dtype) - - file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()} - worker_ctx["file_dict"] = file_dict - - return worker_ctx - - -def write_binary_recording( - recording: "BaseRecording", - file_paths: list[Path | str] | Path | str, - dtype: npt.DTypeLike | None = None, - add_file_extension: bool = True, - byte_offset: int = 0, - verbose: bool = False, - **job_kwargs, -): - """ - Save the trace of a recording extractor in several binary .dat format. - - Note : - time_axis is always 0 (contrary to previous version. - to get time_axis=1 (which is a bad idea) use `write_binary_recording_file_handle()` - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object to be saved in .dat format - file_path : str or list[str] - The path to the file. 
- dtype : dtype or None, default: None - Type of the saved data - add_file_extension, bool, default: True - If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension. - byte_offset : int, default: 0 - Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data - to an existing file where you wrote a header or other data before. - verbose : bool - This is the verbosity of the ChunkRecordingExecutor - {} - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths - num_segments = recording.get_num_segments() - if len(file_path_list) != num_segments: - raise ValueError("'file_paths' must be a list of the same size as the number of segments in the recording") - - file_path_list = [Path(file_path) for file_path in file_path_list] - if add_file_extension: - file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] - - dtype = dtype if dtype is not None else recording.get_dtype() - - dtype_size_bytes = np.dtype(dtype).itemsize - num_channels = recording.get_num_channels() - - file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} - for segment_index, file_path in file_path_dict.items(): - num_frames = recording.get_num_frames(segment_index=segment_index) - data_size_bytes = dtype_size_bytes * num_frames * num_channels - file_size_bytes = data_size_bytes + byte_offset - - # Create an empty file with file_size_bytes - with open(file_path, "wb+") as file: - # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) - file.seek(file_size_bytes - 1) - file.write(b"\0") - - assert Path(file_path).is_file() - - # use executor (loop or workers) - func = _write_binary_chunk - init_func = _init_binary_worker - init_args = (recording, file_path_dict, dtype, byte_offset) - executor = 
ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs - ) - executor.run() - - -# used by write_binary_recording + ChunkRecordingExecutor -def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - byte_offset = worker_ctx["byte_offset"] - file = worker_ctx["file_dict"][segment_index] - - num_channels = recording.get_num_channels() - dtype_size_bytes = np.dtype(dtype).itemsize - - # Calculate byte offsets for the start frames relative to the entire recording - start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes - - traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, order="c", copy=False) - - file.seek(start_byte) - file.write(traces.data) - # flush is important!! - file.flush() - - -write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) - - -def write_binary_recording_file_handle( +def write_binary_file_handle( recording, file_handle=None, time_axis=0, dtype=None, byte_offset=0, verbose=False, **job_kwargs ): """ - Old variant version of write_binary_recording with one file handle. + Old variant version of write_binary with one file handle. Can be useful in some case ??? Not used anymore at the moment. 
@@ -209,115 +99,6 @@ def write_binary_recording_file_handle( file_handle.write(traces.tobytes()) -# used by write_memory_recording -def _init_memory_worker(recording, arrays, shm_names, shapes, dtype): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["dtype"] = np.dtype(dtype) - - if arrays is None: - # create it from share memory name - from multiprocessing.shared_memory import SharedMemory - - arrays = [] - # keep shm alive - worker_ctx["shms"] = [] - for i in range(len(shm_names)): - shm = SharedMemory(shm_names[i]) - worker_ctx["shms"].append(shm) - arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) - arrays.append(arr) - - worker_ctx["arrays"] = arrays - - return worker_ctx - - -# used by write_memory_recording -def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - arr = worker_ctx["arrays"][segment_index] - - # apply function - traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, copy=False) - arr[start_frame:end_frame, :] = traces - - -def write_memory_recording(recording, dtype=None, verbose=False, buffer_type="auto", **job_kwargs): - """ - Save the traces into numpy arrays (memory). 
- try to use the SharedMemory introduce in py3.8 if n_jobs > 1 - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object to be saved in .dat format - dtype : dtype, default: None - Type of the saved data - verbose : bool, default: False - If True, output is verbose (when chunks are used) - buffer_type : "auto" | "numpy" | "sharedmem" - {} - - Returns - --------- - arrays : one array per segment - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - if dtype is None: - dtype = recording.get_dtype() - - # create sharedmmep - arrays = [] - shm_names = [] - shms = [] - shapes = [] - - n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) - if buffer_type == "auto": - if n_jobs > 1: - buffer_type = "sharedmem" - else: - buffer_type = "numpy" - - for segment_index in range(recording.get_num_segments()): - num_frames = recording.get_num_samples(segment_index) - num_channels = recording.get_num_channels() - shape = (num_frames, num_channels) - shapes.append(shape) - if buffer_type == "sharedmem": - arr, shm = make_shared_array(shape, dtype) - shm_names.append(shm.name) - shms.append(shm) - else: - arr = np.zeros(shape, dtype=dtype) - shms.append(None) - arrays.append(arr) - - # use executor (loop or workers) - func = _write_memory_chunk - init_func = _init_memory_worker - if n_jobs > 1: - init_args = (recording, None, shm_names, shapes, dtype) - else: - init_args = (recording, arrays, None, None, dtype) - - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs - ) - executor.run() - - return arrays, shms - - -write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc) - - def write_to_h5_dataset_format( recording, dataset_path, @@ -458,101 +239,16 @@ def write_to_h5_dataset_format( return save_path -def get_random_recording_slices( - recording, - method="full_random", - num_chunks_per_segment=20, - 
chunk_duration="500ms", - chunk_size=None, - margin_frames=0, - seed=None, -): - """ - Get random slice of a recording across segments. - - This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. - - Parameters - ---------- - recording : BaseRecording - The recording to get random chunks from - method : "full_random" - The method used to get random slices. - * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices - and they can overlap. - num_chunks_per_segment : int, default: 20 - Number of chunks per segment - chunk_duration : str | float | None, default "500ms" - The duration of each chunk in 's' or 'ms' - chunk_size : int | None - Size of a chunk in number of frames. This is used only if chunk_duration is None. - This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. - concatenated : bool, default: True - If True chunk are concatenated along time axis - seed : int, default: None - Random seed - margin_frames : int, default: 0 - Margin in number of frames to avoid edge effects - - Returns - ------- - chunk_list : np.array - Array of concatenate chunks per segment - - - """ - # TODO: if segment have differents length make another sampling that dependant on the length of the segment - # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY - # And randomize the number of chunk per segment weighted by segment duration - - if method == "full_random": - if chunk_size is None: - if chunk_duration is not None: - chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) - else: - raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") - - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = 
chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) - rng = np.random.default_rng(seed) - recording_slices = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - # here we set endpoint to True, because the this represents the start of the - # chunk, and should be inclusive - random_starts = rng.integers(low=low, high=high, size=size, endpoint=True) - random_starts = np.sort(random_starts) - recording_slices += [ - (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts - ] - else: - raise ValueError(f"get_random_recording_slices : wrong method {method}") - - return recording_slices - - def get_random_data_chunks( recording, return_scaled=None, return_in_uV=False, concatenated=True, **random_slices_kwargs ): """ Extract random chunks across segments. - Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + Internally, it uses `get_random_sample_slices()` and retrieves the traces chunk as a list or a concatenated unique array. - Please read `get_random_recording_slices()` for more details on parameters. + Please read `get_random_sample_slices()` for more details on parameters. Parameters @@ -569,7 +265,7 @@ def get_random_data_chunks( concatenated : bool, default: True If True chunk are concatenated along time axis **random_slices_kwargs : dict - Options transmited to get_random_recording_slices(), please read documentation from this + Options transmitted to get_random_sample_slices(), please read documentation from this function for more details. 
Returns @@ -586,22 +282,12 @@ def get_random_data_chunks( ) return_in_uV = return_scaled - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) - - chunk_list = [] - for segment_index, start_frame, end_frame in recording_slices: - traces_chunk = recording.get_traces( - start_frame=start_frame, - end_frame=end_frame, - segment_index=segment_index, - return_in_uV=return_in_uV, - ) - chunk_list.append(traces_chunk) - - if concatenated: - return np.concatenate(chunk_list, axis=0) - else: - return chunk_list + return get_chunks( + recording, + concatenated=concatenated, + get_data_kwargs=dict(return_in_uV=return_in_uV), + **random_slices_kwargs, + ) def get_channel_distances(recording): @@ -718,7 +404,7 @@ def get_noise_levels( force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor random_slices_kwargs : dict - Options transmited to get_random_recording_slices(), please read documentation from this + Options transmitted to get_random_sample_slices(), please read documentation from this function for more details. {} @@ -753,7 +439,7 @@ def get_noise_levels( msg = ( "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" - "Please read get_random_recording_slices() documentation for more options." + "Please read get_random_sample_slices() documentation for more options." 
) # if the user use both the old and the new behavior then an error is raised assert len(random_slices_kwargs) == 0, msg @@ -762,7 +448,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + slices = get_random_sample_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] @@ -772,7 +458,7 @@ def append_noise_chunk(res): func = _noise_level_chunk init_func = _noise_level_chunk_init init_args = (recording, return_in_uV, method) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, @@ -782,7 +468,7 @@ def append_noise_chunk(res): gather_func=append_noise_chunk, **job_kwargs, ) - executor.run(recording_slices=recording_slices) + executor.run(slices=slices) noise_levels_chunks = np.stack(noise_levels_chunks) noise_levels = np.mean(noise_levels_chunks, axis=0) @@ -795,130 +481,6 @@ def append_noise_chunk(res): get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) -def get_chunk_with_margin( - rec_segment, - start_frame, - end_frame, - channel_indices, - margin, - add_zeros=False, - add_reflect_padding=False, - window_on_margin=False, - dtype=None, -): - """ - Helper to get chunk with margin - - The margin is extracted from the recording when possible. If - at the edge of the recording, no margin is used unless one - of `add_zeros` or `add_reflect_padding` is True. In the first - case zero padding is used, in the second case np.pad is called - with mod="reflect". 
- """ - length = int(rec_segment.get_num_samples()) - - if channel_indices is None: - channel_indices = slice(None) - - if not (add_zeros or add_reflect_padding): - if window_on_margin and not add_zeros: - raise ValueError("window_on_margin requires add_zeros=True") - - if start_frame is None: - left_margin = 0 - start_frame = 0 - elif start_frame < margin: - left_margin = start_frame - else: - left_margin = margin - - if end_frame is None: - right_margin = 0 - end_frame = length - elif end_frame > (length - margin): - right_margin = length - end_frame - else: - right_margin = margin - - traces_chunk = rec_segment.get_traces( - start_frame - left_margin, - end_frame + right_margin, - channel_indices, - ) - - else: - # either add_zeros or reflect_padding - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = length - - chunk_size = end_frame - start_frame - full_size = chunk_size + 2 * margin - - if start_frame < margin: - start_frame2 = 0 - left_pad = margin - start_frame - else: - start_frame2 = start_frame - margin - left_pad = 0 - - if end_frame > (length - margin): - end_frame2 = length - right_pad = end_frame + margin - length - else: - end_frame2 = end_frame + margin - right_pad = 0 - - traces_chunk = rec_segment.get_traces(start_frame2, end_frame2, channel_indices) - - if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0: - need_copy = True - else: - need_copy = False - - left_margin = margin - right_margin = margin - - if need_copy: - if dtype is None: - dtype = traces_chunk.dtype - - left_margin = margin - if end_frame < (length + margin): - right_margin = margin - else: - right_margin = end_frame + margin - length - - if add_zeros: - traces_chunk2 = np.zeros((full_size, traces_chunk.shape[1]), dtype=dtype) - i0 = left_pad - i1 = left_pad + traces_chunk.shape[0] - traces_chunk2[i0:i1, :] = traces_chunk - if window_on_margin: - # apply inplace taper on border - taper = (1 - np.cos(np.arange(margin) / margin * 
np.pi)) / 2 - taper = taper[:, np.newaxis] - traces_chunk2[:margin] *= taper - traces_chunk2[-margin:] *= taper[::-1] - # enforce non writable when original was not - # (this help numba to have the same signature and not compile twice) - traces_chunk2.flags.writeable = traces_chunk.flags.writeable - traces_chunk = traces_chunk2 - elif add_reflect_padding: - # in this case, we don't want to taper - traces_chunk = np.pad( - traces_chunk.astype(dtype, copy=False), - [(left_pad, right_pad), (0, 0)], - mode="reflect", - ) - else: - # we need a copy to change the dtype - traces_chunk = np.asarray(traces_chunk, dtype=dtype) - - return traces_chunk, left_margin, right_margin - - def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), flip=False): """ Order channels by depth, by first ordering the x-axis, and then the y-axis. diff --git a/src/spikeinterface/core/tests/test_chunkable_tools.py b/src/spikeinterface/core/tests/test_chunkable_tools.py new file mode 100644 index 0000000000..3d686166e7 --- /dev/null +++ b/src/spikeinterface/core/tests/test_chunkable_tools.py @@ -0,0 +1,174 @@ +import numpy as np + +from spikeinterface.core import generate_recording + +from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor +from spikeinterface.core.generate import NoiseGeneratorRecording + + +from spikeinterface.core.chunkable_tools import ( + write_binary, + write_memory, + get_random_sample_slices, + get_chunks, +) + + +def test_write_binary(tmp_path): + # Test write_binary() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=1) + write_binary(recording, file_paths=file_paths, dtype=dtype, 
verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_offset(tmp_path): + # Test write_binary() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=1) + byte_offset = 125 + write_binary(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + file_offset=byte_offset, + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_parallel(tmp_path): + # Test write_binary() with parallel processing (n_jobs=2) + + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + dtype=dtype, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, 
sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_binary_multiple_segment(tmp_path): + # Test write_binary() with multiple segments (n_jobs=2) + # Setup + sampling_frequency = 30_000 + num_channels = 10 + dtype = "float32" + + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_memory_recording(): + # 2 segments + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) + recording = recording.save() + + # write with loop + traces_list, shms = write_memory(recording, dtype=None, verbose=True, n_jobs=1) + + traces_list, shms = write_memory( + recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True + ) + + # write parallel + traces_list, shms = write_memory(recording, dtype=None, 
verbose=False, n_jobs=2, chunk_memory="100k") + # need to clean the buffer + del traces_list + for shm in shms: + shm.unlink() + + +def test_get_random_sample_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = get_random_sample_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + +def test_get_chunks(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + chunks = get_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) + assert chunks.shape == (50000, 1) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 7a2accc887..7e635922db 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -9,10 +9,10 @@ divide_segment_into_chunks, ensure_n_jobs, ensure_chunk_size, - ChunkRecordingExecutor, + ChunkExecutor, fix_job_kwargs, split_job_kwargs, - divide_recording_into_chunks, + divide_chunkable_into_chunks, ) @@ -71,7 +71,7 @@ def test_ensure_chunk_size(): # Test edge case to define single chunk for n_jobs=1 chunk_size = ensure_chunk_size(recording, n_jobs=1, chunk_size=None) - chunks = divide_recording_into_chunks(recording, chunk_size) + chunks = divide_chunkable_into_chunks(recording, chunk_size) assert len(chunks) == recording.get_num_segments() for chunk in chunks: segment_index, start_frame, end_frame = chunk @@ -96,13 +96,13 @@ def init_func(arg1, arg2, arg3): return worker_dict -def test_ChunkRecordingExecutor(): +def test_ChunkExecutor(): recording = generate_recording(num_channels=2) init_args = "a", 120, "yep" # no chunk - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, 
verbose=True, progress_bar=False, n_jobs=1, chunk_size=None ) processor.run() @@ -113,7 +113,7 @@ def gathering_result(res): pass # chunk + loop + gather_func - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -139,7 +139,7 @@ def __call__(self, res): gathering_func2 = GatherClass() # process + gather_func - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -153,12 +153,12 @@ def __call__(self, res): job_name="job_name", ) processor.run() - num_chunks = len(divide_recording_into_chunks(recording, processor.chunk_size)) + num_chunks = len(divide_chunkable_into_chunks(recording, processor.chunk_size)) assert gathering_func2.pos == num_chunks # process spawn - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -174,7 +174,7 @@ def __call__(self, res): processor.run() # thread - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -258,7 +258,7 @@ def test_worker_index(): # making this 2 times ensure to test that global variables are correctly reset for pool_engine in ("process", "thread"): # print(pool_engine) - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func2, init_func2, @@ -322,7 +322,7 @@ def test_get_best_job_kwargs(): # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - # test_ChunkRecordingExecutor() + # test_ChunkExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() test_worker_index() diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4f8e600a3f..90e430a0db 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -5,7 +5,7 @@ from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording from spikeinterface.core.base 
import spike_peak_dtype -from spikeinterface.core.job_tools import divide_recording_into_chunks +from spikeinterface.core.job_tools import divide_chunkable_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -27,12 +27,12 @@ def __init__(self, recording, parents=None, return_output=True, param0=5.5): def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, chunk, peaks): amps = np.zeros(peaks.size, dtype=self._dtype) amps["abs_amplitude"] = np.abs(peaks["amplitude"]) return amps - def get_trace_margin(self): + def get_data_margin(self): return 5 @@ -44,7 +44,7 @@ def __init__(self, recording, return_output=True, parents=None): def get_dtype(self): return np.dtype("float32") - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): kernel = np.array([0.1, 0.8, 0.1]) denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=waveforms) return denoised_waveforms @@ -57,7 +57,7 @@ def __init__(self, recording, return_output=True, parents=None): def get_dtype(self): return np.dtype("float32") - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): rms_by_channels = np.sum(waveforms**2, axis=1) return rms_by_channels @@ -220,11 +220,9 @@ def test_skip_after_n_peaks_and_recording_slices(): assert some_amplitudes.size < spikes.size # slices : 1 every 4 - recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = divide_chunkable_into_chunks(recording, 10_000) recording_slices = recording_slices[::4] - some_amplitudes = run_node_pipeline( - recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices - ) + some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", slices=recording_slices) tolerance = 1.2 assert some_amplitudes.size < (spikes.size // 4) * 
tolerance diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 405a2ecccf..7a327ea44f 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -11,7 +11,6 @@ from spikeinterface.core.recording_tools import ( write_binary_recording, write_memory_recording, - get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -168,17 +167,6 @@ def test_write_memory_recording(): shm.unlink() -def test_get_random_recording_slices(): - rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - rec_slices = get_random_recording_slices( - rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 - ) - assert len(rec_slices) == 40 - for seg_ind, start, stop in rec_slices: - assert stop - start == 500 - assert seg_ind in (0, 1) - - def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -366,9 +354,8 @@ def test_do_recording_attributes_match(): # test_write_binary_recording(tmp_path) # test_write_memory_recording() - test_get_random_recording_slices() # test_get_random_data_chunks() # test_get_closest_channels() # test_get_noise_levels() # test_get_noise_levels_output() - # test_order_channels_by_depth() + test_order_channels_by_depth() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 58aac7faf2..f69e640efa 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -17,7 +17,7 @@ from spikeinterface.core.baserecording import BaseRecording from .baserecording import BaseRecording -from .job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc +from 
.job_tools import ChunkExecutor, _shared_job_kwargs_doc from .core_tools import make_shared_array from .job_tools import fix_job_kwargs @@ -294,16 +294,14 @@ def distribute_waveforms_to_buffers( ) if job_name is None: job_name = f"extract waveforms {mode} multi buffer" - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs - ) + processor = ChunkExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs) processor.run() distribute_waveforms_to_buffers.__doc__ = distribute_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_in_uV, inds_by_unit, mode, sparsity_mask ): @@ -350,7 +348,7 @@ def _init_worker_distribute_buffers( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] @@ -563,7 +561,7 @@ def extract_waveforms_to_single_buffer( if job_name is None: job_name = f"extract waveforms {mode} mono buffer" - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs ) processor.run() @@ -620,7 +618,7 @@ def _init_worker_distribute_single_buffer( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] @@ -948,7 +946,7 @@ def estimate_templates_with_accumulator( if job_name is None: job_name = "estimate_templates_with_accumulator" - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, 
job_name=job_name, verbose=verbose, need_worker_index=True, **job_kwargs ) processor.run() @@ -1035,7 +1033,7 @@ def _init_worker_estimate_templates( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1ef5d76e5a..a51063af3e 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -569,7 +569,7 @@ def add_traces_to_zarr( from .job_tools import ( ensure_chunk_size, fix_job_kwargs, - ChunkRecordingExecutor, + ChunkExecutor, ) assert dataset_paths is not None, "Provide 'file_path'" @@ -606,13 +606,13 @@ def add_traces_to_zarr( func = _write_zarr_chunk init_func = _init_zarr_worker init_args = (recording, zarr_datasets, dtype) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs ) executor.run() -# used by write_zarr_recording + ChunkRecordingExecutor +# used by write_zarr_recording + ChunkExecutor def _init_zarr_worker(recording, zarr_datasets, dtype): import zarr @@ -625,7 +625,7 @@ def _init_zarr_worker(recording, zarr_datasets, dtype): return worker_ctx -# used by write_zarr_recording + ChunkRecordingExecutor +# used by write_zarr_recording + ChunkExecutor def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx): import gc diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 8f18536daa..faef0d9560 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -7,7 +7,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks -from spikeinterface.core.job_tools import 
fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc +from spikeinterface.core.job_tools import fix_job_kwargs, ChunkExecutor, _shared_job_kwargs_doc from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import export_to_phy @@ -258,7 +258,7 @@ def compute_rms( func = _compute_rms_chunk init_func = _init_rms_worker init_args = (recording,) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 310be8cceb..b78e60e94e 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -230,9 +230,10 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, chunk, peaks): from scipy.stats import linregress + traces = chunk # scale traces with margin to match scaling of templates if self._gains is not None: traces = traces.astype("float32") * self._gains + self._offsets @@ -330,7 +331,7 @@ def compute(self, traces, peaks): # TODO: switch to collision mask and return that (to use concatenation) return (scalings, spike_collision_mask) - def get_trace_margin(self): + def get_data_margin(self): return self._margin diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index bb48a08e64..bce6c8e6a4 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -11,7 +11,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs +from spikeinterface.core.job_tools import ChunkExecutor, _shared_job_kwargs_doc, fix_job_kwargs from 
spikeinterface.core.analyzer_extension_core import _inplace_sparse_realign_waveforms @@ -412,7 +412,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): unit_channels, pca_model, ) - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs ) processor.run() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 0495e2c56e..c837e52f58 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -90,11 +90,13 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, chunk, peaks): sample_indices = peaks["sample_index"].copy() unit_index = peaks["unit_index"] chan_inds = peaks["channel_index"] + traces = chunk + # apply shifts per spike sample_indices += self._peak_shifts[unit_index] @@ -110,5 +112,5 @@ def compute(self, traces, peaks): return amplitudes - def get_trace_margin(self): + def get_data_margin(self): return self._margin diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 92b07b8f35..d40e043e20 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -114,7 +114,7 @@ def __init__( else: self.diff_threshold_unscaled = None - def get_trace_margin(self) -> int: + def get_data_margin(self) -> int: """Return the number of margin samples required on each side of a chunk.""" return 0 @@ -326,7 +326,7 @@ def __init__( # internal dtype self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", "bool")]) - def get_trace_margin(self) -> int: + def get_data_margin(self) -> int: """Return the number of margin samples required on each side of a chunk.""" return 0 diff --git 
a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index b262d36dab..b091734c1c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -26,7 +26,7 @@ def find_clusters_from_peaks( verbose : Bool, default: False If True, output is verbose job_kwargs : dict - Parameters for ChunkRecordingExecutor + Parameters for ChunkExecutor {method_doc} diff --git a/src/spikeinterface/sortingcomponents/matching/base.py b/src/spikeinterface/sortingcomponents/matching/base.py index 88a6522148..483d6f867d 100644 --- a/src/spikeinterface/sortingcomponents/matching/base.py +++ b/src/spikeinterface/sortingcomponents/matching/base.py @@ -20,21 +20,22 @@ def __init__(self, recording, templates, return_output=True): templates, Templates ), f"The templates supplied is of type {type(templates)} and must be a Templates" self.templates = templates + self.recording = recording PeakDetector.__init__(self, recording, return_output=return_output, parents=None) def get_dtype(self): return np.dtype(_base_matching_dtype) - def get_trace_margin(self): + def get_data_margin(self): raise NotImplementedError - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - spikes = self.compute_matching(traces, start_frame, end_frame, segment_index) + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + spikes = self.compute_matching(chunk, start_frame, end_frame, segment_index) spikes["segment_index"] = segment_index - margin = self.get_trace_margin() + margin = self.get_data_margin() if margin > 0 and spikes.size > 0: - keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) + keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (chunk.shape[0] - margin)) spikes = spikes[keep] # node pipeline need to return a tuple diff --git 
a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8a19ad458b..96c1d725ba 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -322,7 +322,7 @@ def get_extra_outputs(self): output[key] = getattr(self, key) return output - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): @@ -709,7 +709,7 @@ def _prepare_templates(self): self.circus_templates = templates_array - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 116aff416e..2ec8b0a0fa 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -39,7 +39,7 @@ def find_spikes_from_templates( verbose : Bool, default: False If True, output is verbose job_kwargs : dict - Parameters for ChunkRecordingExecutor + Parameters for ChunkExecutor {method_doc} diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3e0eb0b632..15b5327219 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -94,7 +94,7 @@ def __init__( self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i]) self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i]) - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): @@ -191,7 +191,7 @@ def __init__( projected_temporal_templates = self.svd_model.transform(temporal_templates) 
self.svd_templates = from_temporal_representation(projected_temporal_templates, self.num_channels) - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 947eaf391f..f870fed243 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -318,12 +318,12 @@ def __init__( # noise_levels=None, ) - self.detector_margin0 = self.fast_spike_detector.get_trace_margin() - self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0 + self.detector_margin0 = self.fast_spike_detector.get_data_margin() + self.detector_margin1 = self.fine_spike_detector.get_data_margin() if use_fine_detector else 0 self.peeler_margin = max(self.nbefore, self.nafter) * 2 self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): @@ -505,7 +505,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le peak_detector = self.fast_spike_detector # print('peak_detector', peak_detector) - detector_margin = peak_detector.get_trace_margin() + detector_margin = peak_detector.get_data_margin() if self.peeler_margin > detector_margin: margin_shift = self.peeler_margin - detector_margin diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 5245f3230d..9e5402b399 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -492,7 +492,7 @@ def _push_to_torch(self): self.template_data.compressed_templates = 
(temporal, singular, spatial, temporal_jittered) self.is_pushed = True - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 7c4c4b166e..a433eeb643 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -223,7 +223,7 @@ def interpolate_motion_on_traces( # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing - # in ChunkRecordingExecutor) + # in ChunkExecutor) np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index diff --git a/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py b/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py index 732ada21bc..37ccffad3b 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py @@ -56,11 +56,11 @@ def __init__( self.peak_sign = peak_sign self.detect_threshold = detect_threshold - def get_trace_margin(self): + def get_data_margin(self): return self.exclude_sweep_size - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + traces = chunk traces_center = traces[self.exclude_sweep_size : -self.exclude_sweep_size, :] length = traces_center.shape[0] @@ -162,7 +162,8 @@ def __init__( self.device = device self.return_tensor = return_tensor - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, 
segment_index, max_margin): + traces = chunk peak_sample_ind, peak_chan_ind, peak_amplitude = _torch_detect_peaks( traces, self.peak_sign, self.abs_thresholds, self.exclude_sweep_size, None, self.device ) diff --git a/src/spikeinterface/sortingcomponents/peak_detection/iterative.py b/src/spikeinterface/sortingcomponents/peak_detection/iterative.py index 63bce6d921..4547319934 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/iterative.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/iterative.py @@ -56,7 +56,7 @@ def __init__( self.num_iterations = num_iterations self.tresholds = tresholds - def get_trace_margin(self) -> int: + def get_data_margin(self) -> int: """ Calculate the maximum trace margin from the internal pipeline. Using the strategy as use by the Node pipeline @@ -68,10 +68,10 @@ def get_trace_margin(self) -> int: The maximum trace margin. """ internal_pipeline = (self.peak_detector_node, self.waveform_extraction_node, self.waveform_denoising_node) - pipeline_margin = (node.get_trace_margin() for node in internal_pipeline if hasattr(node, "get_trace_margin")) + pipeline_margin = [node.get_data_margin() for node in internal_pipeline] return max(pipeline_margin) - def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margin) -> Tuple[np.ndarray, np.ndarray]: + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin) -> Tuple[np.ndarray, np.ndarray]: """ Perform the iterative peak detection, waveform extraction, and denoising. @@ -94,7 +94,7 @@ def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margi A tuple containing a single ndarray with the detected peaks. 
""" - traces_chunk = np.array(traces_chunk, copy=True, dtype="float32") + traces_chunk = np.array(chunk, copy=True, dtype="float32") local_peaks_list = [] all_waveforms = [] @@ -110,7 +110,7 @@ def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margi ) (local_peaks,) = self.peak_detector_node.compute( - traces=traces_chunk, + traces_chunk, start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, @@ -124,9 +124,9 @@ def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margi if local_peaks.size == 0: break - waveforms = self.waveform_extraction_node.compute(traces=traces_chunk, peaks=local_peaks) + waveforms = self.waveform_extraction_node.compute(traces_chunk, peaks=local_peaks) denoised_waveforms = self.waveform_denoising_node.compute( - traces=traces_chunk, peaks=local_peaks, waveforms=waveforms + traces_chunk, peaks=local_peaks, waveforms=waveforms ) self.substract_waveforms_from_traces( diff --git a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py index e0c8a29cfd..ff075092e6 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py @@ -64,6 +64,7 @@ def __init__( assert peak_sign in ("both", "neg", "pos") assert noise_levels is not None + self.recording = recording self.noise_levels = noise_levels self.abs_thresholds = self.noise_levels * detect_threshold @@ -83,13 +84,13 @@ def __init__( self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= radius_um - def get_trace_margin(self): + def get_data_margin(self): # the +1 in the border is important because we need peak in the border return self.exclude_sweep_size + 1 - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, 
segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" - + traces = chunk peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( traces, self.peak_sign, self.abs_thresholds, self.exclude_sweep_size, self.neighbours_mask ) @@ -238,12 +239,14 @@ def __init__( for i, neigh in enumerate(self.neighbour_indices_by_chan): self.neighbours_idxs[i, : len(neigh)] = neigh - def get_trace_margin(self): + def get_data_margin(self): return self.exclude_sweep_size - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): from .by_channel import _torch_detect_peaks + traces = chunk + peak_sample_ind, peak_chan_ind, peak_amplitude = _torch_detect_peaks( traces, self.peak_sign, self.abs_thresholds, self.exclude_sweep_size, self.neighbours_idxs, self.device ) @@ -291,7 +294,9 @@ def __init__( self.abs_thresholds, self.exclude_sweep_size, self.neighbours_mask, self.peak_sign, **opencl_context_kwargs ) - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + traces = chunk + peak_sample_ind, peak_chan_ind = self.executor.detect_peak(traces) peak_sample_ind += self.exclude_sweep_size peak_amplitude = traces[peak_sample_ind, peak_chan_ind] diff --git a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py index 509c3f76f8..e7d9081929 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py @@ -98,12 +98,12 @@ def __init__( def get_dtype(self): return self._dtype - def get_trace_margin(self): + def get_data_margin(self): return self.exclude_sweep_size + self.conv_margin + 1 - def compute(self, traces, start_frame, end_frame, segment_index, 
max_margin): - + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" + traces = chunk conv_traces = self.get_convolved_traces(traces) conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin] conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1]) diff --git a/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py index e0d2a49b9d..1a5afa9612 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py @@ -152,7 +152,7 @@ def test_iterative_peak_detection(recording, job_kwargs, pca_model_folder_path, return_output=(True, True), ) - peaks, waveforms = run_node_pipeline(recording=recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) + peaks, waveforms = run_node_pipeline(recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) # Assert there is a field call iteration in structured array peaks assert "iteration" in peaks.dtype.names assert peaks.shape[0] == waveforms.shape[0] @@ -197,7 +197,7 @@ def test_iterative_peak_detection_sparse(recording, job_kwargs, pca_model_folder return_output=(True, True), ) - peaks, waveforms = run_node_pipeline(recording=recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) + peaks, waveforms = run_node_pipeline(recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) # Assert there is a field call iteration in structured array peaks assert "iteration" in peaks.dtype.names assert peaks.shape[0] == waveforms.shape[0] @@ -239,7 +239,7 @@ def test_iterative_peak_detection_thresholds(recording, job_kwargs, pca_model_fo tresholds=tresholds, ) - peaks, waveforms = run_node_pipeline(recording=recording, nodes=[iterative_peak_detector], 
job_kwargs=job_kwargs) + peaks, waveforms = run_node_pipeline(recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) # Assert there is a field call iteration in structured array peaks assert "iteration" in peaks.dtype.names assert peaks.shape[0] == waveforms.shape[0] @@ -435,15 +435,15 @@ def test_peak_sign_consistency(recording, job_kwargs, detection_class): kwargs["peak_sign"] = "neg" peak_detection_node = detection_class(**kwargs) - negative_peaks = run_node_pipeline(recording=recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) + negative_peaks = run_node_pipeline(recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) kwargs["peak_sign"] = "pos" peak_detection_node = detection_class(**kwargs) - positive_peaks = run_node_pipeline(recording=recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) + positive_peaks = run_node_pipeline(recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) kwargs["peak_sign"] = "both" peak_detection_node = detection_class(**kwargs) - all_peaks = run_node_pipeline(recording=recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) + all_peaks = run_node_pipeline(recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) # To account for exclusion of positive peaks that are to close to negative peaks. 
# This should be excluded by the detection method when is exclusive so using peak_sign="both" should diff --git a/src/spikeinterface/sortingcomponents/peak_localization/base.py b/src/spikeinterface/sortingcomponents/peak_localization/base.py index 17853c85aa..5e6a16fe4c 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/base.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/base.py @@ -9,7 +9,7 @@ class LocalizeBase(PipelineNode): def __init__(self, recording, parents, return_output=True, radius_um=75.0): PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) - + self.recording = recording self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) diff --git a/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py b/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py index 9f868c3cd7..ffa97c144f 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py @@ -42,7 +42,7 @@ def __init__(self, recording, parents, return_output=True, radius_um=75.0, featu self.nbefore = waveform_extractor.nbefore self._kwargs.update(dict(feature=feature)) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): peak_locations = np.zeros(peaks.size, dtype=self._dtype) for main_chan in np.unique(peaks["channel_index"]): diff --git a/src/spikeinterface/sortingcomponents/peak_localization/grid.py b/src/spikeinterface/sortingcomponents/peak_localization/grid.py index e39773d66e..43d1324a9e 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/grid.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/grid.py @@ -61,7 +61,7 @@ def __init__( peak_sign="neg", weight_method={}, ): - PipelineNode.__init__(self, recording, return_output=return_output, 
parents=parents) + LocalizeBase.__init__(self, recording, return_output=return_output, parents=parents) self.radius_um = radius_um self.margin_um = margin_um @@ -120,7 +120,7 @@ def __init__( ) ) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): peak_locations = np.zeros(peaks.size, dtype=self._dtype) nb_weights = self.weights.shape[0] diff --git a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py index 8840a5a00d..d9b75e265a 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py @@ -80,7 +80,7 @@ def __init__( self._dtype = np.dtype(dtype_localize_by_method["monopolar_triangulation"]) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): peak_locations = np.zeros(peaks.size, dtype=self._dtype) for i, peak in enumerate(peaks): diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 73a14bdee7..67850f241f 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -191,23 +191,15 @@ def get_prototype_and_waveforms_from_recording( nodes = [node0, node1] - recording_slices = get_shuffled_recording_slices(recording, job_kwargs=job_kwargs, seed=seed) - # res = detect_peaks( - # recording, - # pipeline_nodes=pipeline_nodes, - # skip_after_n_peaks=n_peaks, - # recording_slices=recording_slices, - # method="locally_exclusive", - # method_kwargs=detection_kwargs, - # job_kwargs=job_kwargs, - # ) + slices = get_shuffled_recording_slices(recording, job_kwargs=job_kwargs, seed=seed) + res = run_node_pipeline( recording, nodes, job_kwargs, job_name="get protoype waveforms", skip_after_n_peaks=n_peaks, - recording_slices=recording_slices, + slices=slices, ) rng = np.random.default_rng(seed) diff --git 
a/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py b/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py index 1695f2cc59..ab546f7460 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py @@ -97,7 +97,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): if self.all_channels: if self.peak_sign == "neg": amplitudes = np.min(waveforms, axis=1) @@ -131,7 +131,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): if self.all_channels: all_ptps = np.ptp(waveforms, axis=1) else: @@ -182,7 +182,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) for main_chan in np.unique(peaks["channel_index"]): diff --git a/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py index c6d1070e6d..ecb2edba39 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py +++ b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py @@ -42,6 +42,6 @@ def __init__( self.hanning = hanning[:, None] self._kwargs.update(dict()) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): denoised_waveforms = waveforms * self.hanning return denoised_waveforms diff --git a/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py index d094bae3e0..257c35f860 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py +++ 
b/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py @@ -91,7 +91,7 @@ def load_model(self): return denoiser - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): num_channels = waveforms.shape[2] # Collapse channels and transform to torch tensor diff --git a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py index 1ed9e4bffa..b9cb6ec1ab 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py @@ -49,7 +49,7 @@ def __init__( self.order = min(self.order, self.window_length - 1) self._kwargs.update(dict(order=order, window_length_ms=window_length_ms)) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): # Denoise import scipy.signal diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 0170038c96..9b642126b6 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -216,7 +216,7 @@ def __init__( self.n_components = self.pca_model.n_components self.dtype = np.dtype(dtype) - def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: + def compute(self, chunk: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: """ Projects the waveforms using the PCA model trained in the fit method or loaded from the model_folder_path. 
@@ -285,7 +285,7 @@ def __init__( model_folder_path=model_folder_path, ) - def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: + def compute(self, chunk: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: """ Projects the waveforms using the PCA model trained in the fit method or loaded from the model_folder_path. @@ -374,7 +374,7 @@ def __init__( # this is the final sparse channel count self.out_num_channels = max(np.sum(self.final_sparsity_mask, axis=1)) - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peaks, waveforms) -> np.ndarray: + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, peaks, waveforms) -> np.ndarray: """ Projects the waveforms using the PCA model trained in the fit method or loaded from the model_folder_path. diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py index 286c103d22..dc01ab3431 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py @@ -179,7 +179,7 @@ def test_pca_projection_sparsity(generated_recording, detected_peaks, model_path def test_initialization_with_wrong_parents_failure(generated_recording, model_path_of_trained_pca): recording = generated_recording model_folder_path = model_path_of_trained_pca - dummy_parent = PipelineNode(recording=recording) + dummy_parent = PipelineNode(recording) extract_waveforms = ExtractSparseWaveforms( recording=recording, ms_before=1, ms_after=1, radius_um=40, return_output=True ) diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py index ec223d0047..2191c6dd2a 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py +++ 
b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py @@ -74,7 +74,7 @@ def __init__( dict(feature=feature, threshold=threshold, operator=operator, noise_levels=self.noise_levels) ) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): if self.feature == "ptp": wf_data = np.ptp(waveforms, axis=1) / self.noise_levels elif self.feature == "mean":