From 085219e23001803d65ae5d4fe11161318d61b9a9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 23 Apr 2026 20:12:47 -0600 Subject: [PATCH] Improve Probe.copy() --- src/probeinterface/probe.py | 26 +++++++++------------ tests/test_probe.py | 45 +++++++++++++++++++++++++++++++++++++ tests/test_probegroup.py | 14 ++++++++---- 3 files changed, 66 insertions(+), 19 deletions(-) diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index c19ddd65..4ab8c9d0 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -1,4 +1,5 @@ import numpy as np +from copy import deepcopy from typing import Literal from pathlib import Path @@ -662,24 +663,19 @@ def __eq__(self, other): return True - def copy(self): + def copy(self) -> "Probe": """ - Copy to another Probe instance. + Identity-preserving deep copy of the Probe. - Note: device_channel_indices are not copied - and contact_ids are not copied + Preserves contacts, contact_ids, shank_ids, contact_sides, annotations + (name, model_name, manufacturer, serial_number, description), and + contact_annotations. Does not copy ``device_channel_indices`` because + wiring is attached by the caller at use time, not part of the probe's + identity. """ - other = Probe() - other.set_contacts( - positions=self.contact_positions.copy(), - plane_axes=self.contact_plane_axes.copy(), - shapes=self.contact_shapes.copy(), - shape_params=self.contact_shape_params.copy(), - ) - if self.probe_planar_contour is not None: - other.set_planar_contour(self.probe_planar_contour.copy()) - # channel_indices are not copied - return other + d = deepcopy(self.to_dict()) + d.pop("device_channel_indices", None) + return Probe.from_dict(d) def to_3d(self, axes: Literal["xy", "yz", "xz"] = "xz"): """ diff --git a/tests/test_probe.py b/tests/test_probe.py index 631b6d3f..3ed8a2f1 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -228,6 +228,51 @@ def test_double_side_probe(): assert probe4 == probe +def _annotated_probe(): + probe = generate_dummy_probe() + n = probe.get_contact_count() + probe.set_contact_ids([f"c{i}" for i in range(n)]) + probe.set_shank_ids(np.array(["s0"] * (n // 2) + ["s1"] * (n - n // 2))) + probe.set_device_channel_indices(np.arange(n)[::-1]) + probe.annotate(name="dummy", manufacturer="acme", model_name="x1", serial_number="sn-42") + probe.annotate_contacts(impedance=np.linspace(1.0, 2.0, n)) + return probe + + +def test_copy_preserves_identity(): + probe = _annotated_probe() + probe2 = probe.copy() + + assert probe2 is not probe + np.testing.assert_array_equal(probe2.contact_ids, probe.contact_ids) + np.testing.assert_array_equal(probe2.shank_ids, probe.shank_ids) + assert probe2.annotations == probe.annotations + assert probe2.contact_annotations.keys() == probe.contact_annotations.keys() + for key in probe.contact_annotations: + np.testing.assert_array_equal(probe2.contact_annotations[key], probe.contact_annotations[key]) + + +def test_copy_drops_device_channel_indices(): + probe = _annotated_probe() + probe2 = probe.copy() + + assert probe2.device_channel_indices is None + + +def test_copy_is_independent(): + probe = _annotated_probe() + probe2 = probe.copy() + + probe2.annotations["manufacturer"] = "mutated" + probe2.contact_annotations["impedance"][0] = 999.0 + probe2.move([999, 999]) + probe2._contact_ids[0] = "zzz" + + assert probe.annotations["manufacturer"] == "acme" + assert probe.contact_annotations["impedance"][0] != 999.0 + assert probe.contact_ids[0] == "c0" + + if __name__ == "__main__": import tempfile diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index c9421908..ddd332d4 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -179,11 +179,17 @@ def test_copy_preserves_device_channel_indices(probegroup): ) -def test_copy_does_not_preserve_contact_ids(probegroup): - """Probe.copy() intentionally does not copy contact_ids.""" +def test_copy_preserves_contact_ids(probegroup): + """Probe.copy() preserves contact_ids when they are set on the probe.""" + for index, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + probe.set_contact_ids([f"p{index}-c{i}" for i in range(n)]) + pg_copy = probegroup.copy() - # All contact_ids should be empty strings after copy - assert all(cid == "" for cid in pg_copy.get_global_contact_ids()) + + original_ids = probegroup.get_global_contact_ids() + copied_ids = pg_copy.get_global_contact_ids() + np.testing.assert_array_equal(copied_ids, original_ids) def test_copy_is_independent(probegroup):