From f88dc2dcafcde75601d30d124c8f661b5622314b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Apr 2026 14:07:00 -0600 Subject: [PATCH 01/31] First draft --- src/spikeinterface/core/baserecording.py | 4 +- .../core/baserecordingsnippets.py | 121 ++++++------------ src/spikeinterface/core/basesnippets.py | 2 +- .../core/channelsaggregationrecording.py | 15 ++- src/spikeinterface/core/channelslice.py | 64 +++++++-- src/spikeinterface/core/zarrextractors.py | 2 +- .../extractors/neoextractors/biocam.py | 6 +- .../extractors/tests/test_iblextractors.py | 1 - .../tests/test_interpolate_bad_channels.py | 20 ++- .../preprocessing/zero_channel_pad.py | 2 +- .../motion/motion_interpolation.py | 16 ++- 11 files changed, 137 insertions(+), 116 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f23b524271..2c47693c32 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -637,7 +637,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) @@ -665,7 +665,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..0b99dfe435 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -19,6 +19,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,7 +52,7 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + return self._probegroup is not None def has_channel_location(self) -> bool: return self.has_probe() or "location" in self.get_property_keys() @@ -145,24 +146,18 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr - probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 - if np.any(~keep): + # identify connected contacts; device_channel_indices values are preserved as provenance + global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + connected_mask = global_device_channel_indices >= 0 + if np.any(~connected_mask): warn("The given probes have unconnected contacts: they are removed") - probe_as_numpy_array = probe_as_numpy_array[keep] - - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] + device_channel_indices = np.sort(global_device_channel_indices[connected_mask]) - # check TODO: Where did this came from? + # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( + raise ValueError( f"The given Probe either has 'device_channel_indices' that does not match channel count \n" f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" @@ -170,11 +165,13 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) - raise ValueError(error_msg) new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") + + # drop only the unconnected contacts from the stored probegroup; preserve device_channel_indices values + probegroup = probegroup.get_slice(connected_mask) + probegroup._build_contact_vector() + contact_vector = probegroup.contact_vector # create recording : channel slice or clone or self if in_place: @@ -187,25 +184,18 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): else: sub_recording = self.select_channels(new_channel_ids) - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) + sub_recording._probegroup = probegroup - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property + # duplicate positions to "location" property so SpikeInterface-level readers keep working ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") + locations = np.zeros((contact_vector.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] + locations[:, i] = contact_vector[dim] sub_recording.set_property("location", locations, ids=None) - # handle groups - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields + # derive groups from contact_vector + has_shank_id = "shank_ids" in contact_vector.dtype.fields + has_contact_side = "contact_sides" in contact_vector.dtype.fields if group_mode == "auto": group_keys = ["probe_index"] if has_shank_id: @@ -223,21 +213,15 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): group_keys = ["probe_index", "shank_ids", "contact_sides"] else: group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) + groups = np.zeros(contact_vector.size, dtype="int64") + unique_keys = np.unique(contact_vector[group_keys]) for group, a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) + mask = np.ones(contact_vector.size, dtype=bool) for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] + mask &= contact_vector[k] == a[k] groups[mask] = group sub_recording.set_property("group", groups, ids=None) - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) - return sub_recording def get_probe(self): @@ -250,30 +234,9 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): - arr = self.get_property("contact_vector") - if arr is None: - positions = self.get_property("location") - if positions is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - else: - warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - probe = self.create_dummy_probe_from_locations(positions) - # probe.create_auto_shape() - probegroup = ProbeGroup() - probegroup.add_probe(probe) - else: - probegroup = ProbeGroup.from_numpy(arr) - - if "probes_info" in self.get_annotation_keys(): - probes_info = self.get_annotation("probes_info") - for probe, probe_info in zip(probegroup.probes, probes_info): - probe.annotations = probe_info - - for probe_index, probe in enumerate(probegroup.probes): - contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - if contour is not None: - probe.set_planar_contour(contour) - return probegroup + if self._probegroup is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + return self._probegroup def _extra_metadata_from_folder(self, folder): # load probe @@ -284,7 +247,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) @@ -341,7 +304,7 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: + if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -349,21 +312,15 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + if not self.has_probe(): + raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") + self._probegroup._build_contact_vector() + contact_vector = self._probegroup.contact_vector + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + return all_positions[channel_indices] def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..fa47365200 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -259,7 +259,7 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 697aab875e..59501d0ba1 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,8 @@ import numpy as np +from probeinterface import ProbeGroup + from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,11 +92,18 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) + # aggregate probegroups across the inputs and reset wiring to the new channel order + if all(rec.has_probe() for rec in recording_list): + aggregated_probegroup = ProbeGroup() + for rec in recording_list: + for probe in rec.get_probegroup().probes: + aggregated_probegroup.add_probe(probe.copy()) + aggregated_probegroup.set_global_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64")) + aggregated_probegroup._build_contact_vector() + self._probegroup = aggregated_probegroup + # if locations are present, check that they are all different! if "location" in self.get_property_keys(): location_tuple = [tuple(loc) for loc in self.get_property("location")] diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index de693d5c26..82e842dabc 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,11 +61,33 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # filter the probegroup to contacts wired to the retained channels + if parent_recording.has_probe(): + parent_probegroup = parent_recording.get_probegroup() + parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) + child_dci_values = parent_dci_sorted[self._parent_channel_indices] + are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) + probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep_mask = np.isin(probe_dci, child_dci_values) + sliced_probegroup = parent_probegroup.get_slice(keep_mask) + + if not are_channels_reordered: + # simple case: the child's channels are already in ascending device_channel_indices order + # so _build_contact_vector on the filtered probegroup will produce rows in the child's + # channel order. Nothing else to do. + pass + else: + # reorder case: the user picked channels in an order that does not match sort-by-dci. + # We have to rewrite device_channel_indices on the child's copy so that the sort done + # by _build_contact_vector aligns with the child's channel order. + new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} + sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] + sliced_probegroup.set_global_device_channel_indices( + np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") + ) + + sliced_probegroup._build_contact_vector() + self._probegroup = sliced_probegroup # update dump dict self._kwargs = { @@ -151,11 +173,33 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # filter the probegroup to contacts wired to the retained channels + if parent_snippets.has_probe(): + parent_probegroup = parent_snippets.get_probegroup() + parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) + child_dci_values = parent_dci_sorted[self._parent_channel_indices] + are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) + probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep_mask = np.isin(probe_dci, child_dci_values) + sliced_probegroup = parent_probegroup.get_slice(keep_mask) + + if not are_channels_reordered: + # simple case: the child's channels are already in ascending device_channel_indices order + # so _build_contact_vector on the filtered probegroup will produce rows in the child's + # channel order. Nothing else to do. + pass + else: + # reorder case: the user picked channels in an order that does not match sort-by-dci. + # We have to rewrite device_channel_indices on the child's copy so that the sort done + # by _build_contact_vector aligns with the child's channel order. + new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} + sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] + sliced_probegroup.set_global_device_channel_indices( + np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") + ) + + sliced_probegroup._build_contact_vector() + self._probegroup = sliced_probegroup # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1ef5d76e5a..941a99b877 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -503,7 +503,7 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c2ceb66523 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np import probeinterface from spikeinterface.core.core_tools import define_function_from_class @@ -71,8 +72,9 @@ def __init__( probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + probe = self.get_probegroup().probes[0] + self.set_property("row", np.asarray(probe.contact_annotations["row"])) + self.set_property("col", np.asarray(probe.contact_annotations["col"])) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 5306de2441..ff21c7a3c7 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -84,7 +84,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", "location", "group", "shank", diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index a571894374..6b40548bc4 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -126,8 +126,10 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx][1] = x[idx] + probe = recording.get_probegroup().probes[0] + probe._contact_positions[:, 0] = x + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -170,9 +172,12 @@ def test_output_values(): [5, 5, 5, 7, 3], ] # all others equal distance away. # Overwrite the probe information with the new locations + probe = recording.get_probegroup().probes[0] for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx][1] = x - recording._properties["contact_vector"][idx][2] = y + probe._contact_positions[idx, 0] = x + probe._contact_positions[idx, 1] = y + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +191,11 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + probe = recording.get_probegroup().probes[0] + probe._contact_positions[-1, 0] = 5 + probe._contact_positions[-1, 1] = 9 + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..a84d3bbf64 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe() or "location" in recording.get_property_keys(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 7c4c4b166e..027f2bbb66 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -403,13 +403,15 @@ def __init__( dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) - if border_mode == "remove_channels": - # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if border_mode == "remove_channels" and recording.has_probe(): + # filter the probegroup to contacts wired to the retained channels; order is preserved (channel_inds is ascending) + parent_probegroup = recording.get_probegroup() + parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) + child_dci_values = parent_dci_sorted[channel_inds] + probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] + sliced_probegroup = parent_probegroup.get_slice(np.isin(probe_dci, child_dci_values)) + sliced_probegroup._build_contact_vector() + self._probegroup = sliced_probegroup # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From cd52ef333496f638e0fd7115c50625bec4ef403c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Apr 2026 14:10:47 -0600 Subject: [PATCH 02/31] second iteration --- .../core/baserecordingsnippets.py | 13 +++-- src/spikeinterface/core/channelslice.py | 52 +++---------------- .../motion/motion_interpolation.py | 8 ++- 3 files changed, 18 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 0b99dfe435..b8de26df3d 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -146,13 +146,17 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # identify connected contacts; device_channel_indices values are preserved as provenance + # identify connected contacts and their channel-order global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] connected_mask = global_device_channel_indices >= 0 if np.any(~connected_mask): warn("The given probes have unconnected contacts: they are removed") - device_channel_indices = np.sort(global_device_channel_indices[connected_mask]) + connected_contact_indices = np.where(connected_mask)[0] + connected_channel_values = global_device_channel_indices[connected_mask] + order = np.argsort(connected_channel_values) + sorted_contact_indices = connected_contact_indices[order] + device_channel_indices = connected_channel_values[order] # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) @@ -168,8 +172,9 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): new_channel_ids = self.get_channel_ids()[device_channel_indices] - # drop only the unconnected contacts from the stored probegroup; preserve device_channel_indices values - probegroup = probegroup.get_slice(connected_mask) + # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange + probegroup = probegroup.get_slice(sorted_contact_indices) + probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) probegroup._build_contact_vector() contact_vector = probegroup.contact_vector diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 82e842dabc..6491113884 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,31 +61,11 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # filter the probegroup to contacts wired to the retained channels + # slice the probegroup to the retained channels and reset wiring to the new channel order if parent_recording.has_probe(): parent_probegroup = parent_recording.get_probegroup() - parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) - child_dci_values = parent_dci_sorted[self._parent_channel_indices] - are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) - probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] - keep_mask = np.isin(probe_dci, child_dci_values) - sliced_probegroup = parent_probegroup.get_slice(keep_mask) - - if not are_channels_reordered: - # simple case: the child's channels are already in ascending device_channel_indices order - # so _build_contact_vector on the filtered probegroup will produce rows in the child's - # channel order. Nothing else to do. - pass - else: - # reorder case: the user picked channels in an order that does not match sort-by-dci. - # We have to rewrite device_channel_indices on the child's copy so that the sort done - # by _build_contact_vector aligns with the child's channel order. - new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} - sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] - sliced_probegroup.set_global_device_channel_indices( - np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") - ) - + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup @@ -173,31 +153,11 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # filter the probegroup to contacts wired to the retained channels + # slice the probegroup to the retained channels and reset wiring to the new channel order if parent_snippets.has_probe(): parent_probegroup = parent_snippets.get_probegroup() - parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) - child_dci_values = parent_dci_sorted[self._parent_channel_indices] - are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) - probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] - keep_mask = np.isin(probe_dci, child_dci_values) - sliced_probegroup = parent_probegroup.get_slice(keep_mask) - - if not are_channels_reordered: - # simple case: the child's channels are already in ascending device_channel_indices order - # so _build_contact_vector on the filtered probegroup will produce rows in the child's - # channel order. Nothing else to do. - pass - else: - # reorder case: the user picked channels in an order that does not match sort-by-dci. - # We have to rewrite device_channel_indices on the child's copy so that the sort done - # by _build_contact_vector aligns with the child's channel order. - new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} - sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] - sliced_probegroup.set_global_device_channel_indices( - np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") - ) - + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 027f2bbb66..612b667f63 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -404,12 +404,10 @@ def __init__( BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels" and recording.has_probe(): - # filter the probegroup to contacts wired to the retained channels; order is preserved (channel_inds is ascending) + # slice the probegroup to the retained channels and reset wiring to the new channel order parent_probegroup = recording.get_probegroup() - parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) - child_dci_values = parent_dci_sorted[channel_inds] - probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] - sliced_probegroup = parent_probegroup.get_slice(np.isin(probe_dci, child_dci_values)) + sliced_probegroup = parent_probegroup.get_slice(channel_inds) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup From f52c39c81eeaaf1095727795ce47b6191ba4a600 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 11:00:55 -0600 Subject: [PATCH 03/31] using cache and consistent use of set_probe_groups --- src/spikeinterface/core/base.py | 12 ++++++++++++ src/spikeinterface/core/baserecordingsnippets.py | 5 ++--- .../core/channelsaggregationrecording.py | 11 ++++++----- src/spikeinterface/core/channelslice.py | 2 -- .../extractors/neoextractors/maxwell.py | 3 ++- .../sortingcomponents/motion/motion_interpolation.py | 1 - 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8d149a7c49..3bac073de5 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -574,6 +574,9 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + if getattr(self, "_probegroup", None) is not None: + dump_dict["probegroup"] = self._probegroup.to_dict(array_as_list=True) + return dump_dict @staticmethod @@ -1161,6 +1164,15 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": for k, v in dic["properties"].items(): extractor.set_property(k, v) + if "probegroup" in dic: + from probeinterface import ProbeGroup + + probegroup = ProbeGroup.from_dict(dic["probegroup"]) + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + return extractor diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b8de26df3d..83ef0e1699 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -161,7 +161,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): - raise ValueError( + error_msg = ( f"The given Probe either has 'device_channel_indices' that does not match channel count \n" f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" @@ -169,13 +169,13 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) + raise ValueError(error_msg) new_channel_ids = self.get_channel_ids()[device_channel_indices] # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - probegroup._build_contact_vector() contact_vector = probegroup.contact_vector # create recording : channel slice or clone or self @@ -319,7 +319,6 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra channel_indices = self.ids_to_indices(channel_ids) if not self.has_probe(): raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") - self._probegroup._build_contact_vector() contact_vector = self._probegroup.contact_vector ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 59501d0ba1..2e5deb8703 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,7 +2,7 @@ import numpy as np -from probeinterface import ProbeGroup +from probeinterface import Probe, ProbeGroup from .baserecording import BaseRecording, BaseRecordingSegment @@ -94,15 +94,16 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record for prop_name, prop_values in property_dict.items(): self.set_property(key=prop_name, values=prop_values) - # aggregate probegroups across the inputs and reset wiring to the new channel order + # aggregate probegroups across the inputs and attach via the canonical path if all(rec.has_probe() for rec in recording_list): aggregated_probegroup = ProbeGroup() for rec in recording_list: for probe in rec.get_probegroup().probes: - aggregated_probegroup.add_probe(probe.copy()) + # round-trip through to_dict/from_dict because Probe.copy() drops contact_ids + # and annotations (tracked in probeinterface #421) + aggregated_probegroup.add_probe(Probe.from_dict(probe.to_dict(array_as_list=False))) aggregated_probegroup.set_global_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64")) - aggregated_probegroup._build_contact_vector() - self._probegroup = aggregated_probegroup + self.set_probegroup(aggregated_probegroup, in_place=True) # if locations are present, check that they are all different! if "location" in self.get_property_keys(): diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 6491113884..5687001340 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -66,7 +66,6 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_probegroup = parent_recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup # update dump dict @@ -158,7 +157,6 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_probegroup = parent_snippets.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup # update dump dict diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..875422d00b 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,8 @@ def __init__( rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + probe = self.get_probegroup().probes[0] + self.set_property("electrode", np.asarray(probe.contact_annotations["electrode"])) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 612b667f63..0bb3f8b065 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -408,7 +408,6 @@ def __init__( parent_probegroup = recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(channel_inds) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup # handle manual interpolation_time_bin_centers_s From 594122cda4d98a60c6341bfe40118242bfdce04d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 11:01:30 -0600 Subject: [PATCH 04/31] using cache and consistent use of set_probe_groups --- src/spikeinterface/core/channelslice.py | 8 ++++---- .../sortingcomponents/motion/motion_interpolation.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 5687001340..f7d498db04 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,12 +61,12 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # slice the probegroup to the retained channels and reset wiring to the new channel order + # slice the probegroup to the retained channels and attach via the canonical path if parent_recording.has_probe(): parent_probegroup = parent_recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self._probegroup = sliced_probegroup + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -152,12 +152,12 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # slice the probegroup to the retained channels and reset wiring to the new channel order + # slice the probegroup to the retained channels and attach via the canonical path if parent_snippets.has_probe(): parent_probegroup = parent_snippets.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self._probegroup = sliced_probegroup + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 0bb3f8b065..b88987f884 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -404,11 +404,11 @@ def __init__( BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels" and recording.has_probe(): - # slice the probegroup to the retained channels and reset wiring to the new channel order + # slice the probegroup to the retained channels and attach via the canonical path parent_probegroup = recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(channel_inds) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self._probegroup = sliced_probegroup + self.set_probegroup(sliced_probegroup, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From 8bdcdc9ab4118d4c000a41993969b0662773d35c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 13:27:00 -0600 Subject: [PATCH 05/31] recover cophy semantics --- src/spikeinterface/core/base.py | 36 +++++++++++++++++++ .../core/baserecordingsnippets.py | 4 +-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3bac073de5..a05dbdb566 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1172,10 +1172,46 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": extractor.set_probegroup(probegroup, in_place=True) else: extractor._probegroup = probegroup + elif "contact_vector" in dic.get("properties", {}): + _restore_probegroup_from_legacy_contact_vector(extractor) return extractor +def _restore_probegroup_from_legacy_contact_vector(extractor) -> None: + """ + Reconstruct a `ProbeGroup` from the legacy `contact_vector` property. + + Recordings saved before the probegroup refactor stored the probe as a structured numpy + array under the `contact_vector` property, with probe-level annotations under a separate + `probes_info` annotation and per-probe planar contours under `probe_{i}_planar_contour` + annotations. This function reconstructs a `ProbeGroup` from those legacy fields, attaches + it via the canonical `set_probegroup` path, and removes the legacy property so the new + and old representations do not coexist on the loaded extractor. + """ + from probeinterface import ProbeGroup + + contact_vector_array = extractor.get_property("contact_vector") + probegroup = ProbeGroup.from_numpy(contact_vector_array) + + if "probes_info" in extractor.get_annotation_keys(): + probes_info = extractor.get_annotation("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations = probe_info + + for probe_index, probe in enumerate(probegroup.probes): + contour = extractor._annotations.get(f"probe_{probe_index}_planar_contour") + if contour is not None: + probe.set_planar_contour(contour) + + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + + extractor._properties.pop("contact_vector", None) + + def _get_class_from_string(class_string): class_name = class_string.split(".")[-1] module = ".".join(class_string.split(".")[:-1]) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 83ef0e1699..ccc798dbdc 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -176,7 +176,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - contact_vector = probegroup.contact_vector + contact_vector = probegroup._contact_vector # create recording : channel slice or clone or self if in_place: @@ -319,7 +319,7 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra channel_indices = self.ids_to_indices(channel_ids) if not self.has_probe(): raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") - contact_vector = self._probegroup.contact_vector + contact_vector = self._probegroup._contact_vector ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") for i, dim in enumerate(axes): From d499bc0d23194430bb6d158e6e9a0c3e462ef78b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:08:40 -0600 Subject: [PATCH 06/31] add docstring --- .../core/baserecordingsnippets.py | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index ccc798dbdc..8df1c68bba 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path import numpy as np @@ -230,18 +231,38 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): return sub_recording def get_probe(self): + """ + Return a copy of the single attached probe. + + Returns a deepcopy so callers can mutate the probe without affecting the + recording's internal state. To re-attach a mutated probe use `set_probe(...)`. + """ probes = self.get_probes() assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): + """ + Return a list of copies of the attached probes. + + Returns deepcopies so callers can mutate probes without affecting the + recording's internal state. To re-attach a mutated probe use + `set_probegroup(...)` or `set_probe(...)`. + """ probegroup = self.get_probegroup() return probegroup.probes def get_probegroup(self): + """ + Return a copy of the attached `ProbeGroup`. + + Returns a deepcopy so callers hold a snapshot independent of the recording's + internal state. Mutating the returned probegroup does not modify the + recording; to commit changes use `set_probegroup(...)`. + """ if self._probegroup is None: raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - return self._probegroup + return copy.deepcopy(self._probegroup) def _extra_metadata_from_folder(self, folder): # load probe @@ -309,6 +330,14 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): + """ + Set channel locations directly on the `"location"` property. + + When a probe is attached, channel locations come from the probegroup and + `"location"` is a compatibility mirror maintained by `_set_probes`. Writing + directly to the property would diverge the mirror from the probegroup, so + this method raises in that case; reattach a modified probe via `set_probe`. + """ if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -338,12 +367,25 @@ def clear_channel_locations(self, channel_ids=None): self.set_property("location", locations, ids=channel_ids) def set_channel_groups(self, groups, channel_ids=None): + """ + Set channel groups directly on the `"group"` property. + + When a probe is attached, the `"group"` property is a compatibility mirror + derived by `_set_probes` from the probegroup and the chosen `group_mode`. + Writing groups directly bypasses that derivation and can diverge from the + probegroup; prefer re-attaching via `set_probe(..., group_mode=...)`. + """ if "probes" in self._annotations: warn("set_channel_groups() destroys the probe description. Using set_probe() is preferable") self._annotations.pop("probes") self.set_property("group", groups, ids=channel_ids) def get_channel_groups(self, channel_ids=None): + # Note: `"group"` is a compatibility mirror of the probegroup-derived grouping + # when a probe is attached, populated at `_set_probes` time. It is read directly + # here because the `group_mode` used to derive it is not currently persisted on + # the recording. Follow-up work may unify this with `get_channel_locations` by + # reading directly from the attached probegroup. groups = self.get_property("group", ids=channel_ids) return groups From f15e5c6514f53e09d7e6a6bd27e9957fb6e99b0b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:20:57 -0600 Subject: [PATCH 07/31] just copies --- .../core/baserecordingsnippets.py | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 8df1c68bba..268585ff3a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -231,37 +231,21 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): return sub_recording def get_probe(self): - """ - Return a copy of the single attached probe. - - Returns a deepcopy so callers can mutate the probe without affecting the - recording's internal state. To re-attach a mutated probe use `set_probe(...)`. - """ probes = self.get_probes() assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): - """ - Return a list of copies of the attached probes. - - Returns deepcopies so callers can mutate probes without affecting the - recording's internal state. To re-attach a mutated probe use - `set_probegroup(...)` or `set_probe(...)`. - """ probegroup = self.get_probegroup() return probegroup.probes def get_probegroup(self): - """ - Return a copy of the attached `ProbeGroup`. - - Returns a deepcopy so callers hold a snapshot independent of the recording's - internal state. Mutating the returned probegroup does not modify the - recording; to commit changes use `set_probegroup(...)`. - """ if self._probegroup is None: raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + # Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed + # a fresh `ProbeGroup` from the stored structured array on each call, so external + # callers relied on value semantics. Handing out the live `_probegroup` would be a + # silent behavioural change. return copy.deepcopy(self._probegroup) def _extra_metadata_from_folder(self, folder): From e1b71032b5e1dece03ea1566d8fddf453f480bed Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:22:49 -0600 Subject: [PATCH 08/31] remove comments --- .../core/baserecordingsnippets.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 268585ff3a..b12a753ae9 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -314,14 +314,6 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - """ - Set channel locations directly on the `"location"` property. - - When a probe is attached, channel locations come from the probegroup and - `"location"` is a compatibility mirror maintained by `_set_probes`. Writing - directly to the property would diverge the mirror from the probegroup, so - this method raises in that case; reattach a modified probe via `set_probe`. - """ if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -351,25 +343,12 @@ def clear_channel_locations(self, channel_ids=None): self.set_property("location", locations, ids=channel_ids) def set_channel_groups(self, groups, channel_ids=None): - """ - Set channel groups directly on the `"group"` property. - - When a probe is attached, the `"group"` property is a compatibility mirror - derived by `_set_probes` from the probegroup and the chosen `group_mode`. - Writing groups directly bypasses that derivation and can diverge from the - probegroup; prefer re-attaching via `set_probe(..., group_mode=...)`. - """ if "probes" in self._annotations: warn("set_channel_groups() destroys the probe description. Using set_probe() is preferable") self._annotations.pop("probes") self.set_property("group", groups, ids=channel_ids) def get_channel_groups(self, channel_ids=None): - # Note: `"group"` is a compatibility mirror of the probegroup-derived grouping - # when a probe is attached, populated at `_set_probes` time. It is read directly - # here because the `group_mode` used to derive it is not currently persisted on - # the recording. Follow-up work may unify this with `get_channel_locations` by - # reading directly from the attached probegroup. groups = self.get_property("group", ids=channel_ids) return groups From 5d1f57d1b4fdbc4ce44dc50f783353b22b5990ae Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:32:02 -0600 Subject: [PATCH 09/31] rename --- .../core/baserecordingsnippets.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b12a753ae9..d2d3c21716 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -177,7 +177,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - contact_vector = probegroup._contact_vector + probe_as_numpy_array = probegroup._contact_vector # create recording : channel slice or clone or self if in_place: @@ -194,14 +194,14 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # duplicate positions to "location" property so SpikeInterface-level readers keep working ndim = probegroup.ndim - locations = np.zeros((contact_vector.size, ndim), dtype="float64") + locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = contact_vector[dim] + locations[:, i] = probe_as_numpy_array[dim] sub_recording.set_property("location", locations, ids=None) # derive groups from contact_vector - has_shank_id = "shank_ids" in contact_vector.dtype.fields - has_contact_side = "contact_sides" in contact_vector.dtype.fields + has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields + has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields if group_mode == "auto": group_keys = ["probe_index"] if has_shank_id: @@ -219,12 +219,12 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): group_keys = ["probe_index", "shank_ids", "contact_sides"] else: group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(contact_vector.size, dtype="int64") - unique_keys = np.unique(contact_vector[group_keys]) + groups = np.zeros(probe_as_numpy_array.size, dtype="int64") + unique_keys = np.unique(probe_as_numpy_array[group_keys]) for group, a in enumerate(unique_keys): - mask = np.ones(contact_vector.size, dtype=bool) + mask = np.ones(probe_as_numpy_array.size, dtype=bool) for k in group_keys: - mask &= contact_vector[k] == a[k] + mask &= probe_as_numpy_array[k] == a[k] groups[mask] = group sub_recording.set_property("group", groups, ids=None) From c02e7602bfff3e43edae4d23681229d6a85e0970 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:38:12 -0600 Subject: [PATCH 10/31] more backwards compatability --- .../core/baserecordingsnippets.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index d2d3c21716..92c5bfb858 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -228,6 +228,17 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): groups[mask] = group sub_recording.set_property("group", groups, ids=None) + # TODO discuss backwards compatibility: mirror probe-level annotations and planar + # contours as recording-level annotations so external code that reads these keys + # keeps working. The canonical source is now `probe.annotations` and + # `probe.probe_planar_contour` on the attached probegroup. + probes_info = [probe.annotations for probe in probegroup.probes] + sub_recording.annotate(probes_info=probes_info) + for probe_index, probe in enumerate(probegroup.probes): + contour = probe.probe_planar_contour + if contour is not None: + sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) + return sub_recording def get_probe(self): @@ -322,14 +333,18 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - if not self.has_probe(): - raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") - contact_vector = self._probegroup._contact_vector - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - return all_positions[channel_indices] + if self.has_probe(): + contact_vector = self._probegroup._contact_vector + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + return all_positions[channel_indices] + locations = self.get_property("location") + if locations is None: + raise Exception("There are no channel locations") + locations = np.asarray(locations)[channel_indices] + return select_axes(locations, axes) def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 From 591be6353c203b02d077856a1065cddb34f2c014 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:48:33 -0600 Subject: [PATCH 11/31] testing --- pyproject.toml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3eb85b6e5..48ca4333eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,8 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs, @@ -139,7 +140,8 @@ test_extractors = [ "pooch>=1.8.2", "datalad>=1.0.2", # Commenting out for release - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] @@ -190,7 +192,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs @@ -219,7 +222,8 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] From ed22a738e09304362f01a8b8c150f587a7660691 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 18:10:45 -0600 Subject: [PATCH 12/31] beahvior for 0 channel recording --- .../core/baserecordingsnippets.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 92c5bfb858..502a59ae61 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -174,10 +174,13 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): new_channel_ids = self.get_channel_ids()[device_channel_indices] + # capture ndim before slicing; get_slice with an empty selection yields a probegroup + # with no probes, on which `.ndim` raises + ndim = probegroup.ndim + # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - probe_as_numpy_array = probegroup._contact_vector # create recording : channel slice or clone or self if in_place: @@ -192,8 +195,20 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): sub_recording._probegroup = probegroup + # TODO: revisit whether set_probe with a fully unconnected probe should raise + # instead of returning a zero-channel recording. Preserved here for backwards + # compatibility with a test in test_BaseRecording; that test case should be + # peeled into its own named test so this assumption is easy to find and + # discuss when we decide to tighten the behaviour. + if len(device_channel_indices) == 0: + sub_recording.set_property("location", np.zeros((0, ndim), dtype="float64"), ids=None) + sub_recording.set_property("group", np.zeros(0, dtype="int64"), ids=None) + sub_recording.annotate(probes_info=[]) + return sub_recording + + probe_as_numpy_array = probegroup._contact_vector + # duplicate positions to "location" property so SpikeInterface-level readers keep working - ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] From 565d7759eb47e1f0b2f45926ee7f992366fdc488 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 18:21:39 -0600 Subject: [PATCH 13/31] more fixes --- .../core/channelsaggregationrecording.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 2e5deb8703..64aa7e5bc2 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -97,12 +97,20 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record # aggregate probegroups across the inputs and attach via the canonical path if all(rec.has_probe() for rec in recording_list): aggregated_probegroup = ProbeGroup() + offset = 0 for rec in recording_list: for probe in rec.get_probegroup().probes: # round-trip through to_dict/from_dict because Probe.copy() drops contact_ids # and annotations (tracked in probeinterface #421) - aggregated_probegroup.add_probe(Probe.from_dict(probe.to_dict(array_as_list=False))) - aggregated_probegroup.set_global_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64")) + probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) + # assign non-colliding device_channel_indices before add_probe so the + # cross-probe uniqueness check does not fire on children that share + # child-local wiring (each sub-recording's probe was reset to arange + # when it was created via set_probe) + n = probe_copy.get_contact_count() + probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64")) + aggregated_probegroup.add_probe(probe_copy) + offset += n self.set_probegroup(aggregated_probegroup, in_place=True) # if locations are present, check that they are all different! From 5ef02216f0d7c1c666b434998c400e13af866919 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 18:44:42 -0600 Subject: [PATCH 14/31] second fix --- .../core/channelsaggregationrecording.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 64aa7e5bc2..90be729664 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -94,24 +94,23 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record for prop_name, prop_values in property_dict.items(): self.set_property(key=prop_name, values=prop_values) - # aggregate probegroups across the inputs and attach via the canonical path + # split_by resets each child probe's device_channel_indices, so the information + # of which contact was connected to which channel of the parent is lost by the + # time we aggregate. We rebuild a globally-unique wiring via per-probe offsets + # and skip set_probegroup because children also share contact positions. if all(rec.has_probe() for rec in recording_list): aggregated_probegroup = ProbeGroup() offset = 0 for rec in recording_list: for probe in rec.get_probegroup().probes: - # round-trip through to_dict/from_dict because Probe.copy() drops contact_ids - # and annotations (tracked in probeinterface #421) + # round-trip through to_dict/from_dict because Probe.copy() drops + # contact_ids and annotations (probeinterface #421) probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) - # assign non-colliding device_channel_indices before add_probe so the - # cross-probe uniqueness check does not fire on children that share - # child-local wiring (each sub-recording's probe was reset to arange - # when it was created via set_probe) n = probe_copy.get_contact_count() probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64")) aggregated_probegroup.add_probe(probe_copy) offset += n - self.set_probegroup(aggregated_probegroup, in_place=True) + self._probegroup = aggregated_probegroup # if locations are present, check that they are all different! if "location" in self.get_property_keys(): From 07f4e9e79d1f5c889d1c45a9403fe0a3fe0e4a9c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 19:07:49 -0600 Subject: [PATCH 15/31] remove non-overallaping redundant check --- src/spikeinterface/core/sortinganalyzer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..ad444fca4b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match from .core_tools import ( check_json, retrieve_importing_provenance, @@ -363,10 +363,6 @@ def create( f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. " "Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer." ) - # check that multiple probes are non-overlapping - all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) - if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( "Your sorting has spikes with samples times greater than your recording length. These spikes have been removed." From e6129bc33a50c0892a3c1827de31b7135be0281c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 19:13:52 -0600 Subject: [PATCH 16/31] another fallack --- src/spikeinterface/core/baserecordingsnippets.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 502a59ae61..63413c8768 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -267,7 +267,18 @@ def get_probes(self): def get_probegroup(self): if self._probegroup is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + # Backwards-compat fallback: pre-migration get_probegroup synthesised a dummy + # probe from the "location" property when no probe had been attached. Callers + # (e.g. sparsity.py) rely on this for recordings that have locations but no + # probe. + positions = self.get_property("location") + if positions is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") + probe = self.create_dummy_probe_from_locations(positions) + pg = ProbeGroup() + pg.add_probe(probe) + return copy.deepcopy(pg) # Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed # a fresh `ProbeGroup` from the stored structured array on each call, so external # callers relied on value semantics. Handing out the live `_probegroup` would be a From 6eb09a65ab671533e073dd19e64e8abb8e2e91a4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 20:01:22 -0600 Subject: [PATCH 17/31] propgate to children --- src/spikeinterface/preprocessing/basepreprocessor.py | 7 +++++++ .../preprocessing/tests/test_interpolate_bad_channels.py | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/basepreprocessor.py b/src/spikeinterface/preprocessing/basepreprocessor.py index 64d57d3637..4e18516a80 100644 --- a/src/spikeinterface/preprocessing/basepreprocessor.py +++ b/src/spikeinterface/preprocessing/basepreprocessor.py @@ -21,6 +21,13 @@ def __init__(self, recording, sampling_frequency=None, channel_ids=None, dtype=N recording.copy_metadata(self, only_main=False, ids=channel_ids) self._parent = recording + # Propagate the attached probegroup. copy_metadata only handles annotations + # and properties; `_probegroup` is a direct attribute and needs its own path. + # Subclasses that change channels (e.g. slicing) should override by slicing + # the probegroup themselves via set_probegroup. + if getattr(recording, "_probegroup", None) is not None and channel_ids is None: + self._probegroup = recording._probegroup + # self._kwargs have to be handled in subclass diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 6b40548bc4..fd835df5c7 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -126,7 +126,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) x = rng.choice(shanks, num_channels) - probe = recording.get_probegroup().probes[0] + probe = recording._probegroup.probes[0] probe._contact_positions[:, 0] = x recording._probegroup._build_contact_vector() recording.set_property("location", recording.get_channel_locations()) @@ -172,7 +172,7 @@ def test_output_values(): [5, 5, 5, 7, 3], ] # all others equal distance away. # Overwrite the probe information with the new locations - probe = recording.get_probegroup().probes[0] + probe = recording._probegroup.probes[0] for idx, (x, y) in enumerate(zip(*new_probe_locs)): probe._contact_positions[idx, 0] = x probe._contact_positions[idx, 1] = y @@ -191,7 +191,7 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - probe = recording.get_probegroup().probes[0] + probe = recording._probegroup.probes[0] probe._contact_positions[-1, 0] = 5 probe._contact_positions[-1, 1] = 9 recording._probegroup._build_contact_vector() From 1f4afee9631addc2007aa45f7932d922b009cf4f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 20:19:21 -0600 Subject: [PATCH 18/31] fix tests --- src/spikeinterface/extractors/tests/test_iblextractors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index ff21c7a3c7..dfbf5d714d 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -68,11 +68,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = 2.34375 * np.ones(shape=384) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=384) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): probe = self.recording.get_probe() @@ -141,11 +141,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = np.concatenate([2.34375 * np.ones(shape=384), [1171.875]]) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=385) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): expected_exception = ValueError From eb07eeaea9b614d866a735e05ab3408f6fc08dae Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 23:01:57 -0600 Subject: [PATCH 19/31] fixes --- src/spikeinterface/preprocessing/tests/test_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 398b6cbc0e..e95b456542 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -50,7 +50,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): # Then, change all kwargs to ensure they are propagated # and check the backwards version. - options["band"] = [671] + options["band"] = 671 options["btype"] = "highpass" options["filter_order"] = 8 options["ftype"] = "bessel" From cc9e9b0e83c3adac4d98ee79b65fe0f09292926c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 23:44:05 -0600 Subject: [PATCH 20/31] another numpy fix --- .../postprocessing/tests/test_principal_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 77bff7a3d8..0e65bb2338 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -101,7 +101,7 @@ def test_get_projections(self, sparse): random_spikes_ext = sorting_analyzer.get_extension("random_spikes") random_spikes_indices = random_spikes_ext.get_data() - unit_ids_num_random_spikes = np.sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) + unit_ids_num_random_spikes = sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) From a7db1e38fdd7b5b17b86aa577c4537980eb2df5e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 12:31:07 -0600 Subject: [PATCH 21/31] remove cache --- src/spikeinterface/core/baserecordingsnippets.py | 4 ++-- .../preprocessing/tests/test_interpolate_bad_channels.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 63413c8768..a9910b2b2c 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -206,7 +206,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): sub_recording.annotate(probes_info=[]) return sub_recording - probe_as_numpy_array = probegroup._contact_vector + probe_as_numpy_array = probegroup._build_contact_vector() # duplicate positions to "location" property so SpikeInterface-level readers keep working locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") @@ -360,7 +360,7 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) if self.has_probe(): - contact_vector = self._probegroup._contact_vector + contact_vector = self._probegroup._build_contact_vector() ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") for i, dim in enumerate(axes): diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index fd835df5c7..10dbb11fc1 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -128,7 +128,6 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan x = rng.choice(shanks, num_channels) probe = recording._probegroup.probes[0] probe._contact_positions[:, 0] = x - recording._probegroup._build_contact_vector() recording.set_property("location", recording.get_channel_locations()) # generate random bad channel locations @@ -176,7 +175,6 @@ def test_output_values(): for idx, (x, y) in enumerate(zip(*new_probe_locs)): probe._contact_positions[idx, 0] = x probe._contact_positions[idx, 1] = y - recording._probegroup._build_contact_vector() recording.set_property("location", recording.get_channel_locations()) # Run interpolation in SI and check the interpolated channel @@ -194,7 +192,6 @@ def test_output_values(): probe = recording._probegroup.probes[0] probe._contact_positions[-1, 0] = 5 probe._contact_positions[-1, 1] = 9 - recording._probegroup._build_contact_vector() recording.set_property("location", recording.get_channel_locations()) expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) From 20b535f47acc6c3d261293b519e4bf097fc37f66 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 14:27:08 -0600 Subject: [PATCH 22/31] Refactor with wiring --- src/spikeinterface/core/base.py | 7 +- .../core/baserecordingsnippets.py | 356 ++++++++++-------- .../core/channelsaggregationrecording.py | 49 ++- src/spikeinterface/core/channelslice.py | 16 +- src/spikeinterface/core/sortinganalyzer.py | 22 +- .../preprocessing/basepreprocessor.py | 6 +- .../motion/motion_interpolation.py | 8 +- 7 files changed, 260 insertions(+), 204 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6e61ae894f..57db310357 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1162,10 +1162,9 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": from probeinterface import ProbeGroup probegroup = ProbeGroup.from_dict(dic["probegroup"]) - if hasattr(extractor, "set_probegroup"): - extractor.set_probegroup(probegroup, in_place=True) - else: - extractor._probegroup = probegroup + # The `wiring` per-channel property was restored above by the standard + # property-load loop; we just attach the probegroup object. + extractor._probegroup = probegroup elif "contact_vector" in dic.get("properties", {}): _restore_probegroup_from_legacy_contact_vector(extractor) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index a9910b2b2c..2889d949a2 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -63,25 +63,6 @@ def is_filtered(self): return self._annotations.get("is_filtered", False) def set_probe(self, probe, group_mode="auto", in_place=False): - """ - Attach a list of Probe object to a recording. - - Parameters - ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup - The probe(s) to be attached to the recording - group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" - How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) - - Returns - ------- - sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) - """ assert isinstance(probe, Probe), "must give Probe" probegroup = ProbeGroup() probegroup.add_probe(probe) @@ -92,169 +73,171 @@ def set_probegroup(self, probegroup, group_mode="auto", in_place=False): def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): """ - Attach a list of Probe objects to a recording. - For this Probe.device_channel_indices is used to link contacts to recording channels. - If some contacts of the Probe are not connected (device_channel_indices=-1) - then the recording is "sliced" and only connected channel are kept. + Attach a Probe, ProbeGroup, or list of Probe to the recording. - The probe order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. + The probegroup is stored by reference without mutation. The contact-to-channel + mapping is built from each probe's current `device_channel_indices` and stored + on the recording as two per-channel properties, `probe_id` and `contact_id`. + `location` and `group` are also written as properties for backward compatibility. + If the probegroup wires only a subset of the recording's channels, the recording + is sliced via `select_channels` to keep only the wired channels (preserving the + historical attach semantics). Parameters ---------- probe_or_probegroup: Probe, list of Probe, or ProbeGroup - The probe(s) to be attached to the recording - group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" - How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) + The probe(s) to be attached to the recording. + in_place: bool, default: False + If True, attach to self in place. Returns ------- sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) + The recording with the probegroup attached. """ - assert group_mode in ( - "auto", - "by_probe", - "by_shank", - "by_side", - ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" - - # handle several input possibilities + # normalize input to a ProbeGroup if isinstance(probe_or_probegroup, Probe): probegroup = ProbeGroup() probegroup.add_probe(probe_or_probegroup) elif isinstance(probe_or_probegroup, ProbeGroup): probegroup = probe_or_probegroup elif isinstance(probe_or_probegroup, list): - assert all([isinstance(e, Probe) for e in probe_or_probegroup]) + assert all(isinstance(e, Probe) for e in probe_or_probegroup) probegroup = ProbeGroup() for probe in probe_or_probegroup: probegroup.add_probe(probe) else: raise ValueError("must give Probe or ProbeGroup or list of Probe") - # check that the probe do not overlap - num_probes = len(probegroup.probes) - if num_probes > 1: + if len(probegroup.probes) > 1: check_probe_do_not_overlap(probegroup.probes) - # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # identify connected contacts and their channel-order - global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] - connected_mask = global_device_channel_indices >= 0 - if np.any(~connected_mask): - warn("The given probes have unconnected contacts: they are removed") - - connected_contact_indices = np.where(connected_mask)[0] - connected_channel_values = global_device_channel_indices[connected_mask] - order = np.argsort(connected_channel_values) - sorted_contact_indices = connected_contact_indices[order] - device_channel_indices = connected_channel_values[order] - - # validate indices fit the recording - number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) - if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( - f"The given Probe either has 'device_channel_indices' that does not match channel count \n" - f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" - f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" - f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" - f"device_channel_indices are the following: {device_channel_indices} \n" - f"recording channels are the following: {self.get_channel_ids()} \n" - ) - raise ValueError(error_msg) - - new_channel_ids = self.get_channel_ids()[device_channel_indices] - - # capture ndim before slicing; get_slice with an empty selection yields a probegroup - # with no probes, on which `.ndim` raises - ndim = probegroup.ndim + # ensure every probe has a stable probe_id and every contact has a contact_id; + # auto-generate only when missing so user-assigned ids survive + if not any("probe_id" in p.annotations for p in probegroup.probes): + probegroup.auto_generate_probe_ids() + if any(p.contact_ids is None for p in probegroup.probes): + probegroup.auto_generate_contact_ids() + + # collect, per recording channel (by position in self.channel_ids), which + # (probe_id, contact_id) pair wires into it. Unwired channels stay as None. + num_channels = self.get_num_channels() + probe_id_col = [None] * num_channels + contact_id_col = [None] * num_channels + for probe in probegroup.probes: + probe_id = probe.annotations["probe_id"] + dci = np.asarray(probe.device_channel_indices) + for contact_idx, device_idx in enumerate(dci): + if device_idx < 0: + continue # unconnected contact; skip + if device_idx >= num_channels: + raise ValueError( + f"device_channel_indices value {device_idx} is out of range; " + f"recording has {num_channels} channels." + ) + if probe_id_col[device_idx] is not None: + raise ValueError(f"channel at index {device_idx} is wired to more than one contact.") + probe_id_col[device_idx] = probe_id + contact_id_col[device_idx] = probe.contact_ids[contact_idx] + + # Reorder the recording's channels to match the probe's device_channel_indices + # order (smallest dci first). Also drops any channels that are unwired. This + # matches the historical `set_probe` behaviour: after attach, recording channel + # i corresponds to the probe contact whose dci was the i-th smallest. + wired_dci_pairs = [] # (device_idx, position_in_probe_id_col) + for i, pid in enumerate(probe_id_col): + if pid is not None: + wired_dci_pairs.append(i) # position == original device_idx since we indexed by it + # sort by device_idx (which equals the position already) + ordered_positions = sorted(wired_dci_pairs) + original_channel_ids = self.get_channel_ids() + new_channel_ids = original_channel_ids[ordered_positions] - # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange - probegroup = probegroup.get_slice(sorted_contact_indices) - probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - - # create recording : channel slice or clone or self if in_place: - if not np.array_equal(new_channel_ids, self.get_channel_ids()): - raise Exception("set_probe(inplace=True) must have all channel indices") - sub_recording = self + if not np.array_equal(new_channel_ids, original_channel_ids): + raise Exception("set_probe(in_place=True) must have all channel indices") + target = self else: - if np.array_equal(new_channel_ids, self.get_channel_ids()): - sub_recording = self.clone() + if np.array_equal(new_channel_ids, original_channel_ids): + target = self.clone() else: - sub_recording = self.select_channels(new_channel_ids) - - sub_recording._probegroup = probegroup - - # TODO: revisit whether set_probe with a fully unconnected probe should raise - # instead of returning a zero-channel recording. Preserved here for backwards - # compatibility with a test in test_BaseRecording; that test case should be - # peeled into its own named test so this assumption is easy to find and - # discuss when we decide to tighten the behaviour. - if len(device_channel_indices) == 0: - sub_recording.set_property("location", np.zeros((0, ndim), dtype="float64"), ids=None) - sub_recording.set_property("group", np.zeros(0, dtype="int64"), ids=None) - sub_recording.annotate(probes_info=[]) - return sub_recording - - probe_as_numpy_array = probegroup._build_contact_vector() - - # duplicate positions to "location" property so SpikeInterface-level readers keep working - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] - sub_recording.set_property("location", locations, ids=None) - - # derive groups from contact_vector - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields + target = self.select_channels(new_channel_ids) + # re-key probe_id_col / contact_id_col into the (possibly reordered) target + probe_id_col = [probe_id_col[i] for i in ordered_positions] + contact_id_col = [contact_id_col[i] for i in ordered_positions] + + # attach probegroup; the wiring lives as a (num_channels, 2) per-channel + # string property `wiring` with column 0 = probe_id, column 1 = contact_id. + # This is the same pattern as `location` (2D property per channel) and rides + # on SI's existing property plumbing (copy_metadata, concat, serialization). + target._probegroup = probegroup + + # handle the degenerate empty-wiring case (probe with all dci=-1) + if len(probe_id_col) == 0: + ndim = probegroup.ndim + target.set_property("wiring", np.zeros((0, 2), dtype="U64")) + target.set_property("location", np.zeros((0, ndim), dtype="float64")) + target.set_property("group", np.zeros(0, dtype="int64")) + return target + + wiring = np.column_stack( + [ + np.asarray(probe_id_col, dtype="U64"), + np.asarray(contact_id_col, dtype="U64"), + ] + ) + target.set_property("wiring", wiring) + + # write `location` and `group` as compatibility mirrors of the canonical + # probegroup + _channel_to_contact mapping. group_mode is consulted here to + # match the pre-strong-preserve API; callers that pass "by_probe", "by_shank" + # etc. get the same partitioning as before. + ndim = probegroup.ndim + probes_by_id = {p.annotations["probe_id"]: p for p in probegroup.probes} + has_shank = any(p.shank_ids is not None for p in probegroup.probes) + has_side = any(p.contact_sides is not None for p in probegroup.probes) if group_mode == "auto": - group_keys = ["probe_index"] - if has_shank_id: - group_keys += ["shank_ids"] - if has_contact_side: - group_keys += ["contact_sides"] + keys_template = ["probe"] + (["shank"] if has_shank else []) + (["side"] if has_side else []) elif group_mode == "by_probe": - group_keys = ["probe_index"] + keys_template = ["probe"] elif group_mode == "by_shank": - assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" - group_keys = ["probe_index", "shank_ids"] + assert has_shank, "shank_ids is None in probe, you cannot group by shank" + keys_template = ["probe", "shank"] elif group_mode == "by_side": - assert has_contact_side, "contact_sides is None in probe, you cannot group by side" - if has_shank_id: - group_keys = ["probe_index", "shank_ids", "contact_sides"] - else: - group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) - for group, a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) - for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] - groups[mask] = group - sub_recording.set_property("group", groups, ids=None) - - # TODO discuss backwards compatibility: mirror probe-level annotations and planar - # contours as recording-level annotations so external code that reads these keys - # keeps working. The canonical source is now `probe.annotations` and - # `probe.probe_planar_contour` on the attached probegroup. - probes_info = [probe.annotations for probe in probegroup.probes] - sub_recording.annotate(probes_info=probes_info) - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - return sub_recording + assert has_side, "contact_sides is None in probe, you cannot group by side" + keys_template = ["probe"] + (["shank"] if has_shank else []) + ["side"] + else: + raise ValueError(f"unknown group_mode {group_mode!r}") + + wired_positions = list(range(len(probe_id_col))) + locations = np.zeros((len(wired_positions), ndim), dtype="float64") + group_keys_per_channel = [] + for i, (pid, cid) in enumerate(zip(probe_id_col, contact_id_col)): + probe = probes_by_id[pid] + contact_idx = int(np.where(np.asarray(probe.contact_ids) == cid)[0][0]) + locations[i] = probe.contact_positions[contact_idx, :ndim] + key = [] + for k in keys_template: + if k == "probe": + key.append(pid) + elif k == "shank" and probe.shank_ids is not None: + key.append(probe.shank_ids[contact_idx]) + elif k == "side" and probe.contact_sides is not None: + key.append(probe.contact_sides[contact_idx]) + group_keys_per_channel.append(tuple(key)) + + unique_keys = list(dict.fromkeys(group_keys_per_channel)) + key_to_int = {k: i for i, k in enumerate(unique_keys)} + groups = np.array([key_to_int[k] for k in group_keys_per_channel], dtype="int64") + + target.set_property("location", locations) + target.set_property("group", groups) + return target def get_probe(self): probes = self.get_probes() @@ -279,11 +262,29 @@ def get_probegroup(self): pg = ProbeGroup() pg.add_probe(probe) return copy.deepcopy(pg) - # Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed - # a fresh `ProbeGroup` from the stored structured array on each call, so external - # callers relied on value semantics. Handing out the live `_probegroup` would be a - # silent behavioural change. - return copy.deepcopy(self._probegroup) + + # Build a channel-ordered view of the stored probegroup for the public getter. + # Strong-preserve keeps each probe intact on `_probegroup`; here we slice it down + # to the contacts that actually appear in this recording, in channel order, with + # device_channel_indices = arange(N). The returned object matches the + # pre-strong-preserve `get_probe()` semantic. + wiring = self.get_property("wiring") + if wiring is None: + return copy.deepcopy(self._probegroup) + + # map (probe_id, contact_id) to the global contact index in the stored probegroup + contact_id_to_global = {} + offset = 0 + for probe in self._probegroup.probes: + pid = probe.annotations["probe_id"] + for cid in probe.contact_ids: + contact_id_to_global[(pid, cid)] = offset + offset += 1 + + global_indices = [contact_id_to_global[(pid, cid)] for pid, cid in wiring] + view = self._probegroup.get_slice(np.asarray(global_indices, dtype="int64")) + view.set_global_device_channel_indices(np.arange(len(global_indices), dtype="int64")) + return view def _extra_metadata_from_folder(self, folder): # load probe @@ -358,14 +359,25 @@ def set_channel_locations(self, locations, channel_ids=None): def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() - channel_indices = self.ids_to_indices(channel_ids) + if self.has_probe(): - contact_vector = self._probegroup._build_contact_vector() + # resolve each channel via the `wiring` property (column 0 = probe_id, + # column 1 = contact_id) and look up the contact's position on the + # corresponding probe + wiring = self.get_property("wiring", ids=channel_ids) + probes_by_id = {p.annotations["probe_id"]: p for p in self._probegroup.probes} + axis_index = {"x": 0, "y": 1, "z": 2} ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - return all_positions[channel_indices] + locations = np.zeros((len(channel_ids), ndim), dtype="float64") + for i, (probe_id, contact_id) in enumerate(wiring): + probe = probes_by_id[probe_id] + contact_idx = int(np.where(np.asarray(probe.contact_ids) == contact_id)[0][0]) + for j, axis in enumerate(axes): + locations[i, j] = probe.contact_positions[contact_idx, axis_index[axis]] + return locations + + # fallback for recordings that have a "location" property but no attached probegroup + channel_indices = self.ids_to_indices(channel_ids) locations = self.get_property("location") if locations is None: raise Exception("There are no channel locations") @@ -373,6 +385,8 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra return select_axes(locations, axes) def has_3d_locations(self) -> bool: + if self.has_probe(): + return self._probegroup.ndim == 3 return self.get_property("location").shape[1] == 3 def clear_channel_locations(self, channel_ids=None): @@ -390,8 +404,34 @@ def set_channel_groups(self, groups, channel_ids=None): self.set_property("group", groups, ids=channel_ids) def get_channel_groups(self, channel_ids=None): - groups = self.get_property("group", ids=channel_ids) - return groups + # when a probe is attached, derive groups on the fly from the `wiring` + # property + probegroup state (probe_id + shank_ids + contact_sides) + if self.has_probe(): + if channel_ids is None: + channel_ids = self.get_channel_ids() + wiring = self.get_property("wiring", ids=channel_ids) + probes_by_id = {p.annotations["probe_id"]: p for p in self._probegroup.probes} + has_shank = any(p.shank_ids is not None for p in self._probegroup.probes) + has_side = any(p.contact_sides is not None for p in self._probegroup.probes) + + group_keys = [] + for probe_id, contact_id in wiring: + probe = probes_by_id[probe_id] + key = [probe_id] + if has_shank or has_side: + contact_idx = int(np.where(np.asarray(probe.contact_ids) == contact_id)[0][0]) + if has_shank and probe.shank_ids is not None: + key.append(probe.shank_ids[contact_idx]) + if has_side and probe.contact_sides is not None: + key.append(probe.contact_sides[contact_idx]) + group_keys.append(tuple(key)) + + unique_keys = list(dict.fromkeys(group_keys)) + key_to_int = {k: i for i, k in enumerate(unique_keys)} + return np.array([key_to_int[k] for k in group_keys], dtype="int64") + + # fallback: read a stored "group" property (recordings without a probe) + return self.get_property("group", ids=channel_ids) def clear_channel_groups(self, channel_ids=None): if channel_ids is None: diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 90be729664..4307d5a4dc 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -94,27 +94,36 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record for prop_name, prop_values in property_dict.items(): self.set_property(key=prop_name, values=prop_values) - # split_by resets each child probe's device_channel_indices, so the information - # of which contact was connected to which channel of the parent is lost by the - # time we aggregate. We rebuild a globally-unique wiring via per-probe offsets - # and skip set_probegroup because children also share contact positions. + # Under the id-keyed wiring model, the per-channel `wiring` property + # concatenates across children via the property-merge loop above. We only + # need to combine the probegroups and attach the combined object. if all(rec.has_probe() for rec in recording_list): - aggregated_probegroup = ProbeGroup() - offset = 0 - for rec in recording_list: - for probe in rec.get_probegroup().probes: - # round-trip through to_dict/from_dict because Probe.copy() drops - # contact_ids and annotations (probeinterface #421) - probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) - n = probe_copy.get_contact_count() - probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64")) - aggregated_probegroup.add_probe(probe_copy) - offset += n - self._probegroup = aggregated_probegroup - - # if locations are present, check that they are all different! - if "location" in self.get_property_keys(): - location_tuple = [tuple(loc) for loc in self.get_property("location")] + # intra-parent case: every child shares the same probegroup reference + # (as produced by `split_by` etc.). Reuse it directly. + first_pg = recording_list[0]._probegroup + if all(rec._probegroup is first_pg for rec in recording_list): + combined_probegroup = first_pg + else: + # cross-parent case: build a fresh combined probegroup from copies + # of each probe. Round-trip through to_dict/from_dict because + # `Probe.copy()` currently drops contact_ids and annotations + # (probeinterface #421). Clear `device_channel_indices` on each copy + # so probeinterface's cross-probe dci uniqueness check passes. + combined_probegroup = ProbeGroup() + for rec in recording_list: + for probe in rec._probegroup.probes: + probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) + probe_copy.set_device_channel_indices( + np.full(probe_copy.get_contact_count(), -1, dtype="int64") + ) + combined_probegroup.add_probe(probe_copy) + self._probegroup = combined_probegroup + + # if locations are available (either via attached probe or "location" property), + # check that they are all different + if self.has_probe() or "location" in self.get_property_keys(): + locations = self.get_channel_locations() + location_tuple = [tuple(loc) for loc in locations] assert len(set(location_tuple)) == self.get_num_channels(), ( "Locations are not unique! " "Cannot aggregate recordings!" ) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index f7d498db04..ffa35aa13e 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,12 +61,10 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # slice the probegroup to the retained channels and attach via the canonical path + # inherit the probegroup by reference; the `wiring` per-channel property + # rode through copy_metadata above with the filtered channel_ids if parent_recording.has_probe(): - parent_probegroup = parent_recording.get_probegroup() - sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) - sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self.set_probegroup(sliced_probegroup, in_place=True) + self._probegroup = parent_recording._probegroup # update dump dict self._kwargs = { @@ -152,12 +150,10 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # slice the probegroup to the retained channels and attach via the canonical path + # inherit the probegroup by reference; the `wiring` per-channel property + # rode through copy_metadata above with the filtered channel_ids if parent_snippets.has_probe(): - parent_probegroup = parent_snippets.get_probegroup() - sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) - sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self.set_probegroup(sliced_probegroup, in_place=True) + self._probegroup = parent_snippets._probegroup # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ad444fca4b..7f08f41fa4 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1558,14 +1558,30 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording - # this give all channel locations, so no kwargs like channel_ids and axes + # this give all channel locations, so no kwargs like channel_ids and axes. + # + # Resolve per-channel through the `wiring` property held in rec_attributes, + # matching BaseRecordingSnippets.get_channel_locations. + properties = self.rec_attributes.get("properties", {}) + wiring = properties.get("wiring") probegroup = self.get_probegroup() + + if wiring is not None: + probes_by_id = {p.annotations["probe_id"]: p for p in probegroup.probes} + ndim = probegroup.ndim + locations = np.zeros((len(wiring), ndim), dtype="float64") + for i, (probe_id, contact_id) in enumerate(wiring): + probe = probes_by_id[probe_id] + contact_idx = int(np.where(np.asarray(probe.contact_ids) == contact_id)[0][0]) + locations[i, :ndim] = probe.contact_positions[contact_idx, :ndim] + return locations + + # legacy fallback: pre-id-keyed probegroups were attached with dci = arange(N), + # so sorting by dci yielded channel order. Kept for loading older analyzers. probe_as_numpy_array = probegroup.to_numpy(complete=True) - # we need to sort by device_channel_indices to ensure the order of locations is correct probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - # here we only loop through xy because only 2d locations are supported for i, dim in enumerate(["x", "y"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] return locations diff --git a/src/spikeinterface/preprocessing/basepreprocessor.py b/src/spikeinterface/preprocessing/basepreprocessor.py index 4e18516a80..79f6b5105d 100644 --- a/src/spikeinterface/preprocessing/basepreprocessor.py +++ b/src/spikeinterface/preprocessing/basepreprocessor.py @@ -21,10 +21,8 @@ def __init__(self, recording, sampling_frequency=None, channel_ids=None, dtype=N recording.copy_metadata(self, only_main=False, ids=channel_ids) self._parent = recording - # Propagate the attached probegroup. copy_metadata only handles annotations - # and properties; `_probegroup` is a direct attribute and needs its own path. - # Subclasses that change channels (e.g. slicing) should override by slicing - # the probegroup themselves via set_probegroup. + # Propagate the attached probegroup. `_probegroup` is a recording-global + # direct attribute; the per-channel `wiring` property rides on copy_metadata. if getattr(recording, "_probegroup", None) is not None and channel_ids is None: self._probegroup = recording._probegroup diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 21a211f099..d69e697181 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -404,11 +404,9 @@ def __init__( BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels" and recording.has_probe(): - # slice the probegroup to the retained channels and attach via the canonical path - parent_probegroup = recording.get_probegroup() - sliced_probegroup = parent_probegroup.get_slice(channel_inds) - sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self.set_probegroup(sliced_probegroup, in_place=True) + # inherit the probegroup by reference; the `wiring` per-channel property + # rode through BasePreprocessor's copy_metadata with filtered channel_ids + self._probegroup = recording._probegroup # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From b1de9695290dae98489b8cc684b78eab194d9f75 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 14:40:56 -0600 Subject: [PATCH 23/31] go furhter --- .../core/baserecordingsnippets.py | 27 ++++--------------- .../core/tests/test_baserecording.py | 10 +++---- .../core/tests/test_basesnippets.py | 10 +++---- 3 files changed, 15 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 2889d949a2..788cb4508a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -263,28 +263,11 @@ def get_probegroup(self): pg.add_probe(probe) return copy.deepcopy(pg) - # Build a channel-ordered view of the stored probegroup for the public getter. - # Strong-preserve keeps each probe intact on `_probegroup`; here we slice it down - # to the contacts that actually appear in this recording, in channel order, with - # device_channel_indices = arange(N). The returned object matches the - # pre-strong-preserve `get_probe()` semantic. - wiring = self.get_property("wiring") - if wiring is None: - return copy.deepcopy(self._probegroup) - - # map (probe_id, contact_id) to the global contact index in the stored probegroup - contact_id_to_global = {} - offset = 0 - for probe in self._probegroup.probes: - pid = probe.annotations["probe_id"] - for cid in probe.contact_ids: - contact_id_to_global[(pid, cid)] = offset - offset += 1 - - global_indices = [contact_id_to_global[(pid, cid)] for pid, cid in wiring] - view = self._probegroup.get_slice(np.asarray(global_indices, dtype="int64")) - view.set_global_device_channel_indices(np.arange(len(global_indices), dtype="int64")) - return view + # Strong-preserve: return the stored probegroup as-is. The probe objects carry + # the user's original `device_channel_indices` and the full set of physical + # contacts, not a channel-aligned view. Callers that want channel-ordered + # geometry should use `get_channel_locations()`. + return copy.deepcopy(self._probegroup) def _extra_metadata_from_folder(self, folder): # load probe diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 1ebeb677c6..70d3d7c519 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -201,18 +201,18 @@ def test_BaseRecording(create_cache_folder): positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + # under strong-preserve, get_probe() returns the full preserved probe + # (all physical contacts, user's original device_channel_indices). probe2 = rec_p.get_probe() - positions3 = probe2.contact_positions - assert np.array_equal(positions2, positions3) - - assert np.array_equal(probe2.device_channel_indices, [0, 1]) + assert probe2.get_contact_count() == 6 + assert np.array_equal(probe2.device_channel_indices, [2, -1, 0, -1, -1, -1]) # test save with probe folder = cache_folder / "simple_recording3" rec2 = rec_p.save(folder=folder, chunk_size=10, n_jobs=2) rec2 = load(folder) probe2 = rec2.get_probe() - assert np.array_equal(probe2.contact_positions, [[0, 30.0], [0.0, 0.0]]) + assert probe2.get_contact_count() == 6 positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) traces2 = rec2.get_traces(segment_index=0) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 751a03460c..199ef12321 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -147,18 +147,18 @@ def test_BaseSnippets(create_cache_folder): positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + # under strong-preserve, get_probe() returns the full preserved probe + # (all physical contacts, user's original device_channel_indices). probe2 = snippets_p.get_probe() - positions3 = probe2.contact_positions - assert np.array_equal(positions2, positions3) - - assert np.array_equal(probe2.device_channel_indices, [0, 1]) + assert probe2.get_contact_count() == 6 + assert np.array_equal(probe2.device_channel_indices, [2, -1, 0, -1, -1, -1]) # test save with probe folder = cache_folder / "simple_snippets3" snippets2 = snippets_p.save(folder=folder) snippets2 = load(folder) probe2 = snippets2.get_probe() - assert np.array_equal(probe2.contact_positions, [[0, 30.0], [0.0, 0.0]]) + assert probe2.get_contact_count() == 6 positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) wavefroms2 = snippets2.get_snippets(segment_index=0) From 56761c86d00f0e54e02b28dd96fb029889c71ce6 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 14:48:28 -0600 Subject: [PATCH 24/31] drop deep cpy --- src/spikeinterface/core/baserecordingsnippets.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 788cb4508a..c2183894e5 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,3 @@ -import copy from pathlib import Path import numpy as np @@ -261,13 +260,9 @@ def get_probegroup(self): probe = self.create_dummy_probe_from_locations(positions) pg = ProbeGroup() pg.add_probe(probe) - return copy.deepcopy(pg) + return pg - # Strong-preserve: return the stored probegroup as-is. The probe objects carry - # the user's original `device_channel_indices` and the full set of physical - # contacts, not a channel-aligned view. Callers that want channel-ordered - # geometry should use `get_channel_locations()`. - return copy.deepcopy(self._probegroup) + return self._probegroup def _extra_metadata_from_folder(self, folder): # load probe From 999a5c4bf884780e242946205a4f302c711031db Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 16:12:11 -0600 Subject: [PATCH 25/31] fix --- pyproject.toml | 12 ++++-------- src/spikeinterface/core/base.py | 9 +++++---- src/spikeinterface/core/baserecording.py | 11 +++++++++-- .../core/baserecordingsnippets.py | 9 ++++++++- src/spikeinterface/core/basesnippets.py | 2 +- .../core/channelsaggregationrecording.py | 19 +++++++++++++++---- .../core/tests/test_basesnippets.py | 6 +++--- .../core/tests/test_channelslicerecording.py | 4 +--- 8 files changed, 46 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1a92a109da..17022d1892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,8 +128,7 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge - "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs, @@ -141,8 +140,7 @@ test_extractors = [ "pooch>=1.8.2", "datalad>=1.0.2", # Commenting out for release - # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge - "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] @@ -193,8 +191,7 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge - "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs @@ -223,8 +220,7 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge - "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 57db310357..881ab272de 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -610,12 +610,11 @@ def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> "BaseE return extractor def load_metadata_from_folder(self, folder_metadata): - # hack to load probe for recording folder_metadata = Path(folder_metadata) - self._extra_metadata_from_folder(folder_metadata) - - # load properties + # load properties first so that `_extra_metadata_from_folder` can see + # restored state like the `wiring` property and skip re-running + # `set_probegroup` when the mapping is already in place. prop_folder = folder_metadata / "properties" if prop_folder.is_dir(): for prop_file in prop_folder.iterdir(): @@ -624,6 +623,8 @@ def load_metadata_from_folder(self, folder_metadata): key = prop_file.stem self.set_property(key, values) + self._extra_metadata_from_folder(folder_metadata) + def save_metadata_to_folder(self, folder_metadata): self._extra_metadata_to_folder(folder_metadata) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index ad4f97d25c..268d1232e1 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -392,7 +392,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.has_probe(): + if self.has_probe() and not cached.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) @@ -403,7 +403,14 @@ def _extra_metadata_from_folder(self, folder): folder = Path(folder) if (folder / "probe.json").is_file(): probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + if "wiring" in self.get_property_keys(): + # wiring was restored via the property-load loop; the stored + # probegroup's dci refers to the parent's channel space, so + # re-running `_set_probes` would fail for sliced children. + # Attach the probegroup object directly. + self._probegroup = probegroup + else: + self.set_probegroup(probegroup, in_place=True) # load time vector if any for segment_index, rs in enumerate(self.segments): diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index c2183894e5..f55a3f4013 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -269,7 +269,14 @@ def _extra_metadata_from_folder(self, folder): folder = Path(folder) if (folder / "probe.json").is_file(): probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + if "wiring" in self.get_property_keys(): + # wiring was restored via the property-load loop; the stored + # probegroup's dci refers to the parent's channel space, so + # re-running `_set_probes` would fail for sliced children. + # Attach the probegroup object directly. + self._probegroup = probegroup + else: + self.set_probegroup(probegroup, in_place=True) def _extra_metadata_to_folder(self, folder): # save probe diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index fa47365200..631949259a 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -259,7 +259,7 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.has_probe(): + if self.has_probe() and not cached.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 4307d5a4dc..3bcf134218 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -105,14 +105,25 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record combined_probegroup = first_pg else: # cross-parent case: build a fresh combined probegroup from copies - # of each probe. Round-trip through to_dict/from_dict because - # `Probe.copy()` currently drops contact_ids and annotations - # (probeinterface #421). Clear `device_channel_indices` on each copy - # so probeinterface's cross-probe dci uniqueness check passes. + # of each probe. combined_probegroup = ProbeGroup() for rec in recording_list: for probe in rec._probegroup.probes: + # Round-trip through to_dict/from_dict because `Probe.copy()` + # currently drops contact_ids and annotations (probeinterface + # #421). Once that is fixed we can switch to `probe.copy()`. probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) + # Clear `device_channel_indices` so probeinterface's + # `ProbeGroup.add_probe` cross-probe dci uniqueness check + # passes: each parent's probe originally had dci in its own + # 0..N-1 channel space, and those ranges collide when + # combined. This does not drop provenance: under the + # id-keyed wiring model dci on the probe is a local index, + # not the canonical mapping. The (channel -> probe_id, + # contact_id) mapping is carried by the recording's + # `wiring` property, and everything that identifies the + # physical probe (geometry, contact_ids, annotations, + # planar_contour) survives the to_dict round-trip. probe_copy.set_device_channel_indices( np.full(probe_copy.get_contact_count(), -1, dtype="int64") ) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 199ef12321..40cafed641 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -150,15 +150,15 @@ def test_BaseSnippets(create_cache_folder): # under strong-preserve, get_probe() returns the full preserved probe # (all physical contacts, user's original device_channel_indices). probe2 = snippets_p.get_probe() - assert probe2.get_contact_count() == 6 - assert np.array_equal(probe2.device_channel_indices, [2, -1, 0, -1, -1, -1]) + assert probe2.get_contact_count() == 3 + assert np.array_equal(probe2.device_channel_indices, [2, -1, 0]) # test save with probe folder = cache_folder / "simple_snippets3" snippets2 = snippets_p.save(folder=folder) snippets2 = load(folder) probe2 = snippets2.get_probe() - assert probe2.get_contact_count() == 6 + assert probe2.get_contact_count() == 3 positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) wavefroms2 = snippets2.get_snippets(segment_index=0) diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py index 35e4a88c89..9f7e8e2e8e 100644 --- a/src/spikeinterface/core/tests/test_channelslicerecording.py +++ b/src/spikeinterface/core/tests/test_channelslicerecording.py @@ -61,11 +61,9 @@ def test_ChannelSliceRecording(create_cache_folder): probe.set_device_channel_indices(np.arange(num_chan)) rec_p = rec.set_probe(probe) rec_sliced3 = ChannelSliceRecording(rec_p, channel_ids=[0, 2], renamed_channel_ids=[3, 4]) - probe3 = rec_sliced3.get_probe() - locations3 = probe3.contact_positions + locations3 = rec_sliced3.get_channel_locations() folder = cache_folder / "sliced_recording" rec_saved = rec_sliced3.save(folder=folder, chunk_size=10, n_jobs=2) - probe = rec_saved.get_probe() assert np.array_equal(locations3, rec_saved.get_channel_locations()) traces3 = rec_saved.get_traces(segment_index=0) assert np.all(traces3[:, 0] == 0) From 5119d5394bd65afb0ffd419afd2743cfa89c07dc Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 16:28:55 -0600 Subject: [PATCH 26/31] refactor --- src/spikeinterface/extractors/tests/test_iblextractors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index dfbf5d714d..90ff378575 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -86,6 +86,7 @@ def test_property_keys(self): "offset_to_uV", "location", "group", + "wiring", "shank", "shank_row", "shank_col", From bdf9d24784ac6d5d6cdc18d40816b6e390950330 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 17:09:48 -0600 Subject: [PATCH 27/31] serialization fix --- src/spikeinterface/core/baserecording.py | 5 +++- .../core/baserecordingsnippets.py | 26 +++++++++++++++++++ src/spikeinterface/core/basesnippets.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 12 +++++++-- 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 268d1232e1..dfc6e2a0ad 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -394,7 +394,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if self.has_probe() and not cached.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + cached.set_probegroup(probegroup, in_place=True) return cached @@ -409,6 +409,9 @@ def _extra_metadata_from_folder(self, folder): # re-running `_set_probes` would fail for sliced children. # Attach the probegroup object directly. self._probegroup = probegroup + from .baserecordingsnippets import _restore_probe_ids_from_wiring + + _restore_probe_ids_from_wiring(self._probegroup, self.get_property("wiring")) else: self.set_probegroup(probegroup, in_place=True) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index f55a3f4013..5426dfa291 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -10,6 +10,31 @@ from warnings import warn +def _restore_probe_ids_from_wiring(probegroup, wiring): + """ + Re-attach `probe_id` annotations to probes from the `wiring` property when + they have been stripped (e.g. by a reconstruction path that rebuilds the + ProbeGroup without preserving annotations). `wiring` has shape `(N, 2)` + with probe_ids in column 0; unique probe_ids in order of first appearance + correspond directly to the stored probe order, which is how `_set_probes` + builds wiring in the first place. No-op when annotations are intact or + when the counts don't match. + """ + if probegroup is None or wiring is None or len(probegroup.probes) == 0: + return + if all("probe_id" in p.annotations for p in probegroup.probes): + return + seen = [] + for pid in np.asarray(wiring)[:, 0]: + if pid not in seen: + seen.append(pid) + if len(seen) != len(probegroup.probes): + return # cannot reconstruct safely + for probe, pid in zip(probegroup.probes, seen): + if "probe_id" not in probe.annotations: + probe.annotate(probe_id=pid) + + class BaseRecordingSnippets(BaseExtractor): """ Mixin that handles all probe and channel operations @@ -275,6 +300,7 @@ def _extra_metadata_from_folder(self, folder): # re-running `_set_probes` would fail for sliced children. # Attach the probegroup object directly. self._probegroup = probegroup + _restore_probe_ids_from_wiring(self._probegroup, self.get_property("wiring")) else: self.set_probegroup(probegroup, in_place=True) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 631949259a..15533e1606 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -261,7 +261,7 @@ def _save(self, format="npy", **save_kwargs): if self.has_probe() and not cached.has_probe(): probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + cached.set_probegroup(probegroup, in_place=True) return cached diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 7f08f41fa4..698cd2dd3c 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -560,7 +560,11 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): probegroup_file = folder / "recording_info" / "probegroup.json" if probegroup_file.is_file(): - rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) + probegroup = probeinterface.read_probeinterface(probegroup_file) + from .baserecordingsnippets import _restore_probe_ids_from_wiring + + _restore_probe_ids_from_wiring(probegroup, rec_attributes.get("properties", {}).get("wiring")) + rec_attributes["probegroup"] = probegroup else: rec_attributes["probegroup"] = None @@ -742,7 +746,11 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes = zarr_root["recording_info"].attrs["recording_attributes"] if "probegroup" in zarr_root["recording_info"].attrs: probegroup_dict = zarr_root["recording_info"].attrs["probegroup"] - rec_attributes["probegroup"] = probeinterface.ProbeGroup.from_dict(probegroup_dict) + probegroup = probeinterface.ProbeGroup.from_dict(probegroup_dict) + from .baserecordingsnippets import _restore_probe_ids_from_wiring + + _restore_probe_ids_from_wiring(probegroup, rec_attributes.get("properties", {}).get("wiring")) + rec_attributes["probegroup"] = probegroup else: rec_attributes["probegroup"] = None From 8b642d9b4b4b04ddff2393b266477dbabaaf48fb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 17:30:52 -0600 Subject: [PATCH 28/31] more fixes --- src/spikeinterface/core/baserecording.py | 9 +++++++-- src/spikeinterface/core/basesnippets.py | 6 ++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index dfc6e2a0ad..220e3bf637 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -393,8 +393,13 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): raise ValueError(f"format {format} not supported") if self.has_probe() and not cached.has_probe(): - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup, in_place=True) + # Share the probegroup by reference. We deliberately skip + # `set_probegroup` (which re-runs _set_probes and validates dci) + # because a child of `split_by` references the parent's full + # probegroup whose dci values can exceed the child's channel + # count. Wiring/location/group properties are carried over by + # the caller's `copy_metadata` step. + cached._probegroup = self._probegroup return cached diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 15533e1606..181f913149 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -260,8 +260,10 @@ def _save(self, format="npy", **save_kwargs): raise ValueError(f"format {format} not supported") if self.has_probe() and not cached.has_probe(): - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup, in_place=True) + # Share the probegroup by reference; see BaseRecording._save for + # the rationale (avoids re-running _set_probes validation on a + # child whose parent's dci exceeds the child's channel count). + cached._probegroup = self._probegroup return cached From e801124746221d601d3af0a817c3cad4c32fad8c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 18:02:19 -0600 Subject: [PATCH 29/31] add to main properties --- src/spikeinterface/core/baserecording.py | 1 + src/spikeinterface/core/basesnippets.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 220e3bf637..e8bcb232b5 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -21,6 +21,7 @@ class BaseRecording(BaseRecordingSnippets, ChunkableMixin): _main_properties = [ "group", "location", + "wiring", "gain_to_uV", "offset_to_uV", "gain_to_physical_unit", diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 181f913149..41d34173b0 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -11,7 +11,7 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] + _main_properties = ["group", "location", "wiring", "gain_to_uV", "offset_to_uV"] _main_features = [] def __init__(self, sampling_frequency: float, nbefore: int | None, snippet_len: int, channel_ids: list, dtype): From ffee07c80805ecd6cfc86e29eb9d1bd34ded272f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 19:06:43 -0600 Subject: [PATCH 30/31] cleanup --- src/spikeinterface/core/baserecording.py | 3 - .../core/baserecordingsnippets.py | 62 +++++++++++-------- src/spikeinterface/core/sortinganalyzer.py | 12 +--- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e8bcb232b5..bec1b5441e 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -415,9 +415,6 @@ def _extra_metadata_from_folder(self, folder): # re-running `_set_probes` would fail for sliced children. # Attach the probegroup object directly. self._probegroup = probegroup - from .baserecordingsnippets import _restore_probe_ids_from_wiring - - _restore_probe_ids_from_wiring(self._probegroup, self.get_property("wiring")) else: self.set_probegroup(probegroup, in_place=True) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 5426dfa291..9869e50b9f 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -10,31 +10,6 @@ from warnings import warn -def _restore_probe_ids_from_wiring(probegroup, wiring): - """ - Re-attach `probe_id` annotations to probes from the `wiring` property when - they have been stripped (e.g. by a reconstruction path that rebuilds the - ProbeGroup without preserving annotations). `wiring` has shape `(N, 2)` - with probe_ids in column 0; unique probe_ids in order of first appearance - correspond directly to the stored probe order, which is how `_set_probes` - builds wiring in the first place. No-op when annotations are intact or - when the counts don't match. - """ - if probegroup is None or wiring is None or len(probegroup.probes) == 0: - return - if all("probe_id" in p.annotations for p in probegroup.probes): - return - seen = [] - for pid in np.asarray(wiring)[:, 0]: - if pid not in seen: - seen.append(pid) - if len(seen) != len(probegroup.probes): - return # cannot reconstruct safely - for probe, pid in zip(probegroup.probes, seen): - if "probe_id" not in probe.annotations: - probe.annotate(probe_id=pid) - - class BaseRecordingSnippets(BaseExtractor): """ Mixin that handles all probe and channel operations @@ -87,12 +62,48 @@ def is_filtered(self): return self._annotations.get("is_filtered", False) def set_probe(self, probe, group_mode="auto", in_place=False): + """ + Attach a Probe to a recording. + + Parameters + ---------- + probe: Probe + The probe to be attached to the recording. + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to derive the "group" property mirror. + "auto" is the best splitting possible across multiple probes, multiple shanks, and two sides. + in_place: bool, default: False + If True, attach to self in place (only allowed when all channels are wired). + Useful internally when an extractor calls ``self.set_probe(probe)`` on itself. + + Returns + ------- + sub_recording: BaseRecording + A view of the recording (ChannelSlice or clone or itself) with the probe attached. + """ assert isinstance(probe, Probe), "must give Probe" probegroup = ProbeGroup() probegroup.add_probe(probe) return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) def set_probegroup(self, probegroup, group_mode="auto", in_place=False): + """ + Attach a ProbeGroup to a recording. + + Parameters + ---------- + probegroup: ProbeGroup + The probegroup to be attached to the recording. + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to derive the "group" property mirror. + in_place: bool, default: False + If True, attach to self in place (only allowed when all channels are wired). + + Returns + ------- + sub_recording: BaseRecording + A view of the recording (ChannelSlice or clone or itself) with the probegroup attached. + """ return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): @@ -300,7 +311,6 @@ def _extra_metadata_from_folder(self, folder): # re-running `_set_probes` would fail for sliced children. # Attach the probegroup object directly. self._probegroup = probegroup - _restore_probe_ids_from_wiring(self._probegroup, self.get_property("wiring")) else: self.set_probegroup(probegroup, in_place=True) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 698cd2dd3c..7f08f41fa4 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -560,11 +560,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): probegroup_file = folder / "recording_info" / "probegroup.json" if probegroup_file.is_file(): - probegroup = probeinterface.read_probeinterface(probegroup_file) - from .baserecordingsnippets import _restore_probe_ids_from_wiring - - _restore_probe_ids_from_wiring(probegroup, rec_attributes.get("properties", {}).get("wiring")) - rec_attributes["probegroup"] = probegroup + rec_attributes["probegroup"] = probeinterface.read_probeinterface(probegroup_file) else: rec_attributes["probegroup"] = None @@ -746,11 +742,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes = zarr_root["recording_info"].attrs["recording_attributes"] if "probegroup" in zarr_root["recording_info"].attrs: probegroup_dict = zarr_root["recording_info"].attrs["probegroup"] - probegroup = probeinterface.ProbeGroup.from_dict(probegroup_dict) - from .baserecordingsnippets import _restore_probe_ids_from_wiring - - _restore_probe_ids_from_wiring(probegroup, rec_attributes.get("properties", {}).get("wiring")) - rec_attributes["probegroup"] = probegroup + rec_attributes["probegroup"] = probeinterface.ProbeGroup.from_dict(probegroup_dict) else: rec_attributes["probegroup"] = None From f48b88db2afe09e706da2c773c9087626f6c151c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 19:49:59 -0600 Subject: [PATCH 31/31] fix doccstring --- .../core/baserecordingsnippets.py | 35 +++++-------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 9869e50b9f..ae2071aa3c 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -63,23 +63,23 @@ def is_filtered(self): def set_probe(self, probe, group_mode="auto", in_place=False): """ - Attach a Probe to a recording. + Attach a list of Probe object to a recording. Parameters ---------- - probe: Probe - The probe to be attached to the recording. + probe_or_probegroup: Probe, list of Probe, or ProbeGroup + The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" - How to derive the "group" property mirror. - "auto" is the best splitting possible across multiple probes, multiple shanks, and two sides. - in_place: bool, default: False - If True, attach to self in place (only allowed when all channels are wired). - Useful internally when an extractor calls ``self.set_probe(probe)`` on itself. + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. + in_place: bool + False by default. + Useful internally when extractor do self.set_probegroup(probe) Returns ------- sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) with the probe attached. + A view of the recording (ChannelSlice or clone or itself) """ assert isinstance(probe, Probe), "must give Probe" probegroup = ProbeGroup() @@ -87,23 +87,6 @@ def set_probe(self, probe, group_mode="auto", in_place=False): return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) def set_probegroup(self, probegroup, group_mode="auto", in_place=False): - """ - Attach a ProbeGroup to a recording. - - Parameters - ---------- - probegroup: ProbeGroup - The probegroup to be attached to the recording. - group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" - How to derive the "group" property mirror. - in_place: bool, default: False - If True, attach to self in place (only allowed when all channels are wired). - - Returns - ------- - sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) with the probegroup attached. - """ return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):