1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -9,6 +9,7 @@

#### Bugs Fixed
* Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243)
* Fixed bug where a short customer deadline could interrupt the SDK's internal lookup of a container's partition layout and leave the cached layout empty or stale, causing the customer's retries to repeatedly hit the same failure. See [PR 47066](https://github.com/Azure/azure-sdk-for-python/pull/47066)

#### Other Changes
* Reduced per-client memory overhead when partition-level circuit breaker (PPCB) is enabled by sharing the partition key range routing map cache across CosmosClient instances connected to the same endpoint, and stripping unused fields from cached partition key ranges using compact PKRange namedtuples. See [PR 46297](https://github.com/Azure/azure-sdk-for-python/pull/46297)
@@ -88,15 +88,35 @@ def prepare_fetch_options_and_headers(
) -> Dict[str, Any]:
"""Prepare sanitised feed options and headers for a PK-range fetch.

This mutates *kwargs* in-place (sets ``headers``).
This mutates *kwargs* in-place:

* sets ``headers`` (with the PK-range page size, the incremental-feed
``A-IM`` value, and the optional ``If-None-Match`` ETag); and
* drops any customer-supplied ``timeout`` / ``read_timeout`` kwargs.

Stripping the customer's deadline at the cache layer is deliberate.
Most cache call sites already drop ``**kwargs`` two layers above the
fetch, but a small set of paths -- ``read_feed_ranges`` (sync and
async) and the circuit-breaker recovery path -- forward ``**kwargs``
all the way down. If the customer passed ``timeout=N`` on one of those
paths, the HTTP pipeline's connection-retry policy would otherwise read
it as a wall-clock budget on the routing-map fetch and raise
``CosmosClientTimeoutError`` mid-fetch, leaving the cache empty
(cold-cache call) or stale (refresh call) and pushing the customer's
retry into a doomed loop. The routing-map fetch is internal metadata
and should be governed by the SDK's own retry behaviour, not by
deadlines the customer intended for their data operation. Stripping
here applies uniformly on sync and async, and is belt-and-braces with
the call-site-level drops that already cover the common paths.

:param previous_routing_map: The base routing map for incremental
updates, or ``None`` for a full load.
:type previous_routing_map:
~azure.cosmos._routing.collection_routing_map.CollectionRoutingMap
or None
:param dict feed_options: Raw feed options from the caller.
:param dict kwargs: Keyword arguments (mutated -- ``headers`` is set).
:param dict kwargs: Keyword arguments (mutated -- ``headers`` is set,
``timeout`` and ``read_timeout`` are removed).
:return: The sanitised ``change_feed_options`` dict.
:rtype: dict
"""
@@ -119,6 +139,10 @@ def prepare_fetch_options_and_headers(
headers.pop(http_constants.HttpHeaders.IfNoneMatch, None)

kwargs['headers'] = headers
# Strip customer-side deadlines so they do not bound the metadata
# fetch -- see the function docstring for the full rationale.
kwargs.pop('timeout', None)
kwargs.pop('read_timeout', None)
return change_feed_options
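
The deadline-stripping step can be sketched in isolation. This is a minimal illustration, not the SDK's code: `strip_customer_deadlines` is a hypothetical helper name and the kwargs dict is invented for the example. It only shows that `dict.pop(key, None)` removes the two keys in place and tolerates their absence.

```python
from typing import Any, Dict


def strip_customer_deadlines(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: drop caller-supplied deadlines in place."""
    # pop() with a default never raises, so callers that did not pass
    # a deadline are handled by the same code path.
    kwargs.pop("timeout", None)
    kwargs.pop("read_timeout", None)
    return kwargs


# Invented example kwargs, standing in for what a caller might forward.
caller_kwargs = {"timeout": 2, "read_timeout": 1, "headers": {"x": "y"}}
strip_customer_deadlines(caller_kwargs)

# A second call is a no-op because both keys are already gone.
strip_customer_deadlines(caller_kwargs)
```

Because the dict is mutated in place, every layer below that receives the same `kwargs` object sees it without the deadline keys.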


@@ -44,12 +44,12 @@
if TYPE_CHECKING:
from ...aio._cosmos_client_connection_async import CosmosClientConnection

# Module-level shared state, keyed by endpoint URL. All four dicts and the
# refcount are mutated only while holding ``_shared_cache_lock``. Sharing across
# every async CosmosClient that targets the same endpoint is what eliminates
# the per-client duplicate copies of the routing map (the memory win driving
# this change), and what lets concurrent readers single-flight a single
# refresh.
# Module-level shared state, keyed by endpoint URL. All five module globals
# (four state dicts plus the refcount dict) are mutated only while holding
# ``_shared_cache_lock``. Sharing across every async CosmosClient that targets
# the same endpoint is what eliminates the per-client duplicate copies of the
# routing map (the memory win driving this change), and what lets concurrent
# readers single-flight a single refresh.

# endpoint -> { collection_id -> CollectionRoutingMap }. The actual cached
# routing maps. The inner dict is shared by every client for that endpoint, so
@@ -75,15 +75,24 @@
# and defeat the single-flight invariant.
_shared_locks_locks: Dict[str, threading.Lock] = {}

# endpoint -> { (loop_id, collection_id) -> asyncio.Task }. The single
# in-flight fetch-and-publish task per (loop, collection). Any caller that
# arrives during a cold-cache fetch joins this task via ``asyncio.shield``
# instead of issuing its own network round trip, so concurrent callers
# share a single fetch. The task body owns both the fetch and the cache
# write, so the publish survives any individual caller being cancelled
# (e.g. by ``asyncio.wait_for``) while awaiting it.
_shared_inflight_fetches: Dict[str, Dict[tuple, asyncio.Task]] = {}

# endpoint -> int. Number of live async ``PartitionKeyRangeCache`` instances
# using this endpoint. Incremented on construction and decremented in
# ``release`` (called from ``CosmosClient.__aexit__`` / ``close`` / ``__del__``).
# When the count hits zero we drop the entry from all four dicts so an idle
# When the count hits zero we drop the entry from all five dicts so an idle
# endpoint does not pin memory forever. ``clear_cache`` does NOT touch this
# count — it only wipes routing-map contents.
_shared_cache_refcounts: Dict[str, int] = {}

# Process-wide lock guarding the four dicts above for *this* (async) module.
# Process-wide lock guarding the five dicts above for *this* (async) module.
# Note: the sync module ``_routing/routing_map_provider.py`` defines its own
# independent set of module-level dicts and its own ``_shared_cache_lock`` —
# state is NOT shared between the sync and async modules. A sync and an async
@@ -123,20 +132,23 @@ def __init__(self, client: Any):
self._endpoint = _resolve_endpoint(client)
self._released = False

# Share routing map cache, per-collection asyncio locks, and the
# per-endpoint meta-lock that guards the per-collection-lock dict
# across all clients with the same endpoint. Refcount lets us evict
# the entry when the last sharing client releases it (see ``release``).
# Share routing map cache, per-collection asyncio locks, the
# per-endpoint meta-lock that guards the per-collection-lock dict,
# and the in-flight fetch-task dict across all clients with the same
# endpoint. Refcount lets us evict the entry when the last sharing
# client releases it (see ``release``).
with _shared_cache_lock:
if self._endpoint not in _shared_routing_map_cache:
_shared_routing_map_cache[self._endpoint] = {}
_shared_collection_locks[self._endpoint] = {}
_shared_locks_locks[self._endpoint] = threading.Lock()
_shared_inflight_fetches[self._endpoint] = {}
_shared_cache_refcounts[self._endpoint] = 0
_shared_cache_refcounts[self._endpoint] += 1
self._collection_routing_map_by_item = _shared_routing_map_cache[self._endpoint]
self._collection_locks: Dict[tuple, asyncio.Lock] = _shared_collection_locks[self._endpoint]
self._locks_lock: threading.Lock = _shared_locks_locks[self._endpoint]
self._inflight_fetches: Dict[tuple, asyncio.Task] = _shared_inflight_fetches[self._endpoint]

def clear_cache(self):
"""Clear the shared routing map cache for this endpoint.
@@ -145,13 +157,13 @@ def clear_cache(self):
client references to the same dict object, so concurrent clients
sharing the endpoint continue to share a single cache instance.

The per-collection locks dict is intentionally **not** cleared here:
an in-flight ``_fetch_routing_map`` caller holds one of those locks
and will write its result into the (now-empty) shared cache when it
completes. Keeping the lock in place ensures that any concurrent
arrival serialises behind the in-flight refresh (single-flight
invariant) instead of racing it with a fresh lock. The locks dict
is evicted in ``release()`` once the endpoint refcount hits zero.
The per-collection locks dict and the in-flight fetch-task dict are
intentionally **not** cleared here. A fetch task scheduled before
this call keeps a reference to the (now-empty) cache dict and will
publish its result into it when it completes; any concurrent arrival
meanwhile joins that same task instead of racing it. Both auxiliary
dicts are evicted in ``release()`` once the endpoint refcount hits
zero.
"""
with _shared_cache_lock:
if self._endpoint in _shared_routing_map_cache:
@@ -180,6 +192,7 @@ def release(self) -> None:
_shared_routing_map_cache.pop(endpoint, None)
_shared_collection_locks.pop(endpoint, None)
_shared_locks_locks.pop(endpoint, None)
_shared_inflight_fetches.pop(endpoint, None)
else:
_shared_cache_refcounts[endpoint] = count
except Exception: # pylint: disable=broad-except
@@ -267,9 +280,13 @@ async def get_routing_map(
) -> Optional[CollectionRoutingMap]:
"""Gets or refreshes the routing map for a collection.

This method handles the logic for fetching, caching, and updating the
collection's routing map. It uses a locking mechanism to prevent race
conditions during concurrent updates.
Concurrent callers that arrive while a fetch is already in flight for
the same collection join that fetch via ``asyncio.shield`` rather than
issuing their own network round trip. The fetch task owns the cache
write, so the publish completes even if every awaiting caller is
cancelled (for example by ``asyncio.wait_for``) before the fetch
returns. The next caller — whether the original caller retrying or a
new one — finds the cache populated.

:param str collection_link: The link to the collection.
:param Optional[Dict[str, Any]] feed_options: Optional query options.
@@ -281,37 +298,156 @@ async def get_routing_map(
"""
collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link)

# First check (no lock) for the fast path.
# Fast path: cache hit without acquiring any lock.
if not force_refresh:
cached_map = self._collection_routing_map_by_item.get(collection_id)
if cached_map:
return cached_map

# Acquire lock only when a refresh or initial load is likely needed.
fetch_task = await self._register_or_join_inflight_fetch(
collection_id,
collection_link,
feed_options,
force_refresh,
previous_routing_map,
kwargs,
)

if fetch_task is not None:
# ``shield`` ensures our cancellation only unwinds *this* awaiter;
# the underlying task keeps running on the event loop and the
# cache write inside the task body still happens. Other waiters
# (and any subsequent caller hitting the now-populated cache) are
# unaffected by our cancellation.
fetched_map = await asyncio.shield(fetch_task)
# Return the task's result directly instead of re-reading from
# the cache dict. Between the task completing and this line
# running, any other ready coroutine can execute — including
# ``clear_cache()`` from a concurrent retry path — which would
# empty the dict and leave us returning ``None`` despite the
# fetch having just succeeded. Using the task's return value
# sidesteps that window entirely. Matches the
# ``AsyncCacheNonBlocking`` pattern in the Java/.NET SDKs.
if fetched_map is not None:
return fetched_map

return self._collection_routing_map_by_item.get(collection_id)

async def _register_or_join_inflight_fetch(
self,
collection_id: str,
collection_link: str,
feed_options: Optional[Dict[str, Any]],
force_refresh: bool,
previous_routing_map: Optional[CollectionRoutingMap],
fetch_kwargs: Dict[str, Any],
) -> Optional[asyncio.Task]:
"""Return the in-flight fetch task for this collection, creating one if needed.

Holding the per-collection lock for just the check-or-create window
(no network I/O inside the lock) keeps the critical section small.
The returned task may be one this call just scheduled or one a
concurrent caller scheduled moments earlier — either way the caller
should await it through ``asyncio.shield``.

:param str collection_id: The resolved collection identifier used as the cache key.
:param str collection_link: The link to the collection.
:param Optional[Dict[str, Any]] feed_options: Optional query options.
:param bool force_refresh: Whether the caller asked for a refresh.
:param Optional[CollectionRoutingMap] previous_routing_map: The caller's last
observed routing map, used by the refresh-decision helper.
:param Dict[str, Any] fetch_kwargs: Pipeline kwargs forwarded to the fetch.
:return: A running ``asyncio.Task`` to await, or ``None`` if no fetch
is needed (cache was populated by a concurrent caller after the
fast-path check).
:rtype: Optional[asyncio.Task]
"""
inflight_key = (id(asyncio.get_running_loop()), collection_id)
collection_lock = await self._get_lock_for_collection(collection_id)
async with collection_lock:
# Second check (with lock) — use shared helper for the decision logic.
existing_task = self._inflight_fetches.get(inflight_key)
if existing_task is not None:
if not existing_task.done():
return existing_task
# Stale completed task. Under normal scheduling this never
# happens because ``_fetch_and_publish``'s ``finally`` pops
# the entry before the task transitions to ``done``. The one
# realistic way an orphan can sit here is a previous event
# loop being closed mid-fetch whose ``id()`` was then reused
# by the current loop. Drop the orphan and fall through to start a
# fresh fetch on the live loop.
self._inflight_fetches.pop(inflight_key, None)

should_fetch, base_routing_map = determine_refresh_action(
self._collection_routing_map_by_item,
collection_id,
force_refresh,
previous_routing_map,
)
if not should_fetch:
return None

if should_fetch:
new_routing_map = await self._fetch_routing_map(
collection_link,
new_task = asyncio.create_task(
self._fetch_and_publish(
collection_id,
collection_link,
base_routing_map,
feed_options,
**kwargs
inflight_key,
fetch_kwargs,
)
)
self._inflight_fetches[inflight_key] = new_task
return new_task

async def _fetch_and_publish(
self,
collection_id: str,
collection_link: str,
base_routing_map: Optional[CollectionRoutingMap],
feed_options: Optional[Dict[str, Any]],
inflight_key: tuple,
fetch_kwargs: Dict[str, Any],
) -> Optional[CollectionRoutingMap]:
"""Run ``_fetch_routing_map`` and publish its result, then free the in-flight slot.

# Update the cache.
if new_routing_map:
self._collection_routing_map_by_item[collection_id] = new_routing_map
The cache assignment lives inside this task body so a caller's
cancellation while awaiting the task cannot interrupt the publish.
The ``finally`` block always frees the in-flight slot — on success,
on a fetch error, or on cancellation — so the next caller is free to
schedule a fresh attempt.

:param str collection_id: The resolved collection identifier used as the cache key.
:param str collection_link: The link to the collection.
:param Optional[CollectionRoutingMap] base_routing_map: The base routing map
for incremental updates, or ``None`` for a full load.
:param Optional[Dict[str, Any]] feed_options: Optional query options.
:param tuple inflight_key: The ``(loop_id, collection_id)`` key into the in-flight dict.
:param Dict[str, Any] fetch_kwargs: Pipeline kwargs forwarded to the fetch.
:return: The new routing map, or ``None`` if the fetch produced nothing.
:rtype: Optional[CollectionRoutingMap]
"""
try:
new_routing_map = await self._fetch_routing_map(
collection_link,
collection_id,
base_routing_map,
feed_options,
**fetch_kwargs,
)

return self._collection_routing_map_by_item.get(collection_id)
if new_routing_map:
self._collection_routing_map_by_item[collection_id] = new_routing_map

return new_routing_map
finally:
# ``dict.pop(key, default)`` is a single C-level operation under
# the GIL, so this cleanup is atomic and needs no explicit lock.
# The ``None`` default makes it tolerant of the key already being
# gone. Runs on success, on fetch error, and on cancellation
# alike, so the next caller can register a fresh fetch
# immediately.
self._inflight_fetches.pop(inflight_key, None)
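
The interaction between the shielded in-flight task and a caller's short deadline can be sketched as a standalone script. This is a simplified illustration, not the SDK's implementation: `cache`, `inflight`, `stats`, `fetch_and_publish` and `get` are hypothetical names, `asyncio.sleep` stands in for the network fetch, and the per-collection lock and loop-id key are omitted.

```python
import asyncio
from typing import Dict, Tuple

cache: Dict[str, str] = {}
inflight: Dict[str, "asyncio.Task[str]"] = {}
stats = {"fetches": 0}


async def fetch_and_publish(key: str) -> str:
    """Fetch, write the cache inside the task body, then free the slot."""
    try:
        stats["fetches"] += 1
        await asyncio.sleep(0.2)          # stand-in for the network fetch
        value = f"routing-map-for-{key}"
        cache[key] = value                # publish survives caller cancellation
        return value
    finally:
        inflight.pop(key, None)           # runs on success, error, cancellation


async def get(key: str) -> str:
    if key in cache:                      # fast path
        return cache[key]
    task = inflight.get(key)
    if task is None or task.done():
        task = asyncio.create_task(fetch_and_publish(key))
        inflight[key] = task
    # shield: cancelling this awaiter does not cancel the shared task.
    return await asyncio.shield(task)


async def demo() -> Tuple[bool, str, int]:
    timed_out = False
    try:
        # The caller's deadline is far shorter than the fetch.
        await asyncio.wait_for(get("coll"), timeout=0.01)
    except asyncio.TimeoutError:
        timed_out = True
    pending = inflight.get("coll")
    if pending is not None:
        await pending                     # let the background fetch finish
    retry_value = await get("coll")       # retry now hits the cache
    return timed_out, retry_value, stats["fetches"]


result = asyncio.run(demo())
```

In this sketch the first call times out, yet the retry is served from the cache and only one fetch is ever issued, which is the behaviour the docstrings above describe.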


async def _fetch_routing_map(