diff --git a/changes/0000.feature.md b/changes/0000.feature.md new file mode 100644 index 0000000000..ce26a8c209 --- /dev/null +++ b/changes/0000.feature.md @@ -0,0 +1 @@ +Add `zarr.storage.FsspecStore.get_ranges` for concurrent, coalesced multi-range reads from a single key. A new keyword-only constructor argument `coalesce_options` on `FsspecStore` controls the max gap, max coalesced size, and max concurrency of the underlying requests. diff --git a/src/zarr/core/_coalesce.py b/src/zarr/core/_coalesce.py new file mode 100644 index 0000000000..40fe14bf04 --- /dev/null +++ b/src/zarr/core/_coalesce.py @@ -0,0 +1,195 @@ +# src/zarr/core/_coalesce.py +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, TypedDict + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence + + from zarr.abc.store import ByteRequest + from zarr.core.buffer import Buffer + + +class CoalesceOptions(TypedDict): + """Knobs for coalescing contiguous byte ranges into fewer I/O requests. + + All fields required. See DEFAULT_COALESCE_OPTIONS for a sensible default. + """ + + max_gap_bytes: int + """Two RangeByteRequests separated by at most this many bytes may be merged into one fetch.""" + max_coalesced_bytes: int + """Upper bound on the size of a single merged fetch (ignored for an already-oversized single request).""" + max_concurrency: int + """Maximum number of merged fetches in flight at once.""" + + +DEFAULT_COALESCE_OPTIONS: CoalesceOptions = { + "max_gap_bytes": 1 << 20, # 1 MiB + "max_coalesced_bytes": 16 << 20, # 16 MiB + "max_concurrency": 10, +} + + +async def coalesced_get( + fetch: Callable[[ByteRequest | None], Awaitable[Buffer | None]], + byte_ranges: Iterable[ByteRequest | None], + *, + options: CoalesceOptions, +) -> AsyncGenerator[Sequence[tuple[int, Buffer | None]], None]: + """Read many byte ranges through ``fetch`` with coalescing and concurrency. 
+ + Nearby ranges are merged into a single underlying I/O (subject to + ``options``), and merged fetches are run concurrently. Each yield + corresponds to exactly one underlying I/O operation: a sequence of + ``(input_index, result)`` tuples for all input ranges served by that I/O. + Tuples within a yielded sequence are ordered by start offset. Yields across + groups are in completion order, not input order. + + Parameters + ---------- + fetch + Callable that reads one byte range and returns a ``Buffer`` (or ``None`` + if the underlying key does not exist). Typically constructed via + ``functools.partial(store.get, key, prototype)``. + byte_ranges + Input ranges. ``None`` means "the whole value". + options + Coalescing knobs. + + Yields + ------ + Sequence[tuple[int, Buffer | None]] + Per-I/O batch of ``(input_index, result)`` tuples. + + Notes + ----- + - Only ``RangeByteRequest`` inputs are coalesced. ``OffsetByteRequest``, + ``SuffixByteRequest``, and ``None`` are each treated as uncoalescable + (one fetch, one single-tuple yield per input). + - If any fetch returns ``None`` the iterator stops scheduling further fetches + and completes without yielding the missing group. Only groups already yielded + before the miss remain observable; completed-but-unyielded groups are dropped. + - If a fetch raises, the exception propagates on the yield that produced the + failing group; groups yielded before the failure remain observable. + """ + # Local import to avoid cycles at module import time. + from zarr.abc.store import RangeByteRequest + + indexed: list[tuple[int, ByteRequest | None]] = list(enumerate(byte_ranges)) + if not indexed: + return + + # Split inputs into coalescable (RangeByteRequest only) and uncoalescable (the rest). + mergeable: list[tuple[int, RangeByteRequest]] = [ + (i, r) for i, r in indexed if isinstance(r, RangeByteRequest) + ] + uncoalescable: list[tuple[int, ByteRequest | None]] = [ + (i, r) for i, r in indexed if not isinstance(r, RangeByteRequest) + ] + + # Sort mergeables by start offset, then merge. 
+ mergeable.sort(key=lambda pair: pair[1].start) + groups: list[list[tuple[int, RangeByteRequest]]] = [] + for pair in mergeable: + _i, r = pair + if groups: + last = groups[-1] + last_end = max(x[1].end for x in last) + gap = r.start - last_end + merged_start = min(x[1].start for x in last) + prospective_end = max(last_end, r.end) + prospective_size = prospective_end - merged_start + if ( + gap <= options["max_gap_bytes"] + and prospective_size <= options["max_coalesced_bytes"] + ): + last.append(pair) + continue + groups.append([pair]) + + # Build a uniform list of work items. Each work item is a list of + # (input_index, ByteRequest | None) pairs. Merged groups have multiple + # members (all RangeByteRequest); uncoalescable items have a single member. + work_items: list[list[tuple[int, ByteRequest | None]]] = [ + [(idx, r) for idx, r in g] for g in groups + ] + work_items.extend([(idx, single)] for idx, single in uncoalescable) + + # Completion queue entries are either ("ok", payload), ("missing", None), + # or ("error", exception). Kept as Any internally to avoid dragging + # Sequence out of TYPE_CHECKING. + completion_queue: asyncio.Queue[ + tuple[str, Sequence[tuple[int, Buffer | None]] | BaseException | None] + ] = asyncio.Queue() + semaphore = asyncio.Semaphore(options["max_concurrency"]) + + async def run_one(members: list[tuple[int, ByteRequest | None]]) -> None: + try: + async with semaphore: + if len(members) == 1 and not isinstance(members[0][1], RangeByteRequest): + # Uncoalescable single fetch. + idx, single = members[0] + buf = await fetch(single) + if buf is None: + await completion_queue.put(("missing", None)) + return + await completion_queue.put(("ok", ((idx, buf),))) + return + # Merged group path: all members are RangeByteRequest. 
+ assert all(isinstance(r, RangeByteRequest) for _, r in members) + starts = [r.start for _, r in members] # type: ignore[union-attr] + ends = [r.end for _, r in members] # type: ignore[union-attr] + group_start = min(starts) + group_end = max(ends) + big = await fetch(RangeByteRequest(group_start, group_end)) + if big is None: + await completion_queue.put(("missing", None)) + return + ordered = sorted(members, key=lambda pair: pair[1].start) # type: ignore[union-attr] + sliced: list[tuple[int, Buffer | None]] = [] + for idx, r in ordered: + sliced.append((idx, big[r.start - group_start : r.end - group_start])) # type: ignore[union-attr] + await completion_queue.put(("ok", tuple(sliced))) + except asyncio.CancelledError: + # Cancellation is expected when we stop scheduling on a missing key. + raise + except BaseException as exc: + await completion_queue.put(("error", exc)) + + # Launch all work items as tasks. The semaphore bounds actual concurrency. + tasks: set[asyncio.Task[None]] = set() + for item in work_items: + tasks.add(asyncio.create_task(run_one(item))) + + try: + pending_error: BaseException | None = None + for _ in range(len(work_items)): + kind, payload = await completion_queue.get() + if kind == "ok": + assert payload is not None + assert not isinstance(payload, BaseException) + yield payload + continue + # "missing" or "error": stop scheduling and cancel pending work. + # Late arrivals that raced to enqueue before cancellation took + # effect simply remain in the local completion queue; nothing reads + # it again, so they are dropped when the generator is collected. + for t in tasks: + if not t.done(): + t.cancel() + if kind == "error": + assert isinstance(payload, BaseException) + pending_error = payload + break + if pending_error is not None: + raise pending_error + finally: + # Best-effort cancellation for in-flight tasks (covers the consumer + # break / early-exit case where we did not proactively cancel). 
+ for t in tasks: + if not t.done(): + t.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 74e5869a66..d967e64ca2 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -3,6 +3,7 @@ import json import warnings from contextlib import suppress +from functools import partial from typing import TYPE_CHECKING, Any from packaging.version import parse as parse_version @@ -14,18 +15,24 @@ Store, SuffixByteRequest, ) +from zarr.core._coalesce import ( + DEFAULT_COALESCE_OPTIONS, + CoalesceOptions, + coalesced_get, +) from zarr.core.buffer import Buffer from zarr.errors import ZarrUserWarning from zarr.storage._utils import _join_paths, normalize_path if TYPE_CHECKING: - from collections.abc import AsyncIterator, Iterable + from collections.abc import AsyncIterator, Iterable, Sequence from fsspec import AbstractFileSystem from fsspec.asyn import AsyncFileSystem from fsspec.mapping import FSMap from zarr.core.buffer import BufferPrototype + from zarr.storage._protocols import SupportsGetRanges ALLOWED_EXCEPTIONS: tuple[type[Exception], ...] = ( @@ -124,11 +131,14 @@ def __init__( read_only: bool = False, path: str = "/", allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS, + *, + coalesce_options: CoalesceOptions = DEFAULT_COALESCE_OPTIONS, ) -> None: super().__init__(read_only=read_only) self.fs = fs self.path = normalize_path(path) self.allowed_exceptions = allowed_exceptions + self.coalesce_options = coalesce_options.copy()  # defensive copy: never alias the shared module-level default if not self.fs.async_impl: raise TypeError("Filesystem needs to support async operations.") @@ -315,6 +325,22 @@ async def get( else: return value + async def get_ranges( + self, + key: str, + byte_ranges: Iterable[ByteRequest | None], + *, + prototype: BufferPrototype, + ) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]: + """Read many byte ranges from ``key``, coalescing nearby ranges and fetching concurrently. 
+ + See :class:`zarr.storage._protocols.SupportsGetRanges` for the contract and + :func:`zarr.core._coalesce.coalesced_get` for the full semantics. + """ + fetch = partial(self.get, key, prototype) + async for group in coalesced_get(fetch, byte_ranges, options=self.coalesce_options): + yield group + async def set( self, key: str, @@ -440,3 +466,8 @@ async def getsize(self, key: str) -> int: else: # fsspec doesn't have typing. We'll need to assume or verify this is true return int(size) + + +# Module-level type assertion: FsspecStore structurally satisfies SupportsGetRanges. +# This line is a no-op at runtime but causes mypy/pyright to complain if the shape drifts. +_: type[SupportsGetRanges] = FsspecStore diff --git a/src/zarr/storage/_protocols.py b/src/zarr/storage/_protocols.py new file mode 100644 index 0000000000..086a15dd90 --- /dev/null +++ b/src/zarr/storage/_protocols.py @@ -0,0 +1,34 @@ +# src/zarr/storage/_protocols.py +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterable, Sequence + + from zarr.abc.store import ByteRequest + from zarr.core.buffer import Buffer, BufferPrototype + + +@runtime_checkable +class SupportsGetRanges(Protocol): + """Stores that satisfy this protocol can efficiently read many byte ranges + from a single key in a single call, typically via coalescing and concurrent fetch. + + Private / unstable. Shape may change before being made public. + """ + + def get_ranges( + self, + key: str, + byte_ranges: Iterable[ByteRequest | None], + *, + prototype: BufferPrototype, + ) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]: + """Read many byte ranges from ``key``. + + Each yield corresponds to one underlying I/O operation. + + See :func:`zarr.core._coalesce.coalesced_get` for full semantics. + """ + ... 
diff --git a/tests/test_coalesce.py b/tests/test_coalesce.py new file mode 100644 index 0000000000..68a620ea5d --- /dev/null +++ b/tests/test_coalesce.py @@ -0,0 +1,516 @@ +# tests/test_coalesce.py +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import pytest + +from zarr.abc.store import ( + ByteRequest, + OffsetByteRequest, + RangeByteRequest, + SuffixByteRequest, +) +from zarr.core._coalesce import ( + DEFAULT_COALESCE_OPTIONS, + CoalesceOptions, + coalesced_get, +) +from zarr.core.buffer import Buffer, default_buffer_prototype + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Callable, Sequence + + +def _buf(data: bytes) -> Buffer: + return default_buffer_prototype().buffer.from_bytes(data) + + +@dataclass +class FakeFetch: + """Records every call and serves canned bytes from an in-memory blob.""" + + blob: bytes + key_exists: bool = True + raise_on: Callable[[ByteRequest | None], bool] | None = None + calls: list[ByteRequest | None] = field(default_factory=list) + + async def __call__(self, byte_range: ByteRequest | None) -> Buffer | None: + self.calls.append(byte_range) + if not self.key_exists: + return None + if self.raise_on is not None and self.raise_on(byte_range): + raise OSError("injected") + if byte_range is None: + return _buf(self.blob) + if isinstance(byte_range, RangeByteRequest): + return _buf(self.blob[byte_range.start : byte_range.end]) + if isinstance(byte_range, OffsetByteRequest): + return _buf(self.blob[byte_range.offset :]) + if isinstance(byte_range, SuffixByteRequest): + return _buf(self.blob[-byte_range.suffix :]) + raise AssertionError(f"unknown byte_range {byte_range!r}") + + +async def _collect( + agen: AsyncIterator[Sequence[tuple[int, Buffer | None]]], +) -> list[list[tuple[int, Buffer | None]]]: + """Drain an async generator of groups into a list of lists of tuples.""" + return [list(group) async for group in agen] + + +def 
_contents(groups: list[list[tuple[int, Buffer | None]]]) -> dict[int, bytes]: + """Flatten to {index: bytes}.""" + result: dict[int, bytes] = {} + for group in groups: + for idx, buf in group: + assert buf is not None + result[idx] = buf.to_bytes() + return result + + + # --------------------------------------------------------------------------- + # Shared option values for parametrized structural tests. + # --------------------------------------------------------------------------- + + DEFAULT: CoalesceOptions = DEFAULT_COALESCE_OPTIONS + """The library default; permissive merging.""" + + MERGE_GAP_50: CoalesceOptions = { + "max_gap_bytes": 50, + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 10, + } + """Merge ranges within 50 bytes of each other.""" + + NO_MERGE: CoalesceOptions = { + "max_gap_bytes": -1, + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 10, + } + """No merging for disjoint ranges: every non-negative gap is > -1 (ranges overlapping by a byte or more would still coalesce).""" + + CAP_50: CoalesceOptions = { + "max_gap_bytes": 1000, + "max_coalesced_bytes": 50, + "max_concurrency": 10, + } + """Gap permissive but merged size capped at 50 bytes.""" + + + # A deterministic blob used for content-sensitive cases: byte i == (i % 256). + _INDEXED_BLOB = bytes(i % 256 for i in range(10_000)) + + + # --------------------------------------------------------------------------- + # Parametrized structural/content tests (cases without async timing or errors). 
+# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class StructuralCase: + """One row of the parametrized structure-and-contents table.""" + + id: str + """pytest id for the case.""" + ranges: list[ByteRequest | None] + """Input to coalesced_get.""" + options: CoalesceOptions + """Coalescing knobs.""" + expected_group_sizes: list[int] + """Sorted list of group tuple-counts (order-independent).""" + expected_contents: dict[int, bytes] | None = None + """{input_index: bytes} to verify bytes, or None to skip the content check.""" + expected_n_fetches: int | None = None + """Exact number of calls to the fetch callable, or None to skip the check.""" + + +_STRUCTURAL_CASES: list[StructuralCase] = [ + StructuralCase( + id="empty-input", + ranges=[], + options=DEFAULT, + expected_group_sizes=[], + expected_n_fetches=0, + ), + StructuralCase( + id="single-range", + ranges=[RangeByteRequest(2, 5)], + options=DEFAULT, + expected_group_sizes=[1], + expected_contents={0: _INDEXED_BLOB[2:5]}, + expected_n_fetches=1, + ), + StructuralCase( + id="disjoint-3-no-merge", + ranges=[ + RangeByteRequest(0, 10), + RangeByteRequest(200, 210), + RangeByteRequest(500, 510), + ], + options=MERGE_GAP_50, + expected_group_sizes=[1, 1, 1], + expected_contents={ + 0: _INDEXED_BLOB[0:10], + 1: _INDEXED_BLOB[200:210], + 2: _INDEXED_BLOB[500:510], + }, + expected_n_fetches=3, + ), + StructuralCase( + id="adjacent-3-one-merged-group", + ranges=[ + RangeByteRequest(0, 5), + RangeByteRequest(10, 15), + RangeByteRequest(20, 25), + ], + options=MERGE_GAP_50, + expected_group_sizes=[3], + expected_contents={ + 0: _INDEXED_BLOB[0:5], + 1: _INDEXED_BLOB[10:15], + 2: _INDEXED_BLOB[20:25], + }, + expected_n_fetches=1, + ), + StructuralCase( + id="two-clusters-one-singleton", + ranges=[ + RangeByteRequest(0, 10), + RangeByteRequest(20, 30), + RangeByteRequest(500, 510), + ], + options=MERGE_GAP_50, + expected_group_sizes=[1, 2], + expected_contents={ + 
0: _INDEXED_BLOB[0:10], + 1: _INDEXED_BLOB[20:30], + 2: _INDEXED_BLOB[500:510], + }, + expected_n_fetches=2, + ), + StructuralCase( + id="uncoalescable-mixed-with-range", + ranges=[ + RangeByteRequest(0, 3), + OffsetByteRequest(5), + SuffixByteRequest(2), + None, + ], + options=DEFAULT, + expected_group_sizes=[1, 1, 1, 1], + expected_contents={ + 0: _INDEXED_BLOB[0:3], + 1: _INDEXED_BLOB[5:], + 2: _INDEXED_BLOB[-2:], + 3: _INDEXED_BLOB, + }, + expected_n_fetches=4, + ), + StructuralCase( + id="shuffled-input-indices-preserved", + ranges=[ + RangeByteRequest(500, 510), + RangeByteRequest(0, 10), + RangeByteRequest(200, 210), + RangeByteRequest(300, 310), + ], + options=MERGE_GAP_50, + expected_group_sizes=[1, 1, 1, 1], + expected_contents={ + 0: _INDEXED_BLOB[500:510], + 1: _INDEXED_BLOB[0:10], + 2: _INDEXED_BLOB[200:210], + 3: _INDEXED_BLOB[300:310], + }, + expected_n_fetches=4, + ), + StructuralCase( + id="cap-prevents-merge-of-close-ranges", + # 20 + 20 gap + 20 = 60-byte merged span > cap of 50. + ranges=[RangeByteRequest(0, 20), RangeByteRequest(40, 60)], + options=CAP_50, + expected_group_sizes=[1, 1], + expected_n_fetches=2, + ), + StructuralCase( + id="single-range-larger-than-cap-passes-through", + # Cap only applies to MERGE decisions; a lone oversized range still fetches. 
+ ranges=[RangeByteRequest(0, 200)], + options=CAP_50, + expected_group_sizes=[1], + expected_contents={0: _INDEXED_BLOB[0:200]}, + expected_n_fetches=1, + ), +] + + +@pytest.mark.parametrize("case", _STRUCTURAL_CASES, ids=lambda c: c.id) +async def test_coalescing_structure_and_contents(case: StructuralCase) -> None: + """Group structure, byte contents, and fetch-call count for the deterministic cases.""" + fetch = FakeFetch(_INDEXED_BLOB) + groups = await _collect(coalesced_get(fetch, case.ranges, options=case.options)) + + assert sorted(len(g) for g in groups) == sorted(case.expected_group_sizes) + + if case.expected_contents is not None: + assert _contents(groups) == case.expected_contents + + if case.expected_n_fetches is not None: + assert len(fetch.calls) == case.expected_n_fetches + + +# --------------------------------------------------------------------------- +# Focused non-parametrized tests for cases with distinctive assertion shapes. +# --------------------------------------------------------------------------- + + +async def test_within_group_ordering_is_start_offset() -> None: + """Within a merged group, tuples are ordered by start offset, not input order.""" + fetch = FakeFetch(_INDEXED_BLOB) + # Two ranges that merge; one has a later start but is listed first in input. + ranges: list[ByteRequest | None] = [RangeByteRequest(20, 25), RangeByteRequest(0, 5)] + groups = await _collect(coalesced_get(fetch, ranges, options=MERGE_GAP_50)) + assert len(groups) == 1 + # Input index 1 (start=0) comes first, then 0 (start=20). 
+ assert [idx for idx, _ in groups[0]] == [1, 0] + + +async def test_adjacent_ranges_fire_single_fetch_spanning_merged_region() -> None: + """Verify the merged fetch covers exactly the span from min-start to max-end.""" + fetch = FakeFetch(_INDEXED_BLOB) + ranges: list[ByteRequest | None] = [ + RangeByteRequest(0, 5), + RangeByteRequest(10, 15), + RangeByteRequest(20, 25), + ] + await _collect(coalesced_get(fetch, ranges, options=MERGE_GAP_50)) + assert len(fetch.calls) == 1 + call = fetch.calls[0] + assert isinstance(call, RangeByteRequest) + assert call.start == 0 + assert call.end == 25 + + +# --------------------------------------------------------------------------- +# Concurrency and cancellation. +# --------------------------------------------------------------------------- + + +async def test_max_concurrency_is_honored() -> None: + """With 10 non-mergeable ranges and max_concurrency=3, peak in-flight must not exceed 3.""" + in_flight = 0 + peak = 0 + lock = asyncio.Lock() + + async def fetch(byte_range: ByteRequest | None) -> Buffer | None: + nonlocal in_flight, peak + async with lock: + in_flight += 1 + peak = max(peak, in_flight) + # give the scheduler a chance to run other tasks + await asyncio.sleep(0.01) + async with lock: + in_flight -= 1 + return _buf(b"x") + + ranges: list[ByteRequest | None] = [RangeByteRequest(i * 1000, i * 1000 + 1) for i in range(10)] + opts: CoalesceOptions = { + "max_gap_bytes": 0, # force no merging + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 3, + } + async for _group in coalesced_get(fetch, ranges, options=opts): + pass + assert peak <= 3 + assert peak >= 2 # must have been some real concurrency + + +async def test_consumer_break_cancels_pending_fetches() -> None: + """Breaking out of the async for should cancel pending fetches rather than let them complete.""" + completed_calls = 0 + cancelled_calls = 0 + + async def fetch(byte_range: ByteRequest | None) -> Buffer | None: + nonlocal completed_calls, 
cancelled_calls + assert isinstance(byte_range, RangeByteRequest) + start = byte_range.start + try: + # First fetch returns fast so the async-for body runs and can break. + # Later fetches sleep long enough that cancellation has room to land. + await asyncio.sleep(0.001 if start == 0 else 2.0) + except asyncio.CancelledError: + cancelled_calls += 1 + raise + completed_calls += 1 + return _buf(b"x") + + opts: CoalesceOptions = { + "max_gap_bytes": -1, # no merging + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 3, + } + ranges: list[ByteRequest | None] = [RangeByteRequest(i * 1000, i * 1000 + 1) for i in range(6)] + + agen = coalesced_get(fetch, ranges, options=opts) + async for _group in agen: + break + # Explicitly close the generator so its finally block runs (cancelling + # in-flight tasks) before we make assertions. + await agen.aclose() + + assert cancelled_calls >= 1 + assert completed_calls >= 1 + + +# --------------------------------------------------------------------------- +# Key-missing semantics. 
+# --------------------------------------------------------------------------- + + +async def test_key_missing_from_first_call_yields_nothing() -> None: + """If the very first fetch returns None, the iterator yields no groups.""" + fetch = FakeFetch(b"x" * 100, key_exists=False) + ranges: list[ByteRequest | None] = [RangeByteRequest(0, 10), RangeByteRequest(20, 30)] + groups = await _collect(coalesced_get(fetch, ranges, options=DEFAULT_COALESCE_OPTIONS)) + assert groups == [] + + +@pytest.mark.parametrize( + "byte_range", + [OffsetByteRequest(5), SuffixByteRequest(5), None], + ids=["offset", "suffix", "none"], +) +async def test_key_missing_on_uncoalescable_input_yields_nothing( + byte_range: ByteRequest | None, +) -> None: + """Uncoalescable inputs take a distinct path; key-missing must still short-circuit.""" + fetch = FakeFetch(b"x" * 100, key_exists=False) + groups = await _collect(coalesced_get(fetch, [byte_range], options=DEFAULT_COALESCE_OPTIONS)) + assert groups == [] + + +async def test_key_missing_mid_stream_yields_earlier_groups_only() -> None: + """If a later fetch returns None, earlier-completed groups remain observable.""" + call_count = 0 + + async def fetch(byte_range: ByteRequest | None) -> Buffer | None: + nonlocal call_count + call_count += 1 + # Deterministic: first call serves, second returns None. 
+ await asyncio.sleep(0.01 if call_count == 1 else 0.02) + if call_count >= 2: + return None + return _buf(b"ok") + + opts: CoalesceOptions = { + "max_gap_bytes": -1, + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 1, # serialize for determinism + } + ranges: list[ByteRequest | None] = [RangeByteRequest(0, 2), RangeByteRequest(100, 102)] + groups = await _collect(coalesced_get(fetch, ranges, options=opts)) + assert len(groups) == 1 + assert len(groups[0]) == 1 + + +async def test_key_missing_mid_stream_with_concurrency_drains_late_arrivals() -> None: + """ + Under max_concurrency > 1, a mid-stream miss should still cause the iterator + to complete cleanly even when unrelated tasks are still in flight and arrive + after the miss has been observed. + """ + late_gate = asyncio.Event() + miss_fired = asyncio.Event() + + async def fetch(byte_range: ByteRequest | None) -> Buffer | None: + assert isinstance(byte_range, RangeByteRequest) + start = byte_range.start + if start == 0: + # First to complete: arrives before the miss. + await asyncio.sleep(0.01) + return _buf(b"ok") + if start == 1000: + # Miss: a little later than #0 so #0 yields first. + await asyncio.sleep(0.03) + miss_fired.set() + return None + # Late arrivals: wait until the miss has been processed, then return + # a buffer so the drain loop sees them post-stop. 
+ await asyncio.wait_for(miss_fired.wait(), timeout=5.0) + await asyncio.wait_for(late_gate.wait(), timeout=5.0) + return _buf(b"ok") + + opts: CoalesceOptions = { + "max_gap_bytes": -1, + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 3, + } + ranges: list[ByteRequest | None] = [RangeByteRequest(i * 1000, i * 1000 + 1) for i in range(7)] + + groups: list[list[tuple[int, Buffer | None]]] = [] + agen = coalesced_get(fetch, ranges, options=opts) + try: + async for group in agen: + groups.append(list(group)) + late_gate.set() + finally: + late_gate.set() + + assert len(groups) == 1 + assert len(groups[0]) == 1 + idx, buf = groups[0][0] + assert idx == 0 + assert buf is not None + assert miss_fired.is_set() + + +# --------------------------------------------------------------------------- +# Error propagation. +# --------------------------------------------------------------------------- + + +async def test_fetch_raises_propagates() -> None: + """An exception raised by fetch propagates on the yield that produced the failing group.""" + fetch = FakeFetch( + _INDEXED_BLOB, + raise_on=lambda r: isinstance(r, RangeByteRequest) and r.start >= 100, + ) + opts: CoalesceOptions = { + "max_gap_bytes": -1, + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 1, + } + ranges: list[ByteRequest | None] = [RangeByteRequest(0, 10), RangeByteRequest(200, 210)] + with pytest.raises(OSError, match="injected"): + await _collect(coalesced_get(fetch, ranges, options=opts)) + + +# --------------------------------------------------------------------------- +# Property-style coverage invariant. 
+# --------------------------------------------------------------------------- + + +async def test_coverage_invariant_random_inputs() -> None: + """For any random RangeByteRequest input, every input index appears exactly once.""" + import random + + rng = random.Random(42) + fetch = FakeFetch(_INDEXED_BLOB) + + ranges: list[ByteRequest | None] = [] + for _ in range(50): + start = rng.randint(0, 9000) + length = rng.randint(1, 500) + ranges.append(RangeByteRequest(start, start + length)) + + groups = await _collect(coalesced_get(fetch, ranges, options=DEFAULT_COALESCE_OPTIONS)) + seen: list[int] = [idx for group in groups for idx, _buf in group] + assert sorted(seen) == list(range(len(ranges))) + + flat = _contents(groups) + for i, r in enumerate(ranges): + assert isinstance(r, RangeByteRequest) + assert flat[i] == _INDEXED_BLOB[r.start : r.end] diff --git a/tests/test_store/test_fsspec_get_ranges.py b/tests/test_store/test_fsspec_get_ranges.py new file mode 100644 index 0000000000..a659540182 --- /dev/null +++ b/tests/test_store/test_fsspec_get_ranges.py @@ -0,0 +1,130 @@ +# tests/test_store/test_fsspec_get_ranges.py +"""Lightweight integration tests for FsspecStore.get_ranges using MemoryFileSystem. + +These don't need moto/s3 — they exercise the new method against an in-process +fsspec MemoryFileSystem wrapped in the async wrapper. +""" + +from __future__ import annotations + +import pytest +from packaging.version import parse as parse_version + +from zarr.abc.store import RangeByteRequest +from zarr.core._coalesce import DEFAULT_COALESCE_OPTIONS, CoalesceOptions +from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.storage import FsspecStore +from zarr.storage._fsspec import _make_async + +fsspec = pytest.importorskip("fsspec") + +# AsyncFileSystemWrapper (needed to wrap a sync MemoryFileSystem) landed in fsspec 2024.12.0. +# Older versions are pinned by the min-deps CI job, so skip the whole file there. 
+pytestmark = pytest.mark.skipif( + parse_version(fsspec.__version__) < parse_version("2024.12.0"), + reason="No AsyncFileSystemWrapper", +) + + +@pytest.fixture +def memory_store() -> FsspecStore: + """An FsspecStore backed by fsspec MemoryFileSystem (wrapped async).""" + from fsspec.implementations.memory import MemoryFileSystem + + # Each test gets a clean filesystem; MemoryFileSystem is a singleton per target_options, + # so clear state explicitly. + fs: MemoryFileSystem = MemoryFileSystem() + fs.store.clear() + fs.pseudo_dirs.clear() + async_fs = _make_async(fs) + return FsspecStore(fs=async_fs, path="/root") + + +async def _write(store: FsspecStore, key: str, data: bytes) -> None: + buf = default_buffer_prototype().buffer.from_bytes(data) + await store.set(key, buf) + + +async def test_get_ranges_happy_path(memory_store: FsspecStore) -> None: + blob = bytes(i % 256 for i in range(1024)) + await _write(memory_store, "blob", blob) + proto = default_buffer_prototype() + + ranges = [ + RangeByteRequest(0, 10), + RangeByteRequest(100, 110), + RangeByteRequest(500, 520), + ] + groups: list[list[tuple[int, Buffer | None]]] = [ + list(group) async for group in memory_store.get_ranges("blob", ranges, prototype=proto) + ] + + flat: dict[int, bytes] = {} + for group in groups: + for idx, buf in group: + assert buf is not None + flat[idx] = buf.to_bytes() + + assert flat[0] == blob[0:10] + assert flat[1] == blob[100:110] + assert flat[2] == blob[500:520] + + +async def test_get_ranges_missing_key_yields_nothing(memory_store: FsspecStore) -> None: + proto = default_buffer_prototype() + groups: list[list[tuple[int, Buffer | None]]] = [ + list(group) + async for group in memory_store.get_ranges( + "does-not-exist", [RangeByteRequest(0, 10)], prototype=proto + ) + ] + assert groups == [] + + +async def test_default_coalesce_options_on_store_without_arg() -> None: + from fsspec.implementations.memory import MemoryFileSystem + + fs = MemoryFileSystem() + fs.store.clear() + 
store = FsspecStore(fs=_make_async(fs), path="/x") + assert store.coalesce_options == DEFAULT_COALESCE_OPTIONS + + +async def test_coalesce_options_wired_through() -> None: + from fsspec.implementations.memory import MemoryFileSystem + + fs = MemoryFileSystem() + fs.store.clear() + custom: CoalesceOptions = { + "max_gap_bytes": 0, + "max_coalesced_bytes": 1 << 20, + "max_concurrency": 2, + } + store = FsspecStore(fs=_make_async(fs), path="/x", coalesce_options=custom) + assert store.coalesce_options == custom + + +async def test_get_ranges_mixed_range_types(memory_store: FsspecStore) -> None: + """Covers RangeByteRequest, OffsetByteRequest, SuffixByteRequest, and None in one call.""" + from zarr.abc.store import ByteRequest, OffsetByteRequest, SuffixByteRequest + + blob = bytes(i % 256 for i in range(512)) + await _write(memory_store, "mixed", blob) + proto = default_buffer_prototype() + + ranges: list[ByteRequest | None] = [ + RangeByteRequest(0, 10), + OffsetByteRequest(500), + SuffixByteRequest(12), + None, + ] + flat: dict[int, bytes] = {} + async for group in memory_store.get_ranges("mixed", ranges, prototype=proto): + for idx, buf in group: + assert buf is not None + flat[idx] = buf.to_bytes() + + assert flat[0] == blob[0:10] + assert flat[1] == blob[500:] + assert flat[2] == blob[-12:] + assert flat[3] == blob diff --git a/tests/test_store/test_protocols.py b/tests/test_store/test_protocols.py new file mode 100644 index 0000000000..9dae19ffba --- /dev/null +++ b/tests/test_store/test_protocols.py @@ -0,0 +1,48 @@ +# tests/test_store/test_protocols.py +"""Runtime and static conformance tests for zarr.storage._protocols.SupportsGetRanges.""" + +from __future__ import annotations + +import pytest + +from zarr.storage._protocols import SupportsGetRanges + +fsspec = pytest.importorskip("fsspec") + +from packaging.version import parse as parse_version # noqa: E402 + +# AsyncFileSystemWrapper (needed to wrap a sync MemoryFileSystem) landed in fsspec 2024.12.0. 
+# Older versions are pinned by the min-deps CI job. +_needs_async_wrapper = pytest.mark.skipif( + parse_version(fsspec.__version__) < parse_version("2024.12.0"), + reason="No AsyncFileSystemWrapper", +) + + +@_needs_async_wrapper +def test_fsspec_store_satisfies_supports_get_ranges() -> None: + from fsspec.implementations.memory import MemoryFileSystem + + from zarr.storage import FsspecStore + from zarr.storage._fsspec import _make_async + + fs = MemoryFileSystem() + fs.store.clear() + store = FsspecStore(fs=_make_async(fs), path="/x") + assert isinstance(store, SupportsGetRanges) + + +def test_memory_store_does_not_satisfy_supports_get_ranges() -> None: + """Sanity check: stores that don't implement get_ranges shouldn't satisfy the protocol.""" + from zarr.storage import MemoryStore + + store = MemoryStore() + assert not isinstance(store, SupportsGetRanges) + + +def test_type_assignment_at_module_level() -> None: + """Smoke-test the module-level `_: type[SupportsGetRanges] = FsspecStore`. + + If this runs without error the module imported cleanly; the static check is in mypy. + """ + from zarr.storage import _fsspec # noqa: F401