Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,22 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
return spike_train


def synthesize_amplitude_factor(
num_spikes: int,
amplitude_factor: np.ndarray | None = None,
amplitude_std: float | None = None,
seed: np.random.Generator | int | None = None,
):
if amplitude_factor is not None:
assert amplitude_factor.shape == (num_spikes,)
return amplitude_factor
elif amplitude_std:
rng = np.random.default_rng(seed)
return rng.normal(loc=1, scale=amplitude_std, size=num_spikes)
else:
return None


from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting


Expand Down
19 changes: 17 additions & 2 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
generate_unit_locations,
generate_sorting,
generate_templates,
synthesize_amplitude_factor,
_ensure_unit_params,
_ensure_seed,
)
Expand Down Expand Up @@ -366,6 +367,8 @@ def generate_drifting_recording(
generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0),
noise=None,
generate_noise_kwargs=dict(noise_levels=(6.0, 8.0), spatial_decay=25.0),
amplitude_std: float | None = None,
amplitude_factor: np.ndarray | None = None,
extra_outputs=False,
seed=None,
):
Expand Down Expand Up @@ -405,6 +408,11 @@ def generate_drifting_recording(
Noise generator used to generate background noise
generate_noise_kwargs : dict
Parameters given to generate_noise() if no noise is None
amplitude_std : float, default: 0.05
The standard deviation of the modulation to apply to the spikes when injecting them
into the recording.
amplitude_factor: np.ndarray, optional
Optional fixed per-spike amplitude modulation
extra_outputs : bool, default False
Return optionaly a dict with more variables.
seed : None ot int
Expand Down Expand Up @@ -559,6 +567,13 @@ def generate_drifting_recording(
assert noise.probe.get_contact_count() == probe.get_contact_count(), "Noise num channels mismatch"
assert noise.get_total_duration() == duration, "Noise duration should be the same as the recording duration"

amplitude_factor = synthesize_amplitude_factor(
num_spikes=sorting.count_total_num_spikes(),
amplitude_factor=amplitude_factor,
amplitude_std=amplitude_std,
seed=seed,
)

static_recording = InjectDriftingTemplatesRecording(
sorting=sorting,
parent_recording=noise,
Expand All @@ -567,7 +582,7 @@ def generate_drifting_recording(
displacement_sampling_frequency=displacement_sampling_frequency,
displacement_unit_factor=np.zeros_like(displacement_unit_factor),
num_samples=[int(duration * sampling_frequency)],
amplitude_factor=None,
amplitude_factor=amplitude_factor,
)

drifting_recording = InjectDriftingTemplatesRecording(
Expand All @@ -578,7 +593,7 @@ def generate_drifting_recording(
displacement_sampling_frequency=displacement_sampling_frequency,
displacement_unit_factor=displacement_unit_factor,
num_samples=[int(duration * sampling_frequency)],
amplitude_factor=None,
amplitude_factor=amplitude_factor,
)

if extra_outputs:
Expand Down
12 changes: 8 additions & 4 deletions src/spikeinterface/generation/hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
generate_sorting,
InjectTemplatesRecording,
_ensure_seed,
synthesize_amplitude_factor,
)
from spikeinterface.core.template_tools import get_template_extremum_channel

Expand Down Expand Up @@ -327,6 +328,7 @@ def generate_hybrid_recording(
upsample_factor: int | None = None,
upsample_vector: np.ndarray | None = None,
amplitude_std: float = 0.05,
amplitude_factor: np.ndarray | None = None,
generate_sorting_kwargs: dict = dict(num_units=10, firing_rates=15, refractory_period_ms=4.0, seed=2205),
generate_unit_locations_kwargs: dict = dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs: dict = dict(ms_before=1.0, ms_after=3.0),
Expand Down Expand Up @@ -499,10 +501,12 @@ def generate_hybrid_recording(
upsample_factor = templates_array.shape[3]
upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)

if amplitude_std is not None:
amplitude_factor = rng.normal(loc=1, scale=amplitude_std, size=num_spikes)
else:
amplitude_factor = None
amplitude_factor = synthesize_amplitude_factor(
num_spikes,
amplitude_factor=amplitude_factor,
amplitude_std=amplitude_std,
seed=rng,
)

if motion is not None:
assert num_segments == motion.num_segments, "recording and motion should have the same number of segments"
Expand Down
Loading