From e206e32dc2cddfdca520e24dcbb5c8c4fd123c90 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 21 Apr 2026 13:56:50 +0200 Subject: [PATCH 1/7] add sort_by_depth kwargs to RasterWidget --- src/spikeinterface/widgets/rasters.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 1ddde77bf9..7de1b9a28d 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -363,6 +363,8 @@ class RasterWidget(BaseRasterWidget): A sorting object. Deprecated. sorting_analyzer : SortingAnalyzer | None, default: None A sorting analyzer object. Deprecated. + sort_by_depth: bool = False + Wether or not to sort units by depth, default: False """ def __init__( @@ -375,6 +377,7 @@ def __init__( backend: str | None = None, sorting: BaseSorting | None = None, sorting_analyzer: SortingAnalyzer | None = None, + sort_by_depth : bool = False, **backend_kwargs, ): if sorting is not None: @@ -387,6 +390,8 @@ def __init__( warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) sorting_analyzer_or_sorting = sorting_analyzer + + sorting = self.ensure_sorting(sorting_analyzer_or_sorting) segment_indices = validate_segment_indices(segment_indices, sorting) @@ -394,6 +399,14 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids + if sort_by_depth : + # print("hey") + if not sorting_analyzer_or_sorting.has_extension("unit_locations"): + raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True") + depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") + s_args_depths = np.argsort(depths[:, 1]) + unit_ids = unit_ids[s_args_depths] + # Create dict of dicts structure spike_train_data = {} y_axis_data = {} From b82bba08434165394ef2512149e7f21d7e0f8a34 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:04:20 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/rasters.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 7de1b9a28d..468bf6c3a9 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -377,7 +377,7 @@ def __init__( backend: str | None = None, sorting: BaseSorting | None = None, sorting_analyzer: SortingAnalyzer | None = None, - sort_by_depth : bool = False, + sort_by_depth: bool = False, **backend_kwargs, ): if sorting is not None: @@ -390,8 +390,6 @@ def __init__( warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) sorting_analyzer_or_sorting = sorting_analyzer - - sorting = self.ensure_sorting(sorting_analyzer_or_sorting) segment_indices = validate_segment_indices(segment_indices, sorting) @@ -399,7 +397,7 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - if sort_by_depth : + if sort_by_depth: # print("hey") if not sorting_analyzer_or_sorting.has_extension("unit_locations"): raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True") From 1919bb7eec5c480c868f573d77ea19d03a566527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20H=2E?= <62891573+tayheau@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:07:26 +0200 Subject: [PATCH 3/7] Update src/spikeinterface/widgets/rasters.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/widgets/rasters.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 468bf6c3a9..0ca7f90161 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -363,8 +363,8 @@ class RasterWidget(BaseRasterWidget): A sorting object. Deprecated. sorting_analyzer : SortingAnalyzer | None, default: None A sorting analyzer object. Deprecated. - sort_by_depth: bool = False - Wether or not to sort units by depth, default: False + sort_by_depth : bool, default: False + Whether or not to sort units by depth, default: False """ def __init__( @@ -398,7 +398,6 @@ def __init__( unit_ids = sorting.unit_ids if sort_by_depth: - # print("hey") if not sorting_analyzer_or_sorting.has_extension("unit_locations"): raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True") depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") From b1a56655be10b19d6cadc2ef17332c4c9aefd4cd Mon Sep 17 00:00:00 2001 From: tayheau Date: Thu, 23 Apr 2026 13:59:22 +0200 Subject: [PATCH 4/7] change raster depth sorting logic for widget handling --- src/spikeinterface/widgets/rasters.py | 69 ++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 0ca7f90161..1135b0f7ca 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -63,6 +63,8 @@ def __init__( self, spike_train_data: dict, y_axis_data: dict, + depth_dict: dict | None = None, + sort_by_depth: bool = False, unit_ids: list | None = None, segment_indices: list | None = None, durations: list | None = None, @@ -144,6 +146,8 @@ def __init__( plot_data = dict( spike_train_data=concatenated_spike_trains, y_axis_data=concatenated_y_axis, + depth_dict=depth_dict, + sort_by_depth=sort_by_depth, unit_ids=unit_ids, plot_histograms=plot_histograms, y_lim=y_lim, @@ -200,7 +204,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): spike_train_data = dp.spike_train_data y_axis_data = dp.y_axis_data - + for unit_id in unit_ids: if unit_id not in spike_train_data: continue # Skip this unit if not in data @@ -208,6 +212,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate] unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate] + if dp.sort_by_depth and dp.depth_dict is not None: + ones = np.ones_like(unit_y_data) + unit_y_data = ones * dp.depth_dict[unit_id] + if dp.color_kwargs is None: scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, color=unit_colors[unit_id]) else: @@ -249,7 +257,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): x_lim = [0, np.sum(dp.durations)] scatter_ax.set_xlim(x_lim) - if dp.y_ticks: + if dp.sort_by_depth and dp.depth_dict is not None: + scatter_ax.set_yticks( + ticks=list(range(len(dp.depth_dict))), + labels=list(dp.depth_dict.keys()) + ) + elif dp.y_ticks: scatter_ax.set_yticks(**dp.y_ticks) scatter_ax.set_title(dp.title) @@ -282,8 +295,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - self.unit_selector = UnitSelector(list(data_plot["spike_train_data"].keys())) - self.unit_selector.value = list(data_plot["spike_train_data"].keys())[:1] + if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None: + unit_list = list(data_plot["depth_dict"].keys()) + else: + unit_list = (list(data_plot["spike_train_data"].keys())) + + self.unit_selector = UnitSelector(unit_list) + self.unit_selector.value = unit_list[:1] children = [self.unit_selector] @@ -294,6 +312,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) children.append(self.checkbox_histograms) + if data_plot["depth_dict"] is not None: + self.checkbox_depth = W.Checkbox( + value=data_plot["sort_by_depth"], + description="Sort by depth", + ) + children.append(self.checkbox_depth) + left_sidebar = W.VBox( children=children, layout=W.Layout(align_items="center", width="100%", height="100%"), @@ -311,16 +336,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector.observe(self._update_plot, names="value", type="change") if data_plot["plot_histograms"] is not None: self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") + if data_plot["depth_dict"] is not None: + self.checkbox_depth.observe(self._update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - + def _full_update_plot(self, change=None): self.figure.clear() data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value if data_plot["plot_histograms"] is not None: data_plot["plot_histograms"] = self.checkbox_histograms.value + if data_plot["depth_dict"] is not None: + data_plot["sort_by_depth"] = self.checkbox_depth.value data_plot["plot_legend"] = False backend_kwargs = dict(figure=self.figure, axes=None, ax=None) @@ -334,15 +363,27 @@ def _update_plot(self, change=None): data_plot["unit_ids"] = self.unit_selector.value if data_plot["plot_histograms"] is not None: data_plot["plot_histograms"] = self.checkbox_histograms.value + if data_plot["depth_dict"] is not None: + data_plot["sort_by_depth"] = self.checkbox_depth.value + + if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None: + unit_list = list(data_plot["depth_dict"].keys()) + else: + unit_list = list(data_plot["spike_train_data"].keys()) + + old_value = self.unit_selector.value + self.unit_selector.unit_ids = unit_list + self.unit_selector.selector.options = unit_list + self.unit_selector.value = old_value + + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_legend"] = False backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) - self.figure.canvas.draw() self.figure.canvas.flush_events() - class RasterWidget(BaseRasterWidget): """ Plots spike train rasters. @@ -397,12 +438,14 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - if sort_by_depth: - if not sorting_analyzer_or_sorting.has_extension("unit_locations"): + + if not sorting_analyzer_or_sorting.has_extension("unit_locations"): + if sort_by_depth: raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True") - depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") - s_args_depths = np.argsort(depths[:, 1]) - unit_ids = unit_ids[s_args_depths] + depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") + s_args_depths = np.argsort(depths[:, 1]) + depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())} + # unit_ids = unit_ids[s_args_depths] # Create dict of dicts structure spike_train_data = {} @@ -445,6 +488,8 @@ def __init__( plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, + depth_dict=depth_dict, + sort_by_depth=sort_by_depth, segment_indices=segment_indices, x_lim=time_range, y_label="Unit id", From 1ceac4ee3208faf659048b915aabbfe1334ebc80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:23:29 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/rasters.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 1135b0f7ca..541e41a0dd 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -204,7 +204,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): spike_train_data = dp.spike_train_data y_axis_data = dp.y_axis_data - + for unit_id in unit_ids: if unit_id not in spike_train_data: continue # Skip this unit if not in data @@ -258,10 +258,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): scatter_ax.set_xlim(x_lim) if dp.sort_by_depth and dp.depth_dict is not None: - scatter_ax.set_yticks( - ticks=list(range(len(dp.depth_dict))), - labels=list(dp.depth_dict.keys()) - ) + scatter_ax.set_yticks(ticks=list(range(len(dp.depth_dict))), labels=list(dp.depth_dict.keys())) elif dp.y_ticks: scatter_ax.set_yticks(**dp.y_ticks) @@ -298,7 +295,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None: unit_list = list(data_plot["depth_dict"].keys()) else: - unit_list = (list(data_plot["spike_train_data"].keys())) + unit_list = list(data_plot["spike_train_data"].keys()) self.unit_selector = UnitSelector(unit_list) self.unit_selector.value = unit_list[:1] @@ -341,7 +338,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - + def _full_update_plot(self, change=None): self.figure.clear() data_plot = self.next_data_plot @@ -370,7 +367,7 @@ def _update_plot(self, change=None): unit_list = list(data_plot["depth_dict"].keys()) else: unit_list = list(data_plot["spike_train_data"].keys()) - + old_value = self.unit_selector.value self.unit_selector.unit_ids = unit_list self.unit_selector.selector.options = unit_list @@ -384,6 +381,7 @@ def _update_plot(self, change=None): self.figure.canvas.draw() self.figure.canvas.flush_events() + class RasterWidget(BaseRasterWidget): """ Plots spike train rasters. @@ -438,13 +436,12 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - if not sorting_analyzer_or_sorting.has_extension("unit_locations"): if sort_by_depth: raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True") depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") s_args_depths = np.argsort(depths[:, 1]) - depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())} + depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())} # unit_ids = unit_ids[s_args_depths] # Create dict of dicts structure From 33d3fec86d3ad380402b3565d98ee001335e00ac Mon Sep 17 00:00:00 2001 From: tayheau Date: Thu, 23 Apr 2026 14:31:13 +0200 Subject: [PATCH 6/7] handle BaseSorting error case --- src/spikeinterface/widgets/rasters.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 541e41a0dd..705fdb4de3 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -436,13 +436,16 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - if not sorting_analyzer_or_sorting.has_extension("unit_locations"): - if sort_by_depth: - raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True") - depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") - s_args_depths = np.argsort(depths[:, 1]) - depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())} - # unit_ids = unit_ids[s_args_depths] + if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + if not sorting_analyzer_or_sorting.has_extension("unit_locations"): + if sort_by_depth: + raise AttributeError(f"'unit_locations' necessary for `sort_by_depth=True`") + else: + depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") + s_args_depths = np.argsort(depths[:, 1]) + depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())} + elif sort_by_depth: + raise AttributeError("`sort_by_depth=True` requires a SortingAnalyzer") # Create dict of dicts structure spike_train_data = {} From f61eaee72422d1d8b1d6b90a2d6acd74d5b38998 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:37:28 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/rasters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 705fdb4de3..d08bd24249 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -443,7 +443,7 @@ def __init__( else: depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") s_args_depths = np.argsort(depths[:, 1]) - depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())} + depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())} elif sort_by_depth: raise AttributeError("`sort_by_depth=True` requires a SortingAnalyzer")