diff --git a/changes/3547.misc.md b/changes/3547.misc.md new file mode 100644 index 0000000000..eb4d4bd507 --- /dev/null +++ b/changes/3547.misc.md @@ -0,0 +1,2 @@ +Moved concurrency-limiting functionality to store classes. The global configuration object no longer +controls concurrency limits. Concurrency limits, if applicable, must now be specified when constructing a store. \ No newline at end of file diff --git a/docs/user-guide/config.md b/docs/user-guide/config.md index 21fe9b5def..57fed828a1 100644 --- a/docs/user-guide/config.md +++ b/docs/user-guide/config.md @@ -30,7 +30,7 @@ Configuration options include the following: - Default Zarr format `default_zarr_version` - Default array order in memory `array.order` - Whether empty chunks are written to storage `array.write_empty_chunks` -- Async and threading options, e.g. `async.concurrency` and `threading.max_workers` +- Threading options, e.g. `threading.max_workers` - Selections of implementations of codecs, codec pipelines and buffers - Enabling GPU support with `zarr.config.enable_gpu()`. See GPU support for more. diff --git a/docs/user-guide/performance.md b/docs/user-guide/performance.md index 0e0fa3cd55..56c2b4de4b 100644 --- a/docs/user-guide/performance.md +++ b/docs/user-guide/performance.md @@ -191,20 +191,18 @@ scenarios. ### Concurrent I/O operations Zarr uses asynchronous I/O internally to enable concurrent reads and writes across multiple chunks. -The level of concurrency is controlled by the `async.concurrency` configuration setting, which -determines the maximum number of concurrent I/O operations. - -The default value is 10, which is a conservative value. You may get improved performance by tuning -the concurrency limit. You can adjust this value based on your specific needs: +Concurrency is controlled at the **store level** — each store instance can have its own concurrency +limit, set via the `concurrency_limit` parameter when creating the store. ```python import zarr -# Set concurrency for the current session -zarr.config.set({'async.concurrency': 128}) +# Local filesystem store with custom concurrency limit +store = zarr.storage.LocalStore("data/my_array.zarr", concurrency_limit=64) -# Or use environment variable -# export ZARR_ASYNC_CONCURRENCY=128 +# Remote store with higher concurrency for network I/O +from obstore.store import S3Store +store = zarr.storage.ObjectStore(S3Store.from_url("s3://bucket/path"), concurrency_limit=128) ``` Higher concurrency values can improve throughput when: @@ -217,32 +215,36 @@ Lower concurrency values may be beneficial when: - Memory is constrained (each concurrent operation requires buffer space) - Using Zarr within a parallel computing framework (see below) +Set `concurrency_limit=None` to disable the concurrency limit entirely. + ### Using Zarr with Dask -[Dask](https://www.dask.org/) is a popular parallel computing library that works well with Zarr for processing large arrays. When using Zarr with Dask, it's important to consider the interaction between Dask's thread pool and Zarr's concurrency settings. +[Dask](https://www.dask.org/) is a popular parallel computing library that works well with Zarr for processing large arrays. When using Zarr with Dask, it's important to consider the interaction between Dask's thread pool and the store's concurrency limit. -**Important**: When using many Dask threads, you may need to reduce both Zarr's `async.concurrency` and `threading.max_workers` settings to avoid creating too many concurrent operations. 
The total number of concurrent I/O operations can be roughly estimated as: +**Important**: When using many Dask threads, you may need to reduce the store's `concurrency_limit` and Zarr's `threading.max_workers` setting to avoid creating too many concurrent operations. The total number of concurrent I/O operations can be roughly estimated as: ``` -total_concurrency ≈ dask_threads × zarr_async_concurrency +total_concurrency ≈ dask_threads × store_concurrency_limit ``` -For example, if you're running Dask with 10 threads and Zarr's default concurrency of 64, you could potentially have up to 640 concurrent operations, which may overwhelm your storage system or cause memory issues. +For example, if you're running Dask with 10 threads and a store concurrency limit of 64, you could potentially have up to 640 concurrent operations, which may overwhelm your storage system or cause memory issues. -**Recommendation**: When using Dask with many threads, configure Zarr's concurrency settings: +**Recommendation**: When using Dask with many threads, configure concurrency settings: ```python import zarr import dask.array as da -# If using Dask with many threads (e.g., 8-16), reduce Zarr's concurrency settings +# Create store with reduced concurrency limit for Dask workloads +store = zarr.storage.LocalStore("data/large_array.zarr", concurrency_limit=4) + +# Also limit Zarr's internal thread pool zarr.config.set({ - 'async.concurrency': 4, # Limit concurrent async operations 'threading.max_workers': 4, # Limit Zarr's internal thread pool }) # Open Zarr array -z = zarr.open_array('data/large_array.zarr', mode='r') +z = zarr.open_array(store=store, mode='r') # Create Dask array from Zarr array arr = da.from_array(z, chunks=z.chunks) @@ -253,8 +255,8 @@ result = arr.mean(axis=0).compute() **Configuration guidelines for Dask workloads**: -- `async.concurrency`: Controls the maximum number of concurrent async I/O operations. Start with a lower value (e.g., 4-8) when using many Dask threads. -- `threading.max_workers`: Controls Zarr's internal thread pool size for blocking operations (defaults to CPU count). Reduce this to avoid thread contention with Dask's scheduler. +- `concurrency_limit` (per-store): Controls the maximum number of concurrent async I/O operations for a given store. Start with a lower value (e.g., 4-8) when using many Dask threads. +- `threading.max_workers` (global config): Controls Zarr's internal thread pool size for blocking operations (defaults to CPU count). Reduce this to avoid thread contention with Dask's scheduler. You may need to experiment with different values to find the optimal balance for your workload. Monitor your system's resource usage and adjust these settings based on whether your storage system or CPU is the bottleneck. 
diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..69d6c3082e 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from abc import abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar @@ -8,8 +9,7 @@ from zarr.abc.metadata import Metadata from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.common import NamedConfig, concurrent_map -from zarr.core.config import config +from zarr.core.common import NamedConfig if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -225,11 +225,8 @@ async def decode_partial( ------- Iterable[NDBuffer | None] """ - return await concurrent_map( - list(batch_info), - self._decode_partial_single, - config.get("async.concurrency"), - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[self._decode_partial_single(*info) for info in batch_info]) class ArrayBytesCodecPartialEncodeMixin: @@ -262,11 +259,8 @@ async def encode_partial( The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data. The chunk spec contains information about the chunk. """ - await concurrent_map( - list(batch_info), - self._encode_partial_single, - config.get("async.concurrency"), - ) + # Store handles concurrency limiting internally + await asyncio.gather(*[self._encode_partial_single(*info) for info in batch_info]) class CodecPipeline: @@ -464,11 +458,8 @@ async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], batch_info: Iterable[tuple[CodecInput | None, ArraySpec]], ) -> list[CodecOutput | None]: - return await concurrent_map( - list(batch_info), - _noop_for_none(func), - config.get("async.concurrency"), - ) + # Store handles concurrency limiting internally + return await asyncio.gather(*[_noop_for_none(func)(chunk, spec) for chunk, spec in batch_info]) def _noop_for_none( diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 87df89a683..49f0e90ace 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -670,13 +670,8 @@ async def getsize_prefix(self, prefix: str) -> int: # improve tail latency and might reduce memory pressure (since not all keys # would be in memory at once). 
- # avoid circular import - from zarr.core.common import concurrent_map - from zarr.core.config import config - - keys = [(x,) async for x in self.list_prefix(prefix)] - limit = config.get("async.concurrency") - sizes = await concurrent_map(keys, self.getsize, limit=limit) + keys = [x async for x in self.list_prefix(prefix)] + sizes = await asyncio.gather(*[self.getsize(key) for key in keys]) return sum(sizes) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 564d0e915a..3974a4d8b4 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from asyncio import gather @@ -22,7 +23,6 @@ import numpy as np from typing_extensions import deprecated -import zarr from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.abc.numcodec import Numcodec, _is_numcodec from zarr.codecs._v2 import V2Codec @@ -60,7 +60,6 @@ _default_zarr_format, _warn_order_kwarg, ceildiv, - concurrent_map, parse_shapelike, product, ) @@ -4481,28 +4480,26 @@ async def from_array( if write_data: if isinstance(data, Array): - async def _copy_array_region( - chunk_coords: tuple[int, ...] | slice, _data: AnyArray - ) -> None: + async def _copy_array_region(chunk_coords: tuple[slice, ...], _data: AnyArray) -> None: arr = await _data.async_array.getitem(chunk_coords) await result.setitem(chunk_coords, arr) # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_array_region, - zarr.core.config.config.get("async.concurrency"), + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_array_region(region, data) for region in result._iter_shard_regions()] ) else: - async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> None: - await result.setitem(chunk_coords, _data[chunk_coords]) + async def _copy_arraylike_region( + chunk_coords: tuple[slice, ...], _data: npt.ArrayLike + ) -> None: + await result.setitem(chunk_coords, _data[chunk_coords]) # type: ignore[call-overload, index] # Stream data from the source array to the new array - await concurrent_map( - [(region, data) for region in result._iter_shard_regions()], - _copy_arraylike_region, - zarr.core.config.config.get("async.concurrency"), + # Store handles concurrency limiting internally + await asyncio.gather( + *[_copy_arraylike_region(region, data) for region in result._iter_shard_regions()] ) return result @@ -6001,13 +5998,12 @@ async def _resize( async def _delete_key(key: str) -> None: await (array.store_path / key).delete() - await concurrent_map( - [ - (array.metadata.encode_chunk_key(chunk_coords),) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _delete_key(array.metadata.encode_chunk_key(chunk_coords)) for chunk_coords in old_chunk_coords.difference(new_chunk_coords) - ], - _delete_key, - zarr_config.get("async.concurrency"), + ] ) # Write new metadata diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..0f8350f7ea 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from itertools import islice, pairwise from typing import TYPE_CHECKING, Any, TypeVar @@ -14,7 +15,6 @@ Codec, CodecPipeline, ) -from zarr.core.common import concurrent_map from zarr.core.config import config 
from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning @@ -267,10 +267,12 @@ async def read_batch( else: out[out_selection] = fill_value_or_default(chunk_spec) else: - chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], - lambda byte_getter, prototype: byte_getter.get(prototype), - config.get("async.concurrency"), + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + byte_getter.get(array_spec.prototype) + for byte_getter, array_spec, *_ in batch_info + ] ) chunk_array_batch = await self.decode_batch( [ @@ -368,16 +370,15 @@ async def _read_key( return await byte_setter.get(prototype=prototype) chunk_bytes_batch: Iterable[Buffer | None] - chunk_bytes_batch = await concurrent_map( - [ - ( + # Store handles concurrency limiting internally + chunk_bytes_batch = await asyncio.gather( + *[ + _read_key( None if is_complete_chunk else byte_setter, chunk_spec.prototype, ) for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info - ], - _read_key, - config.get("async.concurrency"), + ] ) chunk_array_decoded = await self.decode_batch( [ @@ -435,15 +436,14 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non else: await byte_setter.set(chunk_bytes) - await concurrent_map( - [ - (byte_setter, chunk_bytes) + # Store handles concurrency limiting internally + await asyncio.gather( + *[ + _write_key(byte_setter, chunk_bytes) for chunk_bytes, (byte_setter, *_) in zip( chunk_bytes_batch, batch_info, strict=False ) - ], - _write_key, - config.get("async.concurrency"), + ] ) async def decode( @@ -470,13 +470,12 @@ async def read( out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: - await concurrent_map( - [ - (single_batch_info, out, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.read_batch(single_batch_info, out, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.read_batch, - config.get("async.concurrency"), + ] ) async def write( @@ -485,13 +484,12 @@ async def write( value: NDBuffer, drop_axes: tuple[int, ...] 
= (), ) -> None: - await concurrent_map( - [ - (single_batch_info, value, drop_axes) + # Process mini-batches concurrently - stores handle I/O concurrency internally + await asyncio.gather( + *[ + self.write_batch(single_batch_info, value, drop_axes) for single_batch_info in batched(batch_info, self.batch_size) - ], - self.write_batch, - config.get("async.concurrency"), + ] ) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 275d062eba..cc1c8d0ecf 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -1,13 +1,11 @@ from __future__ import annotations -import asyncio import functools import math import operator import warnings from collections.abc import Iterable, Mapping, Sequence from enum import Enum -from itertools import starmap from typing import ( TYPE_CHECKING, Any, @@ -28,7 +26,7 @@ from zarr.errors import ZarrRuntimeWarning if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Iterator + from collections.abc import Iterator ZARR_JSON = "zarr.json" @@ -95,28 +93,6 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) -T = TypeVar("T", bound=tuple[Any, ...]) -V = TypeVar("V") - - -async def concurrent_map( - items: Iterable[T], - func: Callable[..., Awaitable[V]], - limit: int | None = None, -) -> list[V]: - if limit is None: - return await asyncio.gather(*list(starmap(func, items))) - - else: - sem = asyncio.Semaphore(limit) - - async def run(item: tuple[Any]) -> V: - async with sem: - return await func(*item) - - return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items]) - - E = TypeVar("E", bound=Enum) diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index f8f8ea4f5f..4afd70b1ab 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -29,13 +29,32 @@ from __future__ import annotations +import os +import warnings from typing import TYPE_CHECKING, Any, Literal, cast from donfig import Config as DConfig if TYPE_CHECKING: + from collections.abc import Mapping + from donfig.config_obj import ConfigSet +# Config keys that have been moved from global config to per-store parameters. +# Maps old config key to a warning message. +_warn_on_set: dict[str, str] = { + "async.concurrency": ( + "The 'async.concurrency' configuration key has no effect. " + "Concurrency limits are now set per-store via the 'concurrency_limit' " + "parameter. For example: zarr.storage.LocalStore(..., concurrency_limit=10)." 
+ ), +} + +# Environment variable forms of the keys above (ZARR_ASYNC__CONCURRENCY -> async.concurrency) +_warn_on_set_env: dict[str, str] = { + "ZARR_ASYNC__CONCURRENCY": _warn_on_set["async.concurrency"], +} + class BadConfigError(ValueError): _msg = "bad Config: %r" @@ -55,6 +74,25 @@ class Config(DConfig): # type: ignore[misc] """ + def set(self, arg: Mapping[str, Any] | None = None, **kwargs: Any) -> ConfigSet: + # Check for keys that now belong to per-store config + if arg is not None: + for key in arg: + if key in _warn_on_set: + warnings.warn(_warn_on_set[key], UserWarning, stacklevel=2) + for key in kwargs: + normalized = key.replace("__", ".") + if normalized in _warn_on_set: + warnings.warn(_warn_on_set[normalized], UserWarning, stacklevel=2) + return super().set(arg, **kwargs) + + def refresh(self, **kwargs: Any) -> None: + # Warn if env vars are being used for removed config keys + for env_key, message in _warn_on_set_env.items(): + if env_key in os.environ: + warnings.warn(message, UserWarning, stacklevel=2) + super().refresh(**kwargs) + def reset(self) -> None: self.clear() self.refresh() @@ -98,7 +136,7 @@ def enable_gpu(self) -> ConfigSet: "write_empty_chunks": False, "target_shard_size_bytes": None, }, - "async": {"concurrency": 10, "timeout": None}, + "async": {"timeout": None}, "threading": {"max_workers": None}, "json_indent": 2, "codec_pipeline": { diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9b5fee275b..658de7ef81 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1440,13 +1440,10 @@ async def _members( ) raise ValueError(msg) - # enforce a concurrency limit by passing a semaphore to all the recursive functions - semaphore = asyncio.Semaphore(config.get("async.concurrency")) async for member in _iter_members_deep( self, max_depth=max_depth, skip_keys=skip_keys, - semaphore=semaphore, use_consolidated_for_children=use_consolidated_for_children, ): yield member @@ -3323,14 +3320,11 @@ async def create_nodes( The created nodes in the order they are created. """ - # Note: the only way to alter this value is via the config. If that's undesirable for some reason, - # then we should consider adding a keyword argument this this function - semaphore = asyncio.Semaphore(config.get("async.concurrency")) create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): # make the key absolute - create_tasks.extend(_persist_metadata(store, key, value, semaphore=semaphore)) + create_tasks.extend(_persist_metadata(store, key, value)) created_object_keys = [] @@ -3476,28 +3470,16 @@ def _ensure_consistent_zarr_format( ) -async def _getitem_semaphore( - node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None -) -> AnyAsyncArray | AsyncGroup: +async def _getitem(node: AsyncGroup, key: str) -> AnyAsyncArray | AsyncGroup: """ - Wrap Group.getitem with an optional semaphore. - - If the semaphore parameter is an - asyncio.Semaphore instance, then the getitem operation is performed inside an async context - manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked - without a context manager. + Fetch a child node from a group by key. 
""" - if semaphore is not None: - async with semaphore: - return await node.getitem(key) - else: - return await node.getitem(key) + return await node.getitem(key) async def _iter_members( node: AsyncGroup, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ Iterate over the arrays and groups contained in a group. @@ -3508,8 +3490,6 @@ async def _iter_members( The group to traverse. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. Yields ------ @@ -3520,10 +3500,7 @@ async def _iter_members( keys = [key async for key in node.store.list_dir(node.path)] keys_filtered = tuple(filter(lambda v: v not in skip_keys, keys)) - node_tasks = tuple( - asyncio.create_task(_getitem_semaphore(node, key, semaphore), name=key) - for key in keys_filtered - ) + node_tasks = tuple(asyncio.create_task(_getitem(node, key), name=key) for key in keys_filtered) for fetched_node_coro in asyncio.as_completed(node_tasks): try: @@ -3550,7 +3527,6 @@ async def _iter_members_deep( *, max_depth: int | None, skip_keys: tuple[str, ...], - semaphore: asyncio.Semaphore | None = None, use_consolidated_for_children: bool = True, ) -> AsyncGenerator[tuple[str, AnyAsyncArray | AsyncGroup], None]: """ @@ -3565,8 +3541,6 @@ async def _iter_members_deep( The maximum depth of recursion. skip_keys : tuple[str, ...] A tuple of keys to skip when iterating over the possible members of the group. - semaphore : asyncio.Semaphore | None - An optional semaphore to use for concurrency control. use_consolidated_for_children : bool, default True Whether to use the consolidated metadata of child groups loaded from the store. Note that this only affects groups loaded from the @@ -3585,7 +3559,7 @@ async def _iter_members_deep( new_depth = None else: new_depth = max_depth - 1 - async for name, node in _iter_members(group, skip_keys=skip_keys, semaphore=semaphore): + async for name, node in _iter_members(group, skip_keys=skip_keys): is_group = isinstance(node, AsyncGroup) if ( is_group @@ -3599,9 +3573,7 @@ async def _iter_members_deep( yield name, node if is_group and do_recursion: node = cast("AsyncGroup", node) - to_recurse[name] = _iter_members_deep( - node, max_depth=new_depth, skip_keys=skip_keys, semaphore=semaphore - ) + to_recurse[name] = _iter_members_deep(node, max_depth=new_depth, skip_keys=skip_keys) for prefix, subgroup_iter in to_recurse.items(): async for name, node in subgroup_iter: @@ -3811,9 +3783,7 @@ async def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> AnyAsync raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover -async def _set_return_key( - *, store: Store, key: str, value: Buffer, semaphore: asyncio.Semaphore | None = None -) -> str: +async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: """ Write a value to storage at the given key. The key is returned. Useful when saving values via routines that return results in execution order, @@ -3828,15 +3798,8 @@ async def _set_return_key( The key to save the value to. value : Buffer The value to save. - semaphore : asyncio.Semaphore | None - An optional semaphore to use to limit the number of concurrent writes. 
""" - - if semaphore is not None: - async with semaphore: - await store.set(key, value) - else: - await store.set(key, value) + await store.set(key, value) return key @@ -3844,7 +3807,6 @@ def _persist_metadata( store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata, - semaphore: asyncio.Semaphore | None = None, ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. @@ -3852,7 +3814,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) + _set_return_key(store=store, key=_join_paths([path, key]), value=value) for key, value in to_save.items() ) diff --git a/src/zarr/storage/__init__.py b/src/zarr/storage/__init__.py index 00df50214f..1fde90d8c8 100644 --- a/src/zarr/storage/__init__.py +++ b/src/zarr/storage/__init__.py @@ -10,10 +10,12 @@ from zarr.storage._logging import LoggingStore from zarr.storage._memory import GpuMemoryStore, MemoryStore from zarr.storage._obstore import ObjectStore +from zarr.storage._utils import ConcurrencyLimiter from zarr.storage._wrapper import WrapperStore from zarr.storage._zip import ZipStore __all__ = [ + "ConcurrencyLimiter", "FsspecStore", "GpuMemoryStore", "LocalStore", diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..afd9194425 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import warnings from contextlib import suppress @@ -17,6 +18,7 @@ from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path +from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable @@ -67,7 +69,7 @@ def _make_async(fs: AbstractFileSystem) -> AsyncFileSystem: return AsyncFileSystemWrapper(fs, asynchronous=True) -class FsspecStore(Store): +class FsspecStore(Store, ConcurrencyLimiter): """ Store for remote data based on FSSpec. @@ -82,6 +84,9 @@ class FsspecStore(Store): filesystem scheme. allowed_exceptions : tuple[type[Exception], ...] When fetching data, these cases will be deemed to correspond to missing keys. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. Attributes ---------- @@ -121,11 +126,14 @@ class FsspecStore(Store): def __init__( self, fs: AsyncFileSystem, + *, read_only: bool = False, path: str = "/", allowed_exceptions: tuple[type[Exception], ...] 
= ALLOWED_EXCEPTIONS, + concurrency_limit: int | None = 50, ) -> None: - super().__init__(read_only=read_only) + Store.__init__(self, read_only=read_only) + ConcurrencyLimiter.__init__(self, concurrency_limit) self.fs = fs self.path = path self.allowed_exceptions = allowed_exceptions @@ -251,6 +259,7 @@ def with_read_only(self, read_only: bool = False) -> FsspecStore: path=self.path, allowed_exceptions=self.allowed_exceptions, read_only=read_only, + concurrency_limit=self.concurrency_limit, ) async def clear(self) -> None: @@ -273,6 +282,7 @@ def __eq__(self, other: object) -> bool: and self.fs == other.fs ) + @with_concurrency_limit async def get( self, key: str, @@ -315,6 +325,7 @@ async def get( else: return value + @with_concurrency_limit async def set( self, key: str, @@ -335,6 +346,24 @@ async def set( raise NotImplementedError await self.fs._pipe_file(path, value.to_bytes()) + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + if not self._is_open: + await self._open() + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + if not isinstance(value, Buffer): + raise TypeError( + f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + path = _dereference_path(self.path, key) + async with self._limit(): + await self.fs._pipe_file(path, value.to_bytes()) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + + @with_concurrency_limit async def delete(self, key: str) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..4cea142307 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -19,12 +19,13 @@ ) from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import AccessModeLiteral, concurrent_map +from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator from zarr.core.buffer import BufferPrototype + from zarr.core.common import AccessModeLiteral def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: @@ -85,7 +86,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) -class LocalStore(Store): +class LocalStore(Store, ConcurrencyLimiter): """ Store for the local file system. @@ -95,6 +96,9 @@ class LocalStore(Store): Directory to use as root of store. read_only : bool Whether the store is read-only + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 100. + Set to None for unlimited concurrency. Attributes ---------- @@ -110,14 +114,21 @@ class LocalStore(Store): root: Path - def __init__(self, root: Path | str, *, read_only: bool = False) -> None: - super().__init__(read_only=read_only) + def __init__( + self, + root: Path | str, + *, + read_only: bool = False, + concurrency_limit: int | None = 100, + ) -> None: if isinstance(root, str): root = Path(root) if not isinstance(root, Path): raise TypeError( f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." 
) + Store.__init__(self, read_only=read_only) + ConcurrencyLimiter.__init__(self, concurrency_limit) self.root = root def with_read_only(self, read_only: bool = False) -> Self: @@ -125,6 +136,7 @@ def with_read_only(self, read_only: bool = False) -> Self: return type(self)( root=self.root, read_only=read_only, + concurrency_limit=self.concurrency_limit, ) @classmethod @@ -187,6 +199,7 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + @with_concurrency_limit async def get( self, key: str, @@ -212,12 +225,20 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - args = [] - for key, byte_range in key_ranges: - assert isinstance(key, str) + # We directly call the I/O functions here, wrapped with the semaphore, + # to avoid deadlock from calling the decorated get() method. + + async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None: path = self.root / key - args.append((_get, path, prototype, byte_range)) - return await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit + try: + async with self._limit(): + return await asyncio.to_thread(_get, path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + return await asyncio.gather( + *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges] + ) async def set(self, key: str, value: Buffer) -> None: # docstring inherited @@ -230,6 +251,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: except FileExistsError: pass + @with_concurrency_limit async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() @@ -242,6 +264,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) + @with_concurrency_limit async def delete(self, key: str) -> None: """ Remove a key from the store. 
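The two store diffs above share a pattern: single-key methods are wrapped with `@with_concurrency_limit`, while batch overrides (`_set_many`, `get_partial_values`) hold the semaphore via `self._limit()` directly around the raw I/O call instead of going through the decorated methods. A self-contained sketch of that pattern follows; the `FakeIOStore` class and its sleep-based I/O are illustrative only and not part of zarr (a real store would also subclass `zarr.abc.store.Store`):

```python
import asyncio

from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit


class FakeIOStore(ConcurrencyLimiter):
    """Toy stand-in for a store, used only to show the limiting pattern."""

    def __init__(self, concurrency_limit: int | None = 4) -> None:
        ConcurrencyLimiter.__init__(self, concurrency_limit)
        self.in_flight = 0
        self.max_in_flight = 0

    @with_concurrency_limit
    async def get(self, key: str) -> str:
        # Single-key path: the decorator acquires the semaphore before this runs.
        self.in_flight += 1
        self.max_in_flight = max(self.max_in_flight, self.in_flight)
        await asyncio.sleep(0.01)  # pretend I/O
        self.in_flight -= 1
        return key

    async def get_many(self, keys: list[str]) -> list[str]:
        # Batch path: hold the semaphore around the raw operation rather than
        # calling the decorated get(), mirroring the _set_many overrides above.
        async def _one(key: str) -> str:
            async with self._limit():
                await asyncio.sleep(0.01)
                return key

        return await asyncio.gather(*[_one(k) for k in keys])


async def main() -> None:
    store = FakeIOStore(concurrency_limit=4)
    await asyncio.gather(*[store.get(f"k{i}") for i in range(20)])
    print(store.max_in_flight)  # never exceeds 4


asyncio.run(main())
```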
diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..bb0f81d2f5 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,12 +1,12 @@ from __future__ import annotations +import asyncio from logging import getLogger from typing import TYPE_CHECKING, Any, Self from zarr.abc.store import ByteRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: @@ -102,12 +102,10 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - - # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None: - return await self.get(key, prototype=prototype, byte_range=byte_range) - - return await concurrent_map(key_ranges, _get, limit=None) + # In-memory operations are fast and don't need concurrency limiting + return await asyncio.gather( + *[self.get(key, prototype, byte_range) for key, byte_range in key_ranges] + ) async def exists(self, key: str) -> bool: # docstring inherited diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 6e4011da59..2fa6a31295 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -6,7 +6,7 @@ from collections import defaultdict from itertools import chain from operator import itemgetter -from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from zarr.abc.store import ( ByteRequest, @@ -15,15 +15,13 @@ Store, SuffixByteRequest, ) -from zarr.core.common import concurrent_map -from zarr.core.config import config -from zarr.storage._utils import _relativize_path +from zarr.storage._utils import ConcurrencyLimiter, _relativize_path, with_concurrency_limit if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence from typing import Any - from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange + from obstore import ListResult, ListStream, ObjectMeta from obstore.store import ObjectStore as _UpstreamObjectStore from zarr.core.buffer import Buffer, BufferPrototype @@ -40,7 +38,7 @@ T_Store = TypeVar("T_Store", bound="_UpstreamObjectStore") -class ObjectStore(Store, Generic[T_Store]): +class ObjectStore(Store, ConcurrencyLimiter, Generic[T_Store]): """ Store that uses obstore for fast read/write from AWS, GCP, Azure. @@ -50,6 +48,9 @@ class ObjectStore(Store, Generic[T_Store]): An obstore store instance that is set up with the proper credentials. read_only : bool Whether to open the store in read-only mode. + concurrency_limit : int, optional + Maximum number of concurrent I/O operations. Default is 50. + Set to None for unlimited concurrency. 
Warnings -------- @@ -69,10 +70,17 @@ def __eq__(self, value: object) -> bool: return self.store == value.store # type: ignore[no-any-return] - def __init__(self, store: T_Store, *, read_only: bool = False) -> None: + def __init__( + self, + store: T_Store, + *, + read_only: bool = False, + concurrency_limit: int | None = 50, + ) -> None: if not store.__class__.__module__.startswith("obstore"): raise TypeError(f"expected ObjectStore class, got {store!r}") - super().__init__(read_only=read_only) + Store.__init__(self, read_only=read_only) + ConcurrencyLimiter.__init__(self, concurrency_limit) self.store = store def with_read_only(self, read_only: bool = False) -> Self: @@ -80,6 +88,7 @@ def with_read_only(self, read_only: bool = False) -> Self: return type(self)( store=self.store, read_only=read_only, + concurrency_limit=self.concurrency_limit, ) def __str__(self) -> str: @@ -97,6 +106,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None: state["store"] = pickle.loads(state["store"]) self.__dict__.update(state) + @with_concurrency_limit async def get( self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None ) -> Buffer | None: @@ -104,41 +114,7 @@ async def get( import obstore as obs try: - if byte_range is None: - resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, RangeByteRequest): - bytes = await obs.get_range_async( - self.store, key, start=byte_range.start, end=byte_range.end - ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] - elif isinstance(byte_range, OffsetByteRequest): - resp = await obs.get_async( - self.store, key, options={"range": {"offset": byte_range.offset}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - elif isinstance(byte_range, SuffixByteRequest): - # some object stores (Azure) don't support suffix requests. In this - # case, our workaround is to first get the length of the object and then - # manually request the byte range at the end. - try: - resp = await obs.get_async( - self.store, key, options={"range": {"suffix": byte_range.suffix}} - ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(self.store, key) - file_size = head_resp["size"] - suffix_len = byte_range.suffix - buffer = await obs.get_range_async( - self.store, - key, - start=file_size - suffix_len, - length=suffix_len, - ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] - else: - raise ValueError(f"Unexpected byte_range, got {byte_range}") + return await self._get_impl(key, prototype, byte_range, obs) except _ALLOWED_EXCEPTIONS: return None @@ -148,7 +124,88 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited - return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) + # We override to: + # 1. Avoid deadlock from calling the decorated get() method + # 2. 
Batch RangeByteRequests per-file using get_ranges_async for performance + import obstore as obs + + key_ranges = list(key_ranges) + # Group bounded range requests by path for batched fetching + per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list) + other_requests: list[tuple[int, str, ByteRequest | None]] = [] + + for idx, (path, byte_range) in enumerate(key_ranges): + if isinstance(byte_range, RangeByteRequest): + per_file_bounded[path].append((idx, byte_range)) + else: + other_requests.append((idx, path, byte_range)) + + buffers: list[Buffer | None] = [None] * len(key_ranges) + + async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]]) -> None: + """Batch multiple range requests for the same file using get_ranges_async.""" + starts = [r.start for _, r in requests] + ends = [r.end for _, r in requests] + async with self._limit(): + responses = await obs.get_ranges_async( + self.store, path=path, starts=starts, ends=ends + ) + for (idx, _), response in zip(requests, responses, strict=True): + buffers[idx] = prototype.buffer.from_bytes(response) # type: ignore[arg-type] + + async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None: + """Fetch a single non-range request with semaphore limiting.""" + try: + async with self._limit(): + buffers[idx] = await self._get_impl(path, prototype, byte_range, obs) + except _ALLOWED_EXCEPTIONS: + pass # buffers[idx] stays None + + futs: list[Coroutine[Any, Any, None]] = [] + for path, requests in per_file_bounded.items(): + futs.append(_fetch_ranges(path, requests)) + for idx, path, byte_range in other_requests: + futs.append(_fetch_one(idx, path, byte_range)) + + await asyncio.gather(*futs) + return buffers + + async def _get_impl( + self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any + ) -> Buffer: + """Implementation of get without semaphore decoration.""" + if byte_range is None: + resp = await obs.get_async(self.store, key) + return prototype.buffer.from_bytes(await resp.bytes_async()) + elif isinstance(byte_range, RangeByteRequest): + bytes = await obs.get_range_async( + self.store, key, start=byte_range.start, end=byte_range.end + ) + return prototype.buffer.from_bytes(bytes) + elif isinstance(byte_range, OffsetByteRequest): + resp = await obs.get_async( + self.store, key, options={"range": {"offset": byte_range.offset}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) + elif isinstance(byte_range, SuffixByteRequest): + try: + resp = await obs.get_async( + self.store, key, options={"range": {"suffix": byte_range.suffix}} + ) + return prototype.buffer.from_bytes(await resp.bytes_async()) + except obs.exceptions.NotSupportedError: + head_resp = await obs.head_async(self.store, key) + file_size = head_resp["size"] + suffix_len = byte_range.suffix + buffer = await obs.get_range_async( + self.store, + key, + start=file_size - suffix_len, + length=suffix_len, + ) + return prototype.buffer.from_bytes(buffer) + else: + raise ValueError(f"Unexpected byte_range, got {byte_range}") async def exists(self, key: str) -> bool: # docstring inherited @@ -166,6 +223,7 @@ def supports_writes(self) -> bool: # docstring inherited return True + @with_concurrency_limit async def set(self, key: str, value: Buffer) -> None: # docstring inherited import obstore as obs @@ -175,20 +233,36 @@ async def set(self, key: str, value: Buffer) -> None: buf = value.as_buffer_like() await obs.put_async(self.store, key, buf) + async def _set_many(self, 
values: Iterable[tuple[str, Buffer]]) -> None: + # Override to avoid deadlock from calling decorated set() method + import obstore as obs + + self._check_writable() + + async def _set_with_limit(key: str, value: Buffer) -> None: + buf = value.as_buffer_like() + async with self._limit(): + await obs.put_async(self.store, key, buf) + + await asyncio.gather(*[_set_with_limit(key, value) for key, value in values]) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited + # Not decorated to avoid deadlock when called in batch via gather() import obstore as obs self._check_writable() buf = value.as_buffer_like() - with contextlib.suppress(obs.exceptions.AlreadyExistsError): - await obs.put_async(self.store, key, buf, mode="create") + async with self._limit(): + with contextlib.suppress(obs.exceptions.AlreadyExistsError): + await obs.put_async(self.store, key, buf, mode="create") @property def supports_deletes(self) -> bool: # docstring inherited return True + @with_concurrency_limit async def delete(self, key: str) -> None: # docstring inherited import obstore as obs @@ -211,8 +285,13 @@ async def delete_dir(self, prefix: str) -> None: prefix += "/" metas = await obs.list(self.store, prefix).collect_async() - keys = [(m["path"],) for m in metas] - await concurrent_map(keys, self.delete, limit=config.get("async.concurrency")) + + async def _delete_with_limit(path: str) -> None: + async with self._limit(): + with contextlib.suppress(FileNotFoundError): + await obs.delete_async(self.store, path) + + await asyncio.gather(*[_delete_with_limit(m["path"]) for m in metas]) @property def supports_listing(self) -> bool: @@ -271,242 +350,3 @@ async def _transform_list_dir( list_result["common_prefixes"], map(itemgetter("path"), list_result["objects"]) ): yield _relativize_path(path=path, prefix=prefix) - - -class _BoundedRequest(TypedDict): - """Range request with a known start and end byte. - - These requests can be multiplexed natively on the Rust side with - `obstore.get_ranges_async`. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - start: int - """Start byte offset.""" - - end: int - """End byte offset.""" - - -class _OtherRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. - """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: OffsetRange | None - # Note: suffix requests are handled separately because some object stores (Azure) - # don't support them - """The range request type.""" - - -class _SuffixRequest(TypedDict): - """Offset or suffix range requests. - - These requests cannot be concurrent on the Rust side, and each need their own call - to `obstore.get_async`, passing in the `range` parameter. 
- """ - - original_request_index: int - """The positional index in the original key_ranges input""" - - path: str - """The path to request from.""" - - range: SuffixRange - """The suffix range.""" - - -class _Response(TypedDict): - """A response buffer associated with the original index that it should be restored to.""" - - original_request_index: int - """The positional index in the original key_ranges input""" - - buffer: Buffer - """The buffer returned from obstore's range request.""" - - -async def _make_bounded_requests( - store: _UpstreamObjectStore, - path: str, - requests: list[_BoundedRequest], - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make all bounded requests for a specific file. - - `obstore.get_ranges_async` allows for making concurrent requests for multiple ranges - within a single file, and will e.g. merge concurrent requests. This only uses one - single Python coroutine. - """ - import obstore as obs - - starts = [r["start"] for r in requests] - ends = [r["end"] for r in requests] - async with semaphore: - responses = await obs.get_ranges_async(store, path=path, starts=starts, ends=ends) - - buffer_responses: list[_Response] = [] - for request, response in zip(requests, responses, strict=True): - buffer_responses.append( - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(response), # type: ignore[arg-type] - } - ) - - return buffer_responses - - -async def _make_other_request( - store: _UpstreamObjectStore, - request: _OtherRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make offset or full-file requests. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. - """ - import obstore as obs - - async with semaphore: - if request["range"] is None: - resp = await obs.get_async(store, request["path"]) - else: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _make_suffix_request( - store: _UpstreamObjectStore, - request: _SuffixRequest, - prototype: BufferPrototype, - semaphore: asyncio.Semaphore, -) -> list[_Response]: - """Make suffix requests. - - This is separated out from `_make_other_request` because some object stores (Azure) - don't support suffix requests. In this case, our workaround is to first get the - length of the object and then manually request the byte range at the end. - - We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all - futures can be gathered together. 
- """ - import obstore as obs - - async with semaphore: - try: - resp = await obs.get_async(store, request["path"], options={"range": request["range"]}) - buffer = await resp.bytes_async() - except obs.exceptions.NotSupportedError: - head_resp = await obs.head_async(store, request["path"]) - file_size = head_resp["size"] - suffix_len = request["range"]["suffix"] - buffer = await obs.get_range_async( - store, - request["path"], - start=file_size - suffix_len, - length=suffix_len, - ) - - return [ - { - "original_request_index": request["original_request_index"], - "buffer": prototype.buffer.from_bytes(buffer), # type: ignore[arg-type] - } - ] - - -async def _get_partial_values( - store: _UpstreamObjectStore, - prototype: BufferPrototype, - key_ranges: Iterable[tuple[str, ByteRequest | None]], -) -> list[Buffer | None]: - """Make multiple range requests. - - ObjectStore has a `get_ranges` method that will additionally merge nearby ranges, - but it's _per_ file. So we need to split these key_ranges into **per-file** key - ranges, and then reassemble the results in the original order. - - We separate into different requests: - - - One call to `obstore.get_ranges_async` **per target file** - - One call to `obstore.get_async` for each other request. - """ - key_ranges = list(key_ranges) - per_file_bounded_requests: dict[str, list[_BoundedRequest]] = defaultdict(list) - other_requests: list[_OtherRequest] = [] - suffix_requests: list[_SuffixRequest] = [] - - for idx, (path, byte_range) in enumerate(key_ranges): - if byte_range is None: - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": None, - } - ) - elif isinstance(byte_range, RangeByteRequest): - per_file_bounded_requests[path].append( - {"original_request_index": idx, "start": byte_range.start, "end": byte_range.end} - ) - elif isinstance(byte_range, OffsetByteRequest): - other_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"offset": byte_range.offset}, - } - ) - elif isinstance(byte_range, SuffixByteRequest): - suffix_requests.append( - { - "original_request_index": idx, - "path": path, - "range": {"suffix": byte_range.suffix}, - } - ) - else: - raise ValueError(f"Unsupported range input: {byte_range}") - - semaphore = asyncio.Semaphore(config.get("async.concurrency")) - - futs: list[Coroutine[Any, Any, list[_Response]]] = [] - for path, bounded_ranges in per_file_bounded_requests.items(): - futs.append( - _make_bounded_requests(store, path, bounded_ranges, prototype, semaphore=semaphore) - ) - - for request in other_requests: - futs.append(_make_other_request(store, request, prototype, semaphore=semaphore)) # noqa: PERF401 - - for suffix_request in suffix_requests: - futs.append(_make_suffix_request(store, suffix_request, prototype, semaphore=semaphore)) # noqa: PERF401 - - buffers: list[Buffer | None] = [None] * len(key_ranges) - - for responses in await asyncio.gather(*futs): - for resp in responses: - buffers[resp["original_request_index"]] = resp["buffer"] - - return buffers diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 10ac395b36..f257082c27 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -1,17 +1,83 @@ from __future__ import annotations +import asyncio +import contextlib +import functools import re from pathlib import Path -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if 
TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Callable, Coroutine, Iterable, Mapping from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer +P = ParamSpec("P") +T_co = TypeVar("T_co", covariant=True) + + +class ConcurrencyLimiter: + """Mixin that adds a semaphore-based concurrency limit to a store. + + Stores that inherit from this class gain a ``concurrency_limit`` attribute, + a ``_semaphore`` for internal use, and a ``_limit()`` context-manager helper. + Use the ``@with_concurrency_limit`` decorator on individual async I/O methods. + """ + + concurrency_limit: int | None + """The concurrency limit, or ``None`` for unlimited.""" + + _semaphore: asyncio.Semaphore | None + + def __init__(self, concurrency_limit: int | None = None) -> None: + self.concurrency_limit = concurrency_limit + self._semaphore = ( + asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None + ) + + def _limit(self) -> asyncio.Semaphore | contextlib.nullcontext[None]: + """Return the semaphore if set, otherwise a no-op context manager.""" + sem = self._semaphore + return sem if sem is not None else contextlib.nullcontext() + + +def with_concurrency_limit( + func: Callable[P, Coroutine[Any, Any, T_co]], +) -> Callable[P, Coroutine[Any, Any, T_co]]: + """Decorator that applies a semaphore-based concurrency limit to an async method. + + This decorator is designed for methods on classes that inherit from + :class:`ConcurrencyLimiter`. + + Examples + -------- + ```python + from zarr.abc.store import Store + from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit + + class MyStore(Store, ConcurrencyLimiter): + def __init__(self, concurrency_limit: int = 100): + Store.__init__(self, read_only=False) + ConcurrencyLimiter.__init__(self, concurrency_limit) + + @with_concurrency_limit + async def get(self, key, prototype, byte_range=None): + # This will only run when semaphore permits + ... + ``` + """ + + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co: + self = args[0] + async with self._limit(): # type: ignore[attr-defined] + return await func(*args, **kwargs) + + return wrapper + def normalize_path(path: str | bytes | Path | None) -> str: if path is None: diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py new file mode 100644 index 0000000000..98530d7225 --- /dev/null +++ b/src/zarr/testing/store_concurrency.py @@ -0,0 +1,260 @@ +"""Base test class for store concurrency limiting behavior.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +import pytest + +from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.storage._utils import ConcurrencyLimiter + +if TYPE_CHECKING: + from zarr.abc.store import Store + +__all__ = ["StoreConcurrencyTests"] + + +S = TypeVar("S", bound="Store") +B = TypeVar("B", bound="Buffer") + + +class StoreConcurrencyTests(Generic[S, B]): + """Base class for testing store concurrency limiting behavior. + + This mixin provides tests for verifying that stores correctly implement + concurrency limiting. 
+ + Subclasses should set: + - store_cls: The store class being tested + - buffer_cls: The buffer class to use (e.g., cpu.Buffer) + - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited) + """ + + store_cls: type[S] + buffer_cls: type[B] + expected_concurrency_limit: int | None + + @pytest.fixture + async def store(self, store_kwargs: dict[str, Any]) -> S: + """Create and open a store instance.""" + return await self.store_cls.open(**store_kwargs) + + @staticmethod + def _get_semaphore(store: Store) -> asyncio.Semaphore | None: + """Get the semaphore from a store, or None if the store doesn't support concurrency limiting.""" + if isinstance(store, ConcurrencyLimiter): + return store._semaphore + return None + + def test_concurrency_limit_default(self, store: S) -> None: + """Test that store has the expected default concurrency limit.""" + # Concrete subclasses inherit from both Store and ConcurrencyLimiter, + # but S is bound to Store so mypy considers this branch unreachable. + if not isinstance(store, ConcurrencyLimiter): + assert self.expected_concurrency_limit is None + return + assert store.concurrency_limit == self.expected_concurrency_limit # type: ignore[unreachable] + + def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None: + """Test that custom concurrency limits can be set.""" + if not issubclass(self.store_cls, ConcurrencyLimiter): + pytest.skip("Store does not support custom concurrency limits") + + # mypy considers this unreachable because S is bound to Store, not ConcurrencyLimiter. + # Test with custom limit + store = self.store_cls(**{**store_kwargs, "concurrency_limit": 42}) # type: ignore[unreachable] + assert isinstance(store, ConcurrencyLimiter) + assert store.concurrency_limit == 42 + assert store._semaphore is not None + + # Test with None (unlimited) + store = self.store_cls(**{**store_kwargs, "concurrency_limit": None}) + assert isinstance(store, ConcurrencyLimiter) + assert store.concurrency_limit is None + assert store._semaphore is None + + async def test_concurrency_limit_enforced(self, store: S) -> None: + """Test that the concurrency limit is actually enforced during execution. + + This test verifies that when many operations are submitted concurrently, + only up to the concurrency limit are actually executing at once. 
+ """ + semaphore = self._get_semaphore(store) + if semaphore is None: + pytest.skip("Store has no concurrency limit") + + assert isinstance(store, ConcurrencyLimiter) + assert store.concurrency_limit is not None # type: ignore[unreachable] + limit = store.concurrency_limit + + # We'll monitor the semaphore's available count + # When it reaches 0, that means `limit` operations are running + min_available = limit + + async def monitored_operation(key: str, value: B) -> None: + nonlocal min_available + # Check semaphore state right after we're scheduled + await asyncio.sleep(0) # Yield to ensure we're in the queue + available = semaphore._value + min_available = min(min_available, available) + + # Now do the actual operation (which will acquire the semaphore) + await store.set(key, value) + + # Launch more operations than the limit to ensure contention + num_ops = limit * 2 + items = [ + (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode())) + for i in range(num_ops) + ] + + await asyncio.gather(*[monitored_operation(k, v) for k, v in items]) + + # The semaphore should have been fully utilized (reached 0 or close to it) + # This indicates that `limit` operations were running concurrently + assert min_available < limit, ( + f"Semaphore was never fully utilized. " + f"Min available: {min_available}, Limit: {limit}. " + f"This suggests operations aren't running concurrently." + ) + + # Ideally it should reach 0, but allow some slack for timing + assert min_available <= 5, ( + f"Semaphore only reached {min_available} available slots. " + f"Expected close to 0 with limit {limit}." + ) + + async def test_batch_write_no_deadlock(self, store: S) -> None: + """Test that batch writes don't deadlock when exceeding concurrency limit.""" + # Create more items than any reasonable concurrency limit + num_items = 200 + items = [ + (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode())) + for i in range(num_items) + ] + + # This should complete without deadlock, even if num_items > concurrency_limit + await asyncio.wait_for(store._set_many(items), timeout=30.0) + + # Verify all items were written correctly + for key, expected_value in items: + result = await store.get(key, default_buffer_prototype()) + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_read_no_deadlock(self, store: S) -> None: + """Test that batch reads don't deadlock when exceeding concurrency limit.""" + # Write test data + num_items = 200 + test_data = { + f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode()) + for i in range(num_items) + } + + for key, value in test_data.items(): + await store.set(key, value) + + # Read all items concurrently - should not deadlock + keys_and_ranges = [(key, None) for key in test_data] + results = await asyncio.wait_for( + store.get_partial_values(default_buffer_prototype(), keys_and_ranges), + timeout=30.0, + ) + + # Verify results + assert len(results) == num_items + for result, (_key, expected_value) in zip(results, test_data.items(), strict=True): + assert result is not None + assert result.to_bytes() == expected_value.to_bytes() + + async def test_batch_delete_no_deadlock(self, store: S) -> None: + """Test that batch deletes don't deadlock when exceeding concurrency limit.""" + if not store.supports_deletes: + pytest.skip("Store does not support deletes") + + # Write test data + num_items = 200 + keys = [f"test_key_{i}" for i in range(num_items)] + for key in keys: + await store.set(key, 
+                self.buffer_cls.from_bytes(b"test_value"))
+
+        # Delete all items concurrently - should not deadlock
+        await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0)
+
+        # Verify all items were deleted
+        for key in keys:
+            result = await store.get(key, default_buffer_prototype())
+            assert result is None
+
+    async def test_concurrent_operations_correctness(self, store: S) -> None:
+        """Test that concurrent operations produce correct results."""
+        num_operations = 100
+
+        # Mix of reads and writes
+        write_keys = [f"write_key_{i}" for i in range(num_operations)]
+        write_values = [
+            self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations)
+        ]
+
+        # Write all concurrently
+        await asyncio.gather(
+            *[store.set(k, v) for k, v in zip(write_keys, write_values, strict=True)]
+        )
+
+        # Read all concurrently
+        results = await asyncio.gather(
+            *[store.get(k, default_buffer_prototype()) for k in write_keys]
+        )
+
+        # Verify correctness
+        for result, expected in zip(results, write_values, strict=True):
+            assert result is not None
+            assert result.to_bytes() == expected.to_bytes()
+
+    @pytest.mark.parametrize("batch_size", [1, 10, 50, 100])
+    async def test_various_batch_sizes(self, store: S, batch_size: int) -> None:
+        """Test that various batch sizes work correctly."""
+        items = [
+            (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode()))
+            for i in range(batch_size)
+        ]
+
+        # Should complete without issues for any batch size
+        await asyncio.wait_for(store._set_many(items), timeout=10.0)
+
+        # Verify
+        for key, expected_value in items:
+            result = await store.get(key, default_buffer_prototype())
+            assert result is not None
+            assert result.to_bytes() == expected_value.to_bytes()
+
+    async def test_empty_batch_operations(self, store: S) -> None:
+        """Test that empty batch operations don't cause issues."""
+        # Empty batch should not raise
+        await store._set_many([])
+
+        # Empty read batch
+        results = await store.get_partial_values(default_buffer_prototype(), [])
+        assert results == []
+
+    async def test_mixed_success_failure_batch(self, store: S) -> None:
+        """Test batch operations with a mix of successful and failing items."""
+        # Write some initial data
+        await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value"))
+
+        # Try to read a mix of existing and non-existing keys
+        key_ranges = [
+            ("existing_key", None),
+            ("non_existing_key_1", None),
+            ("non_existing_key_2", None),
+        ]
+
+        results = await store.get_partial_values(default_buffer_prototype(), key_ranges)
+
+        # First should exist, others should be None
+        assert results[0] is not None
+        assert results[0].to_bytes() == b"existing_value"
+        assert results[1] is None
+        assert results[2] is None
diff --git a/tests/test_common.py b/tests/test_common.py
index 0dedde1d6b..09cb6df2f8 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -31,10 +31,6 @@ def test_access_modes() -> None:
     assert set(ANY_ACCESS_MODE) == set(get_args(AccessModeLiteral))
 
 
-# todo: test
-def test_concurrent_map() -> None: ...
-
-
 # todo: test
 def test_to_thread() -> None: ...
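Note on the suite above: `StoreConcurrencyTests` assumes each concurrency-limited store mixes in a semaphore-based `ConcurrencyLimiter` that exposes a `concurrency_limit` property and a `_semaphore` attribute, with `None` meaning unlimited. The mixin itself is not shown in this section of the diff; the sketch below only illustrates the shape those tests rely on — the two attribute names come from the tests, while the helper name `_with_limit` and everything else here are invented for illustration.

```python
import asyncio
from collections.abc import Awaitable
from typing import TypeVar

T = TypeVar("T")


class ConcurrencyLimiter:
    """Sketch only: the semaphore-based shape that StoreConcurrencyTests exercises."""

    def __init__(self, concurrency_limit: int | None = None) -> None:
        # concurrency_limit=None means "no limit": no semaphore is created at all.
        self._concurrency_limit = concurrency_limit
        self._semaphore: asyncio.Semaphore | None = (
            None if concurrency_limit is None else asyncio.Semaphore(concurrency_limit)
        )

    @property
    def concurrency_limit(self) -> int | None:
        return self._concurrency_limit

    async def _with_limit(self, coro: Awaitable[T]) -> T:
        # Each store I/O call would be wrapped here, so at most `concurrency_limit`
        # operations hold the semaphore at any moment; batch helpers such as
        # _set_many fan out many of these wrapped calls.
        if self._semaphore is None:
            return await coro
        async with self._semaphore:
            return await coro
```

`test_concurrency_limit_enforced` leans on exactly this structure, reading `semaphore._value` to confirm that the number of available slots actually drops toward zero while a batch of operations is in flight.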
diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..d3c109d41a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -55,7 +55,7 @@ def test_config_defaults_set() -> None: "write_empty_chunks": False, "target_shard_size_bytes": None, }, - "async": {"concurrency": 10, "timeout": None}, + "async": {"timeout": None}, "threading": {"max_workers": None}, "json_indent": 2, "codec_pipeline": { @@ -101,7 +101,6 @@ def test_config_defaults_set() -> None: ] ) assert config.get("array.order") == "C" - assert config.get("async.concurrency") == 10 assert config.get("async.timeout") is None assert config.get("codec_pipeline.batch_size") == 1 assert config.get("json_indent") == 2 @@ -109,7 +108,7 @@ def test_config_defaults_set() -> None: @pytest.mark.parametrize( ("key", "old_val", "new_val"), - [("array.order", "C", "F"), ("async.concurrency", 10, 128), ("json_indent", 2, 0)], + [("array.order", "C", "F"), ("json_indent", 2, 0)], ) def test_config_defaults_can_be_overridden(key: str, old_val: Any, new_val: Any) -> None: assert config.get(key) == old_val @@ -347,3 +346,15 @@ def test_deprecated_config(key: str) -> None: with pytest.raises(ValueError): with zarr.config.set({key: "foo"}): pass + + +def test_async_concurrency_config_warns() -> None: + """Test that setting async.concurrency emits a warning directing users to per-store config.""" + with pytest.warns(UserWarning, match="async.concurrency.*no effect"): + with zarr.config.set({"async.concurrency": 20}): + pass + + # Also test the kwarg form + with pytest.warns(UserWarning, match="async.concurrency.*no effect"): + with zarr.config.set(async__concurrency=20): + pass diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py new file mode 100644 index 0000000000..3cfca5052c --- /dev/null +++ b/tests/test_global_concurrency.py @@ -0,0 +1,43 @@ +""" +Tests for store-level concurrency limiting through the array API. 
+""" + +import numpy as np + +import zarr + + +class TestStoreConcurrencyThroughArrayAPI: + """Tests that store-level concurrency limiting works through the array API.""" + + def test_array_operations_with_store_concurrency(self) -> None: + """Test that array read/write works correctly with store-level concurrency limits.""" + store = zarr.storage.MemoryStore() + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) + + def test_array_operations_with_local_store_concurrency(self, tmp_path: object) -> None: + """Test that array read/write works correctly with LocalStore concurrency limits.""" + store = zarr.storage.LocalStore(str(tmp_path), concurrency_limit=10) + arr = zarr.create( + shape=(20, 20), + chunks=(10, 10), + dtype="i4", + store=store, + zarr_format=3, + ) + arr[:] = 42 + + data = arr[:] + + assert np.all(data == 42) diff --git a/tests/test_group.py b/tests/test_group.py index 6f1f4e68fa..9f25036298 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -23,7 +23,6 @@ from zarr.core import sync_group from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.config import config as zarr_config from zarr.core.dtype.common import unpack_dtype_json from zarr.core.dtype.npy.int import UInt8 from zarr.core.group import ( @@ -1738,29 +1737,6 @@ async def test_create_nodes( assert node_spec == {k: v.metadata for k, v in observed_nodes.items()} -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of create_nodes can be constrained by the async concurrency - configuration setting. - """ - set_latency = 0.02 - num_groups = 10 - groups = {str(idx): GroupMetadata() for idx in range(num_groups)} - - latency_store = LatencyStore(store, set_latency=set_latency) - - # check how long it takes to iterate over the groups - # if create_nodes is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups)) - elapsed = time.time() - start - assert elapsed > num_groups * set_latency - - @pytest.mark.parametrize( ("a_func", "b_func"), [ @@ -2250,38 +2226,6 @@ def test_group_members_performance(store: Store) -> None: assert elapsed < (num_groups * get_latency) -@pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_group_members_concurrency_limit(store: MemoryStore) -> None: - """ - Test that the execution time of Group.members can be constrained by the async concurrency - configuration setting. 
- """ - get_latency = 0.02 - - # use the input store to create some groups - group_create = zarr.group(store=store) - num_groups = 10 - - # Create some groups - for i in range(num_groups): - group_create.create_group(f"group{i}") - - latency_store = LatencyStore(store, get_latency=get_latency) - # create a group with some latency on get operations - group_read = zarr.group(store=latency_store) - - # check how long it takes to iterate over the groups - # if .members is sensitive to IO latency, - # this should take (num_groups * get_latency) seconds - # otherwise, it should take only marginally more than get_latency seconds - with zarr_config.set({"async.concurrency": 1}): - start = time.time() - _ = group_read.members() - elapsed = time.time() - start - - assert elapsed > num_groups * get_latency - - @pytest.mark.parametrize("option", ["array", "group", "invalid"]) def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None: """ diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index bdc9b48121..ca9759bd9a 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -15,6 +15,7 @@ from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import assert_bytes_equal if TYPE_CHECKING: @@ -204,3 +205,15 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None: f.write(b"abc") assert path.read_bytes() == b"xyz" assert list(path.parent.iterdir()) == [path] # no temp files + + +class TestLocalStoreConcurrency(StoreConcurrencyTests[LocalStore, cpu.Buffer]): + """Test LocalStore concurrency limiting behavior.""" + + store_cls = LocalStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = 100 # LocalStore default + + @pytest.fixture + def store_kwargs(self, tmpdir: str) -> dict[str, str]: + return {"root": str(tmpdir)} diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 03c8b24271..1004ca20bb 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -14,6 +14,7 @@ from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests +from zarr.testing.store_concurrency import StoreConcurrencyTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: @@ -181,3 +182,15 @@ def test_from_dict(self) -> None: result = GpuMemoryStore.from_dict(d) for v in result._store_dict.values(): assert type(v) is gpu.Buffer + + +class TestMemoryStoreConcurrency(StoreConcurrencyTests[MemoryStore, cpu.Buffer]): + """Test MemoryStore concurrency limiting behavior.""" + + store_cls = MemoryStore + buffer_cls = cpu.Buffer + expected_concurrency_limit = None # MemoryStore has no limit (fast in-memory ops) + + @pytest.fixture + def store_kwargs(self) -> dict[str, Any]: + return {"store_dict": None}
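One behaviour pinned down earlier in this diff (the `tests/test_config.py` hunk) but not implemented in this section is the `UserWarning` expected when `async.concurrency` is still set. As a rough sketch only — the function name `_warn_if_async_concurrency` and its call site are assumptions, not part of this PR — the check implied by `test_async_concurrency_config_warns` could look like:

```python
import warnings
from typing import Any


def _warn_if_async_concurrency(options: dict[str, Any], kwargs: dict[str, Any]) -> None:
    """Sketch: warn when the retired async.concurrency setting is supplied.

    ``options`` is the dict form (config.set({"async.concurrency": 20})) and
    ``kwargs`` the kwarg form (config.set(async__concurrency=20)).
    """
    dotted_keys = set(options) | {k.replace("__", ".") for k in kwargs}
    if "async.concurrency" in dotted_keys:
        warnings.warn(
            "The 'async.concurrency' setting has no effect; concurrency is now limited "
            "per store via the 'concurrency_limit' constructor argument.",
            UserWarning,
            stacklevel=2,
        )
```

Both the dict form and the double-underscore kwarg form used in the test reduce to the same dotted key, which is why a single flattening step is enough in this sketch.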