changes/3717.misc.md: 1 change (1 addition, 0 deletions)
@@ -0,0 +1 @@
+Add benchmarks for Morton order computation with non-power-of-2 and near-miss shard shapes, covering both pure computation and end-to-end read/write performance.
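
For orientation: Morton (Z-) order interleaves the bits of each coordinate into a single code, so that traversing codes in ascending order visits spatially nearby chunks together. A minimal 2D sketch (the bit convention here is illustrative; zarr's own dimension ordering may differ):

def morton2d(x, y, bits=2):
    # Interleave bits: x supplies the even bits of the code, y the odd bits.
    z = 0
    for b in range(bits):
        z |= ((x >> b) & 1) << (2 * b)
        z |= ((y >> b) & 1) << (2 * b + 1)
    return z

# Traversal order of a 4x4 grid: the familiar "Z" pattern.
order = sorted(((x, y) for x in range(4) for y in range(4)), key=lambda p: morton2d(*p))
print(order[:4])  # [(0, 0), (1, 0), (0, 1), (1, 1)]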
src/zarr/core/indexing.py: 85 changes (39 additions, 46 deletions)
@@ -1512,54 +1512,47 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
     out.flags.writeable = False
     return out
 
-    # Optimization: Remove singleton dimensions to enable magic number usage
-    # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
-    singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1)
-    if singleton_dims:
-        squeezed_shape = tuple(s for s in chunk_shape if s != 1)
-        if squeezed_shape:
-            # Compute Morton order on squeezed shape, then expand singleton dims (always 0)
-            squeezed_order = np.asarray(_morton_order(squeezed_shape))
-            out = np.zeros((n_total, n_dims), dtype=np.intp)
-            squeezed_col = 0
-            for full_col in range(n_dims):
-                if chunk_shape[full_col] != 1:
-                    out[:, full_col] = squeezed_order[:, squeezed_col]
-                    squeezed_col += 1
-        else:
-            # All dimensions are singletons, just return the single point
-            out = np.zeros((1, n_dims), dtype=np.intp)
-        out.flags.writeable = False
-        return out
-
-    # Find the largest power-of-2 hypercube that fits within chunk_shape.
-    # Within this hypercube, Morton codes are guaranteed to be in bounds.
-    min_dim = min(chunk_shape)
-    if min_dim >= 1:
-        power = min_dim.bit_length() - 1  # floor(log2(min_dim))
-        hypercube_size = 1 << power  # 2^power
-        n_hypercube = hypercube_size**n_dims
+    # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span
+    # all valid coordinates in chunk_shape. (c - 1).bit_length() gives the number
+    # of bits needed to index c values (0 for singleton dims); n_z = 2**total_bits
+    # is the size of this hypercube.
+    total_bits = sum((c - 1).bit_length() for c in chunk_shape)
+    n_z = 1 << total_bits if total_bits > 0 else 1
+
+    # Decode all Morton codes in the ceiling hypercube, then filter to valid coords.
+    # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33):
+    # n_z=262144, n_total=35937), the argsort strategy below is used instead.
+    if n_z <= 4 * n_total:
+        # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds.
+        # Works well when the overgeneration ratio n_z/n_total is small (≤4).
+        z_values = np.arange(n_z, dtype=np.intp)
+        all_coords = decode_morton_vectorized(z_values, chunk_shape)
+        shape_arr = np.array(chunk_shape, dtype=np.intp)
+        valid_mask = np.all(all_coords < shape_arr, axis=1)
+        order = all_coords[valid_mask]
     else:
-        n_hypercube = 0
+        # Argsort strategy: enumerate all n_total valid coordinates directly,
+        # encode each to a Morton code, then sort by code. Avoids the 8x or
+        # larger overgeneration penalty for near-miss shapes like (33,33,33).
+        # Cost: O(n_total * bits) encode + O(n_total log n_total) sort,
+        # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling.
+        grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij")
+        all_coords = np.stack([g.ravel() for g in grids], axis=1)
+
+        # Encode all coordinates to Morton codes (vectorized over coordinates).
+        bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape)
+        max_coord_bits = max(bits_per_dim)
+        z_codes = np.zeros(n_total, dtype=np.intp)
+        output_bit = 0
+        for coord_bit in range(max_coord_bits):
+            for dim in range(n_dims):
+                if coord_bit < bits_per_dim[dim]:
+                    z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit
+                    output_bit += 1
+
+        sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable")
+        order = all_coords[sort_idx]
 
-    # Within the hypercube, no bounds checking needed - use vectorized decoding
-    if n_hypercube > 0:
-        z_values = np.arange(n_hypercube, dtype=np.intp)
-        order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
-    else:
-        order = np.empty((0, n_dims), dtype=np.intp)
-
-    # For remaining elements outside the hypercube, bounds checking is needed
-    remaining: list[tuple[int, ...]] = []
-    i = n_hypercube
-    while len(order) + len(remaining) < n_total:
-        m = decode_morton(i, chunk_shape)
-        if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
-            remaining.append(m)
-        i += 1
-
-    if remaining:
-        order = np.vstack([order, np.array(remaining, dtype=np.intp)])
     order.flags.writeable = False
     return order

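To see why the two new strategies agree, here is a self-contained pure-Python sketch of the same truncated bit-interleaving (the helpers encode_truncated and decode_truncated are illustrative stand-ins, not zarr's decode_morton_vectorized); for the non-power-of-2 shape (3, 2) both strategies produce the identical traversal:

def bits_per_dim(shape):
    # Bits needed to index each dimension; (c - 1).bit_length() is 0 when c == 1.
    return tuple((c - 1).bit_length() for c in shape)

def decode_truncated(z, shape):
    # Truncated interleave: dimension d contributes only bits_per_dim(shape)[d] bits.
    bits = bits_per_dim(shape)
    coords = [0] * len(shape)
    out_bit = 0
    for coord_bit in range(max(bits)):
        for dim in range(len(shape)):
            if coord_bit < bits[dim]:
                coords[dim] |= ((z >> out_bit) & 1) << coord_bit
                out_bit += 1
    return tuple(coords)

def encode_truncated(coord, shape):
    # Exact inverse of decode_truncated.
    bits = bits_per_dim(shape)
    z = 0
    out_bit = 0
    for coord_bit in range(max(bits)):
        for dim in range(len(shape)):
            if coord_bit < bits[dim]:
                z |= ((coord[dim] >> coord_bit) & 1) << out_bit
                out_bit += 1
    return z

shape = (3, 2)                       # 6 valid coordinates
n_z = 1 << sum(bits_per_dim(shape))  # ceiling hypercube holds 8 codes

# Ceiling strategy: decode every code in the hypercube, keep in-bounds coords.
ceiling = [c for c in (decode_truncated(z, shape) for z in range(n_z))
           if all(x < s for x, s in zip(c, shape))]

# Argsort strategy: enumerate valid coordinates, sort by their Morton code.
coords = [(i, j) for i in range(shape[0]) for j in range(shape[1])]
argsorted = sorted(coords, key=lambda c: encode_truncated(c, shape))

assert ceiling == argsorted
print(ceiling)  # [(0, 0), (1, 0), (0, 1), (1, 1), (2, 0), (2, 1)]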
tests/benchmarks/test_indexing.py: 15 changes (11 additions, 4 deletions)
@@ -106,7 +106,10 @@ def read_with_cache_clear() -> None:

 # Benchmark with larger chunks_per_shard to make Morton order impact more visible
 large_morton_shards = (
-    (32,) * 3,  # With 1x1x1 chunks: 32x32x32 = 32768 chunks per shard
+    (32,) * 3,  # With 1x1x1 chunks: 32x32x32 = 32768 chunks per shard (power-of-2)
+    (30,) * 3,  # With 1x1x1 chunks: 30x30x30 = 27000 chunks per shard (non-power-of-2)
+    (33,)
+    * 3,  # With 1x1x1 chunks: 33x33x33 = 35937 chunks per shard (near-miss: just above power-of-2)
 )


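For context, an end-to-end read/write benchmark over one of these shard shapes might be set up roughly as below. This is a sketch only, assuming zarr-python 3's zarr.create_array with the shards= option; the actual fixtures live elsewhere in this module:

import numpy as np
import zarr
from zarr.storage import MemoryStore

# One shard of 30x30x30 chunks of 1x1x1 each (27000 chunks per shard), so a
# whole-shard write and read each walk the full Morton order once.
arr = zarr.create_array(
    store=MemoryStore(),
    shape=(30, 30, 30),
    chunks=(1, 1, 1),
    shards=(30, 30, 30),
    dtype="uint16",
)
arr[:] = np.arange(27000, dtype="uint16").reshape(30, 30, 30)  # write path
assert arr[29, 29, 29] == 26999                                # read path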
@@ -197,9 +200,13 @@ def read_with_cache_clear() -> None:

 # Benchmark for morton_order_iter directly (no I/O)
 morton_iter_shapes = (
-    (8, 8, 8),  # 512 elements
-    (16, 16, 16),  # 4096 elements
-    (32, 32, 32),  # 32768 elements
+    (8, 8, 8),  # 512 elements (power-of-2)
+    (10, 10, 10),  # 1000 elements (non-power-of-2)
+    (16, 16, 16),  # 4096 elements (power-of-2)
+    (20, 20, 20),  # 8000 elements (non-power-of-2)
+    (32, 32, 32),  # 32768 elements (power-of-2)
+    (30, 30, 30),  # 27000 elements (non-power-of-2)
+    (33, 33, 33),  # 35937 elements (near-miss: just above power-of-2, n_z=262144)
 )


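Given the n_z <= 4 * n_total threshold in the indexing.py diff above, it is easy to check which strategy each benchmark shape exercises:

for shape in [(8, 8, 8), (10, 10, 10), (16, 16, 16), (20, 20, 20),
              (32, 32, 32), (30, 30, 30), (33, 33, 33)]:
    n_total = shape[0] * shape[1] * shape[2]
    n_z = 1 << sum((c - 1).bit_length() for c in shape)
    print(shape, f"n_z/n_total = {n_z / n_total:.2f} ->",
          "ceiling" if n_z <= 4 * n_total else "argsort")

The power-of-2 shapes have a ratio of exactly 1.00 and take the ceiling path with no overgeneration, and (30, 30, 30) overgenerates only ~1.21x; (10, 10, 10) and (20, 20, 20) land just past the cutoff at ~4.10x, and (33, 33, 33) at ~7.29x, so those three fall through to the argsort path.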