diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 8e16757bcc..3044eb3f10 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -41,7 +41,7 @@
 from .numpyextractors import NumpySorting
 from .sparsity import ChannelSparsity, estimate_sparsity
 from .sortingfolder import NumpyFolderSorting
-from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open
+from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open, _write_object_array
 from .node_pipeline import run_node_pipeline
@@ -617,7 +617,6 @@ def _get_zarr_root(self, mode="r+"):
     def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options):
         # used by create and save_as
         import zarr
-        import numcodecs
         from .zarrextractors import add_sorting_to_zarr_group
 
         if is_path_remote(folder):
@@ -646,13 +645,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
         if recording is not None:
             rec_dict = recording.to_dict(relative_to=relative_to, recursive=True)
             if recording.check_serializability("json"):
-                # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON())
-                zarr_rec = np.array([check_json(rec_dict)], dtype=object)
-                zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON())
+                _write_object_array(zarr_root, "recording", check_json(rec_dict), codec="json")
             elif recording.check_serializability("pickle"):
-                # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle())
-                zarr_rec = np.array([rec_dict], dtype=object)
-                zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle())
+                _write_object_array(zarr_root, "recording", rec_dict, codec="pickle")
             else:
                 warnings.warn("The Recording is not serializable! The recording link will be lost for future load")
         else:
@@ -662,11 +657,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
         # sorting provenance
         sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True)
         if sorting.check_serializability("json"):
-            zarr_sort = np.array([check_json(sort_dict)], dtype=object)
-            zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON())
+            _write_object_array(zarr_root, "sorting_provenance", check_json(sort_dict), codec="json")
         elif sorting.check_serializability("pickle"):
-            zarr_sort = np.array([sort_dict], dtype=object)
-            zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle())
+            _write_object_array(zarr_root, "sorting_provenance", sort_dict, codec="pickle")
         else:
             warnings.warn(
                 "The sorting provenance is not serializable! The sorting provenance link will be lost for future load"
@@ -2703,8 +2696,6 @@ def _save_data(self):
                    except:
                        raise Exception(f"Could not save {ext_data_name} as extension data")
        elif self.format == "zarr":
-            import numcodecs
-
            saving_options = self.sorting_analyzer._backend_options.get("saving_options", {})
            extension_group = self._get_zarr_extension_group(mode="r+")
 
@@ -2717,9 +2708,7 @@ def _save_data(self):
                    del extension_group[ext_data_name]
                if isinstance(ext_data, (dict, list)):
                    ext_data_ = check_json(ext_data)
-                    extension_group.create_dataset(
-                        name=ext_data_name, data=np.array([ext_data_], dtype=object), object_codec=numcodecs.JSON()
-                    )
+                    _write_object_array(extension_group, ext_data_name, ext_data_, codec="json")
                    extension_group[ext_data_name].attrs["dict"] = True
                elif isinstance(ext_data, np.ndarray):
                    extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options)
@@ -2739,9 +2728,7 @@ def _save_data(self):
                else:
                    # any object
                    try:
-                        extension_group.create_dataset(
-                            name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle()
-                        )
+                        _write_object_array(extension_group, ext_data_name, ext_data, codec="pickle")
                    except:
                        raise Exception(f"Could not save {ext_data_name} as extension data")
                    extension_group[ext_data_name].attrs["object"] = True
diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 1ef5d76e5a..3f4c01d4bc 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -382,6 +382,51 @@ def resolve_zarr_path(folder_path: str | Path):
     return folder_path, folder_path_kwarg
 
 
+def _write_object_array(
+    group,
+    name: str,
+    data,
+    codec: str = "json",
+    overwrite: bool = True,
+):
+    """
+    Write a length-1 object-dtype array holding a Python dict/list/object.
+
+    Centralizes writing of object blobs so that the zarr-v2 / zarr-v3 codec-placement
+    difference (``object_codec=`` under zarr-v2 vs. ``filters=`` wrapped via
+    ``numcodecs.zarr3.*`` under zarr-v3) only needs to be handled in this one place.
+
+    Parameters
+    ----------
+    group : zarr.Group
+        The zarr group to write into.
+    name : str
+        Name of the array inside ``group``.
+    data : Any
+        The Python object to store. Wrapped into ``np.array([data], dtype=object)``.
+    codec : {"json", "pickle"}, default: "json"
+        Which object codec to use.
+    overwrite : bool, default: True
+        Whether to overwrite an existing array with the same name.
+    """
+    import numcodecs
+
+    if codec == "json":
+        codec_instance = numcodecs.JSON()
+    elif codec == "pickle":
+        codec_instance = numcodecs.Pickle()
+    else:
+        raise ValueError(f"codec must be 'json' or 'pickle', got {codec!r}")
+
+    arr = np.array([data], dtype=object)
+    return group.create_dataset(
+        name=name,
+        data=arr,
+        object_codec=codec_instance,
+        overwrite=overwrite,
+    )
+
+
 def get_default_zarr_compressor(clevel: int = 5):
     """
     Return default Zarr compressor object for good preformance in int16