From 904b3e15b626c672d610059d6313e7204ea74bf5 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 21 May 2026 19:48:00 -0500 Subject: [PATCH 1/3] PKRange Cache fix --- .../_routing/aio/routing_map_provider.py | 172 ++++++- .../routing/test_routing_map_provider.py | 128 ++++- .../test_routing_map_provider_async.py | 477 +++++++++++++++++- .../test_shared_pk_range_cache_async.py | 58 ++- 4 files changed, 762 insertions(+), 73 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index 4cfb429ab7e3..f23e20c072f5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -44,7 +44,7 @@ if TYPE_CHECKING: from ...aio._cosmos_client_connection_async import CosmosClientConnection -# Module-level shared state, keyed by endpoint URL. All four dicts and the +# Module-level shared state, keyed by endpoint URL. All five dicts and the # refcount are mutated only while holding ``_shared_cache_lock``. Sharing across # every async CosmosClient that targets the same endpoint is what eliminates # the per-client duplicate copies of the routing map (the memory win driving @@ -75,15 +75,24 @@ # and defeat the single-flight invariant. _shared_locks_locks: Dict[str, threading.Lock] = {} +# endpoint -> { (loop_id, collection_id) -> asyncio.Task }. The single +# in-flight fetch-and-publish task per (loop, collection). Any caller that +# arrives during a cold-cache fetch joins this task via ``asyncio.shield`` +# instead of issuing its own network round trip, so concurrent callers +# share a single fetch. The task body owns both the fetch and the cache +# write, so the publish survives any individual caller being cancelled +# (e.g. by ``asyncio.wait_for``) while awaiting it. +_shared_inflight_fetches: Dict[str, Dict[tuple, asyncio.Task]] = {} + # endpoint -> int. Number of live async ``PartitionKeyRangeCache`` instances # using this endpoint. Incremented on construction and decremented in # ``release`` (called from ``CosmosClient.__aexit__`` / ``close`` / ``__del__``). -# When the count hits zero we drop the entry from all four dicts so an idle +# When the count hits zero we drop the entry from all five dicts so an idle # endpoint does not pin memory forever. ``clear_cache`` does NOT touch this # count — it only wipes routing-map contents. _shared_cache_refcounts: Dict[str, int] = {} -# Process-wide lock guarding the four dicts above for *this* (async) module. +# Process-wide lock guarding the five dicts above for *this* (async) module. # Note: the sync module ``_routing/routing_map_provider.py`` defines its own # independent set of module-level dicts and its own ``_shared_cache_lock`` — # state is NOT shared between the sync and async modules. A sync and an async @@ -123,20 +132,23 @@ def __init__(self, client: Any): self._endpoint = _resolve_endpoint(client) self._released = False - # Share routing map cache, per-collection asyncio locks, and the - # per-endpoint meta-lock that guards the per-collection-lock dict - # across all clients with the same endpoint. Refcount lets us evict - # the entry when the last sharing client releases it (see ``release``). + # Share routing map cache, per-collection asyncio locks, the + # per-endpoint meta-lock that guards the per-collection-lock dict, + # and the in-flight fetch-task dict across all clients with the same + # endpoint. Refcount lets us evict the entry when the last sharing + # client releases it (see ``release``). with _shared_cache_lock: if self._endpoint not in _shared_routing_map_cache: _shared_routing_map_cache[self._endpoint] = {} _shared_collection_locks[self._endpoint] = {} _shared_locks_locks[self._endpoint] = threading.Lock() + _shared_inflight_fetches[self._endpoint] = {} _shared_cache_refcounts[self._endpoint] = 0 _shared_cache_refcounts[self._endpoint] += 1 self._collection_routing_map_by_item = _shared_routing_map_cache[self._endpoint] self._collection_locks: Dict[tuple, asyncio.Lock] = _shared_collection_locks[self._endpoint] self._locks_lock: threading.Lock = _shared_locks_locks[self._endpoint] + self._inflight_fetches: Dict[tuple, asyncio.Task] = _shared_inflight_fetches[self._endpoint] def clear_cache(self): """Clear the shared routing map cache for this endpoint. @@ -145,13 +157,13 @@ def clear_cache(self): client references to the same dict object, so concurrent clients sharing the endpoint continue to share a single cache instance. - The per-collection locks dict is intentionally **not** cleared here: - an in-flight ``_fetch_routing_map`` caller holds one of those locks - and will write its result into the (now-empty) shared cache when it - completes. Keeping the lock in place ensures that any concurrent - arrival serialises behind the in-flight refresh (single-flight - invariant) instead of racing it with a fresh lock. The locks dict - is evicted in ``release()`` once the endpoint refcount hits zero. + The per-collection locks dict and the in-flight fetch-task dict are + intentionally **not** cleared here. A fetch task scheduled before + this call keeps a reference to the (now-empty) cache dict and will + publish its result into it when it completes; any concurrent arrival + meanwhile joins that same task instead of racing it. Both auxiliary + dicts are evicted in ``release()`` once the endpoint refcount hits + zero. """ with _shared_cache_lock: if self._endpoint in _shared_routing_map_cache: @@ -180,6 +192,7 @@ def release(self) -> None: _shared_routing_map_cache.pop(endpoint, None) _shared_collection_locks.pop(endpoint, None) _shared_locks_locks.pop(endpoint, None) + _shared_inflight_fetches.pop(endpoint, None) else: _shared_cache_refcounts[endpoint] = count except Exception: # pylint: disable=broad-except @@ -267,9 +280,13 @@ async def get_routing_map( ) -> Optional[CollectionRoutingMap]: """Gets or refreshes the routing map for a collection. - This method handles the logic for fetching, caching, and updating the - collection's routing map. It uses a locking mechanism to prevent race - conditions during concurrent updates. + Concurrent callers that arrive while a fetch is already in flight for + the same collection join that fetch via ``asyncio.shield`` rather than + issuing their own network round trip. The fetch task owns the cache + write, so the publish completes even if every awaiting caller is + cancelled (for example by ``asyncio.wait_for``) before the fetch + returns. The next caller — whether the original caller retrying or a + new one — finds the cache populated. :param str collection_link: The link to the collection. :param Optional[Dict[str, Any]] feed_options: Optional query options. @@ -281,37 +298,136 @@ async def get_routing_map( """ collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - # First check (no lock) for the fast path. + # Fast path: cache hit without acquiring any lock. if not force_refresh: cached_map = self._collection_routing_map_by_item.get(collection_id) if cached_map: return cached_map - # Acquire lock only when a refresh or initial load is likely needed. + fetch_task = await self._register_or_join_inflight_fetch( + collection_id, + collection_link, + feed_options, + force_refresh, + previous_routing_map, + kwargs, + ) + + if fetch_task is not None: + # ``shield`` ensures our cancellation only unwinds *this* awaiter; + # the underlying task keeps running on the event loop and the + # cache write inside the task body still happens. Other waiters + # (and any subsequent caller hitting the now-populated cache) are + # unaffected by our cancellation. + await asyncio.shield(fetch_task) + + return self._collection_routing_map_by_item.get(collection_id) + + async def _register_or_join_inflight_fetch( + self, + collection_id: str, + collection_link: str, + feed_options: Optional[Dict[str, Any]], + force_refresh: bool, + previous_routing_map: Optional[CollectionRoutingMap], + fetch_kwargs: Dict[str, Any], + ) -> Optional[asyncio.Task]: + """Return the in-flight fetch task for this collection, creating one if needed. + + Holding the per-collection lock for just the check-or-create window + (no network I/O inside the lock) keeps the critical section small. + The returned task may be one this call just scheduled or one a + concurrent caller scheduled moments earlier — either way the caller + should await it through ``asyncio.shield``. + + :param str collection_id: The resolved collection identifier used as the cache key. + :param str collection_link: The link to the collection. + :param Optional[Dict[str, Any]] feed_options: Optional query options. + :param bool force_refresh: Whether the caller asked for a refresh. + :param Optional[CollectionRoutingMap] previous_routing_map: The caller's last + observed routing map, used by the refresh-decision helper. + :param Dict[str, Any] fetch_kwargs: Pipeline kwargs forwarded to the fetch. + :return: A running ``asyncio.Task`` to await, or ``None`` if no fetch + is needed (cache was populated by a concurrent caller after the + fast-path check). + :rtype: Optional[asyncio.Task] + """ + inflight_key = (id(asyncio.get_running_loop()), collection_id) collection_lock = await self._get_lock_for_collection(collection_id) async with collection_lock: - # Second check (with lock) — use shared helper for the decision logic. + existing_task = self._inflight_fetches.get(inflight_key) + if existing_task is not None: + return existing_task + should_fetch, base_routing_map = determine_refresh_action( self._collection_routing_map_by_item, collection_id, force_refresh, previous_routing_map, ) + if not should_fetch: + return None - if should_fetch: - new_routing_map = await self._fetch_routing_map( - collection_link, + new_task = asyncio.create_task( + self._fetch_and_publish( collection_id, + collection_link, base_routing_map, feed_options, - **kwargs + inflight_key, + fetch_kwargs, ) + ) + self._inflight_fetches[inflight_key] = new_task + return new_task + + async def _fetch_and_publish( + self, + collection_id: str, + collection_link: str, + base_routing_map: Optional[CollectionRoutingMap], + feed_options: Optional[Dict[str, Any]], + inflight_key: tuple, + fetch_kwargs: Dict[str, Any], + ) -> Optional[CollectionRoutingMap]: + """Run ``_fetch_routing_map`` and publish its result, then free the in-flight slot. + + The cache assignment lives inside this task body so a caller's + cancellation while awaiting the task cannot interrupt the publish. + The ``finally`` block always frees the in-flight slot — on success, + on a fetch error, or on cancellation — so the next caller is free to + schedule a fresh attempt. - # Update the cache. - if new_routing_map: - self._collection_routing_map_by_item[collection_id] = new_routing_map + :param str collection_id: The resolved collection identifier used as the cache key. + :param str collection_link: The link to the collection. + :param Optional[CollectionRoutingMap] base_routing_map: The base routing map + for incremental updates, or ``None`` for a full load. + :param Optional[Dict[str, Any]] feed_options: Optional query options. + :param tuple inflight_key: The ``(loop_id, collection_id)`` key into the in-flight dict. + :param Dict[str, Any] fetch_kwargs: Pipeline kwargs forwarded to the fetch. + :return: The new routing map, or ``None`` if the fetch produced nothing. + :rtype: Optional[CollectionRoutingMap] + """ + try: + new_routing_map = await self._fetch_routing_map( + collection_link, + collection_id, + base_routing_map, + feed_options, + **fetch_kwargs, + ) - return self._collection_routing_map_by_item.get(collection_id) + if new_routing_map: + self._collection_routing_map_by_item[collection_id] = new_routing_map + + return new_routing_map + finally: + # Atomic single-key removal; no lock needed. Runs on success, + # on fetch error, and on cancellation alike, so the next caller + # can register a fresh fetch immediately. + inflight_fetches = self._inflight_fetches + if inflight_key in inflight_fetches: + del inflight_fetches[inflight_key] async def _fetch_routing_map( diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index 56f6637ff454..793831c48ef2 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -1,19 +1,19 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import threading +import time import unittest +from typing import Optional, Mapping, Any +from unittest.mock import MagicMock import pytest +from azure.cosmos import _base, http_constants from azure.cosmos._routing import routing_range as routing_range from azure.cosmos._routing.routing_map_provider import CollectionRoutingMap from azure.cosmos._routing.routing_map_provider import SmartRoutingMapProvider from azure.cosmos._routing.routing_map_provider import PartitionKeyRangeCache -from azure.cosmos import http_constants - -from typing import Optional, Mapping, Any -from unittest.mock import MagicMock -import threading @pytest.mark.cosmosEmulator class TestRoutingMapProvider(unittest.TestCase): @@ -214,7 +214,6 @@ def test_get_routing_map_caches_on_first_call(self): self.assertIsNotNone(result) self.assertEqual(len(list(result._orderedPartitionKeyRanges)), 5) # Verify it's cached - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) self.assertIn(collection_id, provider._collection_routing_map_by_item) @@ -246,6 +245,55 @@ def user_hook(headers, _): self.assertEqual(result.change_feed_etag, expected_internal_etag) self.assertEqual(hook_calls, ['"user-hook-etag"']) + def test_get_routing_map_tight_timeout_kwarg_still_populates_cache(self): + """Sync path forwards timeout kwarg and populates cache on successful fetch.""" + # The sync side of the PK-range cache work is much narrower than the + # async side: sync has no asyncio cancellation channel, so the + # "wait_for kills the fetch mid-flight" failure mode doesn't apply. + # What CAN reach the cache on sync is the `timeout=` kwarg + # + # This test covers the happy path of that flow: + # 1. Customer calls `get_routing_map(..., timeout=0.001)`. + # 2. The cache layer forwards the kwarg to the underlying read + # (verified by inspecting what the mock saw). + # 3. The fetch completes successfully (the mock returns instantly + # without honouring the tiny timeout). + # 4. The result lands in the cache as normal. + # 5. A second call hits the cache fast-path with no new fetch. + call_count = {'count': 0} + # We record the timeout the mock saw, to prove the kwargs path is + # intact end-to-end (cache layer didn't silently drop it). + seen_timeout = {'value': None} + original_ranges = self.partition_key_ranges + + class TimeoutAwareClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + seen_timeout['value'] = kwargs.get('timeout') + TestRoutingMapProvider._capture_internal_headers(kwargs, '"timeout-etag"') + return original_ranges + + provider = PartitionKeyRangeCache(TimeoutAwareClient()) + collection_link = "dbs/db/colls/container" + + # === Step 1: first call with a tight timeout kwarg. The mock returns + # instantly so the timeout doesn't actually fire; the fetch succeeds. + result1 = provider.get_routing_map(collection_link, feed_options={}, timeout=0.001) + self.assertIsNotNone(result1) + # === Step 2: verify the cache layer forwarded the timeout kwarg + # down to the mock. If this is None, the kwargs path is broken. + self.assertEqual(seen_timeout['value'], 0.001) + self.assertEqual(call_count['count'], 1) + + # === Step 3: confirm the routing map landed in the cache as normal. + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + self.assertIn(collection_id, provider._collection_routing_map_by_item) + + # === Step 4: second call (no timeout). Cache hit, no extra fetch. + result2 = provider.get_routing_map(collection_link, feed_options={}) + self.assertIs(result2, result1) + self.assertEqual(call_count['count'], 1) + def test_get_routing_map_returns_cached_on_second_call(self): """Second call returns the same cached object without re-fetching.""" call_count = {'count': 0} @@ -313,7 +361,6 @@ def test_is_cache_stale_etag_logic(self): TestRoutingMapProvider.MockedCosmosClientConnection(self.partition_key_ranges) ) collection_link = "dbs/db/colls/container" - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) # Populate cache @@ -348,7 +395,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return incomplete_ranges provider = PartitionKeyRangeCache(IncompleteClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -385,7 +431,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return delta_ranges provider = PartitionKeyRangeCache(DeltaClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -433,7 +478,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return full_ranges provider = PartitionKeyRangeCache(HeaderCapturingClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -491,7 +535,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return delta_ranges provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -538,7 +581,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return delta_ranges provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -610,7 +652,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return full_ranges provider = PartitionKeyRangeCache(RapidSplitClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -666,7 +707,6 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) return delta_ranges provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -724,8 +764,26 @@ def test_concurrent_refresh_serialized_by_lock(self): With `and`, only the first thread that finds the cache stale actually fetches. Subsequent threads see the updated ETag and skip the redundant fetch. """ + + # Sync mirror of the async concurrent-refresh test. On sync, the + # cache uses a `threading.Lock` (not an asyncio Lock) to serialise + # concurrent refreshes — this is the same lock that prevents the + # gateway-side stampede where every cold-cache caller would fire + # its own concurrent fetch. + # + # Even if 5 threads all decide to `force_refresh` at the same moment, + # the lock makes sure they take turns, and the double-checked ETag + # logic short-circuits the second-through-fifth threads once the + # first one has already done the refresh. + # + # The test forces all this contention by gating the mock client with + # a `threading.Event`, then releases the gate and verifies that + # every thread came out the other side with a valid result and no + # exceptions. call_count = {'count': 0} original_ranges = self.partition_key_ranges + # A threading.Event so we can pin the mock client mid-fetch while + # the contention builds up, then release them all at once. fetch_event = threading.Event() class SlowCountingClient: @@ -739,10 +797,14 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) provider = PartitionKeyRangeCache(SlowCountingClient()) collection_link = "dbs/db/colls/container" - # Populate cache with initial map + # === Step 1: populate the cache with an initial map. Open the gate + # so this first load isn't slow. fetch_event.set() # Let the initial load go fast initial_map = provider.get_routing_map(collection_link, feed_options={}) self.assertEqual(call_count['count'], 1) + # === Step 2: close the gate. Subsequent fetches will park inside + # the mock until we open it again — guaranteeing the contention + # window stays open while threads pile up. fetch_event.clear() # Now make subsequent fetches slow results = [None] * 5 @@ -757,21 +819,22 @@ def thread_fn(idx): except Exception as e: errors.append(e) + # === Step 3: launch 5 OS threads all calling force_refresh at once. threads = [threading.Thread(target=thread_fn, args=(i,)) for i in range(5)] for t in threads: t.start() - # Give threads time to all start and contend on the lock - import time + # === Step 4: give them time to all start and stack up on the lock. time.sleep(0.2) - # Release the slow fetch + # Release the slow fetch so the queued threads can drain. fetch_event.set() for t in threads: t.join(timeout=10) + # === Step 5: contract — no thread crashed, all 5 came back with a + # valid map. self.assertEqual(len(errors), 0, f"Threads raised errors: {errors}") - # All threads should get a non-None result for i, r in enumerate(results): self.assertIsNotNone(r, f"Thread {i} got None") @@ -781,23 +844,39 @@ def test_cache_never_none_during_refresh(self): The cache entry is atomically replaced, never deleted. This test verifies that concurrent readers always see either the old valid map or the new valid map. """ + + # Sync mirror of the async "cache never None" test. Same property + # applies: a concurrent reader (running on its own thread, hitting + # the cache fast path) must never observe None while a refresher + # thread is replacing the map. The implementation must use atomic + # dict assignment (`cache[key] = new_map`), never delete-then- + # reinsert. + # + # If a reader ever did see None, it would conclude the cache was + # cold and trigger its own fetch — a needless extra HTTP request, + # multiplied by however many readers happened to look at the wrong + # microsecond. Across many readers under load this would compound + # into the same kind of stampede the lock was added to prevent. original_ranges = self.partition_key_ranges call_count = {'count': 0} class SlowClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - import time + # Artificial 100 ms delay so the refresher is provably mid- + # flight while the reader is polling — without it, the + # refresh might complete before the reader observes even + # one read. time.sleep(0.1) # Simulate network delay TestRoutingMapProvider._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') return original_ranges provider = PartitionKeyRangeCache(SlowClient()) collection_link = "dbs/db/colls/container" - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - # Populate cache + # === Step 1: populate the cache so the reader has something + # non-None to observe before, during, and after the refresh. initial_map = provider.get_routing_map(collection_link, feed_options={}) self.assertIsNotNone(initial_map) @@ -818,6 +897,9 @@ def refresher_fn(): force_refresh=True, previous_routing_map=initial_map ) + # === Step 2: start both threads. Reader spins, refresher does its + # one slow refresh and exits. Once refresher is done we tell reader + # to stop and join it. reader = threading.Thread(target=reader_fn) refresher = threading.Thread(target=refresher_fn) @@ -827,6 +909,8 @@ def refresher_fn(): stop_event.set() reader.join(timeout=5) + # === Step 3: the assertion. Reader saw the slot transition from + # old map -> new map without ever observing a None intermediate. self.assertEqual(none_seen['count'], 0, "Cache entry should never be None during a refresh — it should be atomically replaced") diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py index 5d7408bb6216..f73748963609 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py @@ -1,15 +1,17 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio # pylint: disable=do-not-import-asyncio import unittest import pytest +from azure.cosmos import _base, http_constants from azure.cosmos._routing import routing_range as routing_range from azure.cosmos._routing.aio.routing_map_provider import CollectionRoutingMap from azure.cosmos._routing.aio.routing_map_provider import SmartRoutingMapProvider from azure.cosmos._routing.aio.routing_map_provider import PartitionKeyRangeCache -from azure.cosmos import http_constants +from azure.cosmos.exceptions import CosmosHttpResponseError from typing import Optional, Mapping, Any from unittest.mock import MagicMock @@ -47,9 +49,14 @@ async def _gen(): return _gen() def tearDown(self): - from azure.cosmos._routing.aio.routing_map_provider import _shared_routing_map_cache, _shared_cache_lock + from azure.cosmos._routing.aio.routing_map_provider import ( + _shared_routing_map_cache, + _shared_inflight_fetches, + _shared_cache_lock, + ) with _shared_cache_lock: _shared_routing_map_cache.clear() + _shared_inflight_fetches.clear() def setUp(self): self.partition_key_ranges = [ @@ -180,7 +187,6 @@ async def test_get_routing_map_caches_on_first_call_async(self): self.assertIsNotNone(result) self.assertEqual(len(list(result._orderedPartitionKeyRanges)), 5) - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) self.assertIn(collection_id, provider._collection_routing_map_by_item) @@ -294,7 +300,6 @@ async def test_is_cache_stale_etag_logic_async(self): TestRoutingMapProviderAsync.MockedCosmosClientConnection(self.partition_key_ranges) ) collection_link = "dbs/db/colls/container" - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) cached_map = await provider.get_routing_map(collection_link, feed_options={}) @@ -333,7 +338,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(IncompleteClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -373,7 +377,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(DeltaClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -424,7 +427,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(HeaderCapturingClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -486,7 +488,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -538,7 +539,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -609,7 +609,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(RapidSplitClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -670,7 +669,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -731,9 +729,21 @@ async def test_concurrent_refresh_serialized_by_lock_async(self): Verifies that coroutines don't corrupt the cache and all get a valid result. """ - import asyncio + # The cache uses a small per-collection lock to make sure that even + # under concurrent `force_refresh=True` storms (e.g. lots of 410 + # responses arriving at once), the cache state stays consistent and + # no caller gets back garbage. + # + # We do not assert that exactly one fetch happens — the production + # code allows the first refresh to populate, and subsequent contending + # refreshes may either skip (if they see the new ETag) or proceed. + # The contract this test pins down is the weaker but essential one: + # nothing crashes, nothing corrupts, and every caller gets back a + # valid routing map. call_count = {'count': 0} original_ranges = self.partition_key_ranges + # A gate so we can force contention: refreshes will all pile up + # waiting here, then we open the gate and let them race. fetch_event = asyncio.Event() class SlowCountingClient: @@ -751,10 +761,13 @@ async def _gen(): provider = PartitionKeyRangeCache(SlowCountingClient()) collection_link = "dbs/db/colls/container" - # Populate cache with initial map (let it go fast) + # === Step 1: populate the cache with a known initial map. Let the + # gate be open so this initial load isn't slow. fetch_event.set() initial_map = await provider.get_routing_map(collection_link, feed_options={}) self.assertEqual(call_count['count'], 1) + # === Step 2: close the gate so subsequent fetches will block, + # giving the concurrent callers time to all queue up. fetch_event.clear() async def refresh_fn(): @@ -763,15 +776,19 @@ async def refresh_fn(): force_refresh=True, previous_routing_map=initial_map ) - # Launch 5 concurrent refresh coroutines + # === Step 3: launch 5 concurrent refresh coroutines. With the gate + # closed, they'll all pile up at the lock and/or the fetch event. tasks = [asyncio.create_task(refresh_fn()) for _ in range(5)] + # Yield so they all reach their parked state before we release. await asyncio.sleep(0.1) + # Now open the gate and let them all proceed. fetch_event.set() results = await asyncio.gather(*tasks) - # All coroutines should get a non-None result + # === Step 4: contract — every coroutine got back a non-None result. + # Concurrency didn't corrupt anyone's view of the cache. for i, r in enumerate(results): self.assertIsNotNone(r, f"Coroutine {i} got None") @@ -780,7 +797,19 @@ async def test_cache_never_none_during_refresh_async(self): The cache entry is atomically replaced, never deleted. """ - import asyncio + # Important property for fast-path readers: the cache slot is + # ALWAYS either the old map or the new map — it is never + # transiently set to None while a refresh is in flight. + # + # If the refresh code did `del cache[key]; cache[key] = new_map`, + # there would be a window where a concurrent fast-path reader could + # observe `None` and incorrectly conclude the cache is cold, which + # would trigger a redundant fetch storm. The fix uses atomic + # replacement, so readers always see a valid map. + # + # We verify this by spinning a reader coroutine that polls the + # cache slot continuously while a refresh runs, and asserting it + # never once observed None. original_ranges = self.partition_key_ranges call_count = {'count': 0} @@ -789,6 +818,8 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) call_count['count'] += 1 async def _gen(): + # A small artificial delay so the refresher is provably + # in flight while the reader is polling. await asyncio.sleep(0.05) TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') for r in original_ranges: @@ -798,13 +829,15 @@ async def _gen(): provider = PartitionKeyRangeCache(SlowClient()) collection_link = "dbs/db/colls/container" - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - # Populate cache + # === Step 1: populate the cache so the reader has something + # non-None to observe between refresh windows. initial_map = await provider.get_routing_map(collection_link, feed_options={}) self.assertIsNotNone(initial_map) + # === Step 2: set up the reader. It polls the cache as fast as the + # loop will let it, counting any None observations as failures. none_seen = {'count': 0} stop_event = asyncio.Event() @@ -813,6 +846,7 @@ async def reader_fn(): cached = provider._collection_routing_map_by_item.get(collection_id) if cached is None: none_seen['count'] += 1 + # Yield so this loop doesn't monopolise the event loop. await asyncio.sleep(0) async def refresher_fn(): @@ -821,14 +855,421 @@ async def refresher_fn(): force_refresh=True, previous_routing_map=initial_map ) + # === Step 3: start the reader, then do the refresh, then stop the reader. reader_task = asyncio.create_task(reader_fn()) await refresher_fn() stop_event.set() await reader_task + # === Step 4: contract — the reader saw the slot transition from + # old map to new map without ever observing a None intermediate. self.assertEqual(none_seen['count'], 0, "Cache entry should never be None during a refresh — it should be atomically replaced") + async def test_cache_populated_when_originating_caller_is_cancelled_async(self): + """Cancelling the originating caller mid-fetch must NOT prevent the cache write. + + Reproduces the failure mode where a customer's ``asyncio.wait_for`` + deadline expires while the routing-map fetch is in flight. The fetch + task runs independently of the caller, owns the cache assignment, and + completes successfully so the cache ends up populated. + """ + # The bug this test pins down: in the old code, the routing-map fetch + # ran on the customer's call stack. If the customer wrapped their call + # in `asyncio.wait_for(..., timeout=2)` and the timeout fired mid-fetch, + # the CancelledError tore down the fetch, skipped the cache-write line + # that lived right after the `await`, and the cache stayed empty. Every + # retry repeated the same doomed sequence. + # + # The fix: move the fetch + cache-write into a shared task per + # collection, and have callers wait on it through `asyncio.shield`. Now + # when the caller is cancelled, only the *waiter* unwinds — the task + # itself keeps running on the event loop, finishes the fetch, and + # writes the result into the cache before returning. + # + # This test reproduces the cancellation, then verifies the cache *does* + # get populated once the gated fetch completes — even though nobody is + # awaiting it anymore. + original_ranges = self.partition_key_ranges + # A gate we control: the fetch will block here until we set the event. + # This lets us guarantee the fetch is still in flight at the moment + # the customer's deadline fires. + fetch_gate = asyncio.Event() + call_count = {'count': 0} + + class SlowClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + + async def _gen(): + # Park here — simulates a slow HTTP round trip the customer + # won't wait long enough for. + await fetch_gate.wait() + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"slow-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(SlowClient()) + collection_link = "dbs/db/colls/container" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + + # === Step 1: customer's wait_for fires before the fetch can complete. + # Without the fix this is where the publish would be lost forever. + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + provider.get_routing_map(collection_link, feed_options={}), + timeout=0.05, + ) + + # === Step 2: at this instant the cache MUST still be empty. + # The fetch is gated — it hasn't run to completion yet, so the + # in-flight task hasn't had a chance to write to the cache. If anything + # were in the cache here it would mean the fetch shortcut something, + # and the test wouldn't actually be exercising the fix. + self.assertIsNone(provider._collection_routing_map_by_item.get(collection_id)) + + # === Step 3: let the gated fetch complete. + # The originating caller is long gone (raised TimeoutError above), but + # the shared task is still alive on the event loop. We open the gate + # so it can finish its work, then poll until the cache slot fills in. + fetch_gate.set() + for _ in range(100): + if provider._collection_routing_map_by_item.get(collection_id) is not None: + break + await asyncio.sleep(0.01) + + # === Step 4: cache must now be populated — this is the fix in action. + # Even though the caller that triggered the fetch was cancelled, the + # task survived, the publish ran inside the task, and the routing map + # made it into the cache. + populated = provider._collection_routing_map_by_item.get(collection_id) + self.assertIsNotNone(populated, + "Cache must be populated after the gated fetch completes") + self.assertEqual(len(list(populated._orderedPartitionKeyRanges)), len(original_ranges)) + self.assertEqual(call_count['count'], 1, + "Exactly one fetch should have been issued") + + # === Step 5: the customer's retry now hits a populated cache. + # No new HTTP fetch. The whole point of the fix — the second attempt + # gets the work that the first attempt's fetch finished after timeout. + result = await provider.get_routing_map(collection_link, feed_options={}) + self.assertIs(result, populated) + self.assertEqual(call_count['count'], 1) + + async def test_cache_populated_when_cancelled_with_timeout_kwarg_async(self): + """Caller cancellation + timeout kwarg still results in cache population.""" + # This is the previous test's companion. It pins down the same fix + # behaviour (cache must still populate after the originating caller is + # cancelled) but covers the case where the customer *also* passed a + # `timeout=N` keyword argument — i.e. both timeout mechanisms are in + # play at once: + # + # - the asyncio cancellation (from wait_for), AND + # - the kwargs timeout (a plain Python kwarg the HTTP layer reads). + # + # The kwargs timeout still gets forwarded to the underlying call (we + # verify the mock saw it). The point is that even in this combined + # scenario, the shared-task fix still wins: the task keeps + # running after the caller times out, finishes the fetch, and the + # cache ends up populated. + original_ranges = self.partition_key_ranges + fetch_gate = asyncio.Event() + call_count = {'count': 0} + # We capture the timeout the mock client actually saw, to prove the + # kwargs path is intact end-to-end (not silently dropped before reaching + # the underlying read). + seen_timeout_kwarg = {'value': None} + + class SlowClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + # Record what the cache layer actually forwarded. + seen_timeout_kwarg['value'] = kwargs.get('timeout') + + async def _gen(): + # Gate again — fetch won't complete until we say so. + await fetch_gate.wait() + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"slow-timeout-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(SlowClient()) + collection_link = "dbs/db/colls/container" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + + # === Step 1: customer call dies via wait_for cancellation. + # Note both timeouts are present: the inner kwargs `timeout=0.001` + # (which the cache forwards to the mock) AND the outer + # `wait_for(..., timeout=0.05)` that asynchronously cancels the caller. + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + provider.get_routing_map(collection_link, feed_options={}, timeout=0.001), + timeout=0.05, + ) + + # Sanity-check: the cache layer really did forward the kwargs timeout + # down to the underlying read. If this is None it means the kwargs + # path is broken, regardless of whether the cache populates. + self.assertEqual(seen_timeout_kwarg['value'], 0.001) + # Cache still empty — fetch is gated, hasn't published yet. + self.assertIsNone(provider._collection_routing_map_by_item.get(collection_id)) + + # === Step 2: let the gated fetch finish, then poll. + fetch_gate.set() + for _ in range(100): + if provider._collection_routing_map_by_item.get(collection_id) is not None: + break + await asyncio.sleep(0.01) + + # === Step 3: cache must be populated. Same property as the previous + # test — the orphaned task lived past the caller's cancellation. + populated = provider._collection_routing_map_by_item.get(collection_id) + self.assertIsNotNone(populated) + self.assertEqual(call_count['count'], 1) + + # === Step 4: retry hits the populated cache, no second fetch. + result = await provider.get_routing_map(collection_link, feed_options={}) + self.assertIs(result, populated) + self.assertEqual(call_count['count'], 1) + + async def test_concurrent_cold_cache_callers_share_a_single_fetch_async(self): + """Concurrent cold-cache callers must coalesce onto one fetch task.""" + # This pins down the "one shared task per container, not one per + # caller" property. The bug it guards against: if every + # cold-cache caller spawned its own fetch task, 10 simultaneous + # callers would each fire their own HTTP request at the gateway — a + # gateway-side stampede. + # + # The fix uses an in-flight-fetches dict: the first caller creates the + # task and stores it; later callers find it there and join the same + # task instead of starting a new one. + # + # We verify both halves of the property: + # 1. The mock is called exactly ONCE even though 10 callers arrived + # cold and concurrently. + # 2. All 10 callers receive the SAME routing-map object (proving + # they really joined one task, didn't each get their own copy). + original_ranges = self.partition_key_ranges + # Gate the fetch so all 10 callers have time to arrive and join the + # in-flight task before any of them can succeed. Without the gate, + # the first caller might finish so quickly that the others arrive + # AFTER the task is done — which would be a different (cache-hit) + # code path, not the shared-task path we're testing here. + fetch_gate = asyncio.Event() + call_count = {'count': 0} + + class SlowClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + + async def _gen(): + await fetch_gate.wait() + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"shared-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(SlowClient()) + collection_link = "dbs/db/colls/container" + + async def caller(): + return await provider.get_routing_map(collection_link, feed_options={}) + + # Fire all 10 callers as concurrent tasks. Each one independently + # finds the cache empty and goes down the slow path; the in-flight + # dict is what makes them coalesce. + tasks = [asyncio.create_task(caller()) for _ in range(10)] + # Yield so every caller has a chance to enter the slow path and + # either create the in-flight task (one of them) or find it and + # join (the other nine). Without this yield we'd race the + # gate-set below and might not get the contention we're testing. + await asyncio.sleep(0.05) + # Now let the (single) fetch complete. + fetch_gate.set() + results = await asyncio.gather(*tasks) + + # Critical assertion: the mock was called ONCE, not 10 times. + # This is the whole point of the in-flight dict. + self.assertEqual(call_count['count'], 1, + "All 10 concurrent cold-cache callers should share one fetch") + # And every caller got the same object back — proving they all + # awaited the same shared task, not 10 separately-scheduled fetches + # that happened to return equivalent data. + first = results[0] + self.assertIsNotNone(first) + for r in results[1:]: + self.assertIs(r, first, "All callers should observe the same routing map object") + + async def test_waiter_joining_after_originator_cancelled_gets_result_async(self): + """A waiter that joins after the originating caller is cancelled still receives the fetched map.""" + # The trickiest property of the shared-task fix: the originating caller + # (the one who created the in-flight task) can be cancelled at any + # point, but a *later* caller arriving while the fetch is still + # running must successfully join that same task and receive its + # result. The cancellation of the originator can't take the task + # down with it (that's what `asyncio.shield` guarantees). + # + # Scenario walked through: + # 1. Originator starts → registers the in-flight task → parks. + # 2. Originator is cancelled before the fetch can finish. + # 3. A NEW caller (the "waiter") arrives. The fetch is still + # running on the loop. The waiter finds the task in the + # in-flight dict and awaits it. + # 4. We open the fetch gate. The task completes successfully. + # 5. The waiter wakes up with the routing map. + # + # The mock must show ONE call total — the waiter joined, didn't + # start a fresh fetch. + original_ranges = self.partition_key_ranges + fetch_gate = asyncio.Event() + call_count = {'count': 0} + + class SlowClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + + async def _gen(): + await fetch_gate.wait() + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"join-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(SlowClient()) + collection_link = "dbs/db/colls/container" + + # === Step 1: kick off the originator as its own task so we can + # cancel it explicitly without bringing down the test coroutine. + originator = asyncio.create_task( + provider.get_routing_map(collection_link, feed_options={}) + ) + # Yield twice — once for the originator to be scheduled, once for + # it to enter the slow path and register the in-flight task in the + # dict. If we cancel before that registration happens, the waiter + # below won't find anything to join and will start its own fetch. + await asyncio.sleep(0) + await asyncio.sleep(0) + + # === Step 2: cancel the originator. The shared task it created + # should keep running on the event loop (still parked on the gate). + originator.cancel() + with self.assertRaises(asyncio.CancelledError): + await originator + + # === Step 3: a NEW caller arrives — the waiter. The originator is + # gone, but the in-flight task is still alive. The waiter should + # find that task and join it. + waiter = asyncio.create_task( + provider.get_routing_map(collection_link, feed_options={}) + ) + # Yield so the waiter has time to enter the slow path and find the + # already-registered in-flight task. + await asyncio.sleep(0.01) + # === Step 4: now let the gated fetch complete. The waiter is + # awaiting on the shared task; when the task finishes, the waiter + # wakes up with the result. + fetch_gate.set() + result = await waiter + + # === Step 5: the waiter received a real routing map (not None, + # not an exception inherited from the cancelled originator). + self.assertIsNotNone(result) + # And critically: only ONE underlying fetch happened. The waiter + # joined the originator's task, didn't start a separate one. + self.assertEqual(call_count['count'], 1, + "Waiter should join the in-flight task, not start a new fetch") + self.assertEqual(len(list(result._orderedPartitionKeyRanges)), len(original_ranges)) + + async def test_failed_fetch_clears_inflight_slot_so_next_caller_retries_async(self): + """When a fetch fails, the in-flight slot is freed and the next caller can retry.""" + # The shared-task fix relies on the `finally` block inside the fetch + # task to remove its entry from the in-flight dict — REGARDLESS of + # whether the fetch succeeded or raised. If a failed fetch left a + # dead task in the dict, the next caller would find that dead task + # and await it forever (or get back the same stale exception). + # + # This test simulates the failure case: + # 1. First fetch raises CosmosHttpResponseError (simulated 500). + # 2. The caller sees the exception propagate out — expected. + # 3. The in-flight dict slot must be EMPTY now, so a fresh attempt + # can be registered. + # 4. A second caller arrives, finds an empty slot, registers a + # brand-new fetch, and that one succeeds. + original_ranges = self.partition_key_ranges + call_count = {'count': 0} + + class FlakyClient: + """First call raises, second call returns valid ranges.""" + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + attempt = call_count['count'] + + async def _gen(): + if attempt == 1: + # Simulate the kind of transient backend error that + # would cause the fetch task to raise — and bring the + # whole publish path down with it. + raise CosmosHttpResponseError(status_code=500, message="simulated transient failure") + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"retry-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(FlakyClient()) + collection_link = "dbs/db/colls/container" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + # The in-flight dict is keyed by (loop_id, collection_id) so the same + # cache can be safely reused across different event loops. + inflight_key = (id(asyncio.get_running_loop()), collection_id) + + # === Step 1: first call. Mock raises; expect the exception to + # propagate out to us (the awaiting caller). + with self.assertRaises(CosmosHttpResponseError): + await provider.get_routing_map(collection_link, feed_options={}) + + # === Step 2: the critical assertion. The failed task's `finally` + # block should have removed itself from the in-flight dict. If this + # fails, the next caller would be stuck awaiting a dead task. + self.assertNotIn(inflight_key, provider._inflight_fetches, + "Failed fetch should free the in-flight slot") + + # === Step 3: a fresh caller arrives. Because the slot is empty, + # they register a brand-new fetch — that's `call_count` going to 2. + # And this attempt succeeds. + result = await provider.get_routing_map(collection_link, feed_options={}) + self.assertIsNotNone(result) + self.assertEqual(call_count['count'], 2, + "Second attempt should issue a brand-new fetch") + + async def test_inflight_slot_freed_after_successful_fetch_async(self): + """The in-flight slot must be empty after a successful fetch completes.""" + # The companion to the previous test: cleanup must also happen on + # the SUCCESS path. If it only happened on failure, every successful + # fetch would leave a stale `done` task in the in-flight dict, and + # the dict would grow unbounded over the lifetime of the client. + # + # We do exactly one successful fetch, then check the dict is empty. + provider = PartitionKeyRangeCache( + TestRoutingMapProviderAsync.MockedCosmosClientConnection(self.partition_key_ranges) + ) + collection_link = "dbs/db/colls/container" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + inflight_key = (id(asyncio.get_running_loop()), collection_id) + + # Do a normal, successful fetch. + await provider.get_routing_map(collection_link, feed_options={}) + # Slot must be cleaned up. If this fails it means the `finally` block + # is only running on the failure path, not on success. + self.assertNotIn(inflight_key, provider._inflight_fetches, + "Successful fetch should free the in-flight slot") + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py index bfaa10947a2d..4377d4cc69ae 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py @@ -9,6 +9,7 @@ the same class in both sync and async paths. """ +import asyncio # pylint: disable=do-not-import-asyncio import unittest import pytest @@ -31,18 +32,20 @@ def __init__(self, url_connection): class TestSharedPartitionKeyRangeCacheAsync(unittest.IsolatedAsyncioTestCase): def tearDown(self): - # Wipe ALL four shared-cache globals between unit tests, not just - # the routing-map dict, so refcount and lock state stay consistent - # for tests that exercise lifecycle behavior. + # Wipe ALL shared-cache globals between unit tests, not just + # the routing-map dict, so refcount, lock, and in-flight-task + # state stay consistent for tests that exercise lifecycle behavior. from azure.cosmos._routing.aio.routing_map_provider import ( _shared_collection_locks, _shared_locks_locks, + _shared_inflight_fetches, _shared_cache_refcounts, ) with _shared_cache_lock: _shared_routing_map_cache.clear() _shared_collection_locks.clear() _shared_locks_locks.clear() + _shared_inflight_fetches.clear() _shared_cache_refcounts.clear() async def test_same_endpoint_shares_cache_async(self): @@ -113,12 +116,14 @@ def tearDown(self): from azure.cosmos._routing.aio.routing_map_provider import ( _shared_collection_locks, _shared_locks_locks, + _shared_inflight_fetches, _shared_cache_refcounts, ) with _shared_cache_lock: _shared_routing_map_cache.clear() _shared_collection_locks.clear() _shared_locks_locks.clear() + _shared_inflight_fetches.clear() _shared_cache_refcounts.clear() def _refcount(self, endpoint): @@ -140,16 +145,19 @@ async def test_release_evicts_at_zero_async(self): from azure.cosmos._routing.aio.routing_map_provider import ( _shared_collection_locks, _shared_locks_locks, + _shared_inflight_fetches, _shared_cache_refcounts, ) ep = "https://async-lifecycle2.documents.azure.com:443/" c1 = PartitionKeyRangeCache(MockClient(ep)) for d in (_shared_routing_map_cache, _shared_collection_locks, - _shared_locks_locks, _shared_cache_refcounts): + _shared_locks_locks, _shared_inflight_fetches, + _shared_cache_refcounts): self.assertIn(ep, d) c1.release() for d in (_shared_routing_map_cache, _shared_collection_locks, - _shared_locks_locks, _shared_cache_refcounts): + _shared_locks_locks, _shared_inflight_fetches, + _shared_cache_refcounts): self.assertNotIn(ep, d) async def test_release_is_idempotent_async(self): @@ -172,6 +180,46 @@ async def test_clear_cache_does_not_change_refcount_async(self): self.assertEqual(self._refcount(ep), before) self.assertIn(ep, _shared_routing_map_cache) + async def test_release_while_fetch_inflight_async(self): + from azure.cosmos._routing.aio.routing_map_provider import ( + _shared_inflight_fetches, + _shared_cache_refcounts, + ) + + ep = "https://async-lifecycle5.documents.azure.com:443/" + fetch_gate = asyncio.Event() + partition_key_ranges = [{"id": "0", "minInclusive": "", "maxExclusive": "FF"}] + + class SlowReadClient(MockClient): + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + async def _gen(): + await fetch_gate.wait() + for r in partition_key_ranges: + yield r + + return _gen() + + c1 = PartitionKeyRangeCache(SlowReadClient(ep)) + fetch_task = asyncio.create_task( + c1.get_routing_map("dbs/db/colls/container", feed_options={}) + ) + + for _ in range(100): + if c1._inflight_fetches: # pylint: disable=protected-access + break + await asyncio.sleep(0.01) + + self.assertTrue(c1._inflight_fetches) # pylint: disable=protected-access + c1.release() + self.assertNotIn(ep, _shared_cache_refcounts) + self.assertNotIn(ep, _shared_routing_map_cache) + self.assertNotIn(ep, _shared_inflight_fetches) + + fetch_gate.set() + routing_map = await fetch_task + self.assertIsNotNone(routing_map) + self.assertFalse(c1._inflight_fetches) # pylint: disable=protected-access + if __name__ == "__main__": unittest.main() From 8cd16bb903e99c49dca78708d1cf70db3fe5dda7 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Thu, 21 May 2026 22:02:06 -0500 Subject: [PATCH 2/3] addressing co-pilot comments --- .../_routing/aio/routing_map_provider.py | 36 +++++++-- .../test_routing_map_provider_async.py | 26 +++++-- .../test_shared_pk_range_cache_async.py | 74 ++++++++++++++++++- 3 files changed, 121 insertions(+), 15 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index f23e20c072f5..ac4e6763522c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -319,7 +319,17 @@ async def get_routing_map( # cache write inside the task body still happens. Other waiters # (and any subsequent caller hitting the now-populated cache) are # unaffected by our cancellation. - await asyncio.shield(fetch_task) + fetched_map = await asyncio.shield(fetch_task) + # Return the task's result directly instead of re-reading from + # the cache dict. Between the task completing and this line + # running, any other ready coroutine can execute — including + # ``clear_cache()`` from a concurrent retry path — which would + # empty the dict and leave us returning ``None`` despite the + # fetch having just succeeded. Using the task's return value + # sidesteps that window entirely. Matches the + # ``AsyncCacheNonBlocking`` pattern in the Java/.NET SDKs. + if fetched_map is not None: + return fetched_map return self._collection_routing_map_by_item.get(collection_id) @@ -357,7 +367,16 @@ async def _register_or_join_inflight_fetch( async with collection_lock: existing_task = self._inflight_fetches.get(inflight_key) if existing_task is not None: - return existing_task + if not existing_task.done(): + return existing_task + # Stale completed task. Under normal scheduling this never + # happens because ``_fetch_and_publish``'s ``finally`` pops + # the entry before the task transitions to ``done``. The one + # realistic way an orphan can sit here is a previous event + # loop being closed mid-fetch whose ``id()`` was then reused + # by the current loop . Drop the orphan and fall through to start a + # fresh fetch on the live loop. + self._inflight_fetches.pop(inflight_key, None) should_fetch, base_routing_map = determine_refresh_action( self._collection_routing_map_by_item, @@ -422,12 +441,13 @@ async def _fetch_and_publish( return new_routing_map finally: - # Atomic single-key removal; no lock needed. Runs on success, - # on fetch error, and on cancellation alike, so the next caller - # can register a fresh fetch immediately. - inflight_fetches = self._inflight_fetches - if inflight_key in inflight_fetches: - del inflight_fetches[inflight_key] + # ``dict.pop(key, default)`` is a single C-level operation under + # the GIL, so this cleanup is atomic and needs no explicit lock. + # The ``None`` default makes it tolerant of the key already being + # gone. Runs on success, on fetch error, and on cancellation + # alike, so the next caller can register a fresh fetch + # immediately. + self._inflight_fetches.pop(inflight_key, None) async def _fetch_routing_map( diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py index f73748963609..f7107bd008e6 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py @@ -1149,12 +1149,26 @@ async def _gen(): originator = asyncio.create_task( provider.get_routing_map(collection_link, feed_options={}) ) - # Yield twice — once for the originator to be scheduled, once for - # it to enter the slow path and register the in-flight task in the - # dict. If we cancel before that registration happens, the waiter - # below won't find anything to join and will start its own fetch. - await asyncio.sleep(0) - await asyncio.sleep(0) + # Poll for the originator to enter the slow path and register the + # in-flight task. Polling is more robust than ``sleep(0) × N`` + # because it captures the actual condition we care about, not a + # guess at how many event-loop ticks the registration path happens + # to need today. Without this, a future ``await`` added anywhere + # on the registration path could leave the originator un-registered + # when ``cancel()`` fires below — the waiter would then start its + # own fresh fetch and ``call_count`` would still end at 1, making + # this test silently pass for the wrong reason. + for _ in range(100): + if provider._inflight_fetches: # pylint: disable=protected-access + break + await asyncio.sleep(0.01) + # Loud failure if registration never happened — otherwise the + # test would silently pass without ever exercising the join path + # it is trying to validate. + self.assertTrue( + provider._inflight_fetches, # pylint: disable=protected-access + "Originator should have registered an in-flight task before cancellation", + ) # === Step 2: cancel the originator. The shared task it created # should keep running on the event loop (still parked on the gate). diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py index 4377d4cc69ae..bc12cbf806ec 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache_async.py @@ -216,10 +216,82 @@ async def _gen(): self.assertNotIn(ep, _shared_inflight_fetches) fetch_gate.set() - routing_map = await fetch_task + routing_map = await asyncio.wait_for(fetch_task, timeout=5) self.assertIsNotNone(routing_map) self.assertFalse(c1._inflight_fetches) # pylint: disable=protected-access + async def test_clear_cache_while_fetch_inflight_async(self): + """An in-flight fetch survives clear_cache() and repopulates the dict. + + clear_cache() empties the routing-map dict in place — it does not + drop the in-flight fetch task. So a fetch that was already running + when the cache was cleared keeps going on the event loop, finishes, + and publishes its result into the (now-empty) dict, leaving the + cache populated for the next caller. + + This pins the invariant documented on clear_cache(): a future + refactor that reassigns the dict (``= {}``) instead of clearing it + in place would break dict identity. The in-flight task would then + publish into the now-orphan old dict, leaving the cache empty for + new arrivals — and this test would fail loudly on step 4. + """ + from azure.cosmos._routing.aio.routing_map_provider import ( + _shared_inflight_fetches, + ) + + ep = "https://async-lifecycle6.documents.azure.com:443/" + fetch_gate = asyncio.Event() + partition_key_ranges = [{"id": "0", "minInclusive": "", "maxExclusive": "FF"}] + + class SlowReadClient(MockClient): + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + async def _gen(): + await fetch_gate.wait() + for r in partition_key_ranges: + yield r + + return _gen() + + c1 = PartitionKeyRangeCache(SlowReadClient(ep)) + + # === Step 1: start a gated fetch and wait for it to register. + fetch_task = asyncio.create_task( + c1.get_routing_map("dbs/db/colls/container", feed_options={}) + ) + for _ in range(100): + if c1._inflight_fetches: # pylint: disable=protected-access + break + await asyncio.sleep(0.01) + self.assertTrue( + c1._inflight_fetches, # pylint: disable=protected-access + "Fetch task should be registered before we clear the cache", + ) + + # === Step 2: clear the cache while the fetch is parked on the gate. + c1.clear_cache() + # Cache dict is empty — the fetch hasn't completed yet. + self.assertEqual(len(c1._collection_routing_map_by_item), 0) + # Critically: clear_cache() must not have dropped the in-flight task + # entry — that's what lets the survivor repopulate the cache below. + self.assertIn(ep, _shared_inflight_fetches) + self.assertTrue(_shared_inflight_fetches[ep]) + + # === Step 3: open the gate. The surviving task completes and + # publishes its result. + fetch_gate.set() + routing_map = await asyncio.wait_for(fetch_task, timeout=5) + self.assertIsNotNone(routing_map) + + # === Step 4: the cache is now repopulated by the in-flight task — + # proving clear_cache preserved dict identity (in-place .clear()) + # rather than replacing the dict with a fresh one. A regression + # that swapped ``.clear()`` for ``= {}`` would leave the cache + # empty here because the in-flight task would have written into + # the now-orphan old dict. + self.assertEqual(len(c1._collection_routing_map_by_item), 1) + # And the in-flight slot was freed by the task's ``finally`` block. + self.assertFalse(c1._inflight_fetches) # pylint: disable=protected-access + if __name__ == "__main__": unittest.main() From 6573a5cfae81679e016d66a00e32b9572c2bc33b Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Fri, 22 May 2026 00:18:32 -0500 Subject: [PATCH 3/3] addressing co-pilot comments --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + .../_routing/_routing_map_provider_common.py | 28 ++- .../_routing/aio/routing_map_provider.py | 12 +- sdk/cosmos/azure-cosmos/tests/conftest.py | 14 +- .../routing/test_routing_map_provider.py | 43 ++-- .../test_routing_map_provider_async.py | 220 ++++++++++++------ .../test_shared_cache_integration_async.py | 9 +- 7 files changed, 221 insertions(+), 106 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index ceda0633d67c..d5d0dcaa33d6 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -9,6 +9,7 @@ #### Bugs Fixed * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) +* Fixed bug where a short customer deadline could interrupt the SDK's internal lookup of a container's partition layout and leave the cached layout empty or stale, causing the customer's retries to repeatedly hit the same failure. See [PR 47066](https://github.com/Azure/azure-sdk-for-python/pull/47066) #### Other Changes * Reduced per-client memory overhead when partition-level circuit breaker (PPCB) is enabled by sharing the partition key range routing map cache across CosmosClient instances connected to the same endpoint, and stripping unused fields from cached partition key ranges using compact PKRange namedtuples. See [PR 46297](https://github.com/Azure/azure-sdk-for-python/pull/46297) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py index ce579fdb258a..bb969ff6de01 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py @@ -88,7 +88,26 @@ def prepare_fetch_options_and_headers( ) -> Dict[str, Any]: """Prepare sanitised feed options and headers for a PK-range fetch. - This mutates *kwargs* in-place (sets ``headers``). + This mutates *kwargs* in-place: + + * sets ``headers`` (with the PK-range page size, the incremental-feed + ``A-IM`` value, and the optional ``If-None-Match`` ETag); and + * drops any customer-supplied ``timeout`` / ``read_timeout`` kwargs. + + Stripping the customer's deadline at the cache layer is deliberate. + Most cache call sites already drop ``**kwargs`` two layers above the + fetch, but a small set of paths -- ``read_feed_ranges`` (sync and + async) and the circuit-breaker recovery path -- forward ``**kwargs`` + all the way down. If the customer passed ``timeout=N`` on one of those + paths, the HTTP pipeline's connection-retry policy would otherwise read + it as a wall-clock budget on the routing-map fetch and raise + ``CosmosClientTimeoutError`` mid-fetch, leaving the cache empty + (cold-cache call) or stale (refresh call) and pushing the customer's + retry into a doomed loop. The routing-map fetch is internal metadata + and should be governed by the SDK's own retry behaviour, not by + deadlines the customer intended for their data operation. Stripping + here applies uniformly on sync and async, and is belt-and-braces with + the call-site-level drops that already cover the common paths. :param previous_routing_map: The base routing map for incremental updates, or ``None`` for a full load. @@ -96,7 +115,8 @@ def prepare_fetch_options_and_headers( ~azure.cosmos._routing.collection_routing_map.CollectionRoutingMap or None :param dict feed_options: Raw feed options from the caller. - :param dict kwargs: Keyword arguments (mutated -- ``headers`` is set). + :param dict kwargs: Keyword arguments (mutated -- ``headers`` is set, + ``timeout`` and ``read_timeout`` are removed). :return: The sanitised ``change_feed_options`` dict. :rtype: dict """ @@ -119,6 +139,10 @@ def prepare_fetch_options_and_headers( headers.pop(http_constants.HttpHeaders.IfNoneMatch, None) kwargs['headers'] = headers + # Strip customer-side deadlines so they do not bound the metadata + # fetch -- see the function docstring for the full rationale. + kwargs.pop('timeout', None) + kwargs.pop('read_timeout', None) return change_feed_options diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index ac4e6763522c..b62deb5f3f0f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -44,12 +44,12 @@ if TYPE_CHECKING: from ...aio._cosmos_client_connection_async import CosmosClientConnection -# Module-level shared state, keyed by endpoint URL. All five dicts and the -# refcount are mutated only while holding ``_shared_cache_lock``. Sharing across -# every async CosmosClient that targets the same endpoint is what eliminates -# the per-client duplicate copies of the routing map (the memory win driving -# this change), and what lets concurrent readers single-flight a single -# refresh. +# Module-level shared state, keyed by endpoint URL. All five module globals +# (four state dicts plus the refcount dict) are mutated only while holding +# ``_shared_cache_lock``. Sharing across every async CosmosClient that targets +# the same endpoint is what eliminates the per-client duplicate copies of the +# routing map (the memory win driving this change), and what lets concurrent +# readers single-flight a single refresh. # endpoint -> { collection_id -> CollectionRoutingMap }. The actual cached # routing maps. The inner dict is shared by every client for that endpoint, so diff --git a/sdk/cosmos/azure-cosmos/tests/conftest.py b/sdk/cosmos/azure-cosmos/tests/conftest.py index 9f6d602c6534..0cabf7368889 100644 --- a/sdk/cosmos/azure-cosmos/tests/conftest.py +++ b/sdk/cosmos/azure-cosmos/tests/conftest.py @@ -72,10 +72,22 @@ def _reset_shared_pk_range_cache(): # if we ``.clear()`` the outer registry, a freshly-constructed client for # the same endpoint creates a brand-new inner dict and the dict-identity # invariant that test_shared_cache_integration relies on is broken. - # Same reasoning for ``_shared_collection_locks``. + # Same reasoning for ``_shared_collection_locks`` and + # ``_shared_inflight_fetches``. The in-flight dict is async-only, so we + # tolerate it being absent on the sync module via ``getattr``. for pmp in (_sync_pmp, _async_pmp): with pmp._shared_cache_lock: # pylint: disable=protected-access for cache in pmp._shared_routing_map_cache.values(): # pylint: disable=protected-access cache.clear() for locks in pmp._shared_collection_locks.values(): # pylint: disable=protected-access locks.clear() + # Drop references to any leaked in-flight fetch tasks. By the + # time this fixture runs the per-test event loop has already + # been torn down by IsolatedAsyncioTestCase, so any task still + # in the dict is either cancelled or stranded — we cannot + # await it, only release the reference so the next test starts + # with an empty in-flight slot. + inflight_registry = getattr(pmp, "_shared_inflight_fetches", None) # pylint: disable=protected-access + if inflight_registry is not None: + for inflight in inflight_registry.values(): + inflight.clear() diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index 793831c48ef2..6bee944f7080 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -245,51 +245,38 @@ def user_hook(headers, _): self.assertEqual(result.change_feed_etag, expected_internal_etag) self.assertEqual(hook_calls, ['"user-hook-etag"']) - def test_get_routing_map_tight_timeout_kwarg_still_populates_cache(self): - """Sync path forwards timeout kwarg and populates cache on successful fetch.""" - # The sync side of the PK-range cache work is much narrower than the - # async side: sync has no asyncio cancellation channel, so the - # "wait_for kills the fetch mid-flight" failure mode doesn't apply. - # What CAN reach the cache on sync is the `timeout=` kwarg - # - # This test covers the happy path of that flow: - # 1. Customer calls `get_routing_map(..., timeout=0.001)`. - # 2. The cache layer forwards the kwarg to the underlying read - # (verified by inspecting what the mock saw). - # 3. The fetch completes successfully (the mock returns instantly - # without honouring the tiny timeout). - # 4. The result lands in the cache as normal. - # 5. A second call hits the cache fast-path with no new fetch. + def test_get_routing_map_strips_customer_timeout_kwargs(self): + """Cache layer strips ``timeout=`` / ``read_timeout=`` before the fetch.""" call_count = {'count': 0} - # We record the timeout the mock saw, to prove the kwargs path is - # intact end-to-end (cache layer didn't silently drop it). - seen_timeout = {'value': None} + seen_kwargs = {} original_ranges = self.partition_key_ranges class TimeoutAwareClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - seen_timeout['value'] = kwargs.get('timeout') + seen_kwargs.update(kwargs) TestRoutingMapProvider._capture_internal_headers(kwargs, '"timeout-etag"') return original_ranges provider = PartitionKeyRangeCache(TimeoutAwareClient()) collection_link = "dbs/db/colls/container" - # === Step 1: first call with a tight timeout kwarg. The mock returns - # instantly so the timeout doesn't actually fire; the fetch succeeds. - result1 = provider.get_routing_map(collection_link, feed_options={}, timeout=0.001) + result1 = provider.get_routing_map( + collection_link, feed_options={}, timeout=0.001, read_timeout=0.001, + ) self.assertIsNotNone(result1) - # === Step 2: verify the cache layer forwarded the timeout kwarg - # down to the mock. If this is None, the kwargs path is broken. - self.assertEqual(seen_timeout['value'], 0.001) self.assertEqual(call_count['count'], 1) - # === Step 3: confirm the routing map landed in the cache as normal. + self.assertNotIn('timeout', seen_kwargs, + "Cache layer must strip customer 'timeout' before the fetch") + self.assertNotIn('read_timeout', seen_kwargs, + "Cache layer must strip customer 'read_timeout' before the fetch") + + # Internal header capture still needs to flow through. + self.assertIn('_internal_response_headers_capture', seen_kwargs) + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) self.assertIn(collection_id, provider._collection_routing_map_by_item) - - # === Step 4: second call (no timeout). Cache hit, no extra fetch. result2 = provider.get_routing_map(collection_link, feed_options={}) self.assertIs(result2, result1) self.assertEqual(call_count['count'], 1) diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py index f7107bd008e6..edfbcee7358d 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py @@ -223,6 +223,47 @@ def user_hook(headers, _): self.assertEqual(result.change_feed_etag, expected_internal_etag) self.assertEqual(hook_calls, ['"user-hook-etag"']) + async def test_get_routing_map_strips_customer_timeout_kwargs_async(self): + """Async mirror: cache layer strips ``timeout``/``read_timeout`` before PK-range fetch.""" + call_count = {'count': 0} + seen_kwargs: dict = {} + original_ranges = self.partition_key_ranges + + class TimeoutAwareClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + seen_kwargs.update(kwargs) + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"strip-etag"') + + async def _gen(): + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(TimeoutAwareClient()) + collection_link = "dbs/db/colls/container" + + result1 = await provider.get_routing_map( + collection_link, feed_options={}, timeout=0.001, read_timeout=0.001, + ) + self.assertIsNotNone(result1) + self.assertEqual(call_count['count'], 1) + + self.assertNotIn('timeout', seen_kwargs, + "Cache layer must strip customer 'timeout' before the fetch") + self.assertNotIn('read_timeout', seen_kwargs, + "Cache layer must strip customer 'read_timeout' before the fetch") + + # Internal header capture still needs to flow through. + self.assertIn('_internal_response_headers_capture', seen_kwargs) + + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + self.assertIn(collection_id, provider._collection_routing_map_by_item) + result2 = await provider.get_routing_map(collection_link, feed_options={}) + self.assertIs(result2, result1) + self.assertEqual(call_count['count'], 1) + async def test_get_routing_map_returns_cached_on_second_call_async(self): """Second call returns the same cached object without re-fetching.""" call_count = {'count': 0} @@ -959,34 +1000,16 @@ async def _gen(): self.assertEqual(call_count['count'], 1) async def test_cache_populated_when_cancelled_with_timeout_kwarg_async(self): - """Caller cancellation + timeout kwarg still results in cache population.""" - # This is the previous test's companion. It pins down the same fix - # behaviour (cache must still populate after the originating caller is - # cancelled) but covers the case where the customer *also* passed a - # `timeout=N` keyword argument — i.e. both timeout mechanisms are in - # play at once: - # - # - the asyncio cancellation (from wait_for), AND - # - the kwargs timeout (a plain Python kwarg the HTTP layer reads). - # - # The kwargs timeout still gets forwarded to the underlying call (we - # verify the mock saw it). The point is that even in this combined - # scenario, the shared-task fix still wins: the task keeps - # running after the caller times out, finishes the fetch, and the - # cache ends up populated. + """Cache still populates when caller is cancelled and timeout kwarg is present.""" original_ranges = self.partition_key_ranges fetch_gate = asyncio.Event() call_count = {'count': 0} - # We capture the timeout the mock client actually saw, to prove the - # kwargs path is intact end-to-end (not silently dropped before reaching - # the underlying read). - seen_timeout_kwarg = {'value': None} + seen_kwargs: dict = {} class SlowClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - # Record what the cache layer actually forwarded. - seen_timeout_kwarg['value'] = kwargs.get('timeout') + seen_kwargs.update(kwargs) async def _gen(): # Gate again — fetch won't complete until we say so. @@ -1001,64 +1024,86 @@ async def _gen(): collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) - # === Step 1: customer call dies via wait_for cancellation. - # Note both timeouts are present: the inner kwargs `timeout=0.001` - # (which the cache forwards to the mock) AND the outer - # `wait_for(..., timeout=0.05)` that asynchronously cancels the caller. with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( provider.get_routing_map(collection_link, feed_options={}, timeout=0.001), timeout=0.05, ) - # Sanity-check: the cache layer really did forward the kwargs timeout - # down to the underlying read. If this is None it means the kwargs - # path is broken, regardless of whether the cache populates. - self.assertEqual(seen_timeout_kwarg['value'], 0.001) - # Cache still empty — fetch is gated, hasn't published yet. + self.assertNotIn( + 'timeout', seen_kwargs, + "Cache layer must strip customer 'timeout' before the fetch", + ) self.assertIsNone(provider._collection_routing_map_by_item.get(collection_id)) - # === Step 2: let the gated fetch finish, then poll. fetch_gate.set() for _ in range(100): if provider._collection_routing_map_by_item.get(collection_id) is not None: break await asyncio.sleep(0.01) - # === Step 3: cache must be populated. Same property as the previous - # test — the orphaned task lived past the caller's cancellation. populated = provider._collection_routing_map_by_item.get(collection_id) self.assertIsNotNone(populated) self.assertEqual(call_count['count'], 1) - # === Step 4: retry hits the populated cache, no second fetch. + result = await provider.get_routing_map(collection_link, feed_options={}) + self.assertIs(result, populated) + self.assertEqual(call_count['count'], 1) + + async def test_all_waiters_cancelled_cache_still_populates_async(self): + """Cache populates even when all concurrent waiters time out.""" + original_ranges = self.partition_key_ranges + fetch_gate = asyncio.Event() + call_count = {'count': 0} + + class SlowClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + + async def _gen(): + await fetch_gate.wait() + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"all-cancel-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(SlowClient()) + collection_link = "dbs/db/colls/container" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + + async def cancellable_caller(): + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + provider.get_routing_map(collection_link, feed_options={}), + timeout=0.05, + ) + + callers = [asyncio.create_task(cancellable_caller()) for _ in range(5)] + await asyncio.gather(*callers) + + self.assertIsNone(provider._collection_routing_map_by_item.get(collection_id)) + + fetch_gate.set() + for _ in range(100): + if provider._collection_routing_map_by_item.get(collection_id) is not None: + break + await asyncio.sleep(0.01) + + populated = provider._collection_routing_map_by_item.get(collection_id) + self.assertIsNotNone(populated, + "Cache must populate even after every awaiter is cancelled") + self.assertEqual(call_count['count'], 1, + "All 5 cancelled callers should have coalesced on one fetch") + self.assertEqual(len(list(populated._orderedPartitionKeyRanges)), len(original_ranges)) + result = await provider.get_routing_map(collection_link, feed_options={}) self.assertIs(result, populated) self.assertEqual(call_count['count'], 1) async def test_concurrent_cold_cache_callers_share_a_single_fetch_async(self): """Concurrent cold-cache callers must coalesce onto one fetch task.""" - # This pins down the "one shared task per container, not one per - # caller" property. The bug it guards against: if every - # cold-cache caller spawned its own fetch task, 10 simultaneous - # callers would each fire their own HTTP request at the gateway — a - # gateway-side stampede. - # - # The fix uses an in-flight-fetches dict: the first caller creates the - # task and stores it; later callers find it there and join the same - # task instead of starting a new one. - # - # We verify both halves of the property: - # 1. The mock is called exactly ONCE even though 10 callers arrived - # cold and concurrently. - # 2. All 10 callers receive the SAME routing-map object (proving - # they really joined one task, didn't each get their own copy). original_ranges = self.partition_key_ranges - # Gate the fetch so all 10 callers have time to arrive and join the - # in-flight task before any of them can succeed. Without the gate, - # the first caller might finish so quickly that the others arrive - # AFTER the task is done — which would be a different (cache-hit) - # code path, not the shared-task path we're testing here. fetch_gate = asyncio.Event() call_count = {'count': 0} @@ -1080,31 +1125,72 @@ async def _gen(): async def caller(): return await provider.get_routing_map(collection_link, feed_options={}) - # Fire all 10 callers as concurrent tasks. Each one independently - # finds the cache empty and goes down the slow path; the in-flight - # dict is what makes them coalesce. tasks = [asyncio.create_task(caller()) for _ in range(10)] - # Yield so every caller has a chance to enter the slow path and - # either create the in-flight task (one of them) or find it and - # join (the other nine). Without this yield we'd race the - # gate-set below and might not get the contention we're testing. await asyncio.sleep(0.05) - # Now let the (single) fetch complete. fetch_gate.set() results = await asyncio.gather(*tasks) - # Critical assertion: the mock was called ONCE, not 10 times. - # This is the whole point of the in-flight dict. self.assertEqual(call_count['count'], 1, "All 10 concurrent cold-cache callers should share one fetch") - # And every caller got the same object back — proving they all - # awaited the same shared task, not 10 separately-scheduled fetches - # that happened to return equivalent data. first = results[0] self.assertIsNotNone(first) for r in results[1:]: self.assertIs(r, first, "All callers should observe the same routing map object") + async def test_force_refresh_caller_joins_cold_cache_fetch_async(self): + """A force-refresh caller should join an in-flight cold-cache fetch.""" + original_ranges = self.partition_key_ranges + fetch_gate = asyncio.Event() + call_count = {'count': 0} + + class SlowClient: + def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): + call_count['count'] += 1 + + async def _gen(): + await fetch_gate.wait() + TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"mixed-mode-etag"') + for r in original_ranges: + yield r + + return _gen() + + provider = PartitionKeyRangeCache(SlowClient()) + collection_link = "dbs/db/colls/container" + + caller_a = asyncio.create_task( + provider.get_routing_map(collection_link, feed_options={}) + ) + # Wait until caller A has registered the in-flight task. + for _ in range(100): + if provider._inflight_fetches: # pylint: disable=protected-access + break + await asyncio.sleep(0.01) + self.assertTrue( + provider._inflight_fetches, # pylint: disable=protected-access + "Originating cold-cache caller should register an in-flight task", + ) + + caller_b = asyncio.create_task( + provider.get_routing_map( + collection_link, feed_options={}, force_refresh=True + ) + ) + await asyncio.sleep(0.01) + + fetch_gate.set() + result_a, result_b = await asyncio.gather(caller_a, caller_b) + + self.assertEqual( + call_count['count'], 1, + "force_refresh joiner should share the cold-cache fetch, not duplicate it", + ) + self.assertIsNotNone(result_a) + self.assertIs( + result_a, result_b, + "Both callers should observe the same routing map object", + ) + async def test_waiter_joining_after_originator_cancelled_gets_result_async(self): """A waiter that joins after the originating caller is cancelled still receives the fetched map.""" # The trickiest property of the shared-task fix: the originating caller diff --git a/sdk/cosmos/azure-cosmos/tests/test_shared_cache_integration_async.py b/sdk/cosmos/azure-cosmos/tests/test_shared_cache_integration_async.py index 88e959c71e98..fb2d8f8b64f3 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_shared_cache_integration_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_shared_cache_integration_async.py @@ -50,17 +50,22 @@ async def asyncTearDown(self): await self.client1.close() # Release module-level shared routing-map state between async tests so # the test order cannot affect cache contents observed by a later test. - # Clear ALL four shared-cache globals (not just the routing-map dict) - # to keep refcount/lock state consistent. + # Clear ALL five shared-cache globals (not just the routing-map dict) + # to keep refcount/lock/in-flight state consistent. The in-flight dict + # was added alongside the asyncio.shield-based fetch coalescing; if + # we leave it dangling a later test could join a task bound to this + # test's torn-down event loop. from azure.cosmos._routing.aio.routing_map_provider import ( _shared_collection_locks, _shared_locks_locks, + _shared_inflight_fetches, _shared_cache_refcounts, ) with _shared_cache_lock: _shared_routing_map_cache.pop(self.host, None) _shared_collection_locks.pop(self.host, None) _shared_locks_locks.pop(self.host, None) + _shared_inflight_fetches.pop(self.host, None) _shared_cache_refcounts.pop(self.host, None) def _get_routing_provider(self, client):