diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index 286c5787..22a2a298 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -758,77 +758,102 @@ def read_spikegadgets(file: str | Path, raise_error: bool = True) -> ProbeGroup: # phase, gain, on-shank reference, shank thickness). So hardcoding NP1000 # produces correct geometry; `model_name` and `description` are cleared on # the sliced probe to avoid claiming a specific variant. - PART_NUMBER = "NP1000" header_txt = parse_spikegadgets_header(file) root = ElementTree.fromstring(header_txt) hconf = root.find("HardwareConfiguration") sconf = root.find("SpikeConfiguration") - probe_configs = [d for d in hconf if d.attrib.get("name") == "NeuroPixels1"] + # Detect devices present in the header + probe_configs = [d for d in hconf if d.attrib.get("name") in ["NeuroPixels1", "NeuroPixels2"]] n_probes = len(probe_configs) if n_probes == 0: if raise_error: - raise Exception("No Neuropixels 1.0 probes found") + raise Exception("No supported Neuropixels probes found") return None - # NeuroPixels1 SourceOptions blocks carry the per-probe AP/LF gain settings. - # They appear in the same order as the SpikeNTrode probe digits (1, 2, 3). - source_options_blocks = [s for s in hconf.findall("SourceOptions") if s.attrib.get("name") == "NeuroPixels1"] - probe_group = ProbeGroup() - for curr_probe in range(1, n_probes + 1): - # SpikeNTrode elements are the authoritative list of recorded electrodes. - # Each id is "<1-based electrode number>" for up to 960 - # electrodes on NP1.0; the catalogue uses 0-based indices, so - # catalogue_index = electrode_number - 1. The probe number is assumed - # to be a single digit (1, 2, or 3), matching the documented - # SpikeGadgets limit of three simultaneous Neuropixels probes. - electrode_to_hwchan = {} - for ntrode in sconf: - electrode_id = ntrode.attrib["id"] - if int(electrode_id[0]) == curr_probe: - catalogue_index = int(electrode_id[1:]) - 1 - hw_chan = int(ntrode[0].attrib["hwChan"]) - electrode_to_hwchan[catalogue_index] = hw_chan + for curr_probe_idx, probe_config in enumerate(probe_configs): - active_indices = np.array(sorted(electrode_to_hwchan.keys())) + device_name = probe_config.attrib["name"] - full_probe = build_neuropixels_probe(PART_NUMBER) - probe = full_probe.get_slice(active_indices) + # 1. Collect all used probeColumns for this probe index, this is needed to understand how many shanks are present + curr_probe = curr_probe_idx + 1 + used_columns = set() + for ntrode in sconf: + if int(ntrode.attrib["id"][0]) == curr_probe: + # Assuming SpikeChannel follows the structure where probeColumn is defined + for channel in ntrode.findall("SpikeChannel"): + used_columns.add(int(channel.attrib["probeColumn"])) + + # 2. Determine part number based on shank count + if device_name == "NeuroPixels1": + part_number = "NP1000" + elif device_name == "NeuroPixels2": + # NP2.0: 1 shank = columns 0-1; 4 shanks = columns 0-7 + num_shanks = 4 if max(used_columns) == 7 else 1 + part_number = "NP2000" if num_shanks == 1 else "NP2010" + + channel_data = [] + for ntrode in sconf: + electrode_id = ntrode.attrib["id"] + if int(ntrode.attrib["id"][0]) == curr_probe: + chan_data = ntrode[0].attrib + channel_data.append( + { + "hw": int(chan_data["hwChan"]), + "col": int(chan_data["probeColumn"]), + "ap": int(chan_data["coord_ap"]), + "ml": int(chan_data["coord_ml"]), + "dv": int(chan_data["coord_dv"]), + "probe_n": int(electrode_id[0]), + "channel": int(electrode_id[1:]) - 1, # trodes channels start at 1 not 0 + } + ) + + # 2. Extract indices + + device_channels = np.array([c["channel"] for c in channel_data]) - 1 # channel ids start at 1 + active_channels = np.array([c["hw"] for c in channel_data]) + + full_probe = build_neuropixels_probe(part_number) + + contact_positions = full_probe.contact_positions # shape (n_contacts, 2) + ml = contact_positions[:, 0] + dv = contact_positions[:, 1] + + sorted_order = np.lexsort((-ml, dv)) + device_channels_indexes = sorted_order[ + device_channels + ] # the ids in trodes are assigned according to whats seen in trodes, id 0 is tip of the probe (min dv) and max ml. + + probe = full_probe.get_slice(device_channels_indexes) + probe.set_device_channel_indices(active_channels) - # Clear part-number-specific metadata since we don't know the actual part number. probe.model_name = "" probe.description = "" - device_channels = np.array([electrode_to_hwchan[idx] for idx in active_indices]) - probe.set_device_channel_indices(device_channels) - - # Per-contact ADC group and sample order from the catalogue MUX table plus - # the hwChan mapping (which is the readout-channel index for each contact). + # Annotate ADC info adc_sampling_table = probe.annotations.get("adc_sampling_table") - _annotate_probe_with_adc_sampling_info(probe, adc_sampling_table) + if adc_sampling_table is not None: + _annotate_probe_with_adc_sampling_info(probe, adc_sampling_table) - # NP1.0 gain is programmable. Read APGainMode and LFPGainMode from the - # SourceOptions block matching this probe (blocks appear in probe order). - if "ap_gain" not in probe.annotations and curr_probe - 1 < len(source_options_blocks): + # Handle gain settings dynamically + source_options_blocks = [s for s in hconf.findall("SourceOptions") if s.attrib.get("name") == device_name] + if curr_probe_idx < len(source_options_blocks): custom_options = { opt.attrib["name"]: opt.attrib["data"].strip() - for opt in source_options_blocks[curr_probe - 1].findall("CustomOption") + for opt in source_options_blocks[curr_probe_idx].findall("CustomOption") } - ap_gain_str = custom_options.get("APGainMode") - if ap_gain_str: - probe.annotate(ap_gain=float(ap_gain_str)) - if probe.annotations.get("lf_sample_frequency_hz", 0) > 0: - lf_gain_str = custom_options.get("LFPGainMode") - if lf_gain_str: - probe.annotate(lf_gain=float(lf_gain_str)) - - # Shift multiple probes so they don't overlap when plotted - probe.move([250 * (curr_probe - 1), 0]) + if "APGainMode" in custom_options: + probe.annotate(ap_gain=float(custom_options["APGainMode"])) + if probe.annotations.get("lf_sample_frequency_hz", 0) > 0 and "LFPGainMode" in custom_options: + probe.annotate(lf_gain=float(custom_options["LFPGainMode"])) + # Spatial shift for multiple probes + probe.move([250 * curr_probe_idx, 0]) probe_group.add_probe(probe) return probe_group