Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions src/spikeinterface/widgets/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(
self,
spike_train_data: dict,
y_axis_data: dict,
depth_dict: dict | None = None,
sort_by_depth: bool = False,
unit_ids: list | None = None,
segment_indices: list | None = None,
durations: list | None = None,
Expand Down Expand Up @@ -144,6 +146,8 @@ def __init__(
plot_data = dict(
spike_train_data=concatenated_spike_trains,
y_axis_data=concatenated_y_axis,
depth_dict=depth_dict,
sort_by_depth=sort_by_depth,
unit_ids=unit_ids,
plot_histograms=plot_histograms,
y_lim=y_lim,
Expand Down Expand Up @@ -208,6 +212,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate]
unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate]

if dp.sort_by_depth and dp.depth_dict is not None:
ones = np.ones_like(unit_y_data)
unit_y_data = ones * dp.depth_dict[unit_id]

if dp.color_kwargs is None:
scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, color=unit_colors[unit_id])
else:
Expand Down Expand Up @@ -249,7 +257,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
x_lim = [0, np.sum(dp.durations)]
scatter_ax.set_xlim(x_lim)

if dp.y_ticks:
if dp.sort_by_depth and dp.depth_dict is not None:
scatter_ax.set_yticks(ticks=list(range(len(dp.depth_dict))), labels=list(dp.depth_dict.keys()))
elif dp.y_ticks:
scatter_ax.set_yticks(**dp.y_ticks)

scatter_ax.set_title(dp.title)
Expand Down Expand Up @@ -282,8 +292,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
plt.show()

self.unit_selector = UnitSelector(list(data_plot["spike_train_data"].keys()))
self.unit_selector.value = list(data_plot["spike_train_data"].keys())[:1]
if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
unit_list = list(data_plot["depth_dict"].keys())
else:
unit_list = list(data_plot["spike_train_data"].keys())

self.unit_selector = UnitSelector(unit_list)
self.unit_selector.value = unit_list[:1]

children = [self.unit_selector]

Expand All @@ -294,6 +309,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
)
children.append(self.checkbox_histograms)

if data_plot["depth_dict"] is not None:
self.checkbox_depth = W.Checkbox(
value=data_plot["sort_by_depth"],
description="Sort by depth",
)
children.append(self.checkbox_depth)

left_sidebar = W.VBox(
children=children,
layout=W.Layout(align_items="center", width="100%", height="100%"),
Expand All @@ -311,6 +333,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
self.unit_selector.observe(self._update_plot, names="value", type="change")
if data_plot["plot_histograms"] is not None:
self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change")
if data_plot["depth_dict"] is not None:
self.checkbox_depth.observe(self._update_plot, names="value", type="change")

if backend_kwargs["display"]:
display(self.widget)
Expand All @@ -321,6 +345,8 @@ def _full_update_plot(self, change=None):
data_plot["unit_ids"] = self.unit_selector.value
if data_plot["plot_histograms"] is not None:
data_plot["plot_histograms"] = self.checkbox_histograms.value
if data_plot["depth_dict"] is not None:
data_plot["sort_by_depth"] = self.checkbox_depth.value
data_plot["plot_legend"] = False

backend_kwargs = dict(figure=self.figure, axes=None, ax=None)
Expand All @@ -334,11 +360,24 @@ def _update_plot(self, change=None):
data_plot["unit_ids"] = self.unit_selector.value
if data_plot["plot_histograms"] is not None:
data_plot["plot_histograms"] = self.checkbox_histograms.value
if data_plot["depth_dict"] is not None:
data_plot["sort_by_depth"] = self.checkbox_depth.value

if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
unit_list = list(data_plot["depth_dict"].keys())
else:
unit_list = list(data_plot["spike_train_data"].keys())

old_value = self.unit_selector.value
self.unit_selector.unit_ids = unit_list
self.unit_selector.selector.options = unit_list
self.unit_selector.value = old_value

data_plot["unit_ids"] = self.unit_selector.value
data_plot["plot_legend"] = False

backend_kwargs = dict(figure=None, axes=self.axes, ax=None)
self.plot_matplotlib(data_plot, **backend_kwargs)

self.figure.canvas.draw()
self.figure.canvas.flush_events()

Expand All @@ -363,6 +402,8 @@ class RasterWidget(BaseRasterWidget):
A sorting object. Deprecated.
sorting_analyzer : SortingAnalyzer | None, default: None
A sorting analyzer object. Deprecated.
sort_by_depth : bool, default: False
Whether or not to sort units by depth, default: False
"""

def __init__(
Expand All @@ -375,6 +416,7 @@ def __init__(
backend: str | None = None,
sorting: BaseSorting | None = None,
sorting_analyzer: SortingAnalyzer | None = None,
sort_by_depth: bool = False,
**backend_kwargs,
):
if sorting is not None:
Expand All @@ -394,6 +436,17 @@ def __init__(
if unit_ids is None:
unit_ids = sorting.unit_ids

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
if sort_by_depth:
raise AttributeError(f"'unit_locations' necessary for `sort_by_depth=True`")
else:
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
s_args_depths = np.argsort(depths[:, 1])
depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())}
elif sort_by_depth:
raise AttributeError("`sort_by_depth=True` requires a SortingAnalyzer")

# Create dict of dicts structure
spike_train_data = {}
y_axis_data = {}
Expand Down Expand Up @@ -435,6 +488,8 @@ def __init__(
plot_data = dict(
spike_train_data=spike_train_data,
y_axis_data=y_axis_data,
depth_dict=depth_dict,
sort_by_depth=sort_by_depth,
segment_indices=segment_indices,
x_lim=time_range,
y_label="Unit id",
Expand Down
Loading