diff --git a/changes/3713.misc.md b/changes/3713.misc.md
new file mode 100644
index 0000000000..9b0680dfc0
--- /dev/null
+++ b/changes/3713.misc.md
@@ -0,0 +1 @@
+Vectorize get_chunk_slice for faster sharded array writes.
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py
index b54b3c2257..85162c2f74
--- a/src/zarr/codecs/sharding.py
+++ b/src/zarr/codecs/sharding.py
@@ -46,6 +46,8 @@ from zarr.core.indexing import (
     BasicIndexer,
     SelectorTuple,
+    _morton_order,
+    _morton_order_keys,
     c_order_iter,
     get_indexer,
     morton_order_iter,
@@ -144,6 +146,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | No
         else:
             return (int(chunk_start), int(chunk_start + chunk_len))

+    def get_chunk_slices_vectorized(
+        self, chunk_coords_array: npt.NDArray[np.integer[Any]]
+    ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]:
+        """Get chunk slices for multiple coordinates at once.
+
+        Parameters
+        ----------
+        chunk_coords_array : ndarray of shape (n_chunks, n_dims)
+            Array of chunk coordinates to look up.
+
+        Returns
+        -------
+        starts : ndarray of shape (n_chunks,)
+            Start byte positions for each chunk.
+        ends : ndarray of shape (n_chunks,)
+            End byte positions for each chunk.
+        valid : ndarray of shape (n_chunks,)
+            Boolean mask indicating which chunks are non-empty.
+        """
+        # Localize coordinates via modulo (vectorized)
+        shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
+        localized = chunk_coords_array.astype(np.uint64) % shard_shape
+
+        # Build index tuple for advanced indexing
+        index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
+
+        # Fetch all offsets and lengths at once
+        offsets_and_lengths = self.offsets_and_lengths[index_tuple]
+        starts = offsets_and_lengths[:, 0]
+        lengths = offsets_and_lengths[:, 1]
+
+        # Check for valid (non-empty) chunks
+        valid = starts != MAX_UINT_64
+
+        # Compute end positions
+        ends = starts + lengths
+
+        return starts, ends, valid
+
     def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None:
         localized_chunk = self._localize_chunk(chunk_coords)
         if chunk_slice is None:
@@ -225,6 +266,34 @@ def __len__(self) -> int:
     def __iter__(self) -> Iterator[tuple[int, ...]]:
         return c_order_iter(self.index.offsets_and_lengths.shape[:-1])

+    def to_dict_vectorized(
+        self,
+        chunk_coords_array: npt.NDArray[np.integer[Any]],
+    ) -> dict[tuple[int, ...], Buffer | None]:
+        """Build a dict of chunk coordinates to buffers using vectorized lookup.
+
+        Parameters
+        ----------
+        chunk_coords_array : ndarray of shape (n_chunks, n_dims)
+            Array of chunk coordinates for vectorized index lookup.
+
+        Returns
+        -------
+        dict mapping chunk coordinate tuples to Buffer or None
+        """
+        starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
+        chunks_per_shard = tuple(self.index.offsets_and_lengths.shape[:-1])
+        chunk_coords_keys = _morton_order_keys(chunks_per_shard)
+
+        result: dict[tuple[int, ...], Buffer | None] = {}
+        for i, coords in enumerate(chunk_coords_keys):
+            if valid[i]:
+                result[coords] = self.buf[int(starts[i]) : int(ends[i])]
+            else:
+                result[coords] = None
+
+        return result
+

 @dataclass(frozen=True)
 class ShardingCodec(
@@ -511,7 +580,8 @@ async def _encode_partial_single(
             chunks_per_shard=chunks_per_shard,
         )
         shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
-        shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
+        # Use vectorized lookup for better performance
+        shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))

         indexer = list(
             get_indexer(
diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py
index df79728a85..454f7e2290
--- a/src/zarr/core/indexing.py
+++ b/src/zarr/core/indexing.py
@@ -1504,10 +1504,13 @@ def decode_morton_vectorized(


 @lru_cache(maxsize=16)
-def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
+def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
     n_total = product(chunk_shape)
+    n_dims = len(chunk_shape)
     if n_total == 0:
-        return ()
+        out = np.empty((0, n_dims), dtype=np.intp)
+        out.flags.writeable = False
+        return out

     # Optimization: Remove singleton dimensions to enable magic number usage
     # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
@@ -1515,26 +1518,19 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
     if singleton_dims:
         squeezed_shape = tuple(s for s in chunk_shape if s != 1)
         if squeezed_shape:
-            # Compute Morton order on squeezed shape
-            squeezed_order = _morton_order(squeezed_shape)
-            # Expand coordinates to include singleton dimensions (always 0)
-            expanded: list[tuple[int, ...]] = []
-            for coord in squeezed_order:
-                full_coord: list[int] = []
-                squeezed_idx = 0
-                for i in range(len(chunk_shape)):
-                    if chunk_shape[i] == 1:
-                        full_coord.append(0)
-                    else:
-                        full_coord.append(coord[squeezed_idx])
-                        squeezed_idx += 1
-                expanded.append(tuple(full_coord))
-            return tuple(expanded)
+            # Compute Morton order on squeezed shape, then expand singleton dims (always 0)
+            squeezed_order = np.asarray(_morton_order(squeezed_shape))
+            out = np.zeros((n_total, n_dims), dtype=np.intp)
+            squeezed_col = 0
+            for full_col in range(n_dims):
+                if chunk_shape[full_col] != 1:
+                    out[:, full_col] = squeezed_order[:, squeezed_col]
+                    squeezed_col += 1
         else:
             # All dimensions are singletons, just return the single point
-            return ((0,) * len(chunk_shape),)
-
-    n_dims = len(chunk_shape)
+            out = np.zeros((1, n_dims), dtype=np.intp)
+        out.flags.writeable = False
+        return out

     # Find the largest power-of-2 hypercube that fits within chunk_shape.
     # Within this hypercube, Morton codes are guaranteed to be in bounds.
@@ -1547,27 +1543,34 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
         n_hypercube = 0

     # Within the hypercube, no bounds checking needed - use vectorized decoding
-    order: list[tuple[int, ...]]
     if n_hypercube > 0:
         z_values = np.arange(n_hypercube, dtype=np.intp)
-        hypercube_coords = decode_morton_vectorized(z_values, chunk_shape)
-        order = [tuple(row) for row in hypercube_coords]
+        order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
     else:
-        order = []
+        order = np.empty((0, n_dims), dtype=np.intp)

-    # For remaining elements, bounds checking is needed
+    # For remaining elements outside the hypercube, bounds checking is needed
+    remaining: list[tuple[int, ...]] = []
     i = n_hypercube
-    while len(order) < n_total:
+    while len(order) + len(remaining) < n_total:
         m = decode_morton(i, chunk_shape)
         if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
-            order.append(m)
+            remaining.append(m)
         i += 1

-    return tuple(order)
+    if remaining:
+        order = np.vstack([order, np.array(remaining, dtype=np.intp)])
+    order.flags.writeable = False
+    return order
+
+
+@lru_cache(maxsize=16)
+def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
+    return tuple(tuple(int(x) for x in row) for row in _morton_order(chunk_shape))


 def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
-    return iter(_morton_order(tuple(chunk_shape)))
+    return iter(_morton_order_keys(tuple(chunk_shape)))


 def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
diff --git a/tests/benchmarks/test_indexing.py b/tests/benchmarks/test_indexing.py
index dff2269dcb..d30d731f0f
--- a/tests/benchmarks/test_indexing.py
+++ b/tests/benchmarks/test_indexing.py
@@ -74,7 +74,7 @@ def test_sharded_morton_indexing(
     The Morton order cache is cleared before each iteration to measure
     the full computation cost.
     """
-    from zarr.core.indexing import _morton_order
+    from zarr.core.indexing import _morton_order, _morton_order_keys

     # Create array where each shard contains many small chunks
     # e.g., shards=(32,32,32) with chunks=(2,2,2) means 16x16x16 = 4096 chunks per shard
@@ -98,6 +98,7 @@ def test_sharded_morton_indexing(

     def read_with_cache_clear() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         getitem(data, indexer)

     benchmark(read_with_cache_clear)
@@ -122,7 +123,7 @@ def test_sharded_morton_indexing_large(
     the Morton order computation a more significant portion of total time.
     The Morton order cache is cleared before each iteration.
     """
-    from zarr.core.indexing import _morton_order
+    from zarr.core.indexing import _morton_order, _morton_order_keys

     # 1x1x1 chunks means chunks_per_shard equals shard shape
     shape = tuple(s * 2 for s in shards)  # 2 shards per dimension
@@ -145,6 +146,7 @@ def test_sharded_morton_indexing_large(

     def read_with_cache_clear() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         getitem(data, indexer)

     benchmark(read_with_cache_clear)
@@ -164,7 +166,7 @@ def test_sharded_morton_single_chunk(
     computing the full Morton order, making the optimization impact clear.
     The Morton order cache is cleared before each iteration.
""" - from zarr.core.indexing import _morton_order + from zarr.core.indexing import _morton_order, _morton_order_keys # 1x1x1 chunks means chunks_per_shard equals shard shape shape = tuple(s * 2 for s in shards) # 2 shards per dimension @@ -187,6 +189,7 @@ def test_sharded_morton_single_chunk( def read_with_cache_clear() -> None: _morton_order.cache_clear() + _morton_order_keys.cache_clear() getitem(data, indexer) benchmark(read_with_cache_clear) @@ -211,10 +214,11 @@ def test_morton_order_iter( optimization impact without array read/write overhead. The cache is cleared before each iteration. """ - from zarr.core.indexing import _morton_order, morton_order_iter + from zarr.core.indexing import _morton_order, _morton_order_keys, morton_order_iter def compute_morton_order() -> None: _morton_order.cache_clear() + _morton_order_keys.cache_clear() # Consume the iterator to force computation list(morton_order_iter(shape)) @@ -239,7 +243,7 @@ def test_sharded_morton_write_single_chunk( """ import numpy as np - from zarr.core.indexing import _morton_order + from zarr.core.indexing import _morton_order, _morton_order_keys # 1x1x1 chunks means chunks_per_shard equals shard shape shape = tuple(s * 2 for s in shards) # 2 shards per dimension @@ -262,6 +266,7 @@ def test_sharded_morton_write_single_chunk( def write_with_cache_clear() -> None: _morton_order.cache_clear() + _morton_order_keys.cache_clear() data[indexer] = write_data benchmark(write_with_cache_clear)