diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 1ddde77bf9..d08bd24249 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -63,6 +63,8 @@ def __init__( self, spike_train_data: dict, y_axis_data: dict, + depth_dict: dict | None = None, + sort_by_depth: bool = False, unit_ids: list | None = None, segment_indices: list | None = None, durations: list | None = None, @@ -144,6 +146,8 @@ def __init__( plot_data = dict( spike_train_data=concatenated_spike_trains, y_axis_data=concatenated_y_axis, + depth_dict=depth_dict, + sort_by_depth=sort_by_depth, unit_ids=unit_ids, plot_histograms=plot_histograms, y_lim=y_lim, @@ -208,6 +212,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate] unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate] + if dp.sort_by_depth and dp.depth_dict is not None: + ones = np.ones_like(unit_y_data) + unit_y_data = ones * dp.depth_dict[unit_id] + if dp.color_kwargs is None: scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, color=unit_colors[unit_id]) else: @@ -249,7 +257,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): x_lim = [0, np.sum(dp.durations)] scatter_ax.set_xlim(x_lim) - if dp.y_ticks: + if dp.sort_by_depth and dp.depth_dict is not None: + scatter_ax.set_yticks(ticks=list(range(len(dp.depth_dict))), labels=list(dp.depth_dict.keys())) + elif dp.y_ticks: scatter_ax.set_yticks(**dp.y_ticks) scatter_ax.set_title(dp.title) @@ -282,8 +292,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - self.unit_selector = UnitSelector(list(data_plot["spike_train_data"].keys())) - self.unit_selector.value = list(data_plot["spike_train_data"].keys())[:1] + if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None: + unit_list = list(data_plot["depth_dict"].keys()) + else: + unit_list = list(data_plot["spike_train_data"].keys()) + + self.unit_selector = UnitSelector(unit_list) + self.unit_selector.value = unit_list[:1] children = [self.unit_selector] @@ -294,6 +309,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) children.append(self.checkbox_histograms) + if data_plot["depth_dict"] is not None: + self.checkbox_depth = W.Checkbox( + value=data_plot["sort_by_depth"], + description="Sort by depth", + ) + children.append(self.checkbox_depth) + left_sidebar = W.VBox( children=children, layout=W.Layout(align_items="center", width="100%", height="100%"), @@ -311,6 +333,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector.observe(self._update_plot, names="value", type="change") if data_plot["plot_histograms"] is not None: self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") + if data_plot["depth_dict"] is not None: + self.checkbox_depth.observe(self._update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -321,6 +345,8 @@ def _full_update_plot(self, change=None): data_plot["unit_ids"] = self.unit_selector.value if data_plot["plot_histograms"] is not None: data_plot["plot_histograms"] = self.checkbox_histograms.value + if data_plot["depth_dict"] is not None: + data_plot["sort_by_depth"] = self.checkbox_depth.value data_plot["plot_legend"] = False backend_kwargs = dict(figure=self.figure, axes=None, ax=None) @@ -334,11 +360,24 @@ def _update_plot(self, change=None): data_plot["unit_ids"] = self.unit_selector.value if data_plot["plot_histograms"] is not None: data_plot["plot_histograms"] = self.checkbox_histograms.value + if data_plot["depth_dict"] is not None: + data_plot["sort_by_depth"] = self.checkbox_depth.value + + if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None: + unit_list = list(data_plot["depth_dict"].keys()) + else: + unit_list = list(data_plot["spike_train_data"].keys()) + + old_value = self.unit_selector.value + self.unit_selector.unit_ids = unit_list + self.unit_selector.selector.options = unit_list + self.unit_selector.value = old_value + + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_legend"] = False backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) - self.figure.canvas.draw() self.figure.canvas.flush_events() @@ -363,6 +402,8 @@ class RasterWidget(BaseRasterWidget): A sorting object. Deprecated. sorting_analyzer : SortingAnalyzer | None, default: None A sorting analyzer object. Deprecated. + sort_by_depth : bool, default: False + Whether or not to sort units by depth, default: False """ def __init__( @@ -375,6 +416,7 @@ def __init__( backend: str | None = None, sorting: BaseSorting | None = None, sorting_analyzer: SortingAnalyzer | None = None, + sort_by_depth: bool = False, **backend_kwargs, ): if sorting is not None: @@ -394,6 +436,17 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids + if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + if not sorting_analyzer_or_sorting.has_extension("unit_locations"): + if sort_by_depth: + raise AttributeError(f"'unit_locations' necessary for `sort_by_depth=True`") + else: + depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy") + s_args_depths = np.argsort(depths[:, 1]) + depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())} + elif sort_by_depth: + raise AttributeError("`sort_by_depth=True` requires a SortingAnalyzer") + # Create dict of dicts structure spike_train_data = {} y_axis_data = {} @@ -435,6 +488,8 @@ def __init__( plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, + depth_dict=depth_dict, + sort_by_depth=sort_by_depth, segment_indices=segment_indices, x_lim=time_range, y_label="Unit id",