Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from .numpyextractors import NumpySorting
from .sparsity import ChannelSparsity, estimate_sparsity
from .sortingfolder import NumpyFolderSorting
from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open
from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open, _write_object_array
from .node_pipeline import run_node_pipeline


Expand Down Expand Up @@ -617,7 +617,6 @@ def _get_zarr_root(self, mode="r+"):
def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options):
# used by create and save_as
import zarr
import numcodecs
from .zarrextractors import add_sorting_to_zarr_group

if is_path_remote(folder):
Expand Down Expand Up @@ -646,13 +645,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
if recording is not None:
rec_dict = recording.to_dict(relative_to=relative_to, recursive=True)
if recording.check_serializability("json"):
# zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON())
zarr_rec = np.array([check_json(rec_dict)], dtype=object)
zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON())
_write_object_array(zarr_root, "recording", check_json(rec_dict), codec="json")
elif recording.check_serializability("pickle"):
# zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle())
zarr_rec = np.array([rec_dict], dtype=object)
zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle())
_write_object_array(zarr_root, "recording", rec_dict, codec="pickle")
else:
warnings.warn("The Recording is not serializable! The recording link will be lost for future load")
else:
Expand All @@ -662,11 +657,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
# sorting provenance
sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True)
if sorting.check_serializability("json"):
zarr_sort = np.array([check_json(sort_dict)], dtype=object)
zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON())
_write_object_array(zarr_root, "sorting_provenance", check_json(sort_dict), codec="json")
elif sorting.check_serializability("pickle"):
zarr_sort = np.array([sort_dict], dtype=object)
zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle())
_write_object_array(zarr_root, "sorting_provenance", sort_dict, codec="pickle")
else:
warnings.warn(
"The sorting provenance is not serializable! The sorting provenance link will be lost for future load"
Expand Down Expand Up @@ -2703,8 +2696,6 @@ def _save_data(self):
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
elif self.format == "zarr":
import numcodecs

saving_options = self.sorting_analyzer._backend_options.get("saving_options", {})
extension_group = self._get_zarr_extension_group(mode="r+")

Expand All @@ -2717,9 +2708,7 @@ def _save_data(self):
del extension_group[ext_data_name]
if isinstance(ext_data, (dict, list)):
ext_data_ = check_json(ext_data)
extension_group.create_dataset(
name=ext_data_name, data=np.array([ext_data_], dtype=object), object_codec=numcodecs.JSON()
)
_write_object_array(extension_group, ext_data_name, ext_data_, codec="json")
extension_group[ext_data_name].attrs["dict"] = True
elif isinstance(ext_data, np.ndarray):
extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options)
Expand All @@ -2739,9 +2728,7 @@ def _save_data(self):
else:
# any object
try:
extension_group.create_dataset(
name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle()
)
_write_object_array(extension_group, ext_data_name, ext_data, codec="pickle")
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
extension_group[ext_data_name].attrs["object"] = True
Expand Down
45 changes: 45 additions & 0 deletions src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,51 @@ def resolve_zarr_path(folder_path: str | Path):
return folder_path, folder_path_kwarg


def _write_object_array(
group,
name: str,
data,
codec: str = "json",
overwrite: bool = True,
):
"""
Write a length-1 object-dtype array holding a Python dict/list/object.

Centralizes the v2/v3 codec-placement difference for object blobs: under zarr-v2
the object codec goes in ``object_codec=``; under zarr-v3 it goes in ``filters=``
(wrapped via ``numcodecs.zarr3.*``). The helper picks the right path automatically.

Parameters
----------
group : zarr.Group
The zarr group to write into.
name : str
Name of the array inside ``group``.
data : Any
The Python object to store. Wrapped into ``np.array([data], dtype=object)``.
codec : {"json", "pickle"}, default: "json"
Which object codec to use.
overwrite : bool, default: True
Whether to overwrite an existing array with the same name.
"""
import numcodecs

if codec == "json":
codec_instance = numcodecs.JSON()
elif codec == "pickle":
codec_instance = numcodecs.Pickle()
else:
raise ValueError(f"codec must be 'json' or 'pickle', got {codec!r}")

arr = np.array([data], dtype=object)
return group.create_dataset(
name=name,
data=arr,
object_codec=codec_instance,
overwrite=overwrite,
)


def get_default_zarr_compressor(clevel: int = 5):
"""
    Return default Zarr compressor object for good performance in int16
Expand Down
Loading