From 6c4e53f84d9efc9fa08870b58f2b67276b0d3e95 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 24 Oct 2025 21:31:05 +0200 Subject: [PATCH 01/12] add global concurrency limit instead of per-routine concurrency limits --- src/zarr/abc/codec.py | 4 -- src/zarr/abc/store.py | 4 +- src/zarr/core/array.py | 4 -- src/zarr/core/codec_pipeline.py | 5 -- src/zarr/core/common.py | 115 +++++++++++++++++++++++++++++++- src/zarr/core/group.py | 10 +-- src/zarr/storage/_local.py | 3 +- src/zarr/storage/_memory.py | 3 +- src/zarr/storage/_obstore.py | 8 +-- 9 files changed, 128 insertions(+), 28 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..b5f7819a91 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -9,7 +9,6 @@ from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import NamedConfig, concurrent_map -from zarr.core.config import config if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -228,7 +227,6 @@ async def decode_partial( return await concurrent_map( list(batch_info), self._decode_partial_single, - config.get("async.concurrency"), ) @@ -265,7 +263,6 @@ async def encode_partial( await concurrent_map( list(batch_info), self._encode_partial_single, - config.get("async.concurrency"), ) @@ -467,7 +464,6 @@ async def _batching_helper( return await concurrent_map( list(batch_info), _noop_for_none(func), - config.get("async.concurrency"), ) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4b3edf78d1..30602edf34 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -464,11 +464,9 @@ async def getsize_prefix(self, prefix: str) -> int: # avoid circular import from zarr.core.common import concurrent_map - from zarr.core.config import config keys = [(x,) async for x in self.list_prefix(prefix)] - limit = config.get("async.concurrency") - sizes = await concurrent_map(keys, self.getsize, limit=limit) + sizes = await concurrent_map(keys, self.getsize) return sum(sizes) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 8bd8be40b2..2f42836fc2 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -22,7 +22,6 @@ import numpy as np from typing_extensions import deprecated -import zarr from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.abc.numcodec import Numcodec, _is_numcodec from zarr.codecs._v2 import V2Codec @@ -1853,7 +1852,6 @@ async def _delete_key(key: str) -> None: for chunk_coords in old_chunk_coords.difference(new_chunk_coords) ], _delete_key, - zarr_config.get("async.concurrency"), ) # Write new metadata @@ -4530,7 +4528,6 @@ async def _copy_array_region( await concurrent_map( [(region, data) for region in result._iter_shard_regions()], _copy_array_region, - zarr.core.config.config.get("async.concurrency"), ) else: @@ -4541,7 +4538,6 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non await concurrent_map( [(region, data) for region in result._iter_shard_regions()], _copy_arraylike_region, - zarr.core.config.config.get("async.concurrency"), ) return result diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 63fcda7065..e6864c607e 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -270,7 +270,6 @@ async def read_batch( chunk_bytes_batch = await concurrent_map( [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], lambda byte_getter, prototype: byte_getter.get(prototype), - config.get("async.concurrency"), ) chunk_array_batch = await self.decode_batch( [ @@ -375,7 +374,6 @@ async def _read_key( for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info ], _read_key, - config.get("async.concurrency"), ) chunk_array_decoded = await self.decode_batch( [ @@ -441,7 +439,6 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non ) ], _write_key, - config.get("async.concurrency"), ) async def decode( @@ -474,7 +471,6 @@ async def read( for single_batch_info in batched(batch_info, self.batch_size) ], self.read_batch, - config.get("async.concurrency"), ) async def write( @@ -489,7 +485,6 @@ async def write( for single_batch_info in batched(batch_info, self.batch_size) ], self.write_batch, - config.get("async.concurrency"), ) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 651ebd72f3..8f6f899cf8 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -4,7 +4,9 @@ import functools import math import operator +import threading import warnings +import weakref from collections.abc import Iterable, Mapping, Sequence from enum import Enum from itertools import starmap @@ -82,15 +84,126 @@ def ceildiv(a: float, b: float) -> int: V = TypeVar("V") +# Global semaphore management for per-process concurrency limiting +# Use WeakKeyDictionary to automatically clean up semaphores when event loops are garbage collected +_global_semaphores: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore] = ( + weakref.WeakKeyDictionary() +) +# Use threading.Lock instead of asyncio.Lock to coordinate across event loops +_global_semaphore_lock = threading.Lock() + + +def get_global_semaphore() -> asyncio.Semaphore: + """ + Get the global semaphore for the current event loop. + + This ensures that all concurrent operations across the process share the same + concurrency limit, preventing excessive concurrent task creation when multiple + arrays or operations are running simultaneously. + + The semaphore is lazily created per event loop and uses the configured + `async.concurrency` value from zarr config. The semaphore is cached per event + loop, so subsequent calls return the same semaphore instance. + + Note: Config changes after the first call will not affect the semaphore limit. + To apply new config values, use :func:`reset_global_semaphores` to clear the cache. + + Returns + ------- + asyncio.Semaphore + The global semaphore for this event loop. + + Raises + ------ + RuntimeError + If called outside of an async context (no running event loop). + + See Also + -------- + reset_global_semaphores : Clear the global semaphore cache + """ + loop = asyncio.get_running_loop() + + # Acquire lock FIRST to prevent TOCTOU race condition + with _global_semaphore_lock: + if loop not in _global_semaphores: + limit = zarr_config.get("async.concurrency") + _global_semaphores[loop] = asyncio.Semaphore(limit) + return _global_semaphores[loop] + + +def reset_global_semaphores() -> None: + """ + Clear all cached global semaphores. + + This is useful when you want config changes to take effect, or for testing. + The next call to :func:`get_global_semaphore` will create a new semaphore + using the current configuration. + + Warning: This should only be called when no async operations are in progress, + as it will invalidate all existing semaphore references. + + Examples + -------- + >>> import zarr + >>> zarr.config.set({"async.concurrency": 50}) + >>> reset_global_semaphores() # Apply new config + """ + with _global_semaphore_lock: + _global_semaphores.clear() + + async def concurrent_map( items: Iterable[T], func: Callable[..., Awaitable[V]], limit: int | None = None, + *, + use_global_semaphore: bool = True, ) -> list[V]: - if limit is None: + """ + Execute an async function concurrently over multiple items with concurrency limiting. + + Parameters + ---------- + items : Iterable[T] + Items to process, where each item is a tuple of arguments to pass to func. + func : Callable[..., Awaitable[V]] + Async function to execute for each item. + limit : int | None, optional + If provided and use_global_semaphore is False, creates a local semaphore + with this limit. If None, no concurrency limiting is applied. + use_global_semaphore : bool, default True + If True, uses the global per-process semaphore for concurrency limiting, + ensuring all concurrent operations share the same limit. If False, uses + the `limit` parameter for local limiting (legacy behavior). + + Returns + ------- + list[V] + Results from executing func on all items. + """ + if use_global_semaphore: + if limit is not None: + raise ValueError( + "Cannot specify both use_global_semaphore=True and a limit value. " + "Either use the global semaphore (use_global_semaphore=True, limit=None) " + "or specify a local limit (use_global_semaphore=False, limit=)." + ) + # Use the global semaphore for process-wide concurrency limiting + sem = get_global_semaphore() + + async def run(item: tuple[Any]) -> V: + async with sem: + return await func(*item) + + return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) + + elif limit is None: + # No concurrency limiting return await asyncio.gather(*list(starmap(func, items))) else: + # Legacy mode: create local semaphore with specified limit sem = asyncio.Semaphore(limit) async def run(item: tuple[Any]) -> V: diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 26aed4fd60..2f381431a3 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -44,6 +44,7 @@ NodeType, ShapeLike, ZarrFormat, + get_global_semaphore, parse_shapelike, ) from zarr.core.config import config @@ -1441,8 +1442,8 @@ async def _members( ) raise ValueError(msg) - # enforce a concurrency limit by passing a semaphore to all the recursive functions - semaphore = asyncio.Semaphore(config.get("async.concurrency")) + # Use global semaphore for process-wide concurrency limiting + semaphore = get_global_semaphore() async for member in _iter_members_deep( self, max_depth=max_depth, @@ -3338,9 +3339,8 @@ async def create_nodes( The created nodes in the order they are created. """ - # Note: the only way to alter this value is via the config. If that's undesirable for some reason, - # then we should consider adding a keyword argument this this function - semaphore = asyncio.Semaphore(config.get("async.concurrency")) + # Use global semaphore for process-wide concurrency limiting + semaphore = get_global_semaphore() create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f64da71bb4..d6f10be862 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -217,7 +217,8 @@ async def get_partial_values( assert isinstance(key, str) path = self.root / key args.append((_get, path, prototype, byte_range)) - return await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit + # Use global semaphore to limit concurrent thread spawning + return await concurrent_map(args, asyncio.to_thread) async def set(self, key: str, value: Buffer) -> None: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index a3fd058680..12d7424185 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -104,7 +104,8 @@ async def get_partial_values( async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: return await self.get(key, prototype=prototype, byte_range=byte_range) - return await concurrent_map(key_ranges, _get, limit=None) + # In-memory operations are fast and don't benefit from concurrency limiting + return await concurrent_map(key_ranges, _get, use_global_semaphore=False) async def exists(self, key: str) -> bool: # docstring inherited diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..e1d1bde672 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -13,8 +13,7 @@ Store, SuffixByteRequest, ) -from zarr.core.common import concurrent_map -from zarr.core.config import config +from zarr.core.common import concurrent_map, get_global_semaphore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence @@ -209,7 +208,7 @@ async def delete_dir(self, prefix: str) -> None: metas = await obs.list(self.store, prefix).collect_async() keys = [(m["path"],) for m in metas] - await concurrent_map(keys, self.delete, limit=config.get("async.concurrency")) + await concurrent_map(keys, self.delete) @property def supports_listing(self) -> bool: @@ -485,7 +484,8 @@ async def _get_partial_values( else: raise ValueError(f"Unsupported range input: {byte_range}") - semaphore = asyncio.Semaphore(config.get("async.concurrency")) + # Use global semaphore for process-wide concurrency limiting + semaphore = get_global_semaphore() futs: list[Coroutine[Any, Any, list[_Response]]] = [] for path, bounded_ranges in per_file_bounded_requests.items(): From e98e6c0a1f4f731bfb76eaa8ffdbe050b8b2c7e6 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 24 Oct 2025 21:32:17 +0200 Subject: [PATCH 02/12] add test --- tests/test_global_concurrency.py | 330 +++++++++++++++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 tests/test_global_concurrency.py diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py new file mode 100644 index 0000000000..5df1d68a39 --- /dev/null +++ b/tests/test_global_concurrency.py @@ -0,0 +1,330 @@ +""" +Tests for global per-process concurrency limiting. +""" + +import asyncio +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import zarr +from zarr.core.common import get_global_semaphore, reset_global_semaphores +from zarr.core.config import config + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +class TestGlobalSemaphore: + """Tests for the global semaphore management.""" + + async def test_get_global_semaphore_creates_per_loop(self) -> None: + """Test that each event loop gets its own semaphore.""" + sem1 = get_global_semaphore() + assert sem1 is not None + assert isinstance(sem1, asyncio.Semaphore) + + # Getting it again should return the same instance + sem2 = get_global_semaphore() + assert sem1 is sem2 + + async def test_global_semaphore_uses_config_limit(self) -> None: + """Test that the global semaphore respects the configured limit.""" + # Set a custom concurrency limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 5}) + + # Clear existing semaphores to force recreation + reset_global_semaphores() + + sem = get_global_semaphore() + + # The semaphore should have the configured limit + # We can verify this by acquiring all tokens and checking the semaphore is locked + for i in range(5): + await sem.acquire() + if i < 4: + assert not sem.locked() # Should still have capacity + else: + assert sem.locked() # All tokens acquired, semaphore is now locked + + # Release all tokens + for _ in range(5): + sem.release() + + finally: + # Restore original config + config.set({"async.concurrency": original_limit}) + # Clear semaphores again to reset state + reset_global_semaphores() + + async def test_global_semaphore_shared_across_operations(self) -> None: + """Test that multiple concurrent operations share the same semaphore.""" + # Track the maximum number of concurrent tasks + max_concurrent = 0 + current_concurrent = 0 + lock = asyncio.Lock() + + async def tracked_operation() -> None: + """An operation that tracks concurrency.""" + nonlocal max_concurrent, current_concurrent + + async with lock: + current_concurrent += 1 + max_concurrent = max(max_concurrent, current_concurrent) + + # Small delay to ensure overlap + await asyncio.sleep(0.01) + + async with lock: + current_concurrent -= 1 + + # Set a low concurrency limit to make the test observable + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 5}) + + # Clear existing semaphores + reset_global_semaphores() + + # Get the global semaphore + sem = get_global_semaphore() + + # Create many tasks that use the semaphore + async def task_with_semaphore() -> None: + async with sem: + await tracked_operation() + + # Launch 20 tasks (4x the limit) + tasks = [task_with_semaphore() for _ in range(20)] + await asyncio.gather(*tasks) + + # Maximum concurrent should respect the limit + assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5" + assert max_concurrent >= 3, ( + f"Max concurrent was {max_concurrent}, expected some concurrency" + ) + + finally: + config.set({"async.concurrency": original_limit}) + reset_global_semaphores() + + async def test_semaphore_reuse_across_calls(self) -> None: + """Test that repeated calls to get_global_semaphore return the same instance.""" + reset_global_semaphores() + + # Call multiple times and verify we get the same instance + sem1 = get_global_semaphore() + sem2 = get_global_semaphore() + sem3 = get_global_semaphore() + + assert sem1 is sem2 is sem3, "Should return same semaphore instance on repeated calls" + + # Verify it's still the same after using it + async with sem1: + sem4 = get_global_semaphore() + assert sem1 is sem4 + + def test_config_change_after_creation(self) -> None: + """Test and document that config changes don't affect existing semaphores.""" + original_limit: Any = config.get("async.concurrency") + try: + # Set initial config + config.set({"async.concurrency": 5}) + + async def check_limit() -> None: + reset_global_semaphores() + + # Create semaphore with limit=5 + sem1 = get_global_semaphore() + initial_capacity: int = sem1._value + + # Change config + config.set({"async.concurrency": 50}) + + # Get semaphore again - should be same instance with old limit + sem2 = get_global_semaphore() + assert sem1 is sem2, "Should return same semaphore instance" + assert sem2._value == initial_capacity, ( + f"Semaphore limit changed from {initial_capacity} to {sem2._value}. " + "Config changes should not affect existing semaphores." + ) + + # Clean up + reset_global_semaphores() + + asyncio.run(check_limit()) + + finally: + config.set({"async.concurrency": original_limit}) + + +class TestArrayConcurrency: + """Tests that array operations use global concurrency limiting.""" + + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") + async def test_multiple_arrays_share_concurrency_limit(self) -> None: + """Test that reading from multiple arrays shares the global concurrency limit.""" + from zarr.core.common import concurrent_map + + # Track concurrent task executions + max_concurrent_tasks = 0 + current_concurrent_tasks = 0 + task_lock = asyncio.Lock() + + async def tracked_chunk_operation(chunk_id: int) -> int: + """Simulate a chunk operation with tracking.""" + nonlocal max_concurrent_tasks, current_concurrent_tasks + + async with task_lock: + current_concurrent_tasks += 1 + max_concurrent_tasks = max(max_concurrent_tasks, current_concurrent_tasks) + + # Small delay to simulate I/O + await asyncio.sleep(0.001) + + async with task_lock: + current_concurrent_tasks -= 1 + + return chunk_id + + # Set a low concurrency limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 10}) + + # Clear existing semaphores + reset_global_semaphores() + + # Simulate reading many chunks using concurrent_map (which uses the global semaphore) + # This simulates what happens when reading from multiple arrays + chunk_ids = [(i,) for i in range(100)] + await concurrent_map(chunk_ids, tracked_chunk_operation) + + # The maximum concurrent tasks should respect the global limit + assert max_concurrent_tasks <= 10, ( + f"Max concurrent tasks was {max_concurrent_tasks}, expected <= 10" + ) + + assert max_concurrent_tasks >= 5, ( + f"Max concurrent tasks was {max_concurrent_tasks}, " + f"expected at least some concurrency" + ) + + finally: + config.set({"async.concurrency": original_limit}) + # Note: We don't reset_global_semaphores() here because doing so while + # many tasks are still cleaning up can trigger ResourceWarnings from + # asyncio internals. The semaphore will be reused by subsequent tests. + + def test_sync_api_uses_global_concurrency(self) -> None: + """Test that synchronous API also benefits from global concurrency limiting.""" + # This test verifies that the sync API (which wraps async) uses global limiting + + # Set a low concurrency limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 8}) + + # Create a small array - the key is that zarr internally uses + # concurrent_map which now uses the global semaphore + store = zarr.storage.MemoryStore() + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + # Read data (synchronously) + data: NDArray[Any] = arr[:] + + # Verify we got the right data + assert np.all(data == 42) + + # The test passes if no errors occurred + # The concurrency limiting is happening under the hood + + finally: + config.set({"async.concurrency": original_limit}) + + +class TestConcurrentMapGlobal: + """Tests for concurrent_map using global semaphore.""" + + async def test_concurrent_map_uses_global_by_default(self) -> None: + """Test that concurrent_map uses global semaphore by default.""" + from zarr.core.common import concurrent_map + + # Track concurrent executions + max_concurrent = 0 + current_concurrent = 0 + lock = asyncio.Lock() + + async def tracked_task(x: int) -> int: + nonlocal max_concurrent, current_concurrent + + async with lock: + current_concurrent += 1 + max_concurrent = max(max_concurrent, current_concurrent) + + await asyncio.sleep(0.01) + + async with lock: + current_concurrent -= 1 + + return x * 2 + + # Set a low limit + original_limit: Any = config.get("async.concurrency") + try: + config.set({"async.concurrency": 5}) + + # Clear existing semaphores + reset_global_semaphores() + + # Use concurrent_map with default settings (use_global_semaphore=True) + items = [(i,) for i in range(20)] + results = await concurrent_map(items, tracked_task) + + assert len(results) == 20 + assert max_concurrent <= 5 + assert max_concurrent >= 3 # Should have some concurrency + + finally: + config.set({"async.concurrency": original_limit}) + reset_global_semaphores() + + async def test_concurrent_map_legacy_mode(self) -> None: + """Test that concurrent_map legacy mode still works.""" + from zarr.core.common import concurrent_map + + async def simple_task(x: int) -> int: + await asyncio.sleep(0.001) + return x * 2 + + # Use legacy mode with local limit + items = [(i,) for i in range(10)] + results = await concurrent_map(items, simple_task, limit=3, use_global_semaphore=False) + + assert len(results) == 10 + assert results == [i * 2 for i in range(10)] + + async def test_concurrent_map_parameter_validation(self) -> None: + """Test that concurrent_map validates conflicting parameters.""" + from zarr.core.common import concurrent_map + + async def simple_task(x: int) -> int: + return x * 2 + + items = [(i,) for i in range(10)] + + # Should raise ValueError when both limit and use_global_semaphore=True + with pytest.raises( + ValueError, match="Cannot specify both use_global_semaphore=True and a limit" + ): + await concurrent_map(items, simple_task, limit=5, use_global_semaphore=True) From 735ee8e68232da2ddeeb74c89ea902926869ae3c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 24 Oct 2025 21:51:14 +0200 Subject: [PATCH 03/12] lint --- tests/test_global_concurrency.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py index 5df1d68a39..f6366e3c53 100644 --- a/tests/test_global_concurrency.py +++ b/tests/test_global_concurrency.py @@ -3,7 +3,7 @@ """ import asyncio -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np import pytest @@ -12,9 +12,6 @@ from zarr.core.common import get_global_semaphore, reset_global_semaphores from zarr.core.config import config -if TYPE_CHECKING: - from numpy.typing import NDArray - class TestGlobalSemaphore: """Tests for the global semaphore management.""" @@ -241,7 +238,7 @@ def test_sync_api_uses_global_concurrency(self) -> None: arr[:] = 42 # Read data (synchronously) - data: NDArray[Any] = arr[:] + data = arr[:] # Verify we got the right data assert np.all(data == 42) From 3ef6cfba55f2e7203551424792e55005230cee90 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 27 Oct 2025 16:31:44 +0100 Subject: [PATCH 04/12] changelog --- changes/3547.misc.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3547.misc.md diff --git a/changes/3547.misc.md b/changes/3547.misc.md new file mode 100644 index 0000000000..771bfe8861 --- /dev/null +++ b/changes/3547.misc.md @@ -0,0 +1 @@ +Moved concurrency limits to a global per-event loop setting instead of per-array call. \ No newline at end of file From 229e3b3eeab054de6fa506facd4466f3343fa8b4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 18 Dec 2025 10:04:06 +0100 Subject: [PATCH 05/12] move concurrency limiting logic to stores --- src/zarr/abc/codec.py | 21 ++--- src/zarr/abc/store.py | 8 +- src/zarr/core/array.py | 24 ++--- src/zarr/core/codec_pipeline.py | 51 ++++++----- src/zarr/storage/_fsspec.py | 34 +++++++ src/zarr/storage/_local.py | 45 ++++++++-- src/zarr/storage/_memory.py | 13 ++- src/zarr/storage/_obstore.py | 154 +++++++++++++++++++++++--------- src/zarr/storage/_utils.py | 68 +++++++++++++- tests/test_group.py | 56 ------------ 10 files changed, 304 insertions(+), 170 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index b5f7819a91..69d6c3082e 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar @@ -8,7 +9,7 @@ from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.common import NamedConfig, concurrent_map +from zarr.core.common import NamedConfig if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -224,10 +225,8 @@ async def decode_partial( ------- Iterable[NDBuffer | None] """ - return await concurrent_map( - list(batch_info), - self._decode_partial_single, - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[self._decode_partial_single(*info) for info in batch_info]) class ArrayBytesCodecPartialEncodeMixin: @@ -260,10 +259,8 @@ async def encode_partial( The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data. The chunk spec contains information about the chunk. """ - await concurrent_map( - list(batch_info), - self._encode_partial_single, - ) + # Store handles concurrency limiting internally + await asyncio.gather(*[self._encode_partial_single(*info) for info in batch_info]) class CodecPipeline: @@ -461,10 +458,8 @@ async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], batch_info: Iterable[tuple[CodecInput | None, ArraySpec]], ) -> list[CodecOutput | None]: - return await concurrent_map( - list(batch_info), - _noop_for_none(func), - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[_noop_for_none(func)(chunk, spec) for chunk, spec in batch_info]) def _noop_for_none( diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 30602edf34..4ccab1877f 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from asyncio import gather from dataclasses import dataclass @@ -462,11 +463,8 @@ async def getsize_prefix(self, prefix: str) -> int: # improve tail latency and might reduce memory pressure (since not all keys # would be in memory at once). - # avoid circular import - from zarr.core.common import concurrent_map - - keys = [(x,) async for x in self.list_prefix(prefix)] - sizes = await concurrent_map(keys, self.getsize) + keys = [x async for x in self.list_prefix(prefix)] + sizes = await asyncio.gather(*[self.getsize(key) for key in keys]) return sum(sizes) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index d036cd7974..01ff74f38f 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from asyncio import gather @@ -59,7 +60,6 @@ _default_zarr_format, _warn_order_kwarg, ceildiv, - concurrent_map, parse_shapelike, product, ) @@ -1847,12 +1847,12 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) async def _delete_key(key: str) -> None: await (self.store_path / key).delete() - await concurrent_map( - [ - (self.metadata.encode_chunk_key(chunk_coords),) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _delete_key(self.metadata.encode_chunk_key(chunk_coords)) for chunk_coords in old_chunk_coords.difference(new_chunk_coords) - ], - _delete_key, + ] ) # Write new metadata @@ -4533,9 +4533,9 @@ async def _copy_array_region( await result.setitem(chunk_coords, arr) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_array_region, + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_array_region(region, data) for region in result._iter_shard_regions()] ) else: @@ -4543,9 +4543,9 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non await result.setitem(chunk_coords, _data[chunk_coords]) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_arraylike_region, + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_arraylike_region(region, data) for region in result._iter_shard_regions()] ) return result diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index e99080acec..0f8350f7ea 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from itertools import islice, pairwise from typing import TYPE_CHECKING, Any, TypeVar @@ -14,7 +15,6 @@ Codec, CodecPipeline, ) -from zarr.core.common import concurrent_map from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning @@ -267,9 +267,12 @@ async def read_batch( else: out[out_selection] = fill_value_or_default(chunk_spec) else: - chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], - lambda byte_getter, prototype: byte_getter.get(prototype), + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + byte_getter.get(array_spec.prototype) + for byte_getter, array_spec, *_ in batch_info + ] ) chunk_array_batch = await self.decode_batch( [ @@ -367,15 +370,15 @@ async def _read_key( return await byte_setter.get(prototype=prototype) chunk_bytes_batch: Iterable[Buffer | None] - chunk_bytes_batch = await concurrent_map( - [ - ( + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + _read_key( None if is_complete_chunk else byte_setter, chunk_spec.prototype, ) for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info - ], - _read_key, + ] ) chunk_array_decoded = await self.decode_batch( [ @@ -433,14 +436,14 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non else: await byte_setter.set(chunk_bytes) - await concurrent_map( - [ - (byte_setter, chunk_bytes) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _write_key(byte_setter, chunk_bytes) for chunk_bytes, (byte_setter, *_) in zip( chunk_bytes_batch, batch_info, strict=False ) - ], - _write_key, + ] ) async def decode( @@ -467,12 +470,12 @@ async def read( out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, out, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.read_batch(single_batch_info, out, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.read_batch, + ] ) async def write( @@ -481,12 +484,12 @@ async def write( value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, value, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.write_batch(single_batch_info, value, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.write_batch, + ] ) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 7945fba467..e1ca718784 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from contextlib import suppress @@ -17,6 +18,7 @@ from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -82,6 +84,9 @@ class FsspecStore(Store): filesystem scheme. allowed_exceptions : tuple[type[Exception], ...] When fetching data, these cases will be deemed to correspond to missing keys. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. Attributes ---------- @@ -117,18 +122,24 @@ class FsspecStore(Store): fs: AsyncFileSystem allowed_exceptions: tuple[type[Exception], ...] path: str + _semaphore: asyncio.Semaphore | None def __init__( self, fs: AsyncFileSystem, + *, read_only: bool = False, path: str = "/", allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, + concurrency_limit: int | None = 50, ) -> None: super().__init__(read_only=read_only) self.fs = fs self.path = path self.allowed_exceptions = allowed_exceptions + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) if not self.fs.async_impl: raise TypeError("Filesystem needs to support async operations.") @@ -273,6 +284,7 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) + @with_concurrency_limit() async def get( self, key: str, @@ -315,6 +327,7 @@ async def get( else: return value + @with_concurrency_limit() async def set( self, key: str, @@ -335,6 +348,27 @@ async def set( raise NotImplementedError await self.fs._pipe_file(path, value.to_bytes()) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + if not self._is_open: + await self._open() + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + if not isinstance(value, Buffer): + raise TypeError( + f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + path = _dereference_path(self.path, key) + if self._semaphore: + async with self._semaphore: + await self.fs._pipe_file(path, value.to_bytes()) + else: + await self.fs._pipe_file(path, value.to_bytes()) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + + @with_concurrency_limit() async def delete(self, key: str) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index d6f10be862..ea48c756d3 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,12 +19,13 @@ ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import AccessModeLiteral, concurrent_map +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator from zarr.core.buffer import BufferPrototype + from zarr.core.common import AccessModeLiteral def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: @@ -95,6 +96,9 @@ class LocalStore(Store): Directory to use as root of store. read_only : bool Whether the store is read-only + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 100. + Set to None for unlimited concurrency. Attributes ---------- @@ -109,8 +113,15 @@ class LocalStore(Store): supports_listing: bool = True root: Path + _semaphore: asyncio.Semaphore | None - def __init__(self, root: Path | str, *, read_only: bool = False) -> None: + def __init__( + self, + root: Path | str, + *, + read_only: bool = False, + concurrency_limit: int | None = 100, + ) -> None: super().__init__(read_only=read_only) if isinstance(root, str): root = Path(root) @@ -119,12 +130,17 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None: f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." ) self.root = root + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited + concurrency_limit = self._semaphore._value if self._semaphore else None return type(self)( root=self.root, read_only=read_only, + concurrency_limit=concurrency_limit, ) @classmethod @@ -187,6 +203,7 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + @with_concurrency_limit() async def get( self, key: str, @@ -212,13 +229,23 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - args = [] - for key, byte_range in key_ranges: - assert isinstance(key, str) + # Note: We directly call the I/O functions here, wrapped with semaphore + # to avoid deadlock from calling the decorated get() method + + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: path = self.root / key - args.append((_get, path, prototype, byte_range)) - # Use global semaphore to limit concurrent thread spawning - return await concurrent_map(args, asyncio.to_thread) + try: + if self._semaphore: + async with self._semaphore: + return await asyncio.to_thread(_get, path, prototype, byte_range) + else: + return await asyncio.to_thread(_get, path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + return await asyncio.gather( + *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] + ) async def set(self, key: str, value: Buffer) -> None: # docstring inherited @@ -231,6 +258,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: except FileExistsError: pass + @with_concurrency_limit() async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() @@ -243,6 +271,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) + @with_concurrency_limit() async def delete(self, key: str) -> None: """ Remove a key from the store. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e968e3cd26..be222c96b7 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,12 +1,12 @@ from __future__ import annotations +import asyncio from logging import getLogger from typing import TYPE_CHECKING, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: @@ -102,13 +102,10 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - - # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: - return await self.get(key, prototype=prototype, byte_range=byte_range) - - # In-memory operations are fast and don't benefit from concurrency limiting - return await concurrent_map(key_ranges, _get, use_global_semaphore=False) + # In-memory operations are fast and don't need concurrency limiting + return await asyncio.gather( + *[self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + ) async def exists(self, key: str) -> bool: # docstring inherited diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index e1d1bde672..223142d371 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -13,7 +13,8 @@ Store, SuffixByteRequest, ) -from zarr.core.common import concurrent_map, get_global_semaphore +from zarr.core.common import get_global_semaphore +from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence @@ -46,6 +47,9 @@ class ObjectStore(Store, Generic[T_Store]): An obstore store instance that is set up with the proper credentials. read_only : bool Whether to open the store in read-only mode. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. Warnings -------- @@ -55,6 +59,7 @@ class ObjectStore(Store, Generic[T_Store]): store: T_Store """The underlying obstore instance.""" + _semaphore: asyncio.Semaphore | None def __eq__(self, value: object) -> bool: if not isinstance(value, ObjectStore): @@ -65,17 +70,28 @@ def __eq__(self, value: object) -> bool: return self.store == value.store # type: ignore[no-any-return] - def __init__(self, store: T_Store, *, read_only: bool = False) -> None: + def __init__( + self, + store: T_Store, + *, + read_only: bool = False, + concurrency_limit: int | None = 50, + ) -> None: if not store.__class__.__module__.startswith("obstore"): raise TypeError(f"expected ObjectStore class, got {store!r}") super().__init__(read_only=read_only) self.store = store + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited + concurrency_limit = self._semaphore._value if self._semaphore else None return type(self)( store=self.store, read_only=read_only, + concurrency_limit=concurrency_limit, ) def __str__(self) -> str: @@ -93,6 +109,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) + @with_concurrency_limit() async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: @@ -100,41 +117,7 @@ async def get( import obstore as obs try: - if byte_range is None: - resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, RangeByteRequest): - bytes = await obs.get_range_async( - self.store, key, start=byte_range.start, end=byte_range.end - ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] - elif isinstance(byte_range, OffsetByteRequest): - resp = await obs.get_async( - self.store, key, options={"range": {"offset": byte_range.offset}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, SuffixByteRequest): - # some object stores (Azure) don't support suffix requests. In this - # case, our workaround is to first get the length of the object and then - # manually request the byte range at the end. - try: - resp = await obs.get_async( - self.store, key, options={"range": {"suffix": byte_range.suffix}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(self.store, key) - file_size = head_resp["size"] - suffix_len = byte_range.suffix - buffer = await obs.get_range_async( - self.store, - key, - start=file_size - suffix_len, - length=suffix_len, - ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] - else: - raise ValueError(f"Unexpected byte_range, got {byte_range}") + return await self._get_impl(key, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: return None @@ -144,7 +127,60 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) + # Note: We directly call obs operations here, wrapped with semaphore + # to avoid deadlock from calling the decorated get() method + import obstore as obs + + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: + try: + if self._semaphore: + async with self._semaphore: + return await self._get_impl(key, prototype, byte_range, obs) + else: + return await self._get_impl(key, prototype, byte_range, obs) + except _ALLOWED_EXCEPTIONS: + return None + + return await asyncio.gather( + *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] + ) + + async def _get_impl( + self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any + ) -> Buffer: + """Implementation of get without semaphore decoration.""" + if byte_range is None: + resp = await obs.get_async(self.store, key) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + elif isinstance(byte_range, RangeByteRequest): + bytes = await obs.get_range_async( + self.store, key, start=byte_range.start, end=byte_range.end + ) + return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + elif isinstance(byte_range, OffsetByteRequest): + resp = await obs.get_async( + self.store, key, options={"range": {"offset": byte_range.offset}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + elif isinstance(byte_range, SuffixByteRequest): + try: + resp = await obs.get_async( + self.store, key, options={"range": {"suffix": byte_range.suffix}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + except obs.exceptions.NotSupportedError: + head_resp = await obs.head_async(self.store, key) + file_size = head_resp["size"] + suffix_len = byte_range.suffix + buffer = await obs.get_range_async( + self.store, + key, + start=file_size - suffix_len, + length=suffix_len, + ) + return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}") async def exists(self, key: str) -> bool: # docstring inherited @@ -162,6 +198,7 @@ def supports_writes(self) -> bool: # docstring inherited return True + @with_concurrency_limit() async def set(self, key: str, value: Buffer) -> None: # docstring inherited import obstore as obs @@ -171,20 +208,43 @@ async def set(self, key: str, value: Buffer) -> None: buf = value.as_buffer_like() await obs.put_async(self.store, key, buf) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + import obstore as obs + + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + buf = value.as_buffer_like() + if self._semaphore: + async with self._semaphore: + await obs.put_async(self.store, key, buf) + else: + await obs.put_async(self.store, key, buf) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited + # Note: Not decorated to avoid deadlock when called in batch via gather() import obstore as obs self._check_writable() buf = value.as_buffer_like() - with contextlib.suppress(obs.exceptions.AlreadyExistsError): - await obs.put_async(self.store, key, buf, mode="create") + if self._semaphore: + async with self._semaphore: + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") + else: + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") @property def supports_deletes(self) -> bool: # docstring inherited return True + @with_concurrency_limit() async def delete(self, key: str) -> None: # docstring inherited import obstore as obs @@ -207,8 +267,18 @@ async def delete_dir(self, prefix: str) -> None: prefix += "/" metas = await obs.list(self.store, prefix).collect_async() - keys = [(m["path"],) for m in metas] - await concurrent_map(keys, self.delete) + + # Delete with semaphore limiting to avoid deadlock + async def _delete_with_limit(path: str) -> None: + if self._semaphore: + async with self._semaphore: + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + else: + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + + await asyncio.gather(*[_delete_with_limit(m["path"]) for m in metas]) @property def supports_listing(self) -> bool: diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 39c28d44c3..9ce01c2d99 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -1,17 +1,81 @@ from __future__ import annotations +import functools import re from pathlib import Path -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + import asyncio + from collections.abc import Callable, Coroutine, Iterable, Mapping from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer +P = ParamSpec("P") +T_co = TypeVar("T_co", covariant=True) + + +def with_concurrency_limit( + semaphore_attr: str = "_semaphore", +) -> Callable[[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]]: + """ + Decorator that applies a semaphore-based concurrency limit to an async method. + + This decorator is designed for Store methods that need to limit concurrent operations. + The store instance should have a `_semaphore` attribute (or custom attribute name) + that is either an asyncio.Semaphore or None (for unlimited concurrency). + + Parameters + ---------- + semaphore_attr : str, optional + Name of the semaphore attribute on the class instance. Default is "_semaphore". + + Returns + ------- + Callable + The decorated async function with concurrency limiting applied. + + Examples + -------- + ```python + class MyStore(Store): + def __init__(self, concurrency_limit: int = 100): + self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None + + @with_concurrency_limit() + async def get(self, key: str) -> Buffer | None: + # This will only run when semaphore permits + return await expensive_io_operation(key) + ``` + """ + + def decorator( + func: Callable[P, Coroutine[Any, Any, T_co]], + ) -> Callable[P, Coroutine[Any, Any, T_co]]: + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: + # First arg should be 'self' + if not args: + raise TypeError(f"{func.__name__} requires at least one argument (self)") + + self = args[0] + semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr, None) + + if semaphore is None: + # No concurrency limit - run directly + return await func(*args, **kwargs) + else: + # Apply concurrency limit + async with semaphore: + return await func(*args, **kwargs) + + return wrapper + + return decorator + def normalize_path(path: str | bytes | Path | None) -> str: if path is None: diff --git a/tests/test_group.py b/tests/test_group.py index 6f1f4e68fa..9f25036298 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -23,7 +23,6 @@ from zarr.core import sync_group from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.config import config as zarr_config from zarr.core.dtype.common import unpack_dtype_json from zarr.core.dtype.npy.int import UInt8 from zarr.core.group import ( @@ -1738,29 +1737,6 @@ async def test_create_nodes( assert node_spec == {k: v.metadata for k, v in observed_nodes.items()} -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of create_nodes can be constrained by the async concurrency - configuration setting. - """ - set_latency = 0.02 - num_groups = 10 - groups = {str(idx): GroupMetadata() for idx in range(num_groups)} - - latency_store = LatencyStore(store, set_latency=set_latency) - - # check how long it takes to iterate over the groups - # if create_nodes is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups)) - elapsed = time.time() - start - assert elapsed > num_groups * set_latency - - @pytest.mark.parametrize( ("a_func", "b_func"), [ @@ -2250,38 +2226,6 @@ def test_group_members_performance(store: Store) -> None: assert elapsed < (num_groups * get_latency) -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_group_members_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of Group.members can be constrained by the async concurrency - configuration setting. - """ - get_latency = 0.02 - - # use the input store to create some groups - group_create = zarr.group(store=store) - num_groups = 10 - - # Create some groups - for i in range(num_groups): - group_create.create_group(f"group{i}") - - latency_store = LatencyStore(store, get_latency=get_latency) - # create a group with some latency on get operations - group_read = zarr.group(store=latency_store) - - # check how long it takes to iterate over the groups - # if .members is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = group_read.members() - elapsed = time.time() - start - - assert elapsed > num_groups * get_latency - - @pytest.mark.parametrize("option", ["array", "group", "invalid"]) def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None: """ From 05d191ae3672cbf0f61bb7b08b1f56cc045eee05 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 19 Dec 2025 15:59:54 +0100 Subject: [PATCH 06/12] add store concurrency tests --- src/zarr/storage/_utils.py | 18 +- src/zarr/testing/store_concurrency.py | 247 ++++++++++++++++++++++++++ tests/test_store/test_local.py | 13 ++ tests/test_store/test_memory.py | 13 ++ 4 files changed, 284 insertions(+), 7 deletions(-) create mode 100644 src/zarr/testing/store_concurrency.py diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 9ce01c2d99..d156a06891 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -55,6 +55,13 @@ async def get(self, key: str) -> Buffer | None: def decorator( func: Callable[P, Coroutine[Any, Any, T_co]], ) -> Callable[P, Coroutine[Any, Any, T_co]]: + """ + This decorator wraps the invocation of `func` in an `async with semaphore` context manager. + The semaphore object is resolved by getting the `semaphor_attr` attribute from the first + argument to func. When this decorator is used on a method of a class, that first argument + is a reference to the class instance (`self`). + """ + @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: # First arg should be 'self' @@ -62,15 +69,12 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: raise TypeError(f"{func.__name__} requires at least one argument (self)") self = args[0] - semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr, None) - if semaphore is None: - # No concurrency limit - run directly + semaphore: asyncio.Semaphore = getattr(self, semaphore_attr) + + # Apply concurrency limit + async with semaphore: return await func(*args, **kwargs) - else: - # Apply concurrency limit - async with semaphore: - return await func(*args, **kwargs) return wrapper diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py new file mode 100644 index 0000000000..06cf23857d --- /dev/null +++ b/src/zarr/testing/store_concurrency.py @@ -0,0 +1,247 @@ +"""Base test class for store concurrency limiting behavior.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Generic, TypeVar + +import pytest + +from zarr.core.buffer import Buffer, default_buffer_prototype + +if TYPE_CHECKING: + from zarr.abc.store import Store + +__all__ = ["StoreConcurrencyTests"] + + +S = TypeVar("S", bound="Store") +B = TypeVar("B", bound="Buffer") + + +class StoreConcurrencyTests(Generic[S, B]): + """Base class for testing store concurrency limiting behavior. + + This mixin provides tests for verifying that stores correctly implement + concurrency limiting. + + Subclasses should set: + - store_cls: The store class being tested + - buffer_cls: The buffer class to use (e.g., cpu.Buffer) + - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited) + """ + + store_cls: type[S] + buffer_cls: type[B] + expected_concurrency_limit: int | None + + @pytest.fixture + async def store(self, store_kwargs: dict) -> S: + """Create and open a store instance.""" + return await self.store_cls.open(**store_kwargs) + + def test_concurrency_limit_default(self, store: S) -> None: + """Test that store has the expected default concurrency limit.""" + if hasattr(store, "_semaphore"): + if self.expected_concurrency_limit is None: + assert store._semaphore is None, "Expected no concurrency limit" + else: + assert store._semaphore is not None, "Expected concurrency limit to be set" + assert store._semaphore._value == self.expected_concurrency_limit, ( + f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}" + ) + + def test_concurrency_limit_custom(self, store_kwargs: dict) -> None: + """Test that custom concurrency limits can be set.""" + if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames: + pytest.skip("Store does not support custom concurrency limits") + + # Test with custom limit + store = self.store_cls(**store_kwargs, concurrency_limit=42) + if hasattr(store, "_semaphore"): + assert store._semaphore is not None + assert store._semaphore._value == 42 + + # Test with None (unlimited) + store = self.store_cls(**store_kwargs, concurrency_limit=None) + if hasattr(store, "_semaphore"): + assert store._semaphore is None + + async def test_concurrency_limit_enforced(self, store: S) -> None: + """Test that the concurrency limit is actually enforced during execution. + + This test verifies that when many operations are submitted concurrently, + only up to the concurrency limit are actually executing at once. + """ + if not hasattr(store, "_semaphore") or store._semaphore is None: + pytest.skip("Store has no concurrency limit") + + limit = store._semaphore._value + + # We'll monitor the semaphore's available count + # When it reaches 0, that means `limit` operations are running + min_available = limit + + async def monitored_operation(key: str, value: B) -> None: + nonlocal min_available + # Check semaphore state right after we're scheduled + await asyncio.sleep(0) # Yield to ensure we're in the queue + available = store._semaphore._value + min_available = min(min_available, available) + + # Now do the actual operation (which will acquire the semaphore) + await store.set(key, value) + + # Launch more operations than the limit to ensure contention + num_ops = limit * 2 + items = [ + (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode())) + for i in range(num_ops) + ] + + await asyncio.gather(*[monitored_operation(k, v) for k, v in items]) + + # The semaphore should have been fully utilized (reached 0 or close to it) + # This indicates that `limit` operations were running concurrently + assert min_available < limit, ( + f"Semaphore was never fully utilized. " + f"Min available: {min_available}, Limit: {limit}. " + f"This suggests operations aren't running concurrently." + ) + + # Ideally it should reach 0, but allow some slack for timing + assert min_available <= 5, ( + f"Semaphore only reached {min_available} available slots. " + f"Expected close to 0 with limit {limit}." + ) + + async def test_batch_write_no_deadlock(self, store: S) -> None: + """Test that batch writes don't deadlock when exceeding concurrency limit.""" + # Create more items than any reasonable concurrency limit + num_items = 200 + items = [ + (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode())) + for i in range(num_items) + ] + + # This should complete without deadlock, even if num_items > concurrency_limit + await asyncio.wait_for(store._set_many(items), timeout=30.0) + + # Verify all items were written correctly + for key, expected_value in items: + result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_read_no_deadlock(self, store: S) -> None: + """Test that batch reads don't deadlock when exceeding concurrency limit.""" + # Write test data + num_items = 200 + test_data = { + f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode()) + for i in range(num_items) + } + + for key, value in test_data.items(): + await store.set(key, value) + + # Read all items concurrently - should not deadlock + keys_and_ranges = [(key, None) for key in test_data] + results = await asyncio.wait_for( + store.get_partial_values(default_buffer_prototype(), keys_and_ranges), + timeout=30.0, + ) + + # Verify results + assert len(results) == num_items + for result, (key, expected_value) in zip(results, test_data.items()): + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_delete_no_deadlock(self, store: S) -> None: + """Test that batch deletes don't deadlock when exceeding concurrency limit.""" + if not store.supports_deletes: + pytest.skip("Store does not support deletes") + + # Write test data + num_items = 200 + keys = [f"test_key_{i}" for i in range(num_items)] + for key in keys: + await store.set(key, self.buffer_cls.from_bytes(b"test_value")) + + # Delete all items concurrently - should not deadlock + await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0) + + # Verify all items were deleted + for key in keys: + result = await store.get(key, default_buffer_prototype()) + assert result is None + + async def test_concurrent_operations_correctness(self, store: S) -> None: + """Test that concurrent operations produce correct results.""" + num_operations = 100 + + # Mix of reads and writes + write_keys = [f"write_key_{i}" for i in range(num_operations)] + write_values = [ + self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations) + ] + + # Write all concurrently + await asyncio.gather(*[store.set(k, v) for k, v in zip(write_keys, write_values)]) + + # Read all concurrently + results = await asyncio.gather( + *[store.get(k, default_buffer_prototype()) for k in write_keys] + ) + + # Verify correctness + for result, expected in zip(results, write_values): + assert result is not None + assert result.to_bytes() == expected.to_bytes() + + @pytest.mark.parametrize("batch_size", [1, 10, 50, 100]) + async def test_various_batch_sizes(self, store: S, batch_size: int) -> None: + """Test that various batch sizes work correctly.""" + items = [ + (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode())) + for i in range(batch_size) + ] + + # Should complete without issues for any batch size + await asyncio.wait_for(store._set_many(items), timeout=10.0) + + # Verify + for key, expected_value in items: + result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_empty_batch_operations(self, store: S) -> None: + """Test that empty batch operations don't cause issues.""" + # Empty batch should not raise + await store._set_many([]) + + # Empty read batch + results = await store.get_partial_values(default_buffer_prototype(), []) + assert results == [] + + async def test_mixed_success_failure_batch(self, store: S) -> None: + """Test batch operations with mix of successful and failing items.""" + # Write some initial data + await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value")) + + # Try to read mix of existing and non-existing keys + key_ranges = [ + ("existing_key", None), + ("non_existing_key_1", None), + ("non_existing_key_2", None), + ] + + results = await store.get_partial_values(default_buffer_prototype(), key_ranges) + + # First should exist, others should be None + assert results[0] is not None + assert results[0].to_bytes() == b"existing_value" + assert results[1] is None + assert results[2] is None diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 6756bc83d9..73eec991f8 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -12,6 +12,7 @@ from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import assert_bytes_equal @@ -150,3 +151,15 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files + + +class TestLocalStoreConcurrency(StoreConcurrencyTests[LocalStore, cpu.Buffer]): + """Test LocalStore concurrency limiting behavior.""" + + store_cls = LocalStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = 100 # LocalStore default + + @pytest.fixture + def store_kwargs(self, tmpdir: str) -> dict[str, str]: + return {"root": str(tmpdir)} diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 29fa9b2964..2222905745 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -12,6 +12,7 @@ from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: @@ -130,3 +131,15 @@ def test_from_dict(self) -> None: result = GpuMemoryStore.from_dict(d) for v in result._store_dict.values(): assert type(v) is gpu.Buffer + + +class TestMemoryStoreConcurrency(StoreConcurrencyTests[MemoryStore, cpu.Buffer]): + """Test MemoryStore concurrency limiting behavior.""" + + store_cls = MemoryStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = None # MemoryStore has no limit (fast in-memory ops) + + @pytest.fixture + def store_kwargs(self) -> dict[str, Any]: + return {"store_dict": None} From 21fc1d547a31c067800227c6dabb095ee6e9ee94 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 8 Feb 2026 20:52:32 +0100 Subject: [PATCH 07/12] remove more references to the global concurrency limit --- src/zarr/core/common.py | 139 +--------- src/zarr/core/group.py | 58 +---- src/zarr/storage/_obstore.py | 299 ++++------------------ src/zarr/storage/_utils.py | 8 +- src/zarr/testing/store_concurrency.py | 8 +- tests/test_common.py | 4 - tests/test_global_concurrency.py | 354 +++----------------------- 7 files changed, 103 insertions(+), 767 deletions(-) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 88f1388091..e45c256310 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -1,15 +1,11 @@ from __future__ import annotations -import asyncio import functools import math import operator -import threading import warnings -import weakref from collections.abc import Iterable, Mapping, Sequence from enum import Enum -from itertools import starmap from typing import ( TYPE_CHECKING, Any, @@ -29,7 +25,7 @@ from zarr.errors import ZarrRuntimeWarning if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Iterator + from collections.abc import Iterator ZARR_JSON = "zarr.json" @@ -96,139 +92,6 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) -T = TypeVar("T", bound=tuple[Any, ...]) -V = TypeVar("V") - - -# Global semaphore management for per-process concurrency limiting -# Use WeakKeyDictionary to automatically clean up semaphores when event loops are garbage collected -_global_semaphores: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore] = ( - weakref.WeakKeyDictionary() -) -# Use threading.Lock instead of asyncio.Lock to coordinate across event loops -_global_semaphore_lock = threading.Lock() - - -def get_global_semaphore() -> asyncio.Semaphore: - """ - Get the global semaphore for the current event loop. - - This ensures that all concurrent operations across the process share the same - concurrency limit, preventing excessive concurrent task creation when multiple - arrays or operations are running simultaneously. - - The semaphore is lazily created per event loop and uses the configured - `async.concurrency` value from zarr config. The semaphore is cached per event - loop, so subsequent calls return the same semaphore instance. - - Note: Config changes after the first call will not affect the semaphore limit. - To apply new config values, use :func:`reset_global_semaphores` to clear the cache. - - Returns - ------- - asyncio.Semaphore - The global semaphore for this event loop. - - Raises - ------ - RuntimeError - If called outside of an async context (no running event loop). - - See Also - -------- - reset_global_semaphores : Clear the global semaphore cache - """ - loop = asyncio.get_running_loop() - - # Acquire lock FIRST to prevent TOCTOU race condition - with _global_semaphore_lock: - if loop not in _global_semaphores: - limit = zarr_config.get("async.concurrency") - _global_semaphores[loop] = asyncio.Semaphore(limit) - return _global_semaphores[loop] - - -def reset_global_semaphores() -> None: - """ - Clear all cached global semaphores. - - This is useful when you want config changes to take effect, or for testing. - The next call to :func:`get_global_semaphore` will create a new semaphore - using the current configuration. - - Warning: This should only be called when no async operations are in progress, - as it will invalidate all existing semaphore references. - - Examples - -------- - >>> import zarr - >>> zarr.config.set({"async.concurrency": 50}) - >>> reset_global_semaphores() # Apply new config - """ - with _global_semaphore_lock: - _global_semaphores.clear() - - -async def concurrent_map( - items: Iterable[T], - func: Callable[..., Awaitable[V]], - limit: int | None = None, - *, - use_global_semaphore: bool = True, -) -> list[V]: - """ - Execute an async function concurrently over multiple items with concurrency limiting. - - Parameters - ---------- - items : Iterable[T] - Items to process, where each item is a tuple of arguments to pass to func. - func : Callable[..., Awaitable[V]] - Async function to execute for each item. - limit : int | None, optional - If provided and use_global_semaphore is False, creates a local semaphore - with this limit. If None, no concurrency limiting is applied. - use_global_semaphore : bool, default True - If True, uses the global per-process semaphore for concurrency limiting, - ensuring all concurrent operations share the same limit. If False, uses - the `limit` parameter for local limiting (legacy behavior). - - Returns - ------- - list[V] - Results from executing func on all items. - """ - if use_global_semaphore: - if limit is not None: - raise ValueError( - "Cannot specify both use_global_semaphore=True and a limit value. " - "Either use the global semaphore (use_global_semaphore=True, limit=None) " - "or specify a local limit (use_global_semaphore=False, limit=)." - ) - # Use the global semaphore for process-wide concurrency limiting - sem = get_global_semaphore() - - async def run(item: tuple[Any]) -> V: - async with sem: - return await func(*item) - - return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) - - elif limit is None: - # No concurrency limiting - return await asyncio.gather(*list(starmap(func, items))) - - else: - # Legacy mode: create local semaphore with specified limit - sem = asyncio.Semaphore(limit) - - async def run(item: tuple[Any]) -> V: - async with sem: - return await func(*item) - - return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) - - E = TypeVar("E", bound=Enum) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 50b57a569f..658de7ef81 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -44,7 +44,6 @@ NodeType, ShapeLike, ZarrFormat, - get_global_semaphore, parse_shapelike, ) from zarr.core.config import config @@ -1441,13 +1440,10 @@ async def _members( ) raise ValueError(msg) - # Use global semaphore for process-wide concurrency limiting - semaphore = get_global_semaphore() async for member in _iter_members_deep( self, max_depth=max_depth, skip_keys=skip_keys, - semaphore=semaphore, use_consolidated_for_children=use_consolidated_for_children, ): yield member @@ -3324,13 +3320,11 @@ async def create_nodes( The created nodes in the order they are created. """ - # Use global semaphore for process-wide concurrency limiting - semaphore = get_global_semaphore() create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): # make the key absolute - create_tasks.extend(_persist_metadata(store, key, value, semaphore=semaphore)) + create_tasks.extend(_persist_metadata(store, key, value)) created_object_keys = [] @@ -3476,28 +3470,16 @@ def _ensure_consistent_zarr_format( ) -async def _getitem_semaphore( - node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None -) -> AnyAsyncArray | AsyncGroup: +async def _getitem(node: AsyncGroup, key: str) -> AnyAsyncArray | AsyncGroup: """ - Wrap Group.getitem with an optional semaphore. - - If the semaphore parameter is an - asyncio.Semaphore instance, then the getitem operation is performed inside an async context - manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked - without a context manager. + Fetch a child node from a group by key. """ - if semaphore is not None: - async with semaphore: - return await node.getitem(key) - else: - return await node.getitem(key) + return await node.getitem(key) async def _iter_members( node: AsyncGroup, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ Iterate over the arrays and groups contained in a group. @@ -3508,8 +3490,6 @@ async def _iter_members( The group to traverse. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. Yields ------ @@ -3520,10 +3500,7 @@ async def _iter_members( keys = [key async for key in node.store.list_dir(node.path)] keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) - node_tasks = tuple( - asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key) - for key in keys_filtered - ) + node_tasks = tuple(asyncio.create_task(_getitem(node, key), name=key) for key in keys_filtered) for fetched_node_coro in asyncio.as_completed(node_tasks): try: @@ -3550,7 +3527,6 @@ async def _iter_members_deep( *, max_depth: int | None, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None = None, use_consolidated_for_children: bool = True, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ @@ -3565,8 +3541,6 @@ async def _iter_members_deep( The maximum depth of recursion. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. use_consolidated_for_children : bool, default True Whether to use the consolidated metadata of child groups loaded from the store. Note that this only affects groups loaded from the @@ -3585,7 +3559,7 @@ async def _iter_members_deep( new_depth = None else: new_depth = max_depth - 1 - async for name, node in _iter_members(group, skip_keys=skip_keys, semaphore=semaphore): + async for name, node in _iter_members(group, skip_keys=skip_keys): is_group = isinstance(node, AsyncGroup) if ( is_group @@ -3599,9 +3573,7 @@ async def _iter_members_deep( yield name, node if is_group and do_recursion: node = cast("AsyncGroup", node) - to_recurse[name] = _iter_members_deep( - node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore - ) + to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys) for prefix, subgroup_iter in to_recurse.items(): async for name, node in subgroup_iter: @@ -3811,9 +3783,7 @@ async def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> AnyAsync raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover -async def _set_return_key( - *, store: Store, key: str, value: Buffer, semaphore: asyncio.Semaphore | None = None -) -> str: +async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: """ Write a value to storage at the given key. The key is returned. Useful when saving values via routines that return results in execution order, @@ -3828,15 +3798,8 @@ async def _set_return_key( The key to save the value to. value : Buffer The value to save. - semaphore : asyncio.Semaphore | None - An optional semaphore to use to limit the number of concurrent writes. """ - - if semaphore is not None: - async with semaphore: - await store.set(key, value) - else: - await store.set(key, value) + await store.set(key, value) return key @@ -3844,7 +3807,6 @@ def _persist_metadata( store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata, - semaphore: asyncio.Semaphore | None = None, ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. @@ -3852,7 +3814,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) + _set_return_key(store=store, key=_join_paths([path, key]), value=value) for key, value in to_save.items() ) diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 223142d371..697f51ddb0 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -4,7 +4,7 @@ import contextlib import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from zarr.abc.store import ( ByteRequest, @@ -13,14 +13,13 @@ Store, SuffixByteRequest, ) -from zarr.core.common import get_global_semaphore from zarr.storage._utils import with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence from typing import Any - from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange + from obstore import ListResult, ListStream, ObjectMeta from obstore.store import ObjectStore as _UpstreamObjectStore from zarr.core.buffer import Buffer, BufferPrototype @@ -127,23 +126,59 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - # Note: We directly call obs operations here, wrapped with semaphore - # to avoid deadlock from calling the decorated get() method + # We override to: + # 1. Avoid deadlock from calling the decorated get() method + # 2. Batch RangeByteRequests per-file using get_ranges_async for performance import obstore as obs - async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: + key_ranges = list(key_ranges) + # Group bounded range requests by path for batched fetching + per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list) + other_requests: list[tuple[int, str, ByteRequest | None]] = [] + + for idx, (path, byte_range) in enumerate(key_ranges): + if isinstance(byte_range, RangeByteRequest): + per_file_bounded[path].append((idx, byte_range)) + else: + other_requests.append((idx, path, byte_range)) + + buffers: list[Buffer | None] = [None] * len(key_ranges) + + async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) -> None: + """Batch multiple range requests for the same file using get_ranges_async.""" + starts = [r.start for _, r in requests] + ends = [r.end for _, r in requests] + if self._semaphore: + async with self._semaphore: + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + else: + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + for (idx, _), response in zip(requests, responses, strict=True): + buffers[idx] = prototype.buffer.from_bytes(response) # type: ignore[arg-type] + + async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None: + """Fetch a single non-range request with semaphore limiting.""" try: if self._semaphore: async with self._semaphore: - return await self._get_impl(key, prototype, byte_range, obs) + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) else: - return await self._get_impl(key, prototype, byte_range, obs) + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: - return None + pass # buffers[idx] stays None - return await asyncio.gather( - *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] - ) + futs: list[Coroutine[Any, Any, None]] = [] + for path, requests in per_file_bounded.items(): + futs.append(_fetch_ranges(path, requests)) + for idx, path, byte_range in other_requests: + futs.append(_fetch_one(idx, path, byte_range)) + + await asyncio.gather(*futs) + return buffers async def _get_impl( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any @@ -336,243 +371,3 @@ async def _transform_list_dir( objects = [obj["path"].removeprefix(prefix).lstrip("/") for obj in list_result["objects"]] for item in prefixes + objects: yield item - - -class _BoundedRequest(TypedDict): - """Range request with a known start and end byte. - - These requests can be multiplexed natively on the Rust side with - `obstore.get_ranges_async`. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - start: int - """Start byte offset.""" - - end: int - """End byte offset.""" - - -class _OtherRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: OffsetRange | None - # Note: suffix requests are handled separately because some object stores (Azure) - # don't support them - """The range request type.""" - - -class _SuffixRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: SuffixRange - """The suffix range.""" - - -class _Response(TypedDict): - """A response buffer associated with the original index that it should be restored to.""" - - original_request_index: int - """The positional index in the original key_ranges input""" - - buffer: Buffer - """The buffer returned from obstore's range request.""" - - -async def _make_bounded_requests( - store: _UpstreamObjectStore, - path: str, - requests: list[_BoundedRequest], - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make all bounded requests for a specific file. - - `obstore.get_ranges_async` allows for making concurrent requests for multiple ranges - within a single file, and will e.g. merge concurrent requests. This only uses one - single Python coroutine. - """ - import obstore as obs - - starts = [r["start"] for r in requests] - ends = [r["end"] for r in requests] - async with semaphore: - responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends) - - buffer_responses: list[_Response] = [] - for request, response in zip(requests, responses, strict=True): - buffer_responses.append( - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(response), # type: ignore[arg-type] - } - ) - - return buffer_responses - - -async def _make_other_request( - store: _UpstreamObjectStore, - request: _OtherRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make offset or full-file requests. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. - """ - import obstore as obs - - async with semaphore: - if request["range"] is None: - resp = await obs.get_async(store, request["path"]) - else: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _make_suffix_request( - store: _UpstreamObjectStore, - request: _SuffixRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make suffix requests. - - This is separated out from `_make_other_request` because some object stores (Azure) - don't support suffix requests. In this case, our workaround is to first get the - length of the object and then manually request the byte range at the end. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. - """ - import obstore as obs - - async with semaphore: - try: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(store, request["path"]) - file_size = head_resp["size"] - suffix_len = request["range"]["suffix"] - buffer = await obs.get_range_async( - store, - request["path"], - start=file_size - suffix_len, - length=suffix_len, - ) - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _get_partial_values( - store: _UpstreamObjectStore, - prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRequest | None]], -) -> list[Buffer | None]: - """Make multiple range requests. - - ObjectStore has a `get_ranges` method that will additionally merge nearby ranges, - but it's _per_ file. So we need to split these key_ranges into **per-file** key - ranges, and then reassemble the results in the original order. - - We separate into different requests: - - - One call to `obstore.get_ranges_async` **per target file** - - One call to `obstore.get_async` for each other request. - """ - key_ranges = list(key_ranges) - per_file_bounded_requests: dict[str, list[_BoundedRequest]] = defaultdict(list) - other_requests: list[_OtherRequest] = [] - suffix_requests: list[_SuffixRequest] = [] - - for idx, (path, byte_range) in enumerate(key_ranges): - if byte_range is None: - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": None, - } - ) - elif isinstance(byte_range, RangeByteRequest): - per_file_bounded_requests[path].append( - {"original_request_index": idx, "start": byte_range.start, "end": byte_range.end} - ) - elif isinstance(byte_range, OffsetByteRequest): - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"offset": byte_range.offset}, - } - ) - elif isinstance(byte_range, SuffixByteRequest): - suffix_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"suffix": byte_range.suffix}, - } - ) - else: - raise ValueError(f"Unsupported range input: {byte_range}") - - # Use global semaphore for process-wide concurrency limiting - semaphore = get_global_semaphore() - - futs: list[Coroutine[Any, Any, list[_Response]]] = [] - for path, bounded_ranges in per_file_bounded_requests.items(): - futs.append( - _make_bounded_requests(store, path, bounded_ranges, prototype, semaphore=semaphore) - ) - - for request in other_requests: - futs.append(_make_other_request(store, request, prototype, semaphore=semaphore)) # noqa: PERF401 - - for suffix_request in suffix_requests: - futs.append(_make_suffix_request(store, suffix_request, prototype, semaphore=semaphore)) # noqa: PERF401 - - buffers: list[Buffer | None] = [None] * len(key_ranges) - - for responses in await asyncio.gather(*futs): - for resp in responses: - buffers[resp["original_request_index"]] = resp["buffer"] - - return buffers diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 0deac52dd9..80bce250e9 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -70,10 +70,12 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: self = args[0] - semaphore: asyncio.Semaphore = getattr(self, semaphore_attr) + semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr) - # Apply concurrency limit - async with semaphore: + if semaphore is not None: + async with semaphore: + return await func(*args, **kwargs) + else: return await func(*args, **kwargs) return wrapper diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py index 06cf23857d..0dd6dcff17 100644 --- a/src/zarr/testing/store_concurrency.py +++ b/src/zarr/testing/store_concurrency.py @@ -154,7 +154,7 @@ async def test_batch_read_no_deadlock(self, store: S) -> None: # Verify results assert len(results) == num_items - for result, (key, expected_value) in zip(results, test_data.items()): + for result, (_key, expected_value) in zip(results, test_data.items(), strict=True): assert result is not None assert result.to_bytes() == expected_value.to_bytes() @@ -188,7 +188,9 @@ async def test_concurrent_operations_correctness(self, store: S) -> None: ] # Write all concurrently - await asyncio.gather(*[store.set(k, v) for k, v in zip(write_keys, write_values)]) + await asyncio.gather( + *[store.set(k, v) for k, v in zip(write_keys, write_values, strict=True)] + ) # Read all concurrently results = await asyncio.gather( @@ -196,7 +198,7 @@ async def test_concurrent_operations_correctness(self, store: S) -> None: ) # Verify correctness - for result, expected in zip(results, write_values): + for result, expected in zip(results, write_values, strict=True): assert result is not None assert result.to_bytes() == expected.to_bytes() diff --git a/tests/test_common.py b/tests/test_common.py index 0944c3375a..9484d15ca3 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,10 +31,6 @@ def test_access_modes() -> None: assert set(ANY_ACCESS_MODE) == set(get_args(AccessModeLiteral)) -# todo: test -def test_concurrent_map() -> None: ... - - # todo: test def test_to_thread() -> None: ... diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py index f6366e3c53..3cfca5052c 100644 --- a/tests/test_global_concurrency.py +++ b/tests/test_global_concurrency.py @@ -1,327 +1,43 @@ """ -Tests for global per-process concurrency limiting. +Tests for store-level concurrency limiting through the array API. """ -import asyncio -from typing import Any - import numpy as np -import pytest import zarr -from zarr.core.common import get_global_semaphore, reset_global_semaphores -from zarr.core.config import config - - -class TestGlobalSemaphore: - """Tests for the global semaphore management.""" - - async def test_get_global_semaphore_creates_per_loop(self) -> None: - """Test that each event loop gets its own semaphore.""" - sem1 = get_global_semaphore() - assert sem1 is not None - assert isinstance(sem1, asyncio.Semaphore) - - # Getting it again should return the same instance - sem2 = get_global_semaphore() - assert sem1 is sem2 - - async def test_global_semaphore_uses_config_limit(self) -> None: - """Test that the global semaphore respects the configured limit.""" - # Set a custom concurrency limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 5}) - - # Clear existing semaphores to force recreation - reset_global_semaphores() - - sem = get_global_semaphore() - - # The semaphore should have the configured limit - # We can verify this by acquiring all tokens and checking the semaphore is locked - for i in range(5): - await sem.acquire() - if i < 4: - assert not sem.locked() # Should still have capacity - else: - assert sem.locked() # All tokens acquired, semaphore is now locked - - # Release all tokens - for _ in range(5): - sem.release() - - finally: - # Restore original config - config.set({"async.concurrency": original_limit}) - # Clear semaphores again to reset state - reset_global_semaphores() - - async def test_global_semaphore_shared_across_operations(self) -> None: - """Test that multiple concurrent operations share the same semaphore.""" - # Track the maximum number of concurrent tasks - max_concurrent = 0 - current_concurrent = 0 - lock = asyncio.Lock() - - async def tracked_operation() -> None: - """An operation that tracks concurrency.""" - nonlocal max_concurrent, current_concurrent - - async with lock: - current_concurrent += 1 - max_concurrent = max(max_concurrent, current_concurrent) - - # Small delay to ensure overlap - await asyncio.sleep(0.01) - - async with lock: - current_concurrent -= 1 - - # Set a low concurrency limit to make the test observable - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 5}) - - # Clear existing semaphores - reset_global_semaphores() - - # Get the global semaphore - sem = get_global_semaphore() - - # Create many tasks that use the semaphore - async def task_with_semaphore() -> None: - async with sem: - await tracked_operation() - - # Launch 20 tasks (4x the limit) - tasks = [task_with_semaphore() for _ in range(20)] - await asyncio.gather(*tasks) - - # Maximum concurrent should respect the limit - assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5" - assert max_concurrent >= 3, ( - f"Max concurrent was {max_concurrent}, expected some concurrency" - ) - - finally: - config.set({"async.concurrency": original_limit}) - reset_global_semaphores() - - async def test_semaphore_reuse_across_calls(self) -> None: - """Test that repeated calls to get_global_semaphore return the same instance.""" - reset_global_semaphores() - - # Call multiple times and verify we get the same instance - sem1 = get_global_semaphore() - sem2 = get_global_semaphore() - sem3 = get_global_semaphore() - - assert sem1 is sem2 is sem3, "Should return same semaphore instance on repeated calls" - - # Verify it's still the same after using it - async with sem1: - sem4 = get_global_semaphore() - assert sem1 is sem4 - - def test_config_change_after_creation(self) -> None: - """Test and document that config changes don't affect existing semaphores.""" - original_limit: Any = config.get("async.concurrency") - try: - # Set initial config - config.set({"async.concurrency": 5}) - - async def check_limit() -> None: - reset_global_semaphores() - - # Create semaphore with limit=5 - sem1 = get_global_semaphore() - initial_capacity: int = sem1._value - - # Change config - config.set({"async.concurrency": 50}) - - # Get semaphore again - should be same instance with old limit - sem2 = get_global_semaphore() - assert sem1 is sem2, "Should return same semaphore instance" - assert sem2._value == initial_capacity, ( - f"Semaphore limit changed from {initial_capacity} to {sem2._value}. " - "Config changes should not affect existing semaphores." - ) - - # Clean up - reset_global_semaphores() - - asyncio.run(check_limit()) - - finally: - config.set({"async.concurrency": original_limit}) - - -class TestArrayConcurrency: - """Tests that array operations use global concurrency limiting.""" - - @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") - async def test_multiple_arrays_share_concurrency_limit(self) -> None: - """Test that reading from multiple arrays shares the global concurrency limit.""" - from zarr.core.common import concurrent_map - - # Track concurrent task executions - max_concurrent_tasks = 0 - current_concurrent_tasks = 0 - task_lock = asyncio.Lock() - - async def tracked_chunk_operation(chunk_id: int) -> int: - """Simulate a chunk operation with tracking.""" - nonlocal max_concurrent_tasks, current_concurrent_tasks - - async with task_lock: - current_concurrent_tasks += 1 - max_concurrent_tasks = max(max_concurrent_tasks, current_concurrent_tasks) - - # Small delay to simulate I/O - await asyncio.sleep(0.001) - - async with task_lock: - current_concurrent_tasks -= 1 - - return chunk_id - - # Set a low concurrency limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 10}) - - # Clear existing semaphores - reset_global_semaphores() - - # Simulate reading many chunks using concurrent_map (which uses the global semaphore) - # This simulates what happens when reading from multiple arrays - chunk_ids = [(i,) for i in range(100)] - await concurrent_map(chunk_ids, tracked_chunk_operation) - - # The maximum concurrent tasks should respect the global limit - assert max_concurrent_tasks <= 10, ( - f"Max concurrent tasks was {max_concurrent_tasks}, expected <= 10" - ) - - assert max_concurrent_tasks >= 5, ( - f"Max concurrent tasks was {max_concurrent_tasks}, " - f"expected at least some concurrency" - ) - - finally: - config.set({"async.concurrency": original_limit}) - # Note: We don't reset_global_semaphores() here because doing so while - # many tasks are still cleaning up can trigger ResourceWarnings from - # asyncio internals. The semaphore will be reused by subsequent tests. - - def test_sync_api_uses_global_concurrency(self) -> None: - """Test that synchronous API also benefits from global concurrency limiting.""" - # This test verifies that the sync API (which wraps async) uses global limiting - - # Set a low concurrency limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 8}) - - # Create a small array - the key is that zarr internally uses - # concurrent_map which now uses the global semaphore - store = zarr.storage.MemoryStore() - arr = zarr.create( - shape=(20, 20), - chunks=(10, 10), - dtype="i4", - store=store, - zarr_format=3, - ) - arr[:] = 42 - - # Read data (synchronously) - data = arr[:] - - # Verify we got the right data - assert np.all(data == 42) - - # The test passes if no errors occurred - # The concurrency limiting is happening under the hood - - finally: - config.set({"async.concurrency": original_limit}) - - -class TestConcurrentMapGlobal: - """Tests for concurrent_map using global semaphore.""" - - async def test_concurrent_map_uses_global_by_default(self) -> None: - """Test that concurrent_map uses global semaphore by default.""" - from zarr.core.common import concurrent_map - - # Track concurrent executions - max_concurrent = 0 - current_concurrent = 0 - lock = asyncio.Lock() - - async def tracked_task(x: int) -> int: - nonlocal max_concurrent, current_concurrent - - async with lock: - current_concurrent += 1 - max_concurrent = max(max_concurrent, current_concurrent) - - await asyncio.sleep(0.01) - - async with lock: - current_concurrent -= 1 - - return x * 2 - - # Set a low limit - original_limit: Any = config.get("async.concurrency") - try: - config.set({"async.concurrency": 5}) - - # Clear existing semaphores - reset_global_semaphores() - - # Use concurrent_map with default settings (use_global_semaphore=True) - items = [(i,) for i in range(20)] - results = await concurrent_map(items, tracked_task) - - assert len(results) == 20 - assert max_concurrent <= 5 - assert max_concurrent >= 3 # Should have some concurrency - - finally: - config.set({"async.concurrency": original_limit}) - reset_global_semaphores() - - async def test_concurrent_map_legacy_mode(self) -> None: - """Test that concurrent_map legacy mode still works.""" - from zarr.core.common import concurrent_map - - async def simple_task(x: int) -> int: - await asyncio.sleep(0.001) - return x * 2 - - # Use legacy mode with local limit - items = [(i,) for i in range(10)] - results = await concurrent_map(items, simple_task, limit=3, use_global_semaphore=False) - - assert len(results) == 10 - assert results == [i * 2 for i in range(10)] - - async def test_concurrent_map_parameter_validation(self) -> None: - """Test that concurrent_map validates conflicting parameters.""" - from zarr.core.common import concurrent_map - - async def simple_task(x: int) -> int: - return x * 2 - items = [(i,) for i in range(10)] - # Should raise ValueError when both limit and use_global_semaphore=True - with pytest.raises( - ValueError, match="Cannot specify both use_global_semaphore=True and a limit" - ): - await concurrent_map(items, simple_task, limit=5, use_global_semaphore=True) +class TestStoreConcurrencyThroughArrayAPI: + """Tests that store-level concurrency limiting works through the array API.""" + + def test_array_operations_with_store_concurrency(self) -> None: + """Test that array read/write works correctly with store-level concurrency limits.""" + store = zarr.storage.MemoryStore() + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) + + def test_array_operations_with_local_store_concurrency(self, tmp_path: object) -> None: + """Test that array read/write works correctly with LocalStore concurrency limits.""" + store = zarr.storage.LocalStore(str(tmp_path), concurrency_limit=10) + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) From 9cec350fc2118cf3dfb14df0ba63fbc2109b0266 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 19 Feb 2026 22:09:33 +0100 Subject: [PATCH 08/12] lint --- .github/dependabot.yml | 11 - README.md | 49 +- changes/3655.bugfix.md | 1 + changes/3657.bugfix.md | 1 + changes/3695.bugfix.md | 1 + changes/3700.bugfix.md | 1 + changes/3702.bugfix.md | 1 + changes/3704.misc.md | 1 + changes/3705.bugfix.md | 1 + changes/3706.misc.md | 1 + changes/3708.misc.md | 1 + changes/3712.misc.md | 1 + docs/contributing.md | 16 +- docs/quick-start.md | 4 +- docs/user-guide/arrays.md | 4 +- docs/user-guide/extending.md | 2 +- docs/user-guide/groups.md | 2 +- docs/user-guide/storage.md | 4 +- docs/user-guide/v3_migration.md | 2 +- src/zarr/codecs/bytes.py | 11 +- src/zarr/codecs/sharding.py | 13 +- src/zarr/core/array.py | 962 +++++++++++++++----- src/zarr/core/common.py | 16 +- src/zarr/core/dtype/npy/string.py | 37 + src/zarr/core/dtype/registry.py | 4 + src/zarr/core/indexing.py | 109 ++- src/zarr/experimental/cache_store.py | 143 +-- src/zarr/storage/_logging.py | 3 + src/zarr/storage/_obstore.py | 23 +- src/zarr/storage/_wrapper.py | 13 +- src/zarr/testing/store.py | 45 +- src/zarr/testing/store_concurrency.py | 10 +- tests/benchmarks/test_indexing.py | 215 +++++ tests/test_array.py | 67 ++ tests/test_codecs/test_codecs.py | 66 +- tests/test_codecs/test_sharding.py | 47 +- tests/test_common.py | 14 +- tests/test_dtype_registry.py | 12 + tests/test_experimental/test_cache_store.py | 102 ++- tests/test_store/test_latency.py | 57 ++ tests/test_store/test_logging.py | 40 + 41 files changed, 1658 insertions(+), 455 deletions(-) create mode 100644 changes/3655.bugfix.md create mode 100644 changes/3657.bugfix.md create mode 100644 changes/3695.bugfix.md create mode 100644 changes/3700.bugfix.md create mode 100644 changes/3702.bugfix.md create mode 100644 changes/3704.misc.md create mode 100644 changes/3705.bugfix.md create mode 100644 changes/3706.misc.md create mode 100644 changes/3708.misc.md create mode 100644 changes/3712.misc.md create mode 100644 tests/test_store/test_latency.py diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 469b6a4d19..82419a5143 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,17 +10,6 @@ updates: actions: patterns: - "*" - - # Updates for support/v2 branch - - package-ecosystem: "pip" - directory: "/" - target-branch: "support/v2" - schedule: - interval: "weekly" - groups: - requirements: - patterns: - - "*" - package-ecosystem: "github-actions" directory: "/" target-branch: "support/v2" diff --git a/README.md b/README.md index 97f5617934..3911ba17b8 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ + @@ -23,9 +24,9 @@ Package Status - - status - + + status + @@ -47,17 +48,16 @@ Pre-commit Status - + pre-commit status - Coverage - coverage + coverage @@ -70,27 +70,28 @@ - Developer Chat - - - - - + Developer Chat + + + + + - Funding - - - CZI's Essential Open Source Software for Science - - + Funding + + + CZI's Essential Open Source Software for Science + + - Citation - - - DOI - - + + Citation + + + DOI + + diff --git a/changes/3655.bugfix.md b/changes/3655.bugfix.md new file mode 100644 index 0000000000..67d384f00d --- /dev/null +++ b/changes/3655.bugfix.md @@ -0,0 +1 @@ +Fixed a bug in the sharding codec that prevented nested shard reads in certain cases. \ No newline at end of file diff --git a/changes/3657.bugfix.md b/changes/3657.bugfix.md new file mode 100644 index 0000000000..1411704674 --- /dev/null +++ b/changes/3657.bugfix.md @@ -0,0 +1 @@ +Fix obstore _transform_list_dir implementation to correctly relativize paths (removing lstrip usage). \ No newline at end of file diff --git a/changes/3695.bugfix.md b/changes/3695.bugfix.md new file mode 100644 index 0000000000..a7d847e4f1 --- /dev/null +++ b/changes/3695.bugfix.md @@ -0,0 +1 @@ +Raise error when trying to encode :class:`numpy.dtypes.StringDType` with `na_object` set. \ No newline at end of file diff --git a/changes/3700.bugfix.md b/changes/3700.bugfix.md new file mode 100644 index 0000000000..86acb71d0e --- /dev/null +++ b/changes/3700.bugfix.md @@ -0,0 +1 @@ +CacheStore, LoggingStore and LatencyStore now support with_read_only. \ No newline at end of file diff --git a/changes/3702.bugfix.md b/changes/3702.bugfix.md new file mode 100644 index 0000000000..94a2902567 --- /dev/null +++ b/changes/3702.bugfix.md @@ -0,0 +1 @@ +Skip chunk coordinate enumeration in resize when the array is only growing, avoiding unbounded memory usage for large arrays. \ No newline at end of file diff --git a/changes/3704.misc.md b/changes/3704.misc.md new file mode 100644 index 0000000000..d15d4924e0 --- /dev/null +++ b/changes/3704.misc.md @@ -0,0 +1 @@ +Remove an expensive `isinstance` check from the bytes codec decoding routine. \ No newline at end of file diff --git a/changes/3705.bugfix.md b/changes/3705.bugfix.md new file mode 100644 index 0000000000..2abcb4ee7c --- /dev/null +++ b/changes/3705.bugfix.md @@ -0,0 +1 @@ +Fix a performance bug in morton curve generation. \ No newline at end of file diff --git a/changes/3706.misc.md b/changes/3706.misc.md new file mode 100644 index 0000000000..70a0e44c58 --- /dev/null +++ b/changes/3706.misc.md @@ -0,0 +1 @@ +Allow NumPy ints as input when declaring a shape. \ No newline at end of file diff --git a/changes/3708.misc.md b/changes/3708.misc.md new file mode 100644 index 0000000000..dce7546c97 --- /dev/null +++ b/changes/3708.misc.md @@ -0,0 +1 @@ +Optimize Morton order computation with hypercube optimization, vectorized decoding, and singleton dimension removal, providing 10-45x speedup for typical chunk shapes. diff --git a/changes/3712.misc.md b/changes/3712.misc.md new file mode 100644 index 0000000000..8fa2f2d2f7 --- /dev/null +++ b/changes/3712.misc.md @@ -0,0 +1 @@ +Added benchmarks for Morton order computation in sharded arrays. diff --git a/docs/contributing.md b/docs/contributing.md index c330504536..e4f341a8b3 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -131,7 +131,7 @@ The hooks can be installed locally by running: prek install ``` -This would run the checks every time a commit is created locally. The checks will by default only run on the files modified by a commit, but the checks can be triggered for all the files by running: +This will run the checks every time a commit is created locally. The checks will by default only run on the files modified by a commit, but the checks can be triggered for all the files by running: ```bash prek run --all-files @@ -249,13 +249,13 @@ Pull requests submitted by an external contributor should be reviewed and approv Pull requests should not be merged until all CI checks have passed (GitHub Actions, Codecov) against code that has had the latest main merged in. -Before merging the milestone must be set either to decide whether a PR will be in the next patch, minor, or major release. The next section explains which types of changes go in each release. +Before merging, the milestone must be set to decide whether a PR will be in the next patch, minor, or major release. The next section explains which types of changes go in each release. ## Compatibility and versioning policies ### Versioning -Versions of this library are identified by a triplet of integers with the form `..`, for example `3.0.4`. A release of `zarr-python` is associated with a new version identifier. That new identifier is generated by incrementing exactly one of the components of the previous version identifier by 1. When incrementing the `major` component of the version identifier, the `minor` and `patch` components is reset to 0. When incrementing the minor component, the patch component is reset to 0. +Versions of this library are identified by a triplet of integers with the form `..`, for example `3.0.4`. A release of `zarr-python` is associated with a new version identifier. That new identifier is generated by incrementing exactly one of the components of the previous version identifier by 1. When incrementing the `major` component of the version identifier, the `minor` and `patch` components are reset to 0. When incrementing the minor component, the patch component is reset to 0. Releases are classified by the library changes contained in that release. This classification determines which component of the version identifier is incremented on release. @@ -263,7 +263,7 @@ Releases are classified by the library changes contained in that release. This c Users and downstream projects should carefully consider the impact of a major release before adopting it. In advance of a major release, developers should communicate the scope of the upcoming changes, and help users prepare for them. -* **minor** releases (for example, `3.0.0` -> `3.1.0`) are for changes that do not require significant effort from most users or downstream downstream projects to respond to. API changes are possible in minor releases if the burden on users imposed by those changes is sufficiently small. +* **minor** releases (for example, `3.0.0` -> `3.1.0`) are for changes that do not require significant effort from most users or downstream projects to respond to. API changes are possible in minor releases if the burden on users imposed by those changes is sufficiently small. For example, a recently released API may need fixes or refinements that are breaking, but low impact due to the recency of the feature. Such API changes are permitted in a minor release. @@ -271,11 +271,11 @@ Releases are classified by the library changes contained in that release. This c * **patch** releases (for example, `3.1.0` -> `3.1.1`) are for changes that contain no breaking or behaviour changes for downstream projects or users. Examples of changes suitable for a patch release are bugfixes and documentation improvements. - Users should always feel safe upgrading to a the latest patch release. + Users should always feel safe upgrading to the latest patch release. Note that this versioning scheme is not consistent with [Semantic Versioning](https://semver.org/). Contrary to SemVer, the Zarr library may release breaking changes in `minor` releases, or even `patch` releases under exceptional circumstances. But we should strive to avoid doing so. -A better model for our versioning scheme is [Intended Effort Versioning](https://jacobtomlinson.dev/effver/), or "EffVer". The guiding principle off EffVer is to categorize releases based on the *expected effort required to upgrade to that release*. +A better model for our versioning scheme is [Intended Effort Versioning](https://jacobtomlinson.dev/effver/), or "EffVer". The guiding principle of EffVer is to categorize releases based on the *expected effort required to upgrade to that release*. Zarr developers should make changes as smooth as possible for users. This means making backwards-compatible changes wherever possible. When a backwards-incompatible change is necessary, users should be notified well in advance, e.g. via informative deprecation warnings. @@ -288,12 +288,12 @@ If an existing Zarr format version changes, or a new version of the Zarr format ## Release procedure Open an issue on GitHub announcing the release using the release checklist template: -[https://github.com/zarr-developers/zarr-python/issues/new?template=release-checklist.md](https://github.com/zarr-developers/zarr-python/issues/new?template=release-checklist.md>). The release checklist includes all steps necessary for the release. +[https://github.com/zarr-developers/zarr-python/issues/new?template=release-checklist.md](https://github.com/zarr-developers/zarr-python/issues/new?template=release-checklist.md). The release checklist includes all steps necessary for the release. ## Benchmarks Zarr uses [pytest-benchmark](https://pytest-benchmark.readthedocs.io/en/latest/) for running -performance benchmarks as part of our test suite. The benchmarks can be are found in `tests/benchmarks`. +performance benchmarks as part of our test suite. The benchmarks are found in `tests/benchmarks`. By default pytest is configured to run these benchmarks as plain tests (i.e., no benchmarking). To run a benchmark with timing measurements, use the `--benchmark-enable` when invoking `pytest`. diff --git a/docs/quick-start.md b/docs/quick-start.md index 42ac95d169..bb7a556b96 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -1,4 +1,4 @@ -This section will help you get up and running with +This section will help you get up and running with the Zarr library in Python to efficiently manage and analyze multi-dimensional arrays. ### Creating an Array @@ -92,7 +92,7 @@ spam[:] = np.arange(10) print(root.tree()) ``` -This creates a group with two datasets: `foo` and `bar`. +This creates a group hierarchy with a group (`foo`) and two arrays (`bar` and `spam`). #### Batch Hierarchy Creation diff --git a/docs/user-guide/arrays.md b/docs/user-guide/arrays.md index f63f5bc6b2..cd6a93cac9 100644 --- a/docs/user-guide/arrays.md +++ b/docs/user-guide/arrays.md @@ -72,7 +72,7 @@ print(z[:, 0]) print(z[:]) ``` -Read more about NumPy-style indexing can be found in the +More information about NumPy-style indexing can be found in the [NumPy documentation](https://numpy.org/doc/stable/user/basics.indexing.html). ## Persistent arrays @@ -297,7 +297,7 @@ array without loading the entire array into memory. Note that although this functionality is similar to some of the advanced indexing capabilities available on NumPy arrays and on h5py datasets, **the Zarr API for advanced indexing is different from both NumPy and h5py**, so please -read this section carefully. For a complete description of the indexing API, +read this section carefully. For a complete description of the indexing API, see the documentation for the [`zarr.Array`][] class. ### Indexing with coordinate arrays diff --git a/docs/user-guide/extending.md b/docs/user-guide/extending.md index d857fa3356..39444135df 100644 --- a/docs/user-guide/extending.md +++ b/docs/user-guide/extending.md @@ -29,7 +29,7 @@ of the array data. Examples include compression codecs, such as Custom codecs for Zarr are implemented by subclassing the relevant base class, see [`zarr.abc.codec.ArrayArrayCodec`][], [`zarr.abc.codec.ArrayBytesCodec`][] and -[`zarr.abc.codec.BytesBytesCodec`][]. Most custom codecs should implemented the +[`zarr.abc.codec.BytesBytesCodec`][]. Most custom codecs should implement the `_encode_single` and `_decode_single` methods. These methods operate on single chunks of the array data. Alternatively, custom codecs can implement the `encode` and `decode` methods, which operate on batches of chunks, in case the codec is intended to implement diff --git a/docs/user-guide/groups.md b/docs/user-guide/groups.md index 57201216b6..e093590dfe 100644 --- a/docs/user-guide/groups.md +++ b/docs/user-guide/groups.md @@ -13,7 +13,7 @@ root = zarr.create_group(store=store) print(root) ``` -Groups have a similar API to the Group class from [h5py](https://www.h5py.org/). For example, groups can contain other groups: +Groups have a similar API to the Group class from [h5py](https://www.h5py.org/). For example, groups can contain other groups: ```python exec="true" session="groups" source="above" foo = root.create_group('foo') diff --git a/docs/user-guide/storage.md b/docs/user-guide/storage.md index 82b576b889..e75cd21381 100644 --- a/docs/user-guide/storage.md +++ b/docs/user-guide/storage.md @@ -91,8 +91,8 @@ print(group) ## Explicit Store Creation -In some cases, it may be helpful to create a store instance directly. Zarr-Python offers four -built-in store: [`zarr.storage.LocalStore`][], [`zarr.storage.FsspecStore`][], +In some cases, it may be helpful to create a store instance directly. Zarr-Python offers +built-in stores: [`zarr.storage.LocalStore`][], [`zarr.storage.FsspecStore`][], [`zarr.storage.ZipStore`][], [`zarr.storage.MemoryStore`][], and [`zarr.storage.ObjectStore`][]. ### Local Store diff --git a/docs/user-guide/v3_migration.md b/docs/user-guide/v3_migration.md index 15425de27a..d5a8067a88 100644 --- a/docs/user-guide/v3_migration.md +++ b/docs/user-guide/v3_migration.md @@ -20,7 +20,7 @@ so we can improve this guide. The goals described above necessitated some breaking changes to the API (hence the major version update), but where possible we have maintained backwards compatibility -in the most widely used parts of the API. This in the [`zarr.Array`][] and +in the most widely used parts of the API. This includes the [`zarr.Array`][] and [`zarr.Group`][] classes and the "top-level API" (e.g. [`zarr.open_array`][] and [`zarr.open_group`][]). diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 39c26bd4a8..1fbdeef497 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -5,10 +5,8 @@ from enum import Enum from typing import TYPE_CHECKING -import numpy as np - from zarr.abc.codec import ArrayBytesCodec -from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer +from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import JSON, parse_enum, parse_named_configuration from zarr.core.dtype.common import HasEndianness @@ -72,7 +70,6 @@ async def _decode_single( chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> NDBuffer: - assert isinstance(chunk_bytes, Buffer) # TODO: remove endianness enum in favor of literal union endian_str = self.endian.value if self.endian is not None else None if isinstance(chunk_spec.dtype, HasEndianness): @@ -80,12 +77,8 @@ async def _decode_single( else: dtype = chunk_spec.dtype.to_native_dtype() as_array_like = chunk_bytes.as_array_like() - if isinstance(as_array_like, NDArrayLike): - as_nd_array_like = as_array_like - else: - as_nd_array_like = np.asanyarray(as_array_like) chunk_array = chunk_spec.prototype.nd_buffer.from_ndarray_like( - as_nd_array_like.view(dtype=dtype) + as_array_like.view(dtype=dtype) # type: ignore[attr-defined] ) # ensure correct chunk shape diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 8124ea44ea..b54b3c2257 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -52,6 +52,7 @@ ) from zarr.core.metadata.v3 import parse_codecs from zarr.registry import get_ndbuffer_class, get_pipeline_class +from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: from collections.abc import Iterator @@ -86,11 +87,16 @@ class _ShardingByteGetter(ByteGetter): async def get( self, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: - assert byte_range is None, "byte_range is not supported within shards" assert prototype == default_buffer_prototype(), ( f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" ) - return self.shard_dict.get(self.chunk_coords) + value = self.shard_dict.get(self.chunk_coords) + if value is None: + return None + if byte_range is None: + return value + start, stop = _normalize_byte_range_index(value, byte_range) + return value[start:stop] @dataclass(frozen=True) @@ -597,7 +603,8 @@ async def _decode_shard_index( ) ) ) - assert index_array is not None + # This cannot be None because we have the bytes already + index_array = cast(NDBuffer, index_array) return _ShardIndex(index_array.as_numpy_array()) async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 64f79e68cd..3974a4d8b4 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -1370,13 +1370,7 @@ async def example(): result = asyncio.run(example()) ``` """ - if self.shards is None: - chunks_per_shard = 1 - else: - chunks_per_shard = product( - tuple(a // b for a, b in zip(self.shards, self.chunks, strict=True)) - ) - return (await self._nshards_initialized()) * chunks_per_shard + return await _nchunks_initialized(self) async def _nshards_initialized(self) -> int: """ @@ -1414,10 +1408,10 @@ async def example(): result = asyncio.run(example()) ``` """ - return len(await _shards_initialized(self)) + return await _nshards_initialized(self) async def nbytes_stored(self) -> int: - return await self.store_path.store.getsize_prefix(self.store_path.path) + return await _nbytes_stored(self.store_path) def _iter_chunk_coords( self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None @@ -1582,49 +1576,16 @@ async def _get_selection( out: NDBuffer | None = None, fields: Fields | None = None, ) -> NDArrayLikeOrScalar: - # check fields are sensible - out_dtype = check_fields(fields, self.dtype) - - # setup output buffer - if out is not None: - if isinstance(out, NDBuffer): - out_buffer = out - else: - raise TypeError(f"out argument needs to be an NDBuffer. Got {type(out)!r}") - if out_buffer.shape != indexer.shape: - raise ValueError( - f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}" - ) - else: - out_buffer = prototype.nd_buffer.empty( - shape=indexer.shape, - dtype=out_dtype, - order=self.order, - ) - if product(indexer.shape) > 0: - # need to use the order from the metadata for v2 - _config = self.config - if self.metadata.zarr_format == 2: - _config = replace(_config, order=self.order) - - # reading chunks and decoding them - await self.codec_pipeline.read( - [ - ( - self.store_path / self.metadata.encode_chunk_key(chunk_coords), - self.metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), - chunk_selection, - out_selection, - is_complete_chunk, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer - ], - out_buffer, - drop_axes=indexer.drop_axes, - ) - if isinstance(indexer, BasicIndexer) and indexer.shape == (): - return out_buffer.as_scalar() - return out_buffer.as_ndarray_like() + return await _get_selection( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, + indexer, + prototype=prototype, + out=out, + fields=fields, + ) async def getitem( self, @@ -1669,14 +1630,14 @@ async def example(): value = asyncio.run(example()) ``` """ - if prototype is None: - prototype = default_buffer_prototype() - indexer = BasicIndexer( + return await _getitem( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, selection, - shape=self.metadata.shape, - chunk_grid=self.metadata.chunk_grid, + prototype=prototype, ) - return await self._get_selection(indexer, prototype=prototype) async def get_orthogonal_selection( self, @@ -1686,11 +1647,15 @@ async def get_orthogonal_selection( fields: Fields | None = None, prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: - if prototype is None: - prototype = default_buffer_prototype() - indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) - return await self._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + return await _get_orthogonal_selection( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, + selection, + out=out, + fields=fields, + prototype=prototype, ) async def get_mask_selection( @@ -1701,11 +1666,15 @@ async def get_mask_selection( fields: Fields | None = None, prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: - if prototype is None: - prototype = default_buffer_prototype() - indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) - return await self._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + return await _get_mask_selection( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, + mask, + out=out, + fields=fields, + prototype=prototype, ) async def get_coordinate_selection( @@ -1716,18 +1685,17 @@ async def get_coordinate_selection( fields: Fields | None = None, prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: - if prototype is None: - prototype = default_buffer_prototype() - indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) - out_array = await self._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + return await _get_coordinate_selection( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, + selection, + out=out, + fields=fields, + prototype=prototype, ) - if hasattr(out_array, "shape"): - # restore shape - out_array = np.array(out_array).reshape(indexer.sel_shape) - return out_array - async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None: """ Asynchronously save the array metadata. @@ -1742,56 +1710,15 @@ async def _set_selection( prototype: BufferPrototype, fields: Fields | None = None, ) -> None: - # check fields are sensible - check_fields(fields, self.dtype) - fields = check_no_multi_fields(fields) - - # check value shape - if np.isscalar(value): - array_like = prototype.buffer.create_zero_length().as_array_like() - if isinstance(array_like, np._typing._SupportsArrayFunc): - # TODO: need to handle array types that don't support __array_function__ - # like PyTorch and JAX - array_like_ = cast("np._typing._SupportsArrayFunc", array_like) - value = np.asanyarray(value, dtype=self.dtype, like=array_like_) - else: - if not hasattr(value, "shape"): - value = np.asarray(value, self.dtype) - # assert ( - # value.shape == indexer.shape - # ), f"shape of value doesn't match indexer shape. Expected {indexer.shape}, got {value.shape}" - if not hasattr(value, "dtype") or value.dtype.name != self.dtype.name: - if hasattr(value, "astype"): - # Handle things that are already NDArrayLike more efficiently - value = value.astype(dtype=self.dtype, order="A") - else: - value = np.array(value, dtype=self.dtype, order="A") - value = cast("NDArrayLike", value) - - # We accept any ndarray like object from the user and convert it - # to an NDBuffer (or subclass). From this point onwards, we only pass - # Buffer and NDBuffer between components. - value_buffer = prototype.nd_buffer.from_ndarray_like(value) - - # need to use the order from the metadata for v2 - _config = self.config - if self.metadata.zarr_format == 2: - _config = replace(_config, order=self.metadata.order) - - # merging with existing data and encoding chunks - await self.codec_pipeline.write( - [ - ( - self.store_path / self.metadata.encode_chunk_key(chunk_coords), - self.metadata.get_chunk_spec(chunk_coords, _config, prototype), - chunk_selection, - out_selection, - is_complete_chunk, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer - ], - value_buffer, - drop_axes=indexer.drop_axes, + return await _set_selection( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, + indexer, + value, + prototype=prototype, + fields=fields, ) async def setitem( @@ -1833,14 +1760,15 @@ async def setitem( - This method is asynchronous and should be awaited. - Supports basic indexing, where the selection is contiguous and does not involve advanced indexing. """ - if prototype is None: - prototype = default_buffer_prototype() - indexer = BasicIndexer( + return await _setitem( + self.store_path, + self.metadata, + self.codec_pipeline, + self.config, selection, - shape=self.metadata.shape, - chunk_grid=self.metadata.chunk_grid, + value, + prototype=prototype, ) - return await self._set_selection(indexer, value, prototype=prototype) @property def oindex(self) -> AsyncOIndex[T_ArrayMetadata]: @@ -1882,31 +1810,7 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True) ----- - This method is asynchronous and should be awaited. """ - new_shape = parse_shapelike(new_shape) - assert len(new_shape) == len(self.metadata.shape) - new_metadata = self.metadata.update_shape(new_shape) - - if delete_outside_chunks: - # Remove all chunks outside of the new shape - old_chunk_coords = set(self.metadata.chunk_grid.all_chunk_coords(self.metadata.shape)) - new_chunk_coords = set(self.metadata.chunk_grid.all_chunk_coords(new_shape)) - - async def _delete_key(key: str) -> None: - await (self.store_path / key).delete() - - # Store handles concurrency limiting internally - await asyncio.gather( - *[ - _delete_key(self.metadata.encode_chunk_key(chunk_coords)) - for chunk_coords in old_chunk_coords.difference(new_chunk_coords) - ] - ) - - # Write new metadata - await self._save_metadata(new_metadata) - - # Update metadata (in place) - object.__setattr__(self, "metadata", new_metadata) + return await _resize(self, new_shape, delete_outside_chunks) async def append(self, data: npt.ArrayLike, axis: int = 0) -> tuple[int, ...]: """Append `data` to `axis`. @@ -1927,40 +1831,7 @@ async def append(self, data: npt.ArrayLike, axis: int = 0) -> tuple[int, ...]: The size of all dimensions other than `axis` must match between this array and `data`. """ - # ensure data is array-like - if not hasattr(data, "shape"): - data = np.asanyarray(data) - - self_shape_preserved = tuple(s for i, s in enumerate(self.shape) if i != axis) - data_shape_preserved = tuple(s for i, s in enumerate(data.shape) if i != axis) - if self_shape_preserved != data_shape_preserved: - raise ValueError( - f"shape of data to append is not compatible with the array. " - f"The shape of the data is ({data_shape_preserved})" - f"and the shape of the array is ({self_shape_preserved})." - "All dimensions must match except for the dimension being " - "appended." - ) - # remember old shape - old_shape = self.shape - - # determine new shape - new_shape = tuple( - self.shape[i] if i != axis else self.shape[i] + data.shape[i] - for i in range(len(self.shape)) - ) - - # resize - await self.resize(new_shape) - - # store data - append_selection = tuple( - slice(None) if i != axis else slice(old_shape[i], new_shape[i]) - for i in range(len(self.shape)) - ) - await self.setitem(append_selection, data) - - return new_shape + return await _append(self, data, axis) async def update_attributes(self, new_attributes: dict[str, JSON]) -> Self: """ @@ -1988,11 +1859,7 @@ async def update_attributes(self, new_attributes: dict[str, JSON]) -> Self: - The updated attributes will be merged with existing attributes, and any conflicts will be overwritten by the new values. """ - self.metadata.attributes.update(new_attributes) - - # Write new metadata - await self._save_metadata(self.metadata) - + await _update_attributes(self, new_attributes) return self def __repr__(self) -> str: @@ -2049,10 +1916,7 @@ async def info_complete(self) -> Any: ------- [zarr.AsyncArray.info][] - A property giving just the statically known information about an array. """ - return self._info( - await self._nshards_initialized(), - await self.store_path.store.getsize_prefix(self.store_path.path), - ) + return await _info_complete(self) def _info( self, count_chunks_initialized: int | None = None, count_bytes_stored: int | None = None @@ -4616,9 +4480,7 @@ async def from_array( if write_data: if isinstance(data, Array): - async def _copy_array_region( - chunk_coords: tuple[int, ...] | slice, _data: AnyArray - ) -> None: + async def _copy_array_region(chunk_coords: tuple[slice, ...], _data: AnyArray) -> None: arr = await _data.async_array.getitem(chunk_coords) await result.setitem(chunk_coords, arr) @@ -4629,8 +4491,10 @@ async def _copy_array_region( ) else: - async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> None: - await result.setitem(chunk_coords, _data[chunk_coords]) + async def _copy_arraylike_region( + chunk_coords: tuple[slice, ...], _data: npt.ArrayLike + ) -> None: + await result.setitem(chunk_coords, _data[chunk_coords]) # type: ignore[call-overload, index] # Stream data from the source array to the new array # Store handles concurrency limiting internally @@ -5584,3 +5448,687 @@ def _iter_chunk_regions( return _iter_regions( array.shape, array.chunks, origin=origin, selection_shape=selection_shape, trim_excess=True ) + + +async def _nchunks_initialized( + array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], +) -> int: + """ + Calculate the number of chunks that have been initialized in storage. + + This value is calculated as the product of the number of initialized shards and the number + of chunks per shard. For arrays that do not use sharding, the number of chunks per shard is + effectively 1, and in that case the number of chunks initialized is the same as the number + of stored objects associated with an array. + + Parameters + ---------- + array : AsyncArray + The array to inspect. + + Returns + ------- + nchunks_initialized : int + The number of chunks that have been initialized. + """ + if array.shards is None: + chunks_per_shard = 1 + else: + chunks_per_shard = product( + tuple(a // b for a, b in zip(array.shards, array.chunks, strict=True)) + ) + return (await _nshards_initialized(array)) * chunks_per_shard + + +async def _nshards_initialized( + array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], +) -> int: + """ + Calculate the number of shards that have been initialized in storage. + + This is the number of shards that have been persisted to the storage backend. + + Parameters + ---------- + array : AsyncArray + The array to inspect. + + Returns + ------- + nshards_initialized : int + The number of shards that have been initialized. + """ + return len(await _shards_initialized(array)) + + +async def _nbytes_stored( + store_path: StorePath, +) -> int: + """ + Calculate the number of bytes stored for an array. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + + Returns + ------- + nbytes_stored : int + The number of bytes stored. + """ + return await store_path.store.getsize_prefix(store_path.path) + + +async def _get_selection( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + indexer: Indexer, + *, + prototype: BufferPrototype, + out: NDBuffer | None = None, + fields: Fields | None = None, +) -> NDArrayLikeOrScalar: + """ + Get a selection from an array. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + indexer : Indexer + The indexer specifying the selection. + prototype : BufferPrototype + A buffer prototype to use for the retrieved data. + out : NDBuffer | None, optional + An output buffer to write the data to. + fields : Fields | None, optional + Fields to select from structured arrays. + + Returns + ------- + NDArrayLikeOrScalar + The selected data. + """ + # Get dtype from metadata + if metadata.zarr_format == 2: + zdtype = metadata.dtype + else: + zdtype = metadata.data_type + dtype = zdtype.to_native_dtype() + + # Determine memory order + if metadata.zarr_format == 2: + order = metadata.order + else: + order = config.order + + # check fields are sensible + out_dtype = check_fields(fields, dtype) + + # setup output buffer + if out is not None: + if isinstance(out, NDBuffer): + out_buffer = out + else: + raise TypeError(f"out argument needs to be an NDBuffer. Got {type(out)!r}") + if out_buffer.shape != indexer.shape: + raise ValueError( + f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}" + ) + else: + out_buffer = prototype.nd_buffer.empty( + shape=indexer.shape, + dtype=out_dtype, + order=order, + ) + if product(indexer.shape) > 0: + # need to use the order from the metadata for v2 + _config = config + if metadata.zarr_format == 2: + _config = replace(_config, order=order) + + # reading chunks and decoding them + await codec_pipeline.read( + [ + ( + store_path / metadata.encode_chunk_key(chunk_coords), + metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), + chunk_selection, + out_selection, + is_complete_chunk, + ) + for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer + ], + out_buffer, + drop_axes=indexer.drop_axes, + ) + if isinstance(indexer, BasicIndexer) and indexer.shape == (): + return out_buffer.as_scalar() + return out_buffer.as_ndarray_like() + + +async def _getitem( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + selection: BasicSelection, + *, + prototype: BufferPrototype | None = None, +) -> NDArrayLikeOrScalar: + """ + Retrieve a subset of the array's data based on the provided selection. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + selection : BasicSelection + A selection object specifying the subset of data to retrieve. + prototype : BufferPrototype, optional + A buffer prototype to use for the retrieved data (default is None). + + Returns + ------- + NDArrayLikeOrScalar + The retrieved subset of the array's data. + """ + if prototype is None: + prototype = default_buffer_prototype() + indexer = BasicIndexer( + selection, + shape=metadata.shape, + chunk_grid=metadata.chunk_grid, + ) + return await _get_selection( + store_path, metadata, codec_pipeline, config, indexer, prototype=prototype + ) + + +async def _get_orthogonal_selection( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + selection: OrthogonalSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, +) -> NDArrayLikeOrScalar: + """ + Get an orthogonal selection from the array. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + selection : OrthogonalSelection + The orthogonal selection specification. + out : NDBuffer | None, optional + An output buffer to write the data to. + fields : Fields | None, optional + Fields to select from structured arrays. + prototype : BufferPrototype | None, optional + A buffer prototype to use for the retrieved data. + + Returns + ------- + NDArrayLikeOrScalar + The selected data. + """ + if prototype is None: + prototype = default_buffer_prototype() + indexer = OrthogonalIndexer(selection, metadata.shape, metadata.chunk_grid) + return await _get_selection( + store_path, + metadata, + codec_pipeline, + config, + indexer=indexer, + out=out, + fields=fields, + prototype=prototype, + ) + + +async def _get_mask_selection( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + mask: MaskSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, +) -> NDArrayLikeOrScalar: + """ + Get a mask selection from the array. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + mask : MaskSelection + The boolean mask specifying the selection. + out : NDBuffer | None, optional + An output buffer to write the data to. + fields : Fields | None, optional + Fields to select from structured arrays. + prototype : BufferPrototype | None, optional + A buffer prototype to use for the retrieved data. + + Returns + ------- + NDArrayLikeOrScalar + The selected data. + """ + if prototype is None: + prototype = default_buffer_prototype() + indexer = MaskIndexer(mask, metadata.shape, metadata.chunk_grid) + return await _get_selection( + store_path, + metadata, + codec_pipeline, + config, + indexer=indexer, + out=out, + fields=fields, + prototype=prototype, + ) + + +async def _get_coordinate_selection( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + selection: CoordinateSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, +) -> NDArrayLikeOrScalar: + """ + Get a coordinate selection from the array. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + selection : CoordinateSelection + The coordinate selection specification. + out : NDBuffer | None, optional + An output buffer to write the data to. + fields : Fields | None, optional + Fields to select from structured arrays. + prototype : BufferPrototype | None, optional + A buffer prototype to use for the retrieved data. + + Returns + ------- + NDArrayLikeOrScalar + The selected data. + """ + if prototype is None: + prototype = default_buffer_prototype() + indexer = CoordinateIndexer(selection, metadata.shape, metadata.chunk_grid) + out_array = await _get_selection( + store_path, + metadata, + codec_pipeline, + config, + indexer=indexer, + out=out, + fields=fields, + prototype=prototype, + ) + + if hasattr(out_array, "shape"): + # restore shape + out_array = np.array(out_array).reshape(indexer.sel_shape) + return out_array + + +async def _set_selection( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + indexer: Indexer, + value: npt.ArrayLike, + *, + prototype: BufferPrototype, + fields: Fields | None = None, +) -> None: + """ + Set a selection in an array. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + indexer : Indexer + The indexer specifying the selection. + value : npt.ArrayLike + The values to write. + prototype : BufferPrototype + A buffer prototype to use. + fields : Fields | None, optional + Fields to select from structured arrays. + """ + # Get dtype from metadata + if metadata.zarr_format == 2: + zdtype = metadata.dtype + else: + zdtype = metadata.data_type + dtype = zdtype.to_native_dtype() + + # check fields are sensible + check_fields(fields, dtype) + fields = check_no_multi_fields(fields) + + # check value shape + if np.isscalar(value): + array_like = prototype.buffer.create_zero_length().as_array_like() + if isinstance(array_like, np._typing._SupportsArrayFunc): + # TODO: need to handle array types that don't support __array_function__ + # like PyTorch and JAX + array_like_ = cast("np._typing._SupportsArrayFunc", array_like) + value = np.asanyarray(value, dtype=dtype, like=array_like_) + else: + if not hasattr(value, "shape"): + value = np.asarray(value, dtype) + # assert ( + # value.shape == indexer.shape + # ), f"shape of value doesn't match indexer shape. Expected {indexer.shape}, got {value.shape}" + if not hasattr(value, "dtype") or value.dtype.name != dtype.name: + if hasattr(value, "astype"): + # Handle things that are already NDArrayLike more efficiently + value = value.astype(dtype=dtype, order="A") + else: + value = np.array(value, dtype=dtype, order="A") + value = cast("NDArrayLike", value) + + # We accept any ndarray like object from the user and convert it + # to an NDBuffer (or subclass). From this point onwards, we only pass + # Buffer and NDBuffer between components. + value_buffer = prototype.nd_buffer.from_ndarray_like(value) + + # Determine memory order + if metadata.zarr_format == 2: + order = metadata.order + else: + order = config.order + + # need to use the order from the metadata for v2 + _config = config + if metadata.zarr_format == 2: + _config = replace(_config, order=order) + + # merging with existing data and encoding chunks + await codec_pipeline.write( + [ + ( + store_path / metadata.encode_chunk_key(chunk_coords), + metadata.get_chunk_spec(chunk_coords, _config, prototype), + chunk_selection, + out_selection, + is_complete_chunk, + ) + for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer + ], + value_buffer, + drop_axes=indexer.drop_axes, + ) + + +async def _setitem( + store_path: StorePath, + metadata: ArrayMetadata, + codec_pipeline: CodecPipeline, + config: ArrayConfig, + selection: BasicSelection, + value: npt.ArrayLike, + prototype: BufferPrototype | None = None, +) -> None: + """ + Set values in the array using basic indexing. + + Parameters + ---------- + store_path : StorePath + The store path of the array. + metadata : ArrayMetadata + The array metadata. + codec_pipeline : CodecPipeline + The codec pipeline for encoding/decoding. + config : ArrayConfig + The array configuration. + selection : BasicSelection + The selection defining the region of the array to set. + value : npt.ArrayLike + The values to be written into the selected region of the array. + prototype : BufferPrototype or None, optional + A prototype buffer that defines the structure and properties of the array chunks being modified. + If None, the default buffer prototype is used. + """ + if prototype is None: + prototype = default_buffer_prototype() + indexer = BasicIndexer( + selection, + shape=metadata.shape, + chunk_grid=metadata.chunk_grid, + ) + return await _set_selection( + store_path, metadata, codec_pipeline, config, indexer, value, prototype=prototype + ) + + +async def _resize( + array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], + new_shape: ShapeLike, + delete_outside_chunks: bool = True, +) -> None: + """ + Resize an array to a new shape. + + Parameters + ---------- + array : AsyncArray + The array to resize. + new_shape : ShapeLike + The desired new shape of the array. + delete_outside_chunks : bool, optional + If True (default), chunks that fall outside the new shape will be deleted. + If False, the data in those chunks will be preserved. + """ + new_shape = parse_shapelike(new_shape) + assert len(new_shape) == len(array.metadata.shape) + new_metadata = array.metadata.update_shape(new_shape) + + # ensure deletion is only run if array is shrinking as the delete_outside_chunks path is unbounded in memory + only_growing = all(new >= old for new, old in zip(new_shape, array.metadata.shape, strict=True)) + + if delete_outside_chunks and not only_growing: + # Remove all chunks outside of the new shape + old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(array.metadata.shape)) + new_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(new_shape)) + + async def _delete_key(key: str) -> None: + await (array.store_path / key).delete() + + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _delete_key(array.metadata.encode_chunk_key(chunk_coords)) + for chunk_coords in old_chunk_coords.difference(new_chunk_coords) + ] + ) + + # Write new metadata + await save_metadata(array.store_path, new_metadata) + + # Update metadata (in place) + object.__setattr__(array, "metadata", new_metadata) + + +async def _append( + array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], + data: npt.ArrayLike, + axis: int = 0, +) -> tuple[int, ...]: + """ + Append data to an array along the specified axis. + + Parameters + ---------- + array : AsyncArray + The array to append to. + data : npt.ArrayLike + Data to be appended. + axis : int + Axis along which to append. + + Returns + ------- + new_shape : tuple[int, ...] + The new shape of the array after appending. + + Notes + ----- + The size of all dimensions other than `axis` must match between the + array and `data`. + """ + # ensure data is array-like + if not hasattr(data, "shape"): + data = np.asanyarray(data) + + self_shape_preserved = tuple(s for i, s in enumerate(array.shape) if i != axis) + data_shape_preserved = tuple(s for i, s in enumerate(data.shape) if i != axis) + if self_shape_preserved != data_shape_preserved: + raise ValueError( + f"shape of data to append is not compatible with the array. " + f"The shape of the data is ({data_shape_preserved})" + f"and the shape of the array is ({self_shape_preserved})." + "All dimensions must match except for the dimension being " + "appended." + ) + # remember old shape + old_shape = array.shape + + # determine new shape + new_shape = tuple( + array.shape[i] if i != axis else array.shape[i] + data.shape[i] + for i in range(len(array.shape)) + ) + + # resize + await _resize(array, new_shape) + + # store data + append_selection = tuple( + slice(None) if i != axis else slice(old_shape[i], new_shape[i]) + for i in range(len(array.shape)) + ) + await _setitem( + array.store_path, + array.metadata, + array.codec_pipeline, + array.config, + append_selection, + data, + ) + + return new_shape + + +async def _update_attributes( + array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], + new_attributes: dict[str, JSON], +) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: + """ + Update the array's attributes. + + Parameters + ---------- + array : AsyncArray + The array whose attributes to update. + new_attributes : dict[str, JSON] + A dictionary of new attributes to update or add to the array. + + Returns + ------- + AsyncArray + The array with the updated attributes. + """ + array.metadata.attributes.update(new_attributes) + + # Write new metadata + await save_metadata(array.store_path, array.metadata) + + return array + + +async def _info_complete( + array: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], +) -> Any: + """ + Return all the information for an array, including dynamic information like storage size. + + Parameters + ---------- + array : AsyncArray + The array to get info for. + + Returns + ------- + ArrayInfo + Complete information about the array including: + - The count of chunks initialized + - The sum of the bytes written + """ + return array._info( + await _nshards_initialized(array), + await array.store_path.store.getsize_prefix(array.store_path.path), + ) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index e45c256310..cc1c8d0ecf 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -19,6 +19,7 @@ overload, ) +import numpy as np from typing_extensions import ReadOnly from zarr.core.config import config as zarr_config @@ -35,7 +36,7 @@ ZMETADATA_V2_JSON = ".zmetadata" BytesLike = bytes | bytearray | memoryview -ShapeLike = Iterable[int] | int +ShapeLike = Iterable[int | np.integer[Any]] | int | np.integer[Any] # For backwards compatibility ChunkCoords = tuple[int, ...] ZarrFormat = Literal[2, 3] @@ -161,23 +162,28 @@ def parse_named_configuration( def parse_shapelike(data: ShapeLike) -> tuple[int, ...]: - if isinstance(data, int): + """ + Parse a shape-like input into an explicit shape. + """ + if isinstance(data, int | np.integer): if data < 0: raise ValueError(f"Expected a non-negative integer. Got {data} instead") - return (data,) + return (int(data),) try: data_tuple = tuple(data) except TypeError as e: msg = f"Expected an integer or an iterable of integers. Got {data} instead." raise TypeError(msg) from e - if not all(isinstance(v, int) for v in data_tuple): + if not all(isinstance(v, int | np.integer) for v in data_tuple): msg = f"Expected an iterable of integers. Got {data} instead." raise TypeError(msg) if not all(v > -1 for v in data_tuple): msg = f"Expected all values to be non-negative. Got {data} instead." raise ValueError(msg) - return data_tuple + + # cast NumPy scalars to plain python ints + return tuple(int(x) for x in data_tuple) def parse_fill_value(data: Any) -> Any: diff --git a/src/zarr/core/dtype/npy/string.py b/src/zarr/core/dtype/npy/string.py index 41d3a60078..904280a330 100644 --- a/src/zarr/core/dtype/npy/string.py +++ b/src/zarr/core/dtype/npy/string.py @@ -742,6 +742,43 @@ class VariableLengthUTF8(UTF8Base[np.dtypes.StringDType]): # type: ignore[type- dtype_cls = np.dtypes.StringDType + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + """ + Create an instance of this data type from a compatible NumPy data type. + We reject NumPy StringDType instances that have the `na_object` field set, + because this is not representable by the Zarr `string` data type. + + Parameters + ---------- + dtype : TBaseDType + The native data type. + + Returns + ------- + Self + An instance of this data type. + + Raises + ------ + DataTypeValidationError + If the input is not compatible with this data type. + ValueError + If the input is `numpy.dtypes.StringDType` and has `na_object` set. + """ + if cls._check_native_dtype(dtype): + if hasattr(dtype, "na_object"): + msg = ( + f"Zarr data type resolution from {dtype} failed. " + "Attempted to resolve a zarr data type from a `numpy.dtypes.StringDType` " + "with `na_object` set, which is not supported." + ) + raise ValueError(msg) + return cls() + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + def to_native_dtype(self) -> np.dtypes.StringDType: """ Create a NumPy string dtype from this VariableLengthUTF8 ZDType. diff --git a/src/zarr/core/dtype/registry.py b/src/zarr/core/dtype/registry.py index cb9ab50044..315945cf4e 100644 --- a/src/zarr/core/dtype/registry.py +++ b/src/zarr/core/dtype/registry.py @@ -161,6 +161,10 @@ def match_dtype(self, dtype: TBaseDType) -> ZDType[TBaseDType, TBaseScalar]: raise ValueError(msg) matched: list[ZDType[TBaseDType, TBaseScalar]] = [] for val in self.contents.values(): + # DataTypeValidationError means "this dtype doesn't match me", which is + # expected and suppressed. Other exceptions (e.g. ValueError for a dtype + # that matches the type but has an invalid configuration) are propagated + # to the caller. with contextlib.suppress(DataTypeValidationError): matched.append(val.from_native_dtype(dtype)) if len(matched) == 1: diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 7f704bf2b7..df79728a85 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -7,7 +7,7 @@ from collections.abc import Iterator, Sequence from dataclasses import dataclass from enum import Enum -from functools import reduce +from functools import lru_cache, reduce from types import EllipsisType from typing import ( TYPE_CHECKING, @@ -1452,7 +1452,7 @@ def make_slice_selection(selection: Any) -> list[slice]: def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]: # Inspired by compressed morton code as implemented in Neuroglancer # https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code - bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape) + bits = tuple((c - 1).bit_length() for c in chunk_shape) max_coords_bits = max(bits) input_bit = 0 input_value = z @@ -1467,16 +1467,107 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]: return tuple(out) -def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - i = 0 - order: list[tuple[int, ...]] = [] - while len(order) < product(chunk_shape): +def decode_morton_vectorized( + z: npt.NDArray[np.intp], chunk_shape: tuple[int, ...] +) -> npt.NDArray[np.intp]: + """Vectorized Morton code decoding for multiple z values. + + Parameters + ---------- + z : ndarray + 1D array of Morton codes to decode. + chunk_shape : tuple of int + Shape defining the coordinate space. + + Returns + ------- + ndarray + 2D array of shape (len(z), len(chunk_shape)) containing decoded coordinates. + """ + n_dims = len(chunk_shape) + bits = tuple((c - 1).bit_length() for c in chunk_shape) + + max_coords_bits = max(bits) if bits else 0 + out = np.zeros((len(z), n_dims), dtype=np.intp) + + input_bit = 0 + for coord_bit in range(max_coords_bits): + for dim in range(n_dims): + if coord_bit < bits[dim]: + # Extract bit at position input_bit from all z values + bit_values = (z >> input_bit) & 1 + # Place bit at coord_bit position in dimension dim + out[:, dim] |= bit_values << coord_bit + input_bit += 1 + + return out + + +@lru_cache(maxsize=16) +def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + n_total = product(chunk_shape) + if n_total == 0: + return () + + # Optimization: Remove singleton dimensions to enable magic number usage + # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand. + singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1) + if singleton_dims: + squeezed_shape = tuple(s for s in chunk_shape if s != 1) + if squeezed_shape: + # Compute Morton order on squeezed shape + squeezed_order = _morton_order(squeezed_shape) + # Expand coordinates to include singleton dimensions (always 0) + expanded: list[tuple[int, ...]] = [] + for coord in squeezed_order: + full_coord: list[int] = [] + squeezed_idx = 0 + for i in range(len(chunk_shape)): + if chunk_shape[i] == 1: + full_coord.append(0) + else: + full_coord.append(coord[squeezed_idx]) + squeezed_idx += 1 + expanded.append(tuple(full_coord)) + return tuple(expanded) + else: + # All dimensions are singletons, just return the single point + return ((0,) * len(chunk_shape),) + + n_dims = len(chunk_shape) + + # Find the largest power-of-2 hypercube that fits within chunk_shape. + # Within this hypercube, Morton codes are guaranteed to be in bounds. + min_dim = min(chunk_shape) + if min_dim >= 1: + power = min_dim.bit_length() - 1 # floor(log2(min_dim)) + hypercube_size = 1 << power # 2^power + n_hypercube = hypercube_size**n_dims + else: + n_hypercube = 0 + + # Within the hypercube, no bounds checking needed - use vectorized decoding + order: list[tuple[int, ...]] + if n_hypercube > 0: + z_values = np.arange(n_hypercube, dtype=np.intp) + hypercube_coords = decode_morton_vectorized(z_values, chunk_shape) + order = [tuple(row) for row in hypercube_coords] + else: + order = [] + + # For remaining elements, bounds checking is needed + i = n_hypercube + while len(order) < n_total: m = decode_morton(i, chunk_shape) - if m not in order and all(x < y for x, y in zip(m, chunk_shape, strict=False)): + if all(x < y for x, y in zip(m, chunk_shape, strict=False)): order.append(m) i += 1 - for j in range(product(chunk_shape)): - yield order[j] + + return tuple(order) + + +def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: + return iter(_morton_order(tuple(chunk_shape))) def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]: diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py index 3456c94320..87adc90c83 100644 --- a/src/zarr/experimental/cache_store.py +++ b/src/zarr/experimental/cache_store.py @@ -4,7 +4,8 @@ import logging import time from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Literal +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Self from zarr.abc.store import ByteRequest, Store from zarr.storage._wrapper import WrapperStore @@ -15,6 +16,18 @@ from zarr.core.buffer.core import Buffer, BufferPrototype +@dataclass(slots=True) +class _CacheState: + cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict) + current_size: int = 0 + key_sizes: dict[str, int] = field(default_factory=dict) + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + hits: int = 0 + misses: int = 0 + evictions: int = 0 + key_insert_times: dict[str, float] = field(default_factory=dict) + + class CacheStore(WrapperStore[Store]): """ A dual-store caching implementation for Zarr stores. @@ -36,9 +49,6 @@ class CacheStore(WrapperStore[Store]): Maximum size of the cache in bytes. When exceeded, least recently used items are evicted. None means unlimited size. Default is None. Note: Individual values larger than max_size will not be cached. - key_insert_times : dict[str, float] | None, optional - Dictionary to track insertion times (using monotonic time). - Primarily for internal use. Default is None (creates new dict). cache_set_data : bool, optional Whether to cache data when it's written to the store. Default is True. @@ -69,15 +79,8 @@ class CacheStore(WrapperStore[Store]): _cache: Store max_age_seconds: int | Literal["infinity"] max_size: int | None - key_insert_times: dict[str, float] cache_set_data: bool - _cache_order: OrderedDict[str, None] # Track access order for LRU - _current_size: int # Track current cache size - _key_sizes: dict[str, int] # Track size of each cached key - _lock: asyncio.Lock - _hits: int # Cache hit counter - _misses: int # Cache miss counter - _evictions: int # Cache eviction counter + _state: _CacheState def __init__( self, @@ -86,7 +89,6 @@ def __init__( cache_store: Store, max_age_seconds: int | str = "infinity", max_size: int | None = None, - key_insert_times: dict[str, float] | None = None, cache_set_data: bool = True, ) -> None: super().__init__(store) @@ -107,18 +109,25 @@ def __init__( else: self.max_age_seconds = max_age_seconds self.max_size = max_size - if key_insert_times is None: - self.key_insert_times = {} - else: - self.key_insert_times = key_insert_times self.cache_set_data = cache_set_data - self._cache_order = OrderedDict() - self._current_size = 0 - self._key_sizes = {} - self._lock = asyncio.Lock() - self._hits = 0 - self._misses = 0 - self._evictions = 0 + self._state = _CacheState() + + def _with_store(self, store: Store) -> Self: + # Cannot support this operation because it would share a cache, but have a new store + # So cache keys would conflict + raise NotImplementedError("CacheStore does not support this operation.") + + def with_read_only(self, read_only: bool = False) -> Self: + # Create a new cache store that shares the same cache and mutable state + store = type(self)( + store=self._store.with_read_only(read_only), + cache_store=self._cache, + max_age_seconds=self.max_age_seconds, + max_size=self.max_size, + cache_set_data=self.cache_set_data, + ) + store._state = self._state + return store def _is_key_fresh(self, key: str) -> bool: """Check if a cached key is still fresh based on max_age_seconds. @@ -128,7 +137,7 @@ def _is_key_fresh(self, key: str) -> bool: if self.max_age_seconds == "infinity": return True now = time.monotonic() - elapsed = now - self.key_insert_times.get(key, 0) + elapsed = now - self._state.key_insert_times.get(key, 0) return elapsed < self.max_age_seconds async def _accommodate_value(self, value_size: int) -> None: @@ -140,9 +149,9 @@ async def _accommodate_value(self, value_size: int) -> None: return # Remove least recently used items until we have enough space - while self._current_size + value_size > self.max_size and self._cache_order: + while self._state.current_size + value_size > self.max_size and self._state.cache_order: # Get the least recently used key (first in OrderedDict) - lru_key = next(iter(self._cache_order)) + lru_key = next(iter(self._state.cache_order)) await self._evict_key(lru_key) async def _evict_key(self, key: str) -> None: @@ -152,15 +161,15 @@ async def _evict_key(self, key: str) -> None: Updates size tracking atomically with deletion. """ try: - key_size = self._key_sizes.get(key, 0) + key_size = self._state.key_sizes.get(key, 0) # Delete from cache store await self._cache.delete(key) # Update tracking after successful deletion self._remove_from_tracking(key) - self._current_size = max(0, self._current_size - key_size) - self._evictions += 1 + self._state.current_size = max(0, self._state.current_size - key_size) + self._state.evictions += 1 logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size) except Exception: @@ -183,39 +192,39 @@ async def _cache_value(self, key: str, value: Buffer) -> None: ) return - async with self._lock: + async with self._state.lock: # If key already exists, subtract old size first - if key in self._key_sizes: - old_size = self._key_sizes[key] - self._current_size -= old_size + if key in self._state.key_sizes: + old_size = self._state.key_sizes[key] + self._state.current_size -= old_size logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size) # Make room for the new value (this calls _evict_key_locked internally) await self._accommodate_value(value_size) # Update tracking atomically - self._cache_order[key] = None # OrderedDict to track access order - self._current_size += value_size - self._key_sizes[key] = value_size - self.key_insert_times[key] = time.monotonic() + self._state.cache_order[key] = None # OrderedDict to track access order + self._state.current_size += value_size + self._state.key_sizes[key] = value_size + self._state.key_insert_times[key] = time.monotonic() logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size) async def _update_access_order(self, key: str) -> None: """Update the access order for LRU tracking.""" - if key in self._cache_order: - async with self._lock: + if key in self._state.cache_order: + async with self._state.lock: # Move to end (most recently used) - self._cache_order.move_to_end(key) + self._state.cache_order.move_to_end(key) def _remove_from_tracking(self, key: str) -> None: """Remove a key from all tracking structures. - Must be called while holding self._lock. + Must be called while holding self._state.lock. """ - self._cache_order.pop(key, None) - self.key_insert_times.pop(key, None) - self._key_sizes.pop(key, None) + self._state.cache_order.pop(key, None) + self._state.key_insert_times.pop(key, None) + self._state.key_sizes.pop(key, None) async def _get_try_cache( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None @@ -224,7 +233,7 @@ async def _get_try_cache( maybe_cached_result = await self._cache.get(key, prototype, byte_range) if maybe_cached_result is not None: logger.debug("_get_try_cache: key %s found in cache (HIT)", key) - self._hits += 1 + self._state.hits += 1 # Update access order for LRU await self._update_access_order(key) return maybe_cached_result @@ -232,12 +241,12 @@ async def _get_try_cache( logger.debug( "_get_try_cache: key %s not found in cache (MISS), fetching from store", key ) - self._misses += 1 + self._state.misses += 1 maybe_fresh_result = await super().get(key, prototype, byte_range) if maybe_fresh_result is None: # Key doesn't exist in source store await self._cache.delete(key) - async with self._lock: + async with self._state.lock: self._remove_from_tracking(key) else: # Cache the newly fetched value @@ -249,12 +258,12 @@ async def _get_no_cache( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: """Get data directly from source store and update cache.""" - self._misses += 1 + self._state.misses += 1 maybe_fresh_result = await super().get(key, prototype, byte_range) if maybe_fresh_result is None: # Key doesn't exist in source, remove from cache and tracking await self._cache.delete(key) - async with self._lock: + async with self._state.lock: self._remove_from_tracking(key) else: logger.debug("_get_no_cache: key %s found in store, setting in cache", key) @@ -312,7 +321,7 @@ async def set(self, key: str, value: Buffer) -> None: else: logger.debug("set: deleting key %s from cache", key) await self._cache.delete(key) - async with self._lock: + async with self._state.lock: self._remove_from_tracking(key) async def delete(self, key: str) -> None: @@ -328,7 +337,7 @@ async def delete(self, key: str) -> None: await super().delete(key) logger.debug("delete: deleting key %s from cache", key) await self._cache.delete(key) - async with self._lock: + async with self._state.lock: self._remove_from_tracking(key) def cache_info(self) -> dict[str, Any]: @@ -339,20 +348,20 @@ def cache_info(self) -> dict[str, Any]: if self.max_age_seconds == "infinity" else self.max_age_seconds, "max_size": self.max_size, - "current_size": self._current_size, + "current_size": self._state.current_size, "cache_set_data": self.cache_set_data, - "tracked_keys": len(self.key_insert_times), - "cached_keys": len(self._cache_order), + "tracked_keys": len(self._state.key_insert_times), + "cached_keys": len(self._state.cache_order), } def cache_stats(self) -> dict[str, Any]: """Return cache performance statistics.""" - total_requests = self._hits + self._misses - hit_rate = self._hits / total_requests if total_requests > 0 else 0.0 + total_requests = self._state.hits + self._state.misses + hit_rate = self._state.hits / total_requests if total_requests > 0 else 0.0 return { - "hits": self._hits, - "misses": self._misses, - "evictions": self._evictions, + "hits": self._state.hits, + "misses": self._state.misses, + "evictions": self._state.evictions, "total_requests": total_requests, "hit_rate": hit_rate, } @@ -364,11 +373,11 @@ async def clear_cache(self) -> None: await self._cache.clear() # Reset tracking - async with self._lock: - self.key_insert_times.clear() - self._cache_order.clear() - self._key_sizes.clear() - self._current_size = 0 + async with self._state.lock: + self._state.key_insert_times.clear() + self._state.cache_order.clear() + self._state.key_sizes.clear() + self._state.current_size = 0 logger.debug("clear_cache: cleared all cache data") def __repr__(self) -> str: @@ -379,6 +388,6 @@ def __repr__(self) -> str: f"cache_store={self._cache!r}, " f"max_age_seconds={self.max_age_seconds}, " f"max_size={self.max_size}, " - f"current_size={self._current_size}, " - f"cached_keys={len(self._cache_order)})" + f"current_size={self._state.current_size}, " + f"cached_keys={len(self._state.cache_order)})" ) diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index dd20d49ae5..98dca6b23d 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -77,6 +77,9 @@ def _default_handler(self) -> logging.Handler: ) return handler + def _with_store(self, store: T_Store) -> Self: + return type(self)(store=store, log_level=self.log_level, log_handler=self.log_handler) + @contextmanager def log(self, hint: Any = "") -> Generator[None, None, None]: """Context manager to log method calls diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 697f51ddb0..ea2c8d91fe 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -4,6 +4,8 @@ import contextlib import pickle from collections import defaultdict +from itertools import chain +from operator import itemgetter from typing import TYPE_CHECKING, Generic, Self, TypeVar from zarr.abc.store import ( @@ -13,7 +15,7 @@ Store, SuffixByteRequest, ) -from zarr.storage._utils import with_concurrency_limit +from zarr.storage._utils import _relativize_path, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence @@ -186,23 +188,23 @@ async def _get_impl( """Implementation of get without semaphore decoration.""" if byte_range is None: resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return prototype.buffer.from_bytes(await resp.bytes_async()) elif isinstance(byte_range, RangeByteRequest): bytes = await obs.get_range_async( self.store, key, start=byte_range.start, end=byte_range.end ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + return prototype.buffer.from_bytes(bytes) elif isinstance(byte_range, OffsetByteRequest): resp = await obs.get_async( self.store, key, options={"range": {"offset": byte_range.offset}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return prototype.buffer.from_bytes(await resp.bytes_async()) elif isinstance(byte_range, SuffixByteRequest): try: resp = await obs.get_async( self.store, key, options={"range": {"suffix": byte_range.suffix}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return prototype.buffer.from_bytes(await resp.bytes_async()) except obs.exceptions.NotSupportedError: head_resp = await obs.head_async(self.store, key) file_size = head_resp["size"] @@ -213,7 +215,7 @@ async def _get_impl( start=file_size - suffix_len, length=suffix_len, ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + return prototype.buffer.from_bytes(buffer) else: raise ValueError(f"Unexpected byte_range, got {byte_range}") @@ -367,7 +369,8 @@ async def _transform_list_dir( # We assume that the underlying object-store implementation correctly handles the # prefix, so we don't double-check that the returned results actually start with the # given prefix. - prefixes = [obj.lstrip(prefix).lstrip("/") for obj in list_result["common_prefixes"]] - objects = [obj["path"].removeprefix(prefix).lstrip("/") for obj in list_result["objects"]] - for item in prefixes + objects: - yield item + prefix = prefix.rstrip("/") + for path in chain( + list_result["common_prefixes"], map(itemgetter("path"), list_result["objects"]) + ): + yield _relativize_path(path=path, prefix=prefix) diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index 64a5b2d83c..e8a2859abc 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar, cast if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable @@ -31,14 +31,23 @@ class WrapperStore(Store, Generic[T_Store]): def __init__(self, store: T_Store) -> None: self._store = store + def _with_store(self, store: T_Store) -> Self: + """ + Constructs a new instance of the wrapper store with the same details but a new store. + """ + return type(self)(store=store) + @classmethod async def open(cls: type[Self], store_cls: type[T_Store], *args: Any, **kwargs: Any) -> Self: store = store_cls(*args, **kwargs) await store._open() return cls(store=store) + def with_read_only(self, read_only: bool = False) -> Self: + return self._with_store(cast(T_Store, self._store.with_read_only(read_only))) + def __enter__(self) -> Self: - return type(self)(self._store.__enter__()) + return self._with_store(self._store.__enter__()) def __exit__( self, diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 5daf8284eb..1b8e85ed98 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -4,7 +4,7 @@ import json import pickle from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from zarr.storage import WrapperStore @@ -492,24 +492,36 @@ async def test_list_empty_path(self, store: S) -> None: assert observed_prefix_sorted == expected_prefix_sorted async def test_list_dir(self, store: S) -> None: - root = "foo" - store_dict = { - root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"), - root + "/c/1": self.buffer_cls.from_bytes(b"\x01"), - } + roots_and_keys: list[tuple[str, dict[str, Buffer]]] = [ + ( + "foo", + { + "foo/zarr.json": self.buffer_cls.from_bytes(b"bar"), + "foo/c/1": self.buffer_cls.from_bytes(b"\x01"), + }, + ), + ( + "foo/bar", + { + "foo/bar/foobar_first_child": self.buffer_cls.from_bytes(b"1"), + "foo/bar/foobar_second_child/zarr.json": self.buffer_cls.from_bytes(b"2"), + }, + ), + ] assert await _collect_aiterator(store.list_dir("")) == () - assert await _collect_aiterator(store.list_dir(root)) == () - await store._set_many(store_dict.items()) + for root, store_dict in roots_and_keys: + assert await _collect_aiterator(store.list_dir(root)) == () - keys_observed = await _collect_aiterator(store.list_dir(root)) - keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict} + await store._set_many(store_dict.items()) - assert sorted(keys_observed) == sorted(keys_expected) + keys_observed = await _collect_aiterator(store.list_dir(root)) + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict} + assert sorted(keys_observed) == sorted(keys_expected) - keys_observed = await _collect_aiterator(store.list_dir(root + "/")) - assert sorted(keys_expected) == sorted(keys_observed) + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) + assert sorted(keys_expected) == sorted(keys_observed) async def test_set_if_not_exists(self, store: S) -> None: key = "k" @@ -578,10 +590,13 @@ class LatencyStore(WrapperStore[Store]): get_latency: float set_latency: float - def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None: + def __init__(self, store: Store, *, get_latency: float = 0, set_latency: float = 0) -> None: self.get_latency = float(get_latency) self.set_latency = float(set_latency) - self._store = cls + self._store = store + + def _with_store(self, store: Store) -> Self: + return type(self)(store, get_latency=self.get_latency, set_latency=self.set_latency) async def set(self, key: str, value: Buffer) -> None: """ diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py index 0dd6dcff17..f057808305 100644 --- a/src/zarr/testing/store_concurrency.py +++ b/src/zarr/testing/store_concurrency.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import pytest @@ -36,7 +36,7 @@ class StoreConcurrencyTests(Generic[S, B]): expected_concurrency_limit: int | None @pytest.fixture - async def store(self, store_kwargs: dict) -> S: + async def store(self, store_kwargs: dict[str, Any]) -> S: """Create and open a store instance.""" return await self.store_cls.open(**store_kwargs) @@ -51,19 +51,19 @@ def test_concurrency_limit_default(self, store: S) -> None: f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}" ) - def test_concurrency_limit_custom(self, store_kwargs: dict) -> None: + def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None: """Test that custom concurrency limits can be set.""" if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames: pytest.skip("Store does not support custom concurrency limits") # Test with custom limit - store = self.store_cls(**store_kwargs, concurrency_limit=42) + store = self.store_cls(**{**store_kwargs, "concurrency_limit": 42}) if hasattr(store, "_semaphore"): assert store._semaphore is not None assert store._semaphore._value == 42 # Test with None (unlimited) - store = self.store_cls(**store_kwargs, concurrency_limit=None) + store = self.store_cls(**{**store_kwargs, "concurrency_limit": None}) if hasattr(store, "_semaphore"): assert store._semaphore is None diff --git a/tests/benchmarks/test_indexing.py b/tests/benchmarks/test_indexing.py index 9ca0d8e1af..dff2269dcb 100644 --- a/tests/benchmarks/test_indexing.py +++ b/tests/benchmarks/test_indexing.py @@ -50,3 +50,218 @@ def test_slice_indexing( data[:] = 1 benchmark(getitem, data, indexer) + + +# Benchmark for Morton order optimization with power-of-2 shards +# Morton order is used internally by sharding codec for chunk iteration +morton_shards = ( + (16,) * 3, # With 2x2x2 chunks: 8x8x8 = 512 chunks per shard + (32,) * 3, # With 2x2x2 chunks: 16x16x16 = 4096 chunks per shard +) + + +@pytest.mark.parametrize("store", ["memory"], indirect=["store"]) +@pytest.mark.parametrize("shards", morton_shards, ids=str) +def test_sharded_morton_indexing( + store: Store, + shards: tuple[int, ...], + benchmark: BenchmarkFixture, +) -> None: + """Benchmark sharded array indexing with power-of-2 chunks per shard. + + This benchmark exercises the Morton order iteration path in the sharding + codec, which benefits from the hypercube and vectorization optimizations. + The Morton order cache is cleared before each iteration to measure the + full computation cost. + """ + from zarr.core.indexing import _morton_order + + # Create array where each shard contains many small chunks + # e.g., shards=(32,32,32) with chunks=(2,2,2) means 16x16x16 = 4096 chunks per shard + shape = tuple(s * 2 for s in shards) # 2 shards per dimension + chunks = (2,) * 3 # Small chunks to maximize chunks per shard + + data = create_array( + store=store, + shape=shape, + dtype="uint8", + chunks=chunks, + shards=shards, + compressors=None, + filters=None, + fill_value=0, + ) + + data[:] = 1 + # Read a sub-shard region to exercise Morton order iteration + indexer = (slice(shards[0]),) * 3 + + def read_with_cache_clear() -> None: + _morton_order.cache_clear() + getitem(data, indexer) + + benchmark(read_with_cache_clear) + + +# Benchmark with larger chunks_per_shard to make Morton order impact more visible +large_morton_shards = ( + (32,) * 3, # With 1x1x1 chunks: 32x32x32 = 32768 chunks per shard +) + + +@pytest.mark.parametrize("store", ["memory"], indirect=["store"]) +@pytest.mark.parametrize("shards", large_morton_shards, ids=str) +def test_sharded_morton_indexing_large( + store: Store, + shards: tuple[int, ...], + benchmark: BenchmarkFixture, +) -> None: + """Benchmark sharded array indexing with large chunks_per_shard. + + Uses 1x1x1 chunks to maximize chunks_per_shard (32^3 = 32768), making + the Morton order computation a more significant portion of total time. + The Morton order cache is cleared before each iteration. + """ + from zarr.core.indexing import _morton_order + + # 1x1x1 chunks means chunks_per_shard equals shard shape + shape = tuple(s * 2 for s in shards) # 2 shards per dimension + chunks = (1,) * 3 # 1x1x1 chunks: chunks_per_shard = shards + + data = create_array( + store=store, + shape=shape, + dtype="uint8", + chunks=chunks, + shards=shards, + compressors=None, + filters=None, + fill_value=0, + ) + + data[:] = 1 + # Read one full shard + indexer = (slice(shards[0]),) * 3 + + def read_with_cache_clear() -> None: + _morton_order.cache_clear() + getitem(data, indexer) + + benchmark(read_with_cache_clear) + + +@pytest.mark.parametrize("store", ["memory"], indirect=["store"]) +@pytest.mark.parametrize("shards", large_morton_shards, ids=str) +def test_sharded_morton_single_chunk( + store: Store, + shards: tuple[int, ...], + benchmark: BenchmarkFixture, +) -> None: + """Benchmark reading a single chunk from a large shard. + + This isolates the Morton order computation overhead by minimizing I/O. + Reading one chunk from a shard with 32^3 = 32768 chunks still requires + computing the full Morton order, making the optimization impact clear. + The Morton order cache is cleared before each iteration. + """ + from zarr.core.indexing import _morton_order + + # 1x1x1 chunks means chunks_per_shard equals shard shape + shape = tuple(s * 2 for s in shards) # 2 shards per dimension + chunks = (1,) * 3 # 1x1x1 chunks: chunks_per_shard = shards + + data = create_array( + store=store, + shape=shape, + dtype="uint8", + chunks=chunks, + shards=shards, + compressors=None, + filters=None, + fill_value=0, + ) + + data[:] = 1 + # Read only a single chunk (1x1x1) from the shard + indexer = (slice(1),) * 3 + + def read_with_cache_clear() -> None: + _morton_order.cache_clear() + getitem(data, indexer) + + benchmark(read_with_cache_clear) + + +# Benchmark for morton_order_iter directly (no I/O) +morton_iter_shapes = ( + (8, 8, 8), # 512 elements + (16, 16, 16), # 4096 elements + (32, 32, 32), # 32768 elements +) + + +@pytest.mark.parametrize("shape", morton_iter_shapes, ids=str) +def test_morton_order_iter( + shape: tuple[int, ...], + benchmark: BenchmarkFixture, +) -> None: + """Benchmark morton_order_iter directly without I/O. + + This isolates the Morton order computation to measure the + optimization impact without array read/write overhead. + The cache is cleared before each iteration. + """ + from zarr.core.indexing import _morton_order, morton_order_iter + + def compute_morton_order() -> None: + _morton_order.cache_clear() + # Consume the iterator to force computation + list(morton_order_iter(shape)) + + benchmark(compute_morton_order) + + +@pytest.mark.parametrize("store", ["memory"], indirect=["store"]) +@pytest.mark.parametrize("shards", large_morton_shards, ids=str) +def test_sharded_morton_write_single_chunk( + store: Store, + shards: tuple[int, ...], + benchmark: BenchmarkFixture, +) -> None: + """Benchmark writing a single chunk to a large shard. + + This is the clearest end-to-end demonstration of Morton order optimization. + Writing a single chunk to a shard with 32^3 = 32768 chunks requires + computing the full Morton order, but minimizes I/O overhead. + + Expected improvement: ~160ms (matching Morton computation speedup of ~178ms). + The Morton order cache is cleared before each iteration. + """ + import numpy as np + + from zarr.core.indexing import _morton_order + + # 1x1x1 chunks means chunks_per_shard equals shard shape + shape = tuple(s * 2 for s in shards) # 2 shards per dimension + chunks = (1,) * 3 # 1x1x1 chunks: chunks_per_shard = shards + + data = create_array( + store=store, + shape=shape, + dtype="uint8", + chunks=chunks, + shards=shards, + compressors=None, + filters=None, + fill_value=0, + ) + + # Write data for a single chunk + write_data = np.ones((1, 1, 1), dtype="uint8") + indexer = (slice(1), slice(1), slice(1)) + + def write_with_cache_clear() -> None: + _morton_order.cache_clear() + data[indexer] = write_data + + benchmark(write_with_cache_clear) diff --git a/tests/test_array.py b/tests/test_array.py index b7d7bc723d..01a82e1938 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -781,6 +781,73 @@ def test_resize_2d(store: MemoryStore, zarr_format: ZarrFormat) -> None: assert new_shape == result.shape +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_resize_growing_skips_chunk_enumeration( + store: MemoryStore, zarr_format: ZarrFormat +) -> None: + """Growing an array should not enumerate chunk coords for deletion (#3650 mitigation).""" + from zarr.core.chunk_grids import RegularChunkGrid + + z = zarr.create( + shape=(10, 10), + chunks=(5, 5), + dtype="i4", + fill_value=0, + store=store, + zarr_format=zarr_format, + ) + z[:] = np.ones((10, 10), dtype="i4") + + # growth only - ensure no chunk coords are enumerated + with mock.patch.object( + RegularChunkGrid, + "all_chunk_coords", + wraps=z.metadata.chunk_grid.all_chunk_coords, + ) as mock_coords: + z.resize((20, 20)) + mock_coords.assert_not_called() + + assert z.shape == (20, 20) + np.testing.assert_array_equal(np.ones((10, 10), dtype="i4"), z[:10, :10]) + np.testing.assert_array_equal(np.zeros((10, 10), dtype="i4"), z[10:, 10:]) + + # shrink - ensure no regression of behaviour + with mock.patch.object( + RegularChunkGrid, + "all_chunk_coords", + wraps=z.metadata.chunk_grid.all_chunk_coords, + ) as mock_coords: + z.resize((5, 5)) + assert mock_coords.call_count > 0 + + assert z.shape == (5, 5) + np.testing.assert_array_equal(np.ones((5, 5), dtype="i4"), z[:]) + + # mixed: grow dim 0, shrink dim 1 - ensure deletion path runs + z2 = zarr.create( + shape=(10, 10), + chunks=(5, 5), + dtype="i4", + fill_value=0, + store=store, + zarr_format=zarr_format, + overwrite=True, + ) + z2[:] = np.ones((10, 10), dtype="i4") + + with mock.patch.object( + RegularChunkGrid, + "all_chunk_coords", + wraps=z2.metadata.chunk_grid.all_chunk_coords, + ) as mock_coords: + z2.resize((20, 5)) + assert mock_coords.call_count > 0 + + assert z2.shape == (20, 5) + np.testing.assert_array_equal(np.ones((10, 5), dtype="i4"), z2[:10, :]) + np.testing.assert_array_equal(np.zeros((10, 5), dtype="i4"), z2[10:, :]) + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_append_1d(store: MemoryStore, zarr_format: ZarrFormat) -> None: a = np.arange(105) diff --git a/tests/test_codecs/test_codecs.py b/tests/test_codecs/test_codecs.py index eae7168d49..fa2017876e 100644 --- a/tests/test_codecs/test_codecs.py +++ b/tests/test_codecs/test_codecs.py @@ -18,7 +18,7 @@ TransposeCodec, ) from zarr.core.buffer import default_buffer_prototype -from zarr.core.indexing import BasicSelection, morton_order_iter +from zarr.core.indexing import BasicSelection, decode_morton, morton_order_iter from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.dtype import UInt8 from zarr.errors import ZarrUserWarning @@ -171,7 +171,8 @@ def test_open(store: Store) -> None: assert a.metadata == b.metadata -def test_morton() -> None: +def test_morton_exact_order() -> None: + """Test exact morton ordering for power-of-2 shapes.""" assert list(morton_order_iter((2, 2))) == [(0, 0), (1, 0), (0, 1), (1, 1)] assert list(morton_order_iter((2, 2, 2))) == [ (0, 0, 0), @@ -206,21 +207,58 @@ def test_morton() -> None: @pytest.mark.parametrize( "shape", [ - [2, 2, 2], - [5, 2], - [2, 5], - [2, 9, 2], - [3, 2, 12], - [2, 5, 1], - [4, 3, 6, 2, 7], - [3, 2, 1, 6, 4, 5, 2], + (2, 2, 2), + (5, 2), + (2, 5), + (2, 9, 2), + (3, 2, 12), + (2, 5, 1), + (4, 3, 6, 2, 7), + (3, 2, 1, 6, 4, 5, 2), + (1,), + (1, 1), + (5, 1, 3), + (1, 4, 1, 2), ], ) -def test_morton2(shape: tuple[int, ...]) -> None: +def test_morton_is_permutation(shape: tuple[int, ...]) -> None: + """Test that morton_order_iter produces every valid coordinate exactly once.""" + import itertools + + from zarr.core.common import product + + order = list(morton_order_iter(shape)) + expected_len = product(shape) + # completeness: every valid coordinate is present + assert len(order) == expected_len + # no duplicates + assert len(set(order)) == expected_len + # all coordinates are within bounds + assert all(all(c < s for c, s in zip(coord, shape, strict=True)) for coord in order) + # the set of coordinates equals the full cartesian product + assert set(order) == set(itertools.product(*(range(s) for s in shape))) + + +@pytest.mark.parametrize( + "shape", + [ + (2, 2), + (4, 4), + (2, 2, 2), + (4, 4, 4), + (2, 2, 2, 2), + ], +) +def test_morton_ordering(shape: tuple[int, ...]) -> None: + """Test that the iteration order matches consecutive decode_morton outputs. + + For power-of-2 shapes, every decode_morton output is in-bounds, + so the ordering should be exactly decode_morton(0), decode_morton(1), ... + """ + order = list(morton_order_iter(shape)) - for i, x in enumerate(order): - assert x not in order[:i] # no duplicates - assert all(x[j] < shape[j] for j in range(len(shape))) # all indices are within bounds + for i, coord in enumerate(order): + assert coord == decode_morton(i, shape) @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 7eb4deccbf..d0e2d09b7c 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -18,7 +18,6 @@ TransposeCodec, ) from zarr.core.buffer import NDArrayLike, default_buffer_prototype -from zarr.errors import ZarrUserWarning from zarr.storage import StorePath, ZipStore from ..conftest import ArrayRequest @@ -239,12 +238,14 @@ def test_sharding_partial_overwrite( assert np.array_equal(data, read_data) +# Zip storage raises a warning about a duplicate name, which we ignore. +@pytest.mark.filterwarnings("ignore:Duplicate name.*:UserWarning") @pytest.mark.parametrize( "array_fixture", [ - ArrayRequest(shape=(128,) * 3, dtype="uint16", order="F"), + ArrayRequest(shape=(127, 128, 129), dtype="uint16", order="F"), ], - indirect=["array_fixture"], + indirect=True, ) @pytest.mark.parametrize( "outer_index_location", @@ -263,24 +264,23 @@ def test_nested_sharding( ) -> None: data = array_fixture spath = StorePath(store) - msg = "Combining a `sharding_indexed` codec disables partial reads and writes, which may lead to inefficient performance." - with pytest.warns(ZarrUserWarning, match=msg): - a = zarr.create_array( - spath, - shape=data.shape, - chunks=(64, 64, 64), - dtype=data.dtype, - fill_value=0, - serializer=ShardingCodec( - chunk_shape=(32, 32, 32), - codecs=[ - ShardingCodec(chunk_shape=(16, 16, 16), index_location=inner_index_location) - ], - index_location=outer_index_location, - ), - ) + # compressors=None ensures no BytesBytesCodec is added, which keeps + # supports_partial_decode=True and exercises the partial decode path + a = zarr.create_array( + spath, + data=data, + chunks=(64,) * data.ndim, + compressors=None, + serializer=ShardingCodec( + chunk_shape=(32,) * data.ndim, + codecs=[ + ShardingCodec(chunk_shape=(16,) * data.ndim, index_location=inner_index_location) + ], + index_location=outer_index_location, + ), + ) - a[:, :, :] = data + a[:] = data read_data = a[0 : data.shape[0], 0 : data.shape[1], 0 : data.shape[2]] assert isinstance(read_data, NDArrayLike) @@ -326,13 +326,10 @@ def test_nested_sharding_create_array( filters=None, compressors=None, ) - print(a.metadata.to_dict()) - a[:, :, :] = data + a[:] = data - read_data = a[0 : data.shape[0], 0 : data.shape[1], 0 : data.shape[2]] - assert isinstance(read_data, NDArrayLike) - assert data.shape == read_data.shape + read_data = a[:] assert np.array_equal(data, read_data) diff --git a/tests/test_common.py b/tests/test_common.py index 9484d15ca3..09cb6df2f8 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, get_args import numpy as np @@ -15,7 +16,6 @@ from zarr.core.config import parse_indexing_order if TYPE_CHECKING: - from collections.abc import Iterable from typing import Any, Literal @@ -111,9 +111,15 @@ def test_parse_shapelike_invalid_iterable_values(data: Any) -> None: parse_shapelike(data) -@pytest.mark.parametrize("data", [range(10), [0, 1, 2, 3], (3, 4, 5), ()]) -def test_parse_shapelike_valid(data: Iterable[int]) -> None: - assert parse_shapelike(data) == tuple(data) +@pytest.mark.parametrize( + "data", [range(10), [0, 1, 2, np.uint64(3)], (3, 4, 5), (), 1, np.uint8(1)] +) +def test_parse_shapelike_valid(data: Iterable[int] | int) -> None: + if isinstance(data, Iterable): + expected = tuple(data) + else: + expected = (data,) + assert parse_shapelike(data) == expected # todo: more dtypes diff --git a/tests/test_dtype_registry.py b/tests/test_dtype_registry.py index 58b14fe07a..b7ceb502b7 100644 --- a/tests/test_dtype_registry.py +++ b/tests/test_dtype_registry.py @@ -15,9 +15,11 @@ get_data_type_from_json, ) from zarr.core.dtype.common import unpack_dtype_json +from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING from zarr.dtype import ( # type: ignore[attr-defined] Bool, FixedLengthUTF32, + VariableLengthUTF8, ZDType, data_type_registry, parse_data_type, @@ -74,6 +76,16 @@ def test_match_dtype( data_type_registry_fixture.register(wrapper_cls._zarr_v3_name, wrapper_cls) assert isinstance(data_type_registry_fixture.match_dtype(np.dtype(dtype_str)), wrapper_cls) + @pytest.mark.skipif(not _NUMPY_SUPPORTS_VLEN_STRING, reason="requires numpy with T dtype") + @staticmethod + def test_match_dtype_string_na_object_error( + data_type_registry_fixture: DataTypeRegistry, + ) -> None: + data_type_registry_fixture.register(VariableLengthUTF8._zarr_v3_name, VariableLengthUTF8) # type: ignore[arg-type] + dtype: np.dtype[Any] = np.dtypes.StringDType(na_object=None) # type: ignore[call-arg] + with pytest.raises(ValueError, match=r"Zarr data type resolution from StringDType.*failed"): + data_type_registry_fixture.match_dtype(dtype) + @staticmethod def test_unregistered_dtype(data_type_registry_fixture: DataTypeRegistry) -> None: """ diff --git a/tests/test_experimental/test_cache_store.py b/tests/test_experimental/test_cache_store.py index d4a45f78f1..50d3d9506b 100644 --- a/tests/test_experimental/test_cache_store.py +++ b/tests/test_experimental/test_cache_store.py @@ -30,7 +30,61 @@ def cache_store(self) -> MemoryStore: @pytest.fixture def cached_store(self, source_store: Store, cache_store: Store) -> CacheStore: """Create a cached store instance.""" - return CacheStore(source_store, cache_store=cache_store, key_insert_times={}) + return CacheStore(source_store, cache_store=cache_store) + + async def test_with_read_only_round_trip(self) -> None: + """ + Ensure that CacheStore.with_read_only returns another CacheStore with + the requested read_only state, shares cache state, and does not change + the original store's read_only flag. + """ + source = MemoryStore() + cache = MemoryStore() + + # Start from a read-only underlying store + source_ro = source.with_read_only(read_only=True) + cached_ro = CacheStore(store=source_ro, cache_store=cache) + assert cached_ro.read_only + + buf = CPUBuffer.from_bytes(b"0123") + + # Cannot write through the read-only cache store + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await cached_ro.set("foo", buf) + + # Create a writable cache store from the read-only one + writer = cached_ro.with_read_only(read_only=False) + assert isinstance(writer, CacheStore) + assert not writer.read_only + + # Cache configuration and state are shared + assert writer._cache is cached_ro._cache + assert writer._state is cached_ro._state + assert writer._state.key_insert_times is cached_ro._state.key_insert_times + + # Writes via the writable cache store succeed and are cached + await writer.set("foo", buf) + out = await writer.get("foo", default_buffer_prototype()) + assert out is not None + assert out.to_bytes() == buf.to_bytes() + + # The original cache store remains read-only + assert cached_ro.read_only + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await cached_ro.set("bar", buf) + + # Creating a read-only copy from the writable cache store works and is enforced + reader = writer.with_read_only(read_only=True) + assert isinstance(reader, CacheStore) + assert reader.read_only + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await reader.set("baz", buf) async def test_basic_caching(self, cached_store: CacheStore, source_store: Store) -> None: """Test basic cache functionality.""" @@ -71,7 +125,6 @@ async def test_cache_expiration(self) -> None: source_store, cache_store=cache_store, max_age_seconds=1, # 1 second expiration - key_insert_times={}, ) # Store data @@ -96,9 +149,7 @@ async def test_cache_expiration(self) -> None: async def test_cache_set_data_false(self, source_store: Store, cache_store: Store) -> None: """Test behavior when cache_set_data=False.""" - cached_store = CacheStore( - source_store, cache_store=cache_store, cache_set_data=False, key_insert_times={} - ) + cached_store = CacheStore(source_store, cache_store=cache_store, cache_set_data=False) test_data = CPUBuffer.from_bytes(b"no cache data") await cached_store.set("no_cache_key", test_data) @@ -154,9 +205,7 @@ async def test_stale_cache_refresh(self) -> None: """Test that stale cache entries are refreshed from source.""" source_store = MemoryStore() cache_store = MemoryStore() - cached_store = CacheStore( - source_store, cache_store=cache_store, max_age_seconds=1, key_insert_times={} - ) + cached_store = CacheStore(source_store, cache_store=cache_store, max_age_seconds=1) # Store initial data old_data = CPUBuffer.from_bytes(b"old data") @@ -194,14 +243,10 @@ async def test_cache_returns_cached_data_for_performance( self, cached_store: CacheStore, source_store: Store ) -> None: """Test that cache returns cached data for performance, even if not in source.""" - # Skip test if key_insert_times attribute doesn't exist - if not hasattr(cached_store, "key_insert_times"): - pytest.skip("key_insert_times attribute not implemented") - # Put data in cache but not source (simulates orphaned cache entry) test_data = CPUBuffer.from_bytes(b"orphaned data") await cached_store._cache.set("orphan_key", test_data) - cached_store.key_insert_times["orphan_key"] = time.monotonic() + cached_store._state.key_insert_times["orphan_key"] = time.monotonic() # Cache should return data for performance (no source verification) result = await cached_store.get("orphan_key", default_buffer_prototype()) @@ -210,7 +255,7 @@ async def test_cache_returns_cached_data_for_performance( # Cache entry should remain (performance optimization) assert await cached_store._cache.exists("orphan_key") - assert "orphan_key" in cached_store.key_insert_times + assert "orphan_key" in cached_store._state.key_insert_times async def test_cache_coherency_through_expiration(self) -> None: """Test that cache coherency is managed through cache expiration, not source verification.""" @@ -288,7 +333,6 @@ async def test_cache_info_with_max_size(self) -> None: cache_store=cache_store, max_size=1024, max_age_seconds=300, - key_insert_times={}, ) info = cached_store.cache_info() @@ -365,7 +409,7 @@ async def test_max_age_numeric(self) -> None: assert cached_store._is_key_fresh("test_key") # Manually set old timestamp to test expiration - cached_store.key_insert_times["test_key"] = time.monotonic() - 2 # 2 seconds ago + cached_store._state.key_insert_times["test_key"] = time.monotonic() - 2 # 2 seconds ago # Key should now be stale assert not cached_store._is_key_fresh("test_key") @@ -519,7 +563,7 @@ async def test_evict_key_exception_handling(self) -> None: # Manually corrupt the tracking to trigger exception # Remove from one structure but not others to create inconsistency - del cached_store._cache_order["test_key"] + del cached_store._state.cache_order["test_key"] # Try to evict - should handle the KeyError gracefully await cached_store._evict_key("test_key") @@ -540,16 +584,16 @@ async def test_get_no_cache_delete_tracking(self) -> None: await cached_store._cache_value("phantom_key", test_data) # Verify it's in tracking - assert "phantom_key" in cached_store._cache_order - assert "phantom_key" in cached_store.key_insert_times + assert "phantom_key" in cached_store._state.cache_order + assert "phantom_key" in cached_store._state.key_insert_times # Now try to get it - since it's not in source, should clean up tracking result = await cached_store._get_no_cache("phantom_key", default_buffer_prototype()) assert result is None # Should have cleaned up tracking - assert "phantom_key" not in cached_store._cache_order - assert "phantom_key" not in cached_store.key_insert_times + assert "phantom_key" not in cached_store._state.cache_order + assert "phantom_key" not in cached_store._state.key_insert_times async def test_accommodate_value_no_max_size(self) -> None: """Test _accommodate_value early return when max_size is None.""" @@ -609,7 +653,9 @@ async def set_large(key: str) -> None: # Size should be consistent with tracked keys assert info["current_size"] <= 200 # Might pass # But verify actual cache store size matches tracking - total_size = sum(cached_store._key_sizes.get(k, 0) for k in cached_store._cache_order) + total_size = sum( + cached_store._state.key_sizes.get(k, 0) for k in cached_store._state.cache_order + ) assert total_size == info["current_size"] # WOULD FAIL async def test_concurrent_get_and_evict(self) -> None: @@ -638,7 +684,7 @@ async def write_key() -> None: # Verify consistency info = cached_store.cache_info() assert info["current_size"] <= 100 - assert len(cached_store._cache_order) == len(cached_store._key_sizes) + assert len(cached_store._state.cache_order) == len(cached_store._state.key_sizes) async def test_eviction_actually_deletes_from_cache_store(self) -> None: """Test that eviction removes keys from cache_store, not just tracking.""" @@ -659,8 +705,8 @@ async def test_eviction_actually_deletes_from_cache_store(self) -> None: await cached_store.set("key2", data2) # Check tracking - key1 should be removed - assert "key1" not in cached_store._cache_order - assert "key1" not in cached_store._key_sizes + assert "key1" not in cached_store._state.cache_order + assert "key1" not in cached_store._state.key_sizes # CRITICAL: key1 should also be removed from cache_store assert not await cache_store.exists("key1"), ( @@ -733,13 +779,13 @@ async def test_all_tracked_keys_exist_in_cache_store(self) -> None: await cached_store.set(f"key_{i}", data) # Every key in tracking should exist in cache_store - for key in cached_store._cache_order: + for key in cached_store._state.cache_order: assert await cache_store.exists(key), ( f"Key '{key}' is tracked but doesn't exist in cache_store" ) # Every key in _key_sizes should exist in cache_store - for key in cached_store._key_sizes: + for key in cached_store._state.key_sizes: assert await cache_store.exists(key), ( f"Key '{key}' has size tracked but doesn't exist in cache_store" ) @@ -778,7 +824,7 @@ async def failing_delete(key: str) -> None: # Attempt to evict should raise the exception with pytest.raises(RuntimeError, match="Simulated cache deletion failure"): - async with cached_store._lock: + async with cached_store._state.lock: await cached_store._evict_key("test_key") async def test_cache_stats_method(self) -> None: diff --git a/tests/test_store/test_latency.py b/tests/test_store/test_latency.py new file mode 100644 index 0000000000..38ffb17dd6 --- /dev/null +++ b/tests/test_store/test_latency.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from zarr.core.buffer import default_buffer_prototype +from zarr.storage import MemoryStore +from zarr.testing.store import LatencyStore + + +async def test_latency_store_with_read_only_round_trip() -> None: + """ + Ensure that LatencyStore.with_read_only returns another LatencyStore with + the requested read_only state, preserves latency configuration, and does + not change the original wrapper. + """ + base = await MemoryStore.open() + # Start from a read-only underlying store + ro_base = base.with_read_only(read_only=True) + latency_ro = LatencyStore(ro_base, get_latency=0.01, set_latency=0.02) + + assert latency_ro.read_only + assert latency_ro.get_latency == pytest.approx(0.01) + assert latency_ro.set_latency == pytest.approx(0.02) + + buf = default_buffer_prototype().buffer.from_bytes(b"abcd") + + # Cannot write through the read-only wrapper + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await latency_ro.set("key", buf) + + # Create a writable wrapper from the read-only one + writer = latency_ro.with_read_only(read_only=False) + assert isinstance(writer, LatencyStore) + assert not writer.read_only + # Latency configuration is preserved + assert writer.get_latency == latency_ro.get_latency + assert writer.set_latency == latency_ro.set_latency + + # Writes via the writable wrapper succeed + await writer.set("key", buf) + out = await writer.get("key", prototype=default_buffer_prototype()) + assert out is not None + assert out.to_bytes() == buf.to_bytes() + + # Creating a read-only copy from the writable wrapper works and is enforced + reader = writer.with_read_only(read_only=True) + assert isinstance(reader, LatencyStore) + assert reader.read_only + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await reader.set("other", buf) + + # The original read-only wrapper remains read-only + assert latency_ro.read_only diff --git a/tests/test_store/test_logging.py b/tests/test_store/test_logging.py index fa566e45aa..96cd184938 100644 --- a/tests/test_store/test_logging.py +++ b/tests/test_store/test_logging.py @@ -86,6 +86,46 @@ def test_is_open_setter_raises(self, store: LoggingStore[LocalStore]) -> None: ): store._is_open = True + async def test_with_read_only_round_trip(self, local_store: LocalStore) -> None: + """ + Ensure that LoggingStore.with_read_only returns another LoggingStore with + the requested read_only state, preserves logging configuration, and does + not change the original store. + """ + # Start from a read-only underlying store + ro_store = local_store.with_read_only(read_only=True) + wrapped_ro = LoggingStore(store=ro_store, log_level="INFO") + assert wrapped_ro.read_only + + buf = default_buffer_prototype().buffer.from_bytes(b"0123") + + # Cannot write through the read-only wrapper + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await wrapped_ro.set("foo", buf) + + # Create a writable wrapper + writer = wrapped_ro.with_read_only(read_only=False) + assert isinstance(writer, LoggingStore) + assert not writer.read_only + # logging configuration is preserved + assert writer.log_level == wrapped_ro.log_level + assert writer.log_handler == wrapped_ro.log_handler + + # Writes via the writable wrapper succeed + await writer.set("foo", buf) + out = await writer.get("foo", prototype=default_buffer_prototype()) + assert out is not None + assert out.to_bytes() == buf.to_bytes() + + # The original wrapper remains read-only + assert wrapped_ro.read_only + with pytest.raises( + ValueError, match="store was opened in read-only mode and does not support writing" + ): + await wrapped_ro.set("bar", buf) + @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) async def test_logging_store(store: Store, caplog: pytest.LogCaptureFixture) -> None: From 1e73be9a910e716d589e25a3258f9daf351d2aae Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 19 Feb 2026 22:26:27 +0100 Subject: [PATCH 09/12] use protocol --- src/zarr/storage/_fsspec.py | 11 +++++-- src/zarr/storage/_local.py | 12 ++++++-- src/zarr/storage/_obstore.py | 30 +++++++++++++------- src/zarr/storage/_utils.py | 41 +++++++++++++-------------- src/zarr/testing/store_concurrency.py | 41 ++++++++++++++++----------- 5 files changed, 82 insertions(+), 53 deletions(-) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 21f96d87d5..7f7cb08c57 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -255,13 +255,19 @@ def from_url( return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions) + def get_semaphore(self) -> asyncio.Semaphore | None: + return self._semaphore + def with_read_only(self, read_only: bool = False) -> FsspecStore: # docstring inherited + sem = self.get_semaphore() + concurrency_limit = sem._value if sem else None return type(self)( fs=self.fs, path=self.path, allowed_exceptions=self.allowed_exceptions, read_only=read_only, + concurrency_limit=concurrency_limit, ) async def clear(self) -> None: @@ -353,6 +359,7 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: if not self._is_open: await self._open() self._check_writable() + semaphore = self.get_semaphore() async def _set_with_limit(key: str, value: Buffer) -> None: if not isinstance(value, Buffer): @@ -360,8 +367,8 @@ async def _set_with_limit(key: str, value: Buffer) -> None: f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." ) path = _dereference_path(self.path, key) - if self._semaphore: - async with self._semaphore: + if semaphore: + async with semaphore: await self.fs._pipe_file(path, value.to_bytes()) else: await self.fs._pipe_file(path, value.to_bytes()) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 842cab41ca..230d0b3358 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -134,9 +134,13 @@ def __init__( asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None ) + def get_semaphore(self) -> asyncio.Semaphore | None: + return self._semaphore + def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited - concurrency_limit = self._semaphore._value if self._semaphore else None + sem = self.get_semaphore() + concurrency_limit = sem._value if sem else None return type(self)( root=self.root, read_only=read_only, @@ -232,11 +236,13 @@ async def get_partial_values( # Note: We directly call the I/O functions here, wrapped with semaphore # to avoid deadlock from calling the decorated get() method + semaphore = self.get_semaphore() + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: path = self.root / key try: - if self._semaphore: - async with self._semaphore: + if semaphore: + async with semaphore: return await asyncio.to_thread(_get, path, prototype, byte_range) else: return await asyncio.to_thread(_get, path, prototype, byte_range) diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index ea2c8d91fe..313755f0a1 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -86,9 +86,13 @@ def __init__( asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None ) + def get_semaphore(self) -> asyncio.Semaphore | None: + return self._semaphore + def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited - concurrency_limit = self._semaphore._value if self._semaphore else None + sem = self.get_semaphore() + concurrency_limit = sem._value if sem else None return type(self)( store=self.store, read_only=read_only, @@ -134,6 +138,7 @@ async def get_partial_values( import obstore as obs key_ranges = list(key_ranges) + semaphore = self.get_semaphore() # Group bounded range requests by path for batched fetching per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list) other_requests: list[tuple[int, str, ByteRequest | None]] = [] @@ -150,8 +155,8 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) """Batch multiple range requests for the same file using get_ranges_async.""" starts = [r.start for _, r in requests] ends = [r.end for _, r in requests] - if self._semaphore: - async with self._semaphore: + if semaphore: + async with semaphore: responses = await obs.get_ranges_async( self.store, path=path, starts=starts, ends=ends ) @@ -165,8 +170,8 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None: """Fetch a single non-range request with semaphore limiting.""" try: - if self._semaphore: - async with self._semaphore: + if semaphore: + async with semaphore: buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) else: buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) @@ -250,11 +255,12 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: import obstore as obs self._check_writable() + semaphore = self.get_semaphore() async def _set_with_limit(key: str, value: Buffer) -> None: buf = value.as_buffer_like() - if self._semaphore: - async with self._semaphore: + if semaphore: + async with semaphore: await obs.put_async(self.store, key, buf) else: await obs.put_async(self.store, key, buf) @@ -268,8 +274,9 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: self._check_writable() buf = value.as_buffer_like() - if self._semaphore: - async with self._semaphore: + semaphore = self.get_semaphore() + if semaphore: + async with semaphore: with contextlib.suppress(obs.exceptions.AlreadyExistsError): await obs.put_async(self.store, key, buf, mode="create") else: @@ -304,11 +311,12 @@ async def delete_dir(self, prefix: str) -> None: prefix += "/" metas = await obs.list(self.store, prefix).collect_async() + semaphore = self.get_semaphore() # Delete with semaphore limiting to avoid deadlock async def _delete_with_limit(path: str) -> None: - if self._semaphore: - async with self._semaphore: + if semaphore: + async with semaphore: with contextlib.suppress(FileNotFoundError): await obs.delete_async(self.store, path) else: diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 80bce250e9..92189e7d51 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -3,7 +3,7 @@ import functools import re from pathlib import Path -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, Protocol, TypeVar, runtime_checkable from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest @@ -18,20 +18,24 @@ T_co = TypeVar("T_co", covariant=True) -def with_concurrency_limit( - semaphore_attr: str = "_semaphore", -) -> Callable[[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]]: +@runtime_checkable +class HasConcurrencyLimit(Protocol): + """Protocol for stores that support concurrency limiting via a semaphore.""" + + def get_semaphore(self) -> asyncio.Semaphore | None: + """Return the semaphore used for concurrency limiting, or None for unlimited.""" + ... + + +def with_concurrency_limit() -> Callable[ + [Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]] +]: """ Decorator that applies a semaphore-based concurrency limit to an async method. - This decorator is designed for Store methods that need to limit concurrent operations. - The store instance should have a `_semaphore` attribute (or custom attribute name) - that is either an asyncio.Semaphore or None (for unlimited concurrency). - - Parameters - ---------- - semaphore_attr : str, optional - Name of the semaphore attribute on the class instance. Default is "_semaphore". + This decorator is designed for methods on classes that implement the + ``HasConcurrencyLimit`` protocol. The class must define a ``get_semaphore()`` + method returning either an ``asyncio.Semaphore`` or ``None``. Returns ------- @@ -45,6 +49,9 @@ class MyStore(Store): def __init__(self, concurrency_limit: int = 100): self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None + def get_semaphore(self) -> asyncio.Semaphore | None: + return self._semaphore + @with_concurrency_limit() async def get(self, key: str) -> Buffer | None: # This will only run when semaphore permits @@ -55,13 +62,6 @@ async def get(self, key: str) -> Buffer | None: def decorator( func: Callable[P, Coroutine[Any, Any, T_co]], ) -> Callable[P, Coroutine[Any, Any, T_co]]: - """ - This decorator wraps the invocation of `func` in an `async with semaphore` context manager. - The semaphore object is resolved by getting the `semaphor_attr` attribute from the first - argument to func. When this decorator is used on a method of a class, that first argument - is a reference to the class instance (`self`). - """ - @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: # First arg should be 'self' @@ -69,8 +69,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: raise TypeError(f"{func.__name__} requires at least one argument (self)") self = args[0] - - semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr) + semaphore: asyncio.Semaphore | None = self.get_semaphore() # type: ignore[attr-defined] if semaphore is not None: async with semaphore: diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py index f057808305..a13de4b13a 100644 --- a/src/zarr/testing/store_concurrency.py +++ b/src/zarr/testing/store_concurrency.py @@ -40,16 +40,25 @@ async def store(self, store_kwargs: dict[str, Any]) -> S: """Create and open a store instance.""" return await self.store_cls.open(**store_kwargs) + @staticmethod + def _get_semaphore(store: Store) -> asyncio.Semaphore | None: + """Get the semaphore from a store, or None if the store doesn't support concurrency limiting.""" + get_semaphore = getattr(store, "get_semaphore", None) + if get_semaphore is not None: + return get_semaphore() # type: ignore[no-any-return] + return None + def test_concurrency_limit_default(self, store: S) -> None: """Test that store has the expected default concurrency limit.""" - if hasattr(store, "_semaphore"): - if self.expected_concurrency_limit is None: - assert store._semaphore is None, "Expected no concurrency limit" - else: - assert store._semaphore is not None, "Expected concurrency limit to be set" - assert store._semaphore._value == self.expected_concurrency_limit, ( - f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}" - ) + semaphore = self._get_semaphore(store) + if semaphore is None and self.expected_concurrency_limit is not None: + pytest.fail("Expected concurrency limit to be set") + if semaphore is not None and self.expected_concurrency_limit is None: + pytest.fail("Expected no concurrency limit") + if semaphore is not None and self.expected_concurrency_limit is not None: + assert semaphore._value == self.expected_concurrency_limit, ( + f"Expected limit {self.expected_concurrency_limit}, got {semaphore._value}" + ) def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None: """Test that custom concurrency limits can be set.""" @@ -58,14 +67,13 @@ def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None: # Test with custom limit store = self.store_cls(**{**store_kwargs, "concurrency_limit": 42}) - if hasattr(store, "_semaphore"): - assert store._semaphore is not None - assert store._semaphore._value == 42 + semaphore = self._get_semaphore(store) + assert semaphore is not None + assert semaphore._value == 42 # Test with None (unlimited) store = self.store_cls(**{**store_kwargs, "concurrency_limit": None}) - if hasattr(store, "_semaphore"): - assert store._semaphore is None + assert self._get_semaphore(store) is None async def test_concurrency_limit_enforced(self, store: S) -> None: """Test that the concurrency limit is actually enforced during execution. @@ -73,10 +81,11 @@ async def test_concurrency_limit_enforced(self, store: S) -> None: This test verifies that when many operations are submitted concurrently, only up to the concurrency limit are actually executing at once. """ - if not hasattr(store, "_semaphore") or store._semaphore is None: + semaphore = self._get_semaphore(store) + if semaphore is None: pytest.skip("Store has no concurrency limit") - limit = store._semaphore._value + limit = semaphore._value # We'll monitor the semaphore's available count # When it reaches 0, that means `limit` operations are running @@ -86,7 +95,7 @@ async def monitored_operation(key: str, value: B) -> None: nonlocal min_available # Check semaphore state right after we're scheduled await asyncio.sleep(0) # Yield to ensure we're in the queue - available = store._semaphore._value + available = semaphore._value min_available = min(min_available, available) # Now do the actual operation (which will acquire the semaphore) From 17c7226252117d8ca17215abe82981d038277411 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 20 Feb 2026 12:17:22 +0100 Subject: [PATCH 10/12] fix docstring --- src/zarr/storage/_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 92189e7d51..2dea7f6fe4 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -45,6 +45,14 @@ def with_concurrency_limit() -> Callable[ Examples -------- ```python + import asyncio + from zarr.abc.store import Store + from zarr.abc.buffer import Buffer + from zarr.storage._utils import with_concurrency_limit + + async def expensive_io_operation(key: str): + asyncio.sleep(10) + class MyStore(Store): def __init__(self, concurrency_limit: int = 100): self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None From f3d92f621badb01579b89a5d2d0963513a6f2003 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 20 Feb 2026 12:31:10 +0100 Subject: [PATCH 11/12] update docs, warn when config is used to set concurrency limits --- changes/3547.misc.md | 3 ++- docs/user-guide/config.md | 2 +- docs/user-guide/performance.md | 40 ++++++++++++++++++---------------- src/zarr/core/config.py | 40 +++++++++++++++++++++++++++++++++- tests/test_config.py | 17 ++++++++++++--- 5 files changed, 77 insertions(+), 25 deletions(-) diff --git a/changes/3547.misc.md b/changes/3547.misc.md index 771bfe8861..eb4d4bd507 100644 --- a/changes/3547.misc.md +++ b/changes/3547.misc.md @@ -1 +1,2 @@ -Moved concurrency limits to a global per-event loop setting instead of per-array call. \ No newline at end of file +Moved concurrency-limiting functionality to store classes. The global configuration object no longer +controls concurrency limits. Concurrency limits, if applicable, must now be specified when constructing a store. \ No newline at end of file diff --git a/docs/user-guide/config.md b/docs/user-guide/config.md index 21fe9b5def..57fed828a1 100644 --- a/docs/user-guide/config.md +++ b/docs/user-guide/config.md @@ -30,7 +30,7 @@ Configuration options include the following: - Default Zarr format `default_zarr_version` - Default array order in memory `array.order` - Whether empty chunks are written to storage `array.write_empty_chunks` -- Async and threading options, e.g. `async.concurrency` and `threading.max_workers` +- Threading options, e.g. `threading.max_workers` - Selections of implementations of codecs, codec pipelines and buffers - Enabling GPU support with `zarr.config.enable_gpu()`. See GPU support for more. diff --git a/docs/user-guide/performance.md b/docs/user-guide/performance.md index 0e0fa3cd55..56c2b4de4b 100644 --- a/docs/user-guide/performance.md +++ b/docs/user-guide/performance.md @@ -191,20 +191,18 @@ scenarios. ### Concurrent I/O operations Zarr uses asynchronous I/O internally to enable concurrent reads and writes across multiple chunks. -The level of concurrency is controlled by the `async.concurrency` configuration setting, which -determines the maximum number of concurrent I/O operations. - -The default value is 10, which is a conservative value. You may get improved performance by tuning -the concurrency limit. You can adjust this value based on your specific needs: +Concurrency is controlled at the **store level** — each store instance can have its own concurrency +limit, set via the `concurrency_limit` parameter when creating the store. ```python import zarr -# Set concurrency for the current session -zarr.config.set({'async.concurrency': 128}) +# Local filesystem store with custom concurrency limit +store = zarr.storage.LocalStore("data/my_array.zarr", concurrency_limit=64) -# Or use environment variable -# export ZARR_ASYNC_CONCURRENCY=128 +# Remote store with higher concurrency for network I/O +from obstore.store import S3Store +store = zarr.storage.ObjectStore(S3Store.from_url("s3://bucket/path"), concurrency_limit=128) ``` Higher concurrency values can improve throughput when: @@ -217,32 +215,36 @@ Lower concurrency values may be beneficial when: - Memory is constrained (each concurrent operation requires buffer space) - Using Zarr within a parallel computing framework (see below) +Set `concurrency_limit=None` to disable the concurrency limit entirely. + ### Using Zarr with Dask -[Dask](https://www.dask.org/) is a popular parallel computing library that works well with Zarr for processing large arrays. When using Zarr with Dask, it's important to consider the interaction between Dask's thread pool and Zarr's concurrency settings. +[Dask](https://www.dask.org/) is a popular parallel computing library that works well with Zarr for processing large arrays. When using Zarr with Dask, it's important to consider the interaction between Dask's thread pool and the store's concurrency limit. -**Important**: When using many Dask threads, you may need to reduce both Zarr's `async.concurrency` and `threading.max_workers` settings to avoid creating too many concurrent operations. The total number of concurrent I/O operations can be roughly estimated as: +**Important**: When using many Dask threads, you may need to reduce the store's `concurrency_limit` and Zarr's `threading.max_workers` setting to avoid creating too many concurrent operations. The total number of concurrent I/O operations can be roughly estimated as: ``` -total_concurrency ≈ dask_threads × zarr_async_concurrency +total_concurrency ≈ dask_threads × store_concurrency_limit ``` -For example, if you're running Dask with 10 threads and Zarr's default concurrency of 64, you could potentially have up to 640 concurrent operations, which may overwhelm your storage system or cause memory issues. +For example, if you're running Dask with 10 threads and a store concurrency limit of 64, you could potentially have up to 640 concurrent operations, which may overwhelm your storage system or cause memory issues. -**Recommendation**: When using Dask with many threads, configure Zarr's concurrency settings: +**Recommendation**: When using Dask with many threads, configure concurrency settings: ```python import zarr import dask.array as da -# If using Dask with many threads (e.g., 8-16), reduce Zarr's concurrency settings +# Create store with reduced concurrency limit for Dask workloads +store = zarr.storage.LocalStore("data/large_array.zarr", concurrency_limit=4) + +# Also limit Zarr's internal thread pool zarr.config.set({ - 'async.concurrency': 4, # Limit concurrent async operations 'threading.max_workers': 4, # Limit Zarr's internal thread pool }) # Open Zarr array -z = zarr.open_array('data/large_array.zarr', mode='r') +z = zarr.open_array(store=store, mode='r') # Create Dask array from Zarr array arr = da.from_array(z, chunks=z.chunks) @@ -253,8 +255,8 @@ result = arr.mean(axis=0).compute() **Configuration guidelines for Dask workloads**: -- `async.concurrency`: Controls the maximum number of concurrent async I/O operations. Start with a lower value (e.g., 4-8) when using many Dask threads. -- `threading.max_workers`: Controls Zarr's internal thread pool size for blocking operations (defaults to CPU count). Reduce this to avoid thread contention with Dask's scheduler. +- `concurrency_limit` (per-store): Controls the maximum number of concurrent async I/O operations for a given store. Start with a lower value (e.g., 4-8) when using many Dask threads. +- `threading.max_workers` (global config): Controls Zarr's internal thread pool size for blocking operations (defaults to CPU count). Reduce this to avoid thread contention with Dask's scheduler. You may need to experiment with different values to find the optimal balance for your workload. Monitor your system's resource usage and adjust these settings based on whether your storage system or CPU is the bottleneck. diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index f8f8ea4f5f..4afd70b1ab 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -29,13 +29,32 @@ from __future__ import annotations +import os +import warnings from typing import TYPE_CHECKING, Any, Literal, cast from donfig import Config as DConfig if TYPE_CHECKING: + from collections.abc import Mapping + from donfig.config_obj import ConfigSet +# Config keys that have been moved from global config to per-store parameters. +# Maps old config key to a warning message. +_warn_on_set: dict[str, str] = { + "async.concurrency": ( + "The 'async.concurrency' configuration key has no effect. " + "Concurrency limits are now set per-store via the 'concurrency_limit' " + "parameter. For example: zarr.storage.LocalStore(..., concurrency_limit=10)." + ), +} + +# Environment variable forms of the keys above (ZARR_ASYNC__CONCURRENCY -> async.concurrency) +_warn_on_set_env: dict[str, str] = { + "ZARR_ASYNC__CONCURRENCY": _warn_on_set["async.concurrency"], +} + class BadConfigError(ValueError): _msg = "bad Config: %r" @@ -55,6 +74,25 @@ class Config(DConfig): # type: ignore[misc] """ + def set(self, arg: Mapping[str, Any] | None = None, **kwargs: Any) -> ConfigSet: + # Check for keys that now belong to per-store config + if arg is not None: + for key in arg: + if key in _warn_on_set: + warnings.warn(_warn_on_set[key], UserWarning, stacklevel=2) + for key in kwargs: + normalized = key.replace("__", ".") + if normalized in _warn_on_set: + warnings.warn(_warn_on_set[normalized], UserWarning, stacklevel=2) + return super().set(arg, **kwargs) + + def refresh(self, **kwargs: Any) -> None: + # Warn if env vars are being used for removed config keys + for env_key, message in _warn_on_set_env.items(): + if env_key in os.environ: + warnings.warn(message, UserWarning, stacklevel=2) + super().refresh(**kwargs) + def reset(self) -> None: self.clear() self.refresh() @@ -98,7 +136,7 @@ def enable_gpu(self) -> ConfigSet: "write_empty_chunks": False, "target_shard_size_bytes": None, }, - "async": {"concurrency": 10, "timeout": None}, + "async": {"timeout": None}, "threading": {"max_workers": None}, "json_indent": 2, "codec_pipeline": { diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..d3c109d41a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -55,7 +55,7 @@ def test_config_defaults_set() -> None: "write_empty_chunks": False, "target_shard_size_bytes": None, }, - "async": {"concurrency": 10, "timeout": None}, + "async": {"timeout": None}, "threading": {"max_workers": None}, "json_indent": 2, "codec_pipeline": { @@ -101,7 +101,6 @@ def test_config_defaults_set() -> None: ] ) assert config.get("array.order") == "C" - assert config.get("async.concurrency") == 10 assert config.get("async.timeout") is None assert config.get("codec_pipeline.batch_size") == 1 assert config.get("json_indent") == 2 @@ -109,7 +108,7 @@ def test_config_defaults_set() -> None: @pytest.mark.parametrize( ("key", "old_val", "new_val"), - [("array.order", "C", "F"), ("async.concurrency", 10, 128), ("json_indent", 2, 0)], + [("array.order", "C", "F"), ("json_indent", 2, 0)], ) def test_config_defaults_can_be_overridden(key: str, old_val: Any, new_val: Any) -> None: assert config.get(key) == old_val @@ -347,3 +346,15 @@ def test_deprecated_config(key: str) -> None: with pytest.raises(ValueError): with zarr.config.set({key: "foo"}): pass + + +def test_async_concurrency_config_warns() -> None: + """Test that setting async.concurrency emits a warning directing users to per-store config.""" + with pytest.warns(UserWarning, match="async.concurrency.*no effect"): + with zarr.config.set({"async.concurrency": 20}): + pass + + # Also test the kwarg form + with pytest.warns(UserWarning, match="async.concurrency.*no effect"): + with zarr.config.set(async__concurrency=20): + pass From 127a024fe4b6906bb7b305fccc0014f8c8cf4bea Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 20 Feb 2026 12:57:21 +0100 Subject: [PATCH 12/12] use base class instead of protocol --- src/zarr/storage/__init__.py | 2 + src/zarr/storage/_fsspec.py | 30 +++------ src/zarr/storage/_local.py | 35 +++------- src/zarr/storage/_obstore.py | 60 ++++------------- src/zarr/storage/_utils.py | 97 ++++++++++++--------------- src/zarr/testing/store_concurrency.py | 40 +++++------ 6 files changed, 100 insertions(+), 164 deletions(-) diff --git a/src/zarr/storage/__init__.py b/src/zarr/storage/__init__.py index 00df50214f..1fde90d8c8 100644 --- a/src/zarr/storage/__init__.py +++ b/src/zarr/storage/__init__.py @@ -10,10 +10,12 @@ from zarr.storage._logging import LoggingStore from zarr.storage._memory import GpuMemoryStore, MemoryStore from zarr.storage._obstore import ObjectStore +from zarr.storage._utils import ConcurrencyLimiter from zarr.storage._wrapper import WrapperStore from zarr.storage._zip import ZipStore __all__ = [ + "ConcurrencyLimiter", "FsspecStore", "GpuMemoryStore", "LocalStore", diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 7f7cb08c57..afd9194425 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -18,7 +18,7 @@ from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path -from zarr.storage._utils import with_concurrency_limit +from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -69,7 +69,7 @@ def _make_async(fs: AbstractFileSystem) -> AsyncFileSystem: return AsyncFileSystemWrapper(fs, asynchronous=True) -class FsspecStore(Store): +class FsspecStore(Store, ConcurrencyLimiter): """ Store for remote data based on FSSpec. @@ -122,7 +122,6 @@ class FsspecStore(Store): fs: AsyncFileSystem allowed_exceptions: tuple[type[Exception], ...] path: str - _semaphore: asyncio.Semaphore | None def __init__( self, @@ -133,13 +132,11 @@ def __init__( allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, concurrency_limit: int | None = 50, ) -> None: - super().__init__(read_only=read_only) + Store.__init__(self, read_only=read_only) + ConcurrencyLimiter.__init__(self, concurrency_limit) self.fs = fs self.path = path self.allowed_exceptions = allowed_exceptions - self._semaphore = ( - asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None - ) if not self.fs.async_impl: raise TypeError("Filesystem needs to support async operations.") @@ -255,19 +252,14 @@ def from_url( return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions) - def get_semaphore(self) -> asyncio.Semaphore | None: - return self._semaphore - def with_read_only(self, read_only: bool = False) -> FsspecStore: # docstring inherited - sem = self.get_semaphore() - concurrency_limit = sem._value if sem else None return type(self)( fs=self.fs, path=self.path, allowed_exceptions=self.allowed_exceptions, read_only=read_only, - concurrency_limit=concurrency_limit, + concurrency_limit=self.concurrency_limit, ) async def clear(self) -> None: @@ -290,7 +282,7 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) - @with_concurrency_limit() + @with_concurrency_limit async def get( self, key: str, @@ -333,7 +325,7 @@ async def get( else: return value - @with_concurrency_limit() + @with_concurrency_limit async def set( self, key: str, @@ -359,7 +351,6 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: if not self._is_open: await self._open() self._check_writable() - semaphore = self.get_semaphore() async def _set_with_limit(key: str, value: Buffer) -> None: if not isinstance(value, Buffer): @@ -367,15 +358,12 @@ async def _set_with_limit(key: str, value: Buffer) -> None: f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." ) path = _dereference_path(self.path, key) - if semaphore: - async with semaphore: - await self.fs._pipe_file(path, value.to_bytes()) - else: + async with self._limit(): await self.fs._pipe_file(path, value.to_bytes()) await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) - @with_concurrency_limit() + @with_concurrency_limit async def delete(self, key: str) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 230d0b3358..4cea142307 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,7 +19,7 @@ ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype -from zarr.storage._utils import with_concurrency_limit +from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator @@ -86,7 +86,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) -class LocalStore(Store): +class LocalStore(Store, ConcurrencyLimiter): """ Store for the local file system. @@ -113,7 +113,6 @@ class LocalStore(Store): supports_listing: bool = True root: Path - _semaphore: asyncio.Semaphore | None def __init__( self, @@ -122,29 +121,22 @@ def __init__( read_only: bool = False, concurrency_limit: int | None = 100, ) -> None: - super().__init__(read_only=read_only) if isinstance(root, str): root = Path(root) if not isinstance(root, Path): raise TypeError( f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." ) + Store.__init__(self, read_only=read_only) + ConcurrencyLimiter.__init__(self, concurrency_limit) self.root = root - self._semaphore = ( - asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None - ) - - def get_semaphore(self) -> asyncio.Semaphore | None: - return self._semaphore def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited - sem = self.get_semaphore() - concurrency_limit = sem._value if sem else None return type(self)( root=self.root, read_only=read_only, - concurrency_limit=concurrency_limit, + concurrency_limit=self.concurrency_limit, ) @classmethod @@ -207,7 +199,7 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root - @with_concurrency_limit() + @with_concurrency_limit async def get( self, key: str, @@ -233,18 +225,13 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - # Note: We directly call the I/O functions here, wrapped with semaphore - # to avoid deadlock from calling the decorated get() method - - semaphore = self.get_semaphore() + # We directly call the I/O functions here, wrapped with the semaphore, + # to avoid deadlock from calling the decorated get() method. async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: path = self.root / key try: - if semaphore: - async with semaphore: - return await asyncio.to_thread(_get, path, prototype, byte_range) - else: + async with self._limit(): return await asyncio.to_thread(_get, path, prototype, byte_range) except (FileNotFoundError, IsADirectoryError, NotADirectoryError): return None @@ -264,7 +251,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: except FileExistsError: pass - @with_concurrency_limit() + @with_concurrency_limit async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() @@ -277,7 +264,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) - @with_concurrency_limit() + @with_concurrency_limit async def delete(self, key: str) -> None: """ Remove a key from the store. diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 313755f0a1..2fa6a31295 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -15,7 +15,7 @@ Store, SuffixByteRequest, ) -from zarr.storage._utils import _relativize_path, with_concurrency_limit +from zarr.storage._utils import ConcurrencyLimiter, _relativize_path, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence @@ -38,7 +38,7 @@ T_Store = TypeVar("T_Store", bound="_UpstreamObjectStore") -class ObjectStore(Store, Generic[T_Store]): +class ObjectStore(Store, ConcurrencyLimiter, Generic[T_Store]): """ Store that uses obstore for fast read/write from AWS, GCP, Azure. @@ -60,7 +60,6 @@ class ObjectStore(Store, Generic[T_Store]): store: T_Store """The underlying obstore instance.""" - _semaphore: asyncio.Semaphore | None def __eq__(self, value: object) -> bool: if not isinstance(value, ObjectStore): @@ -80,23 +79,16 @@ def __init__( ) -> None: if not store.__class__.__module__.startswith("obstore"): raise TypeError(f"expected ObjectStore class, got {store!r}") - super().__init__(read_only=read_only) + Store.__init__(self, read_only=read_only) + ConcurrencyLimiter.__init__(self, concurrency_limit) self.store = store - self._semaphore = ( - asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None - ) - - def get_semaphore(self) -> asyncio.Semaphore | None: - return self._semaphore def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited - sem = self.get_semaphore() - concurrency_limit = sem._value if sem else None return type(self)( store=self.store, read_only=read_only, - concurrency_limit=concurrency_limit, + concurrency_limit=self.concurrency_limit, ) def __str__(self) -> str: @@ -114,7 +106,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) - @with_concurrency_limit() + @with_concurrency_limit async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: @@ -138,7 +130,6 @@ async def get_partial_values( import obstore as obs key_ranges = list(key_ranges) - semaphore = self.get_semaphore() # Group bounded range requests by path for batched fetching per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list) other_requests: list[tuple[int, str, ByteRequest | None]] = [] @@ -155,12 +146,7 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) """Batch multiple range requests for the same file using get_ranges_async.""" starts = [r.start for _, r in requests] ends = [r.end for _, r in requests] - if semaphore: - async with semaphore: - responses = await obs.get_ranges_async( - self.store, path=path, starts=starts, ends=ends - ) - else: + async with self._limit(): responses = await obs.get_ranges_async( self.store, path=path, starts=starts, ends=ends ) @@ -170,10 +156,7 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None: """Fetch a single non-range request with semaphore limiting.""" try: - if semaphore: - async with semaphore: - buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) - else: + async with self._limit(): buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: pass # buffers[idx] stays None @@ -240,7 +223,7 @@ def supports_writes(self) -> bool: # docstring inherited return True - @with_concurrency_limit() + @with_concurrency_limit async def set(self, key: str, value: Buffer) -> None: # docstring inherited import obstore as obs @@ -255,31 +238,22 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: import obstore as obs self._check_writable() - semaphore = self.get_semaphore() async def _set_with_limit(key: str, value: Buffer) -> None: buf = value.as_buffer_like() - if semaphore: - async with semaphore: - await obs.put_async(self.store, key, buf) - else: + async with self._limit(): await obs.put_async(self.store, key, buf) await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited - # Note: Not decorated to avoid deadlock when called in batch via gather() + # Not decorated to avoid deadlock when called in batch via gather() import obstore as obs self._check_writable() buf = value.as_buffer_like() - semaphore = self.get_semaphore() - if semaphore: - async with semaphore: - with contextlib.suppress(obs.exceptions.AlreadyExistsError): - await obs.put_async(self.store, key, buf, mode="create") - else: + async with self._limit(): with contextlib.suppress(obs.exceptions.AlreadyExistsError): await obs.put_async(self.store, key, buf, mode="create") @@ -288,7 +262,7 @@ def supports_deletes(self) -> bool: # docstring inherited return True - @with_concurrency_limit() + @with_concurrency_limit async def delete(self, key: str) -> None: # docstring inherited import obstore as obs @@ -311,15 +285,9 @@ async def delete_dir(self, prefix: str) -> None: prefix += "/" metas = await obs.list(self.store, prefix).collect_async() - semaphore = self.get_semaphore() - # Delete with semaphore limiting to avoid deadlock async def _delete_with_limit(path: str) -> None: - if semaphore: - async with semaphore: - with contextlib.suppress(FileNotFoundError): - await obs.delete_async(self.store, path) - else: + async with self._limit(): with contextlib.suppress(FileNotFoundError): await obs.delete_async(self.store, path) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 2dea7f6fe4..f257082c27 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -1,14 +1,15 @@ from __future__ import annotations +import asyncio +import contextlib import functools import re from pathlib import Path -from typing import TYPE_CHECKING, Any, ParamSpec, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if TYPE_CHECKING: - import asyncio from collections.abc import Callable, Coroutine, Iterable, Mapping from zarr.abc.store import ByteRequest @@ -18,76 +19,64 @@ T_co = TypeVar("T_co", covariant=True) -@runtime_checkable -class HasConcurrencyLimit(Protocol): - """Protocol for stores that support concurrency limiting via a semaphore.""" +class ConcurrencyLimiter: + """Mixin that adds a semaphore-based concurrency limit to a store. - def get_semaphore(self) -> asyncio.Semaphore | None: - """Return the semaphore used for concurrency limiting, or None for unlimited.""" - ... + Stores that inherit from this class gain a ``concurrency_limit`` attribute, + a ``_semaphore`` for internal use, and a ``_limit()`` context-manager helper. + Use the ``@with_concurrency_limit`` decorator on individual async I/O methods. + """ + concurrency_limit: int | None + """The concurrency limit, or ``None`` for unlimited.""" -def with_concurrency_limit() -> Callable[ - [Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]] -]: - """ - Decorator that applies a semaphore-based concurrency limit to an async method. + _semaphore: asyncio.Semaphore | None - This decorator is designed for methods on classes that implement the - ``HasConcurrencyLimit`` protocol. The class must define a ``get_semaphore()`` - method returning either an ``asyncio.Semaphore`` or ``None``. + def __init__(self, concurrency_limit: int | None = None) -> None: + self.concurrency_limit = concurrency_limit + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) - Returns - ------- - Callable - The decorated async function with concurrency limiting applied. + def _limit(self) -> asyncio.Semaphore | contextlib.nullcontext[None]: + """Return the semaphore if set, otherwise a no-op context manager.""" + sem = self._semaphore + return sem if sem is not None else contextlib.nullcontext() + + +def with_concurrency_limit( + func: Callable[P, Coroutine[Any, Any, T_co]], +) -> Callable[P, Coroutine[Any, Any, T_co]]: + """Decorator that applies a semaphore-based concurrency limit to an async method. + + This decorator is designed for methods on classes that inherit from + :class:`ConcurrencyLimiter`. Examples -------- ```python - import asyncio from zarr.abc.store import Store - from zarr.abc.buffer import Buffer - from zarr.storage._utils import with_concurrency_limit - - async def expensive_io_operation(key: str): - asyncio.sleep(10) + from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit - class MyStore(Store): + class MyStore(Store, ConcurrencyLimiter): def __init__(self, concurrency_limit: int = 100): - self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None - - def get_semaphore(self) -> asyncio.Semaphore | None: - return self._semaphore + Store.__init__(self, read_only=False) + ConcurrencyLimiter.__init__(self, concurrency_limit) - @with_concurrency_limit() - async def get(self, key: str) -> Buffer | None: + @with_concurrency_limit + async def get(self, key, prototype, byte_range=None): # This will only run when semaphore permits - return await expensive_io_operation(key) + ... ``` """ - def decorator( - func: Callable[P, Coroutine[Any, Any, T_co]], - ) -> Callable[P, Coroutine[Any, Any, T_co]]: - @functools.wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: - # First arg should be 'self' - if not args: - raise TypeError(f"{func.__name__} requires at least one argument (self)") - - self = args[0] - semaphore: asyncio.Semaphore | None = self.get_semaphore() # type: ignore[attr-defined] - - if semaphore is not None: - async with semaphore: - return await func(*args, **kwargs) - else: - return await func(*args, **kwargs) - - return wrapper + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: + self = args[0] + async with self._limit(): # type: ignore[attr-defined] + return await func(*args, **kwargs) - return decorator + return wrapper def normalize_path(path: str | bytes | Path | None) -> str: diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py index a13de4b13a..98530d7225 100644 --- a/src/zarr/testing/store_concurrency.py +++ b/src/zarr/testing/store_concurrency.py @@ -8,6 +8,7 @@ import pytest from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.storage._utils import ConcurrencyLimiter if TYPE_CHECKING: from zarr.abc.store import Store @@ -43,37 +44,36 @@ async def store(self, store_kwargs: dict[str, Any]) -> S: @staticmethod def _get_semaphore(store: Store) -> asyncio.Semaphore | None: """Get the semaphore from a store, or None if the store doesn't support concurrency limiting.""" - get_semaphore = getattr(store, "get_semaphore", None) - if get_semaphore is not None: - return get_semaphore() # type: ignore[no-any-return] + if isinstance(store, ConcurrencyLimiter): + return store._semaphore return None def test_concurrency_limit_default(self, store: S) -> None: """Test that store has the expected default concurrency limit.""" - semaphore = self._get_semaphore(store) - if semaphore is None and self.expected_concurrency_limit is not None: - pytest.fail("Expected concurrency limit to be set") - if semaphore is not None and self.expected_concurrency_limit is None: - pytest.fail("Expected no concurrency limit") - if semaphore is not None and self.expected_concurrency_limit is not None: - assert semaphore._value == self.expected_concurrency_limit, ( - f"Expected limit {self.expected_concurrency_limit}, got {semaphore._value}" - ) + # Concrete subclasses inherit from both Store and ConcurrencyLimiter, + # but S is bound to Store so mypy considers this branch unreachable. + if not isinstance(store, ConcurrencyLimiter): + assert self.expected_concurrency_limit is None + return + assert store.concurrency_limit == self.expected_concurrency_limit # type: ignore[unreachable] def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None: """Test that custom concurrency limits can be set.""" - if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames: + if not issubclass(self.store_cls, ConcurrencyLimiter): pytest.skip("Store does not support custom concurrency limits") + # mypy considers this unreachable because S is bound to Store, not ConcurrencyLimiter. # Test with custom limit - store = self.store_cls(**{**store_kwargs, "concurrency_limit": 42}) - semaphore = self._get_semaphore(store) - assert semaphore is not None - assert semaphore._value == 42 + store = self.store_cls(**{**store_kwargs, "concurrency_limit": 42}) # type: ignore[unreachable] + assert isinstance(store, ConcurrencyLimiter) + assert store.concurrency_limit == 42 + assert store._semaphore is not None # Test with None (unlimited) store = self.store_cls(**{**store_kwargs, "concurrency_limit": None}) - assert self._get_semaphore(store) is None + assert isinstance(store, ConcurrencyLimiter) + assert store.concurrency_limit is None + assert store._semaphore is None async def test_concurrency_limit_enforced(self, store: S) -> None: """Test that the concurrency limit is actually enforced during execution. @@ -85,7 +85,9 @@ async def test_concurrency_limit_enforced(self, store: S) -> None: if semaphore is None: pytest.skip("Store has no concurrency limit") - limit = semaphore._value + assert isinstance(store, ConcurrencyLimiter) + assert store.concurrency_limit is not None # type: ignore[unreachable] + limit = store.concurrency_limit # We'll monitor the semaphore's available count # When it reaches 0, that means `limit` operations are running