diff --git a/pyproject.toml b/pyproject.toml index 8261f863b9..17022d1892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,7 +220,7 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index fcbafdb6bf..881ab272de 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -574,6 +574,9 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + if getattr(self, "_probegroup", None) is not None: + dump_dict["probegroup"] = self._probegroup.to_dict(array_as_list=True) + return dump_dict @staticmethod @@ -607,12 +610,11 @@ def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> "BaseE return extractor def load_metadata_from_folder(self, folder_metadata): - # hack to load probe for recording folder_metadata = Path(folder_metadata) - self._extra_metadata_from_folder(folder_metadata) - - # load properties + # load properties first so that `_extra_metadata_from_folder` can see + # restored state like the `wiring` property and skip re-running + # `set_probegroup` when the mapping is already in place. prop_folder = folder_metadata / "properties" if prop_folder.is_dir(): for prop_file in prop_folder.iterdir(): @@ -621,6 +623,8 @@ def load_metadata_from_folder(self, folder_metadata): key = prop_file.stem self.set_property(key, values) + self._extra_metadata_from_folder(folder_metadata) + def save_metadata_to_folder(self, folder_metadata): self._extra_metadata_to_folder(folder_metadata) @@ -1155,9 +1159,53 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": for k, v in dic["properties"].items(): extractor.set_property(k, v) + if "probegroup" in dic: + from probeinterface import ProbeGroup + + probegroup = ProbeGroup.from_dict(dic["probegroup"]) + # The `wiring` per-channel property was restored above by the standard + # property-load loop; we just attach the probegroup object. + extractor._probegroup = probegroup + elif "contact_vector" in dic.get("properties", {}): + _restore_probegroup_from_legacy_contact_vector(extractor) + return extractor +def _restore_probegroup_from_legacy_contact_vector(extractor) -> None: + """ + Reconstruct a `ProbeGroup` from the legacy `contact_vector` property. + + Recordings saved before the probegroup refactor stored the probe as a structured numpy + array under the `contact_vector` property, with probe-level annotations under a separate + `probes_info` annotation and per-probe planar contours under `probe_{i}_planar_contour` + annotations. This function reconstructs a `ProbeGroup` from those legacy fields, attaches + it via the canonical `set_probegroup` path, and removes the legacy property so the new + and old representations do not coexist on the loaded extractor. + """ + from probeinterface import ProbeGroup + + contact_vector_array = extractor.get_property("contact_vector") + probegroup = ProbeGroup.from_numpy(contact_vector_array) + + if "probes_info" in extractor.get_annotation_keys(): + probes_info = extractor.get_annotation("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations = probe_info + + for probe_index, probe in enumerate(probegroup.probes): + contour = extractor._annotations.get(f"probe_{probe_index}_planar_contour") + if contour is not None: + probe.set_planar_contour(contour) + + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + + extractor._properties.pop("contact_vector", None) + + def _get_class_from_string(class_string): class_name = class_string.split(".")[-1] module = ".".join(class_string.split(".")[:-1]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index d94ec941ff..bec1b5441e 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -21,6 +21,7 @@ class BaseRecording(BaseRecordingSnippets, ChunkableMixin): _main_properties = [ "group", "location", + "wiring", "gain_to_uV", "offset_to_uV", "gain_to_physical_unit", @@ -392,9 +393,14 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + if self.has_probe() and not cached.has_probe(): + # Share the probegroup by reference. We deliberately skip + # `set_probegroup` (which re-runs _set_probes and validates dci) + # because a child of `split_by` references the parent's full + # probegroup whose dci values can exceed the child's channel + # count. Wiring/location/group properties are carried over by + # the caller's `copy_metadata` step. + cached._probegroup = self._probegroup return cached @@ -403,7 +409,14 @@ def _extra_metadata_from_folder(self, folder): folder = Path(folder) if (folder / "probe.json").is_file(): probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + if "wiring" in self.get_property_keys(): + # wiring was restored via the property-load loop; the stored + # probegroup's dci refers to the parent's channel space, so + # re-running `_set_probes` would fail for sliced children. + # Attach the probegroup object directly. + self._probegroup = probegroup + else: + self.set_probegroup(probegroup, in_place=True) # load time vector if any for segment_index, rs in enumerate(self.segments): @@ -414,7 +427,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..ae2071aa3c 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -19,6 +19,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,7 +52,7 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + return self._probegroup is not None def has_channel_location(self) -> bool: return self.has_probe() or "location" in self.get_property_keys() @@ -90,155 +91,171 @@ def set_probegroup(self, probegroup, group_mode="auto", in_place=False): def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): """ - Attach a list of Probe objects to a recording. - For this Probe.device_channel_indices is used to link contacts to recording channels. - If some contacts of the Probe are not connected (device_channel_indices=-1) - then the recording is "sliced" and only connected channel are kept. + Attach a Probe, ProbeGroup, or list of Probe to the recording. - The probe order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. + The probegroup is stored by reference without mutation. The contact-to-channel + mapping is built from each probe's current `device_channel_indices` and stored + on the recording as two per-channel properties, `probe_id` and `contact_id`. + `location` and `group` are also written as properties for backward compatibility. + If the probegroup wires only a subset of the recording's channels, the recording + is sliced via `select_channels` to keep only the wired channels (preserving the + historical attach semantics). Parameters ---------- probe_or_probegroup: Probe, list of Probe, or ProbeGroup - The probe(s) to be attached to the recording - group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" - How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) + The probe(s) to be attached to the recording. + in_place: bool, default: False + If True, attach to self in place. Returns ------- sub_recording: BaseRecording - A view of the recording (ChannelSlice or clone or itself) + The recording with the probegroup attached. """ - assert group_mode in ( - "auto", - "by_probe", - "by_shank", - "by_side", - ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" - - # handle several input possibilities + # normalize input to a ProbeGroup if isinstance(probe_or_probegroup, Probe): probegroup = ProbeGroup() probegroup.add_probe(probe_or_probegroup) elif isinstance(probe_or_probegroup, ProbeGroup): probegroup = probe_or_probegroup elif isinstance(probe_or_probegroup, list): - assert all([isinstance(e, Probe) for e in probe_or_probegroup]) + assert all(isinstance(e, Probe) for e in probe_or_probegroup) probegroup = ProbeGroup() for probe in probe_or_probegroup: probegroup.add_probe(probe) else: raise ValueError("must give Probe or ProbeGroup or list of Probe") - # check that the probe do not overlap - num_probes = len(probegroup.probes) - if num_probes > 1: + if len(probegroup.probes) > 1: check_probe_do_not_overlap(probegroup.probes) - # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr - probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 - if np.any(~keep): - warn("The given probes have unconnected contacts: they are removed") - - probe_as_numpy_array = probe_as_numpy_array[keep] - - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] - - # check TODO: Where did this came from? - number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) - if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( - f"The given Probe either has 'device_channel_indices' that does not match channel count \n" - f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" - f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" - f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" - f"device_channel_indices are the following: {device_channel_indices} \n" - f"recording channels are the following: {self.get_channel_ids()} \n" - ) - raise ValueError(error_msg) - - new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") - - # create recording : channel slice or clone or self + # ensure every probe has a stable probe_id and every contact has a contact_id; + # auto-generate only when missing so user-assigned ids survive + if not any("probe_id" in p.annotations for p in probegroup.probes): + probegroup.auto_generate_probe_ids() + if any(p.contact_ids is None for p in probegroup.probes): + probegroup.auto_generate_contact_ids() + + # collect, per recording channel (by position in self.channel_ids), which + # (probe_id, contact_id) pair wires into it. Unwired channels stay as None. + num_channels = self.get_num_channels() + probe_id_col = [None] * num_channels + contact_id_col = [None] * num_channels + for probe in probegroup.probes: + probe_id = probe.annotations["probe_id"] + dci = np.asarray(probe.device_channel_indices) + for contact_idx, device_idx in enumerate(dci): + if device_idx < 0: + continue # unconnected contact; skip + if device_idx >= num_channels: + raise ValueError( + f"device_channel_indices value {device_idx} is out of range; " + f"recording has {num_channels} channels." + ) + if probe_id_col[device_idx] is not None: + raise ValueError(f"channel at index {device_idx} is wired to more than one contact.") + probe_id_col[device_idx] = probe_id + contact_id_col[device_idx] = probe.contact_ids[contact_idx] + + # Reorder the recording's channels to match the probe's device_channel_indices + # order (smallest dci first). Also drops any channels that are unwired. This + # matches the historical `set_probe` behaviour: after attach, recording channel + # i corresponds to the probe contact whose dci was the i-th smallest. + wired_dci_pairs = [] # (device_idx, position_in_probe_id_col) + for i, pid in enumerate(probe_id_col): + if pid is not None: + wired_dci_pairs.append(i) # position == original device_idx since we indexed by it + # sort by device_idx (which equals the position already) + ordered_positions = sorted(wired_dci_pairs) + original_channel_ids = self.get_channel_ids() + new_channel_ids = original_channel_ids[ordered_positions] + if in_place: - if not np.array_equal(new_channel_ids, self.get_channel_ids()): - raise Exception("set_probe(inplace=True) must have all channel indices") - sub_recording = self + if not np.array_equal(new_channel_ids, original_channel_ids): + raise Exception("set_probe(in_place=True) must have all channel indices") + target = self else: - if np.array_equal(new_channel_ids, self.get_channel_ids()): - sub_recording = self.clone() + if np.array_equal(new_channel_ids, original_channel_ids): + target = self.clone() else: - sub_recording = self.select_channels(new_channel_ids) - - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) - - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property + target = self.select_channels(new_channel_ids) + # re-key probe_id_col / contact_id_col into the (possibly reordered) target + probe_id_col = [probe_id_col[i] for i in ordered_positions] + contact_id_col = [contact_id_col[i] for i in ordered_positions] + + # attach probegroup; the wiring lives as a (num_channels, 2) per-channel + # string property `wiring` with column 0 = probe_id, column 1 = contact_id. + # This is the same pattern as `location` (2D property per channel) and rides + # on SI's existing property plumbing (copy_metadata, concat, serialization). + target._probegroup = probegroup + + # handle the degenerate empty-wiring case (probe with all dci=-1) + if len(probe_id_col) == 0: + ndim = probegroup.ndim + target.set_property("wiring", np.zeros((0, 2), dtype="U64")) + target.set_property("location", np.zeros((0, ndim), dtype="float64")) + target.set_property("group", np.zeros(0, dtype="int64")) + return target + + wiring = np.column_stack( + [ + np.asarray(probe_id_col, dtype="U64"), + np.asarray(contact_id_col, dtype="U64"), + ] + ) + target.set_property("wiring", wiring) + + # write `location` and `group` as compatibility mirrors of the canonical + # probegroup + _channel_to_contact mapping. group_mode is consulted here to + # match the pre-strong-preserve API; callers that pass "by_probe", "by_shank" + # etc. get the same partitioning as before. ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] - sub_recording.set_property("location", locations, ids=None) - - # handle groups - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields + probes_by_id = {p.annotations["probe_id"]: p for p in probegroup.probes} + has_shank = any(p.shank_ids is not None for p in probegroup.probes) + has_side = any(p.contact_sides is not None for p in probegroup.probes) if group_mode == "auto": - group_keys = ["probe_index"] - if has_shank_id: - group_keys += ["shank_ids"] - if has_contact_side: - group_keys += ["contact_sides"] + keys_template = ["probe"] + (["shank"] if has_shank else []) + (["side"] if has_side else []) elif group_mode == "by_probe": - group_keys = ["probe_index"] + keys_template = ["probe"] elif group_mode == "by_shank": - assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" - group_keys = ["probe_index", "shank_ids"] + assert has_shank, "shank_ids is None in probe, you cannot group by shank" + keys_template = ["probe", "shank"] elif group_mode == "by_side": - assert has_contact_side, "contact_sides is None in probe, you cannot group by side" - if has_shank_id: - group_keys = ["probe_index", "shank_ids", "contact_sides"] - else: - group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) - for group, a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) - for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] - groups[mask] = group - sub_recording.set_property("group", groups, ids=None) - - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) - - return sub_recording + assert has_side, "contact_sides is None in probe, you cannot group by side" + keys_template = ["probe"] + (["shank"] if has_shank else []) + ["side"] + else: + raise ValueError(f"unknown group_mode {group_mode!r}") + + wired_positions = list(range(len(probe_id_col))) + locations = np.zeros((len(wired_positions), ndim), dtype="float64") + group_keys_per_channel = [] + for i, (pid, cid) in enumerate(zip(probe_id_col, contact_id_col)): + probe = probes_by_id[pid] + contact_idx = int(np.where(np.asarray(probe.contact_ids) == cid)[0][0]) + locations[i] = probe.contact_positions[contact_idx, :ndim] + key = [] + for k in keys_template: + if k == "probe": + key.append(pid) + elif k == "shank" and probe.shank_ids is not None: + key.append(probe.shank_ids[contact_idx]) + elif k == "side" and probe.contact_sides is not None: + key.append(probe.contact_sides[contact_idx]) + group_keys_per_channel.append(tuple(key)) + + unique_keys = list(dict.fromkeys(group_keys_per_channel)) + key_to_int = {k: i for i, k in enumerate(unique_keys)} + groups = np.array([key_to_int[k] for k in group_keys_per_channel], dtype="int64") + + target.set_property("location", locations) + target.set_property("group", groups) + return target def get_probe(self): probes = self.get_probes() @@ -250,41 +267,39 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): - arr = self.get_property("contact_vector") - if arr is None: + if self._probegroup is None: + # Backwards-compat fallback: pre-migration get_probegroup synthesised a dummy + # probe from the "location" property when no probe had been attached. Callers + # (e.g. sparsity.py) rely on this for recordings that have locations but no + # probe. positions = self.get_property("location") if positions is None: raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - else: - warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - probe = self.create_dummy_probe_from_locations(positions) - # probe.create_auto_shape() - probegroup = ProbeGroup() - probegroup.add_probe(probe) - else: - probegroup = ProbeGroup.from_numpy(arr) + warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") + probe = self.create_dummy_probe_from_locations(positions) + pg = ProbeGroup() + pg.add_probe(probe) + return pg - if "probes_info" in self.get_annotation_keys(): - probes_info = self.get_annotation("probes_info") - for probe, probe_info in zip(probegroup.probes, probes_info): - probe.annotations = probe_info - - for probe_index, probe in enumerate(probegroup.probes): - contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - if contour is not None: - probe.set_planar_contour(contour) - return probegroup + return self._probegroup def _extra_metadata_from_folder(self, folder): # load probe folder = Path(folder) if (folder / "probe.json").is_file(): probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + if "wiring" in self.get_property_keys(): + # wiring was restored via the property-load loop; the stored + # probegroup's dci refers to the parent's channel space, so + # re-running `_set_probes` would fail for sliced children. + # Attach the probegroup object directly. + self._probegroup = probegroup + else: + self.set_probegroup(probegroup, in_place=True) def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) @@ -341,31 +356,41 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: + if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() - channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup + + if self.has_probe(): + # resolve each channel via the `wiring` property (column 0 = probe_id, + # column 1 = contact_id) and look up the contact's position on the + # corresponding probe + wiring = self.get_property("wiring", ids=channel_ids) + probes_by_id = {p.annotations["probe_id"]: p for p in self._probegroup.probes} + axis_index = {"x": 0, "y": 1, "z": 2} ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + locations = np.zeros((len(channel_ids), ndim), dtype="float64") + for i, (probe_id, contact_id) in enumerate(wiring): + probe = probes_by_id[probe_id] + contact_idx = int(np.where(np.asarray(probe.contact_ids) == contact_id)[0][0]) + for j, axis in enumerate(axes): + locations[i, j] = probe.contact_positions[contact_idx, axis_index[axis]] + return locations + + # fallback for recordings that have a "location" property but no attached probegroup + channel_indices = self.ids_to_indices(channel_ids) + locations = self.get_property("location") + if locations is None: + raise Exception("There are no channel locations") + locations = np.asarray(locations)[channel_indices] + return select_axes(locations, axes) def has_3d_locations(self) -> bool: + if self.has_probe(): + return self._probegroup.ndim == 3 return self.get_property("location").shape[1] == 3 def clear_channel_locations(self, channel_ids=None): @@ -383,8 +408,34 @@ def set_channel_groups(self, groups, channel_ids=None): self.set_property("group", groups, ids=channel_ids) def get_channel_groups(self, channel_ids=None): - groups = self.get_property("group", ids=channel_ids) - return groups + # when a probe is attached, derive groups on the fly from the `wiring` + # property + probegroup state (probe_id + shank_ids + contact_sides) + if self.has_probe(): + if channel_ids is None: + channel_ids = self.get_channel_ids() + wiring = self.get_property("wiring", ids=channel_ids) + probes_by_id = {p.annotations["probe_id"]: p for p in self._probegroup.probes} + has_shank = any(p.shank_ids is not None for p in self._probegroup.probes) + has_side = any(p.contact_sides is not None for p in self._probegroup.probes) + + group_keys = [] + for probe_id, contact_id in wiring: + probe = probes_by_id[probe_id] + key = [probe_id] + if has_shank or has_side: + contact_idx = int(np.where(np.asarray(probe.contact_ids) == contact_id)[0][0]) + if has_shank and probe.shank_ids is not None: + key.append(probe.shank_ids[contact_idx]) + if has_side and probe.contact_sides is not None: + key.append(probe.contact_sides[contact_idx]) + group_keys.append(tuple(key)) + + unique_keys = list(dict.fromkeys(group_keys)) + key_to_int = {k: i for i, k in enumerate(unique_keys)} + return np.array([key_to_int[k] for k in group_keys], dtype="int64") + + # fallback: read a stored "group" property (recordings without a probe) + return self.get_property("group", ids=channel_ids) def clear_channel_groups(self, channel_ids=None): if channel_ids is None: diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..41d34173b0 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -11,7 +11,7 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] + _main_properties = ["group", "location", "wiring", "gain_to_uV", "offset_to_uV"] _main_features = [] def __init__(self, sampling_frequency: float, nbefore: int | None, snippet_len: int, channel_ids: list, dtype): @@ -259,9 +259,11 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) + if self.has_probe() and not cached.has_probe(): + # Share the probegroup by reference; see BaseRecording._save for + # the rationale (avoids re-running _set_probes validation on a + # child whose parent's dci exceeds the child's channel count). + cached._probegroup = self._probegroup return cached diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 697aab875e..3bcf134218 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,8 @@ import numpy as np +from probeinterface import Probe, ProbeGroup + from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,14 +92,49 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) - # if locations are present, check that they are all different! - if "location" in self.get_property_keys(): - location_tuple = [tuple(loc) for loc in self.get_property("location")] + # Under the id-keyed wiring model, the per-channel `wiring` property + # concatenates across children via the property-merge loop above. We only + # need to combine the probegroups and attach the combined object. + if all(rec.has_probe() for rec in recording_list): + # intra-parent case: every child shares the same probegroup reference + # (as produced by `split_by` etc.). Reuse it directly. + first_pg = recording_list[0]._probegroup + if all(rec._probegroup is first_pg for rec in recording_list): + combined_probegroup = first_pg + else: + # cross-parent case: build a fresh combined probegroup from copies + # of each probe. + combined_probegroup = ProbeGroup() + for rec in recording_list: + for probe in rec._probegroup.probes: + # Round-trip through to_dict/from_dict because `Probe.copy()` + # currently drops contact_ids and annotations (probeinterface + # #421). Once that is fixed we can switch to `probe.copy()`. + probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) + # Clear `device_channel_indices` so probeinterface's + # `ProbeGroup.add_probe` cross-probe dci uniqueness check + # passes: each parent's probe originally had dci in its own + # 0..N-1 channel space, and those ranges collide when + # combined. This does not drop provenance: under the + # id-keyed wiring model dci on the probe is a local index, + # not the canonical mapping. The (channel -> probe_id, + # contact_id) mapping is carried by the recording's + # `wiring` property, and everything that identifies the + # physical probe (geometry, contact_ids, annotations, + # planar_contour) survives the to_dict round-trip. + probe_copy.set_device_channel_indices( + np.full(probe_copy.get_contact_count(), -1, dtype="int64") + ) + combined_probegroup.add_probe(probe_copy) + self._probegroup = combined_probegroup + + # if locations are available (either via attached probe or "location" property), + # check that they are all different + if self.has_probe() or "location" in self.get_property_keys(): + locations = self.get_channel_locations() + location_tuple = [tuple(loc) for loc in locations] assert len(set(location_tuple)) == self.get_num_channels(), ( "Locations are not unique! " "Cannot aggregate recordings!" ) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index de693d5c26..ffa35aa13e 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,11 +61,10 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # inherit the probegroup by reference; the `wiring` per-channel property + # rode through copy_metadata above with the filtered channel_ids + if parent_recording.has_probe(): + self._probegroup = parent_recording._probegroup # update dump dict self._kwargs = { @@ -151,11 +150,10 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # inherit the probegroup by reference; the `wiring` per-channel property + # rode through copy_metadata above with the filtered channel_ids + if parent_snippets.has_probe(): + self._probegroup = parent_snippets._probegroup # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..7f08f41fa4 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match from .core_tools import ( check_json, retrieve_importing_provenance, @@ -363,10 +363,6 @@ def create( f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. " "Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer." ) - # check that multiple probes are non-overlapping - all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) - if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( "Your sorting has spikes with samples times greater than your recording length. These spikes have been removed." @@ -1562,14 +1558,30 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording - # this give all channel locations, so no kwargs like channel_ids and axes + # this give all channel locations, so no kwargs like channel_ids and axes. + # + # Resolve per-channel through the `wiring` property held in rec_attributes, + # matching BaseRecordingSnippets.get_channel_locations. + properties = self.rec_attributes.get("properties", {}) + wiring = properties.get("wiring") probegroup = self.get_probegroup() + + if wiring is not None: + probes_by_id = {p.annotations["probe_id"]: p for p in probegroup.probes} + ndim = probegroup.ndim + locations = np.zeros((len(wiring), ndim), dtype="float64") + for i, (probe_id, contact_id) in enumerate(wiring): + probe = probes_by_id[probe_id] + contact_idx = int(np.where(np.asarray(probe.contact_ids) == contact_id)[0][0]) + locations[i, :ndim] = probe.contact_positions[contact_idx, :ndim] + return locations + + # legacy fallback: pre-id-keyed probegroups were attached with dci = arange(N), + # so sorting by dci yielded channel order. Kept for loading older analyzers. probe_as_numpy_array = probegroup.to_numpy(complete=True) - # we need to sort by device_channel_indices to ensure the order of locations is correct probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - # here we only loop through xy because only 2d locations are supported for i, dim in enumerate(["x", "y"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] return locations diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 1ebeb677c6..70d3d7c519 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -201,18 +201,18 @@ def test_BaseRecording(create_cache_folder): positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + # under strong-preserve, get_probe() returns the full preserved probe + # (all physical contacts, user's original device_channel_indices). probe2 = rec_p.get_probe() - positions3 = probe2.contact_positions - assert np.array_equal(positions2, positions3) - - assert np.array_equal(probe2.device_channel_indices, [0, 1]) + assert probe2.get_contact_count() == 6 + assert np.array_equal(probe2.device_channel_indices, [2, -1, 0, -1, -1, -1]) # test save with probe folder = cache_folder / "simple_recording3" rec2 = rec_p.save(folder=folder, chunk_size=10, n_jobs=2) rec2 = load(folder) probe2 = rec2.get_probe() - assert np.array_equal(probe2.contact_positions, [[0, 30.0], [0.0, 0.0]]) + assert probe2.get_contact_count() == 6 positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) traces2 = rec2.get_traces(segment_index=0) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 751a03460c..40cafed641 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -147,18 +147,18 @@ def test_BaseSnippets(create_cache_folder): positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + # under strong-preserve, get_probe() returns the full preserved probe + # (all physical contacts, user's original device_channel_indices). probe2 = snippets_p.get_probe() - positions3 = probe2.contact_positions - assert np.array_equal(positions2, positions3) - - assert np.array_equal(probe2.device_channel_indices, [0, 1]) + assert probe2.get_contact_count() == 3 + assert np.array_equal(probe2.device_channel_indices, [2, -1, 0]) # test save with probe folder = cache_folder / "simple_snippets3" snippets2 = snippets_p.save(folder=folder) snippets2 = load(folder) probe2 = snippets2.get_probe() - assert np.array_equal(probe2.contact_positions, [[0, 30.0], [0.0, 0.0]]) + assert probe2.get_contact_count() == 3 positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) wavefroms2 = snippets2.get_snippets(segment_index=0) diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py index 35e4a88c89..9f7e8e2e8e 100644 --- a/src/spikeinterface/core/tests/test_channelslicerecording.py +++ b/src/spikeinterface/core/tests/test_channelslicerecording.py @@ -61,11 +61,9 @@ def test_ChannelSliceRecording(create_cache_folder): probe.set_device_channel_indices(np.arange(num_chan)) rec_p = rec.set_probe(probe) rec_sliced3 = ChannelSliceRecording(rec_p, channel_ids=[0, 2], renamed_channel_ids=[3, 4]) - probe3 = rec_sliced3.get_probe() - locations3 = probe3.contact_positions + locations3 = rec_sliced3.get_channel_locations() folder = cache_folder / "sliced_recording" rec_saved = rec_sliced3.save(folder=folder, chunk_size=10, n_jobs=2) - probe = rec_saved.get_probe() assert np.array_equal(locations3, rec_saved.get_channel_locations()) traces3 = rec_saved.get_traces(segment_index=0) assert np.all(traces3[:, 0] == 0) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index a51063af3e..26c01287ad 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -503,7 +503,7 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c2ceb66523 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np import probeinterface from spikeinterface.core.core_tools import define_function_from_class @@ -71,8 +72,9 @@ def __init__( probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + probe = self.get_probegroup().probes[0] + self.set_property("row", np.asarray(probe.contact_annotations["row"])) + self.set_property("col", np.asarray(probe.contact_annotations["col"])) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..875422d00b 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,8 @@ def __init__( rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + probe = self.get_probegroup().probes[0] + self.set_property("electrode", np.asarray(probe.contact_annotations["electrode"])) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index bd0d2184d4..90ff378575 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -84,9 +84,9 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", "location", "group", + "wiring", "shank", "shank_row", "shank_col", diff --git a/src/spikeinterface/preprocessing/basepreprocessor.py b/src/spikeinterface/preprocessing/basepreprocessor.py index 64d57d3637..79f6b5105d 100644 --- a/src/spikeinterface/preprocessing/basepreprocessor.py +++ b/src/spikeinterface/preprocessing/basepreprocessor.py @@ -21,6 +21,11 @@ def __init__(self, recording, sampling_frequency=None, channel_ids=None, dtype=N recording.copy_metadata(self, only_main=False, ids=channel_ids) self._parent = recording + # Propagate the attached probegroup. `_probegroup` is a recording-global + # direct attribute; the per-channel `wiring` property rides on copy_metadata. + if getattr(recording, "_probegroup", None) is not None and channel_ids is None: + self._probegroup = recording._probegroup + # self._kwargs have to be handled in subclass diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index a571894374..10dbb11fc1 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -126,8 +126,9 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx][1] = x[idx] + probe = recording._probegroup.probes[0] + probe._contact_positions[:, 0] = x + recording.set_property("location", recording.get_channel_locations()) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -170,9 +171,11 @@ def test_output_values(): [5, 5, 5, 7, 3], ] # all others equal distance away. # Overwrite the probe information with the new locations + probe = recording._probegroup.probes[0] for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx][1] = x - recording._properties["contact_vector"][idx][2] = y + probe._contact_positions[idx, 0] = x + probe._contact_positions[idx, 1] = y + recording.set_property("location", recording.get_channel_locations()) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +189,10 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + probe = recording._probegroup.probes[0] + probe._contact_positions[-1, 0] = 5 + probe._contact_positions[-1, 1] = 9 + recording.set_property("location", recording.get_channel_locations()) expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..a84d3bbf64 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe() or "location" in recording.get_property_keys(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a433eeb643..d69e697181 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -403,13 +403,10 @@ def __init__( dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) - if border_mode == "remove_channels": - # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if border_mode == "remove_channels" and recording.has_probe(): + # inherit the probegroup by reference; the `wiring` per-channel property + # rode through BasePreprocessor's copy_metadata with filtered channel_ids + self._probegroup = recording._probegroup # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below