diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1470473fe2..4fa68ebec0 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1072,6 +1072,22 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol return spike_train +def synthesize_amplitude_factor( + num_spikes: int, + amplitude_factor: np.ndarray | None = None, + amplitude_std: float | None = None, + seed: np.random.Generator | int | None = None, +): + if amplitude_factor is not None: + assert amplitude_factor.shape == (num_spikes,) + return amplitude_factor + elif amplitude_std: + rng = np.random.default_rng(seed) + return rng.normal(loc=1, scale=amplitude_std, size=num_spikes) + else: + return None + + from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 5529c56a4f..f15d290864 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -18,6 +18,7 @@ generate_unit_locations, generate_sorting, generate_templates, + synthesize_amplitude_factor, _ensure_unit_params, _ensure_seed, ) @@ -366,6 +367,8 @@ def generate_drifting_recording( generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0), noise=None, generate_noise_kwargs=dict(noise_levels=(6.0, 8.0), spatial_decay=25.0), + amplitude_std: float | None = None, + amplitude_factor: np.ndarray | None = None, extra_outputs=False, seed=None, ): @@ -405,6 +408,11 @@ def generate_drifting_recording( Noise generator used to generate background noise generate_noise_kwargs : dict Parameters given to generate_noise() if no noise is None + amplitude_std : float, default: 0.05 + The standard deviation of the modulation to apply to the spikes when injecting them + into the recording. + amplitude_factor: np.ndarray, optional + Optional fixed per-spike amplitude modulation extra_outputs : bool, default False Return optionaly a dict with more variables. seed : None ot int @@ -559,6 +567,13 @@ def generate_drifting_recording( assert noise.probe.get_contact_count() == probe.get_contact_count(), "Noise num channels mismatch" assert noise.get_total_duration() == duration, "Noise duration should be the same as the recording duration" + amplitude_factor = synthesize_amplitude_factor( + num_spikes=sorting.count_total_num_spikes(), + amplitude_factor=amplitude_factor, + amplitude_std=amplitude_std, + seed=seed, + ) + static_recording = InjectDriftingTemplatesRecording( sorting=sorting, parent_recording=noise, @@ -567,7 +582,7 @@ def generate_drifting_recording( displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=np.zeros_like(displacement_unit_factor), num_samples=[int(duration * sampling_frequency)], - amplitude_factor=None, + amplitude_factor=amplitude_factor, ) drifting_recording = InjectDriftingTemplatesRecording( @@ -578,7 +593,7 @@ def generate_drifting_recording( displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, num_samples=[int(duration * sampling_frequency)], - amplitude_factor=None, + amplitude_factor=amplitude_factor, ) if extra_outputs: diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 2476bc6336..244575aff2 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -10,6 +10,7 @@ generate_sorting, InjectTemplatesRecording, _ensure_seed, + synthesize_amplitude_factor, ) from spikeinterface.core.template_tools import get_template_extremum_channel @@ -327,6 +328,7 @@ def generate_hybrid_recording( upsample_factor: int | None = None, upsample_vector: np.ndarray | None = None, amplitude_std: float = 0.05, + amplitude_factor: np.ndarray | None = None, generate_sorting_kwargs: dict = dict(num_units=10, firing_rates=15, refractory_period_ms=4.0, seed=2205), generate_unit_locations_kwargs: dict = dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs: dict = dict(ms_before=1.0, ms_after=3.0), @@ -499,10 +501,12 @@ def generate_hybrid_recording( upsample_factor = templates_array.shape[3] upsample_vector = rng.integers(0, upsample_factor, size=num_spikes) - if amplitude_std is not None: - amplitude_factor = rng.normal(loc=1, scale=amplitude_std, size=num_spikes) - else: - amplitude_factor = None + amplitude_factor = synthesize_amplitude_factor( + num_spikes, + amplitude_factor=amplitude_factor, + amplitude_std=amplitude_std, + seed=rng, + ) if motion is not None: assert num_segments == motion.num_segments, "recording and motion should have the same number of segments"