From 0926169f72fd8b67e8e804e70a1f9ea9958f0671 Mon Sep 17 00:00:00 2001 From: Tonye Jack Date: Sat, 18 Apr 2026 23:42:02 -0600 Subject: [PATCH 1/2] fix(api-client): share a long-lived httpx client, configure timeouts, and single-flight JWKS / OIDC refetch --- src/auth0_api_python/api_client.py | 77 ++++++++---- src/auth0_api_python/utils.py | 77 ++++++++++-- tests/test_concurrent_fetch.py | 189 +++++++++++++++++++++++++++++ 3 files changed, 311 insertions(+), 32 deletions(-) create mode 100644 tests/test_concurrent_fetch.py diff --git a/src/auth0_api_python/api_client.py b/src/auth0_api_python/api_client.py index 36f23cf..6e1dad0 100644 --- a/src/auth0_api_python/api_client.py +++ b/src/auth0_api_python/api_client.py @@ -1,5 +1,6 @@ import asyncio import time +from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Optional, Union @@ -22,6 +23,7 @@ VerifyAccessTokenError, ) from .utils import ( + aclose_default_httpx_client, calculate_jwk_thumbprint, fetch_jwks, fetch_oidc_metadata, @@ -111,11 +113,28 @@ def __init__(self, options: ApiClientOptions): self._cache_ttl = options.cache_ttl_seconds + # Per-cache-key single-flight locks for OIDC discovery and JWKS + # refetches. Without these, every concurrent request that misses the + # cache at the moment of TTL expiry fires its own outbound HTTP call + # — a thundering herd that Auth0 rate-limits and we time out on. + # The lock guarantees only ONE coroutine per cache key refetches; + # the rest await the result and read from the now-warm cache. + self._discovery_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._jwks_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._jwt = JsonWebToken(["RS256"]) self._dpop_algorithms = ["ES256"] self._dpop_jwt = JsonWebToken(self._dpop_algorithms) + async def aclose(self) -> None: + """Release the shared default httpx client used for JWKS / OIDC fetches. + + Only meaningful when no `custom_fetch` was supplied. Safe to call + multiple times; safe to call when the client was never created. + """ + await aclose_default_httpx_client() + def is_dpop_required(self) -> bool: """Check if DPoP authentication is required.""" return getattr(self.options, "dpop_required", False) @@ -1029,19 +1048,26 @@ async def _discover(self, issuer: Optional[str] = None) -> dict[str, Any]: if cached: return cached - metadata, max_age = await fetch_oidc_metadata( - domain=domain, - custom_fetch=self.options.custom_fetch - ) + # Single-flight: only one coroutine per cache_key refetches; the + # rest await it and re-check the cache after acquiring the lock. + async with self._discovery_locks[cache_key]: + cached = self._discovery_cache.get(cache_key) + if cached: + return cached - effective_ttl = self._cache_ttl - if max_age is not None and self._cache_ttl is not None: - effective_ttl = min(max_age, self._cache_ttl) - elif max_age is not None: - effective_ttl = max_age + metadata, max_age = await fetch_oidc_metadata( + domain=domain, + custom_fetch=self.options.custom_fetch + ) + + effective_ttl = self._cache_ttl + if max_age is not None and self._cache_ttl is not None: + effective_ttl = min(max_age, self._cache_ttl) + elif max_age is not None: + effective_ttl = max_age - self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl) - return metadata + self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl) + return metadata async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]: """ @@ -1060,19 +1086,26 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]: if cached: return cached - jwks_data, max_age = await fetch_jwks( - jwks_uri=jwks_uri, - custom_fetch=self.options.custom_fetch - ) + # Single-flight: only one coroutine per cache_key refetches; the + # rest await it and re-check the cache after acquiring the lock. + async with self._jwks_locks[cache_key]: + cached = self._jwks_cache.get(cache_key) + if cached: + return cached + + jwks_data, max_age = await fetch_jwks( + jwks_uri=jwks_uri, + custom_fetch=self.options.custom_fetch + ) - effective_ttl = self._cache_ttl - if max_age is not None and self._cache_ttl is not None: - effective_ttl = min(max_age, self._cache_ttl) - elif max_age is not None: - effective_ttl = max_age + effective_ttl = self._cache_ttl + if max_age is not None and self._cache_ttl is not None: + effective_ttl = min(max_age, self._cache_ttl) + elif max_age is not None: + effective_ttl = max_age - self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl) - return jwks_data + self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl) + return jwks_data def _validate_claims_presence( self, diff --git a/src/auth0_api_python/utils.py b/src/auth0_api_python/utils.py index 72fbef8..012d969 100644 --- a/src/auth0_api_python/utils.py +++ b/src/auth0_api_python/utils.py @@ -3,6 +3,7 @@ using httpx or a custom fetch approach. """ +import asyncio import base64 import hashlib import json @@ -13,6 +14,62 @@ import httpx from ada_url import URL +# --------------------------------------------------------------------------- +# Shared, lazily-initialized httpx.AsyncClient used by `fetch_jwks` and +# `fetch_oidc_metadata` whenever the caller has NOT supplied a `custom_fetch`. +# +# Why this exists: +# The previous implementation constructed `httpx.AsyncClient()` per call. +# That meant: +# 1. A fresh TCP + TLS handshake to Auth0 on every cache miss (no +# connection pooling / keep-alive across calls). +# 2. httpx's default 5-second connect/read/write/pool timeouts applied. +# Under any non-trivial concurrency — for example, when the in-memory +# JWKS cache expires while N requests are in flight — both effects cause +# `httpx.ConnectTimeout` errors that propagate out of `verify_request` +# and surface to callers as opaque "Unknown auth error" failures. +# +# A single long-lived client gives us connection pooling, explicit timeouts, +# and bounded transport-level retries. +# --------------------------------------------------------------------------- +_DEFAULT_HTTPX_CLIENT: Optional[httpx.AsyncClient] = None +_DEFAULT_HTTPX_CLIENT_LOCK = asyncio.Lock() + + +def _build_default_httpx_client() -> httpx.AsyncClient: + """Construct the default shared client used when no `custom_fetch` is set.""" + return httpx.AsyncClient( + timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0), + limits=httpx.Limits( + max_connections=200, + max_keepalive_connections=50, + ), + transport=httpx.AsyncHTTPTransport(retries=2), + ) + + +async def _get_default_httpx_client() -> httpx.AsyncClient: + """Return the shared default client, creating it on first use.""" + global _DEFAULT_HTTPX_CLIENT + if _DEFAULT_HTTPX_CLIENT is not None and not _DEFAULT_HTTPX_CLIENT.is_closed: + return _DEFAULT_HTTPX_CLIENT + async with _DEFAULT_HTTPX_CLIENT_LOCK: + if _DEFAULT_HTTPX_CLIENT is None or _DEFAULT_HTTPX_CLIENT.is_closed: + _DEFAULT_HTTPX_CLIENT = _build_default_httpx_client() + return _DEFAULT_HTTPX_CLIENT + + +async def aclose_default_httpx_client() -> None: + """Close the module-level shared httpx client. + + Useful for tests and for callers that want a clean shutdown. Safe to call + multiple times; safe to call when the client was never created. + """ + global _DEFAULT_HTTPX_CLIENT + if _DEFAULT_HTTPX_CLIENT is not None and not _DEFAULT_HTTPX_CLIENT.is_closed: + await _DEFAULT_HTTPX_CLIENT.aclose() + _DEFAULT_HTTPX_CLIENT = None + def parse_cache_control_max_age(headers: Mapping[str, str]) -> Optional[int]: """ @@ -102,11 +159,11 @@ async def fetch_oidc_metadata( return data, max_age return response, None else: - async with httpx.AsyncClient() as client: - resp = await client.get(url) - resp.raise_for_status() - max_age = parse_cache_control_max_age(resp.headers) - return resp.json(), max_age + client = await _get_default_httpx_client() + resp = await client.get(url) + resp.raise_for_status() + max_age = parse_cache_control_max_age(resp.headers) + return resp.json(), max_age async def fetch_jwks( @@ -128,11 +185,11 @@ async def fetch_jwks( return data, max_age return response, None else: - async with httpx.AsyncClient() as client: - resp = await client.get(jwks_uri) - resp.raise_for_status() - max_age = parse_cache_control_max_age(resp.headers) - return resp.json(), max_age + client = await _get_default_httpx_client() + resp = await client.get(jwks_uri) + resp.raise_for_status() + max_age = parse_cache_control_max_age(resp.headers) + return resp.json(), max_age def _decode_jwt_segment(token: Union[str, bytes], segment_index: int) -> dict: diff --git a/tests/test_concurrent_fetch.py b/tests/test_concurrent_fetch.py new file mode 100644 index 0000000..a29d675 --- /dev/null +++ b/tests/test_concurrent_fetch.py @@ -0,0 +1,189 @@ +""" +Tests for the shared default httpx client and per-cache-key +single-flight refetch behaviour in `ApiClient`. + +Guards two fixes: + +1. The default httpx client used by `utils.fetch_jwks` / + `utils.fetch_oidc_metadata` (when no `custom_fetch` is supplied) is a + single shared instance with explicit timeouts — not a fresh + `httpx.AsyncClient()` per call relying on httpx's 5-second defaults. + +2. Per-cache-key single-flight on `ApiClient._fetch_jwks` and + `ApiClient._discover`: N concurrent cache misses for the same key + produce exactly ONE upstream HTTP fetch. +""" + +import asyncio + +import pytest +import pytest_asyncio +from conftest import DISCOVERY_URL, JWKS_URL +from pytest_httpx import HTTPXMock + +from auth0_api_python import ApiClient, ApiClientOptions +from auth0_api_python import utils as auth0_utils + +# ===== Fixtures ===== + +@pytest_asyncio.fixture(autouse=True) +async def _reset_default_httpx_client(): + """Ensure each test starts and ends with no shared httpx client. + + The shared client is built lazily on first use; resetting it before + each test guarantees it is (re)created inside the active + `httpx_mock` patch scope, so pytest-httpx can intercept its + requests. + """ + await auth0_utils.aclose_default_httpx_client() + yield + await auth0_utils.aclose_default_httpx_client() + + +# ===== Single-flight: JWKS ===== + +@pytest.mark.asyncio +async def test_concurrent_jwks_misses_trigger_single_fetch(httpx_mock: HTTPXMock): + """ + Test that 50 concurrent callers missing the JWKS cache for the same + URI cause exactly one outbound HTTP fetch, not 50. + """ + httpx_mock.add_response( + method="GET", + url=JWKS_URL, + json={"keys": []}, + is_reusable=True, + ) + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + + results = await asyncio.gather( + *(api_client._fetch_jwks(JWKS_URL) for _ in range(50)) + ) + + assert all(r == {"keys": []} for r in results) + requests = [r for r in httpx_mock.get_requests() if str(r.url) == JWKS_URL] + assert len(requests) == 1, ( + f"Expected exactly 1 outbound JWKS fetch under concurrent miss, " + f"got {len(requests)}" + ) + + +# ===== Single-flight: OIDC discovery ===== + +@pytest.mark.asyncio +async def test_concurrent_oidc_misses_trigger_single_fetch(httpx_mock: HTTPXMock): + """ + Test that 50 concurrent callers missing the OIDC discovery cache + cause exactly one outbound HTTP fetch, not 50. + """ + httpx_mock.add_response( + method="GET", + url=DISCOVERY_URL, + json={"issuer": "https://auth0.local/", "jwks_uri": JWKS_URL}, + is_reusable=True, + ) + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + + results = await asyncio.gather( + *(api_client._discover() for _ in range(50)) + ) + + expected = {"issuer": "https://auth0.local/", "jwks_uri": JWKS_URL} + assert all(r == expected for r in results) + requests = [r for r in httpx_mock.get_requests() if str(r.url) == DISCOVERY_URL] + assert len(requests) == 1, ( + f"Expected exactly 1 outbound OIDC discovery fetch under " + f"concurrent miss, got {len(requests)}" + ) + + +# ===== Per-key locking ===== + +@pytest.mark.asyncio +async def test_jwks_locks_are_per_cache_key(httpx_mock: HTTPXMock): + """ + Test that concurrent misses for DIFFERENT JWKS URIs are not + serialized behind a single global lock — each URI gets its own. + """ + uri_a = "https://a.auth0.local/.well-known/jwks.json" + uri_b = "https://b.auth0.local/.well-known/jwks.json" + httpx_mock.add_response(method="GET", url=uri_a, json={"keys": ["a"]}) + httpx_mock.add_response(method="GET", url=uri_b, json={"keys": ["b"]}) + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + + a, b = await asyncio.gather( + api_client._fetch_jwks(uri_a), + api_client._fetch_jwks(uri_b), + ) + + assert a == {"keys": ["a"]} + assert b == {"keys": ["b"]} + + requests_a = [r for r in httpx_mock.get_requests() if str(r.url) == uri_a] + requests_b = [r for r in httpx_mock.get_requests() if str(r.url) == uri_b] + assert len(requests_a) == 1 + assert len(requests_b) == 1 + + +# ===== Default httpx client ===== + +@pytest.mark.asyncio +async def test_default_httpx_client_is_shared(): + """ + Test that the default httpx client is a singleton across calls when + no `custom_fetch` is supplied — not a fresh client per call. + """ + first = await auth0_utils._get_default_httpx_client() + second = await auth0_utils._get_default_httpx_client() + + assert first is second, "default httpx client should be a singleton" + + +@pytest.mark.asyncio +async def test_default_httpx_client_has_explicit_timeouts(): + """ + Test that the default httpx client has explicit, non-default + timeouts. Regression guard: httpx's default 5-second timeouts + fall over under concurrent load. + """ + client = await auth0_utils._get_default_httpx_client() + + assert client.timeout.connect is not None + assert client.timeout.read is not None + assert client.timeout.write is not None + assert client.timeout.pool is not None + # Read timeout in particular should be generous (>= 5s). + assert client.timeout.read >= 5.0 + + +# ===== Shutdown ===== + +@pytest.mark.asyncio +async def test_aclose_is_idempotent(): + """ + Test that `ApiClient.aclose()` and `aclose_default_httpx_client()` + are safe to call multiple times — including before the client was + ever created — and that the client can be re-created after close. + """ + # Safe to call before the client is ever built. + await auth0_utils.aclose_default_httpx_client() + await auth0_utils.aclose_default_httpx_client() + + api_client = ApiClient( + ApiClientOptions(domain="auth0.local", audience="my-audience") + ) + await api_client.aclose() + await api_client.aclose() # idempotent + + # Closed client must be re-creatable on next use. + new_client = await auth0_utils._get_default_httpx_client() + assert not new_client.is_closed From 1e613f52491b1a6b332641123c2cec3f4b4a296f Mon Sep 17 00:00:00 2001 From: Tonye Jack Date: Sat, 18 Apr 2026 23:53:35 -0600 Subject: [PATCH 2/2] Updated comments --- src/auth0_api_python/api_client.py | 6 +-- src/auth0_api_python/utils.py | 24 +--------- tests/test_concurrent_fetch.py | 70 ++++-------------------------- 3 files changed, 11 insertions(+), 89 deletions(-) diff --git a/src/auth0_api_python/api_client.py b/src/auth0_api_python/api_client.py index 6e1dad0..5c3726b 100644 --- a/src/auth0_api_python/api_client.py +++ b/src/auth0_api_python/api_client.py @@ -128,11 +128,7 @@ def __init__(self, options: ApiClientOptions): self._dpop_jwt = JsonWebToken(self._dpop_algorithms) async def aclose(self) -> None: - """Release the shared default httpx client used for JWKS / OIDC fetches. - - Only meaningful when no `custom_fetch` was supplied. Safe to call - multiple times; safe to call when the client was never created. - """ + """Release the shared default httpx client. Idempotent; no-op when a `custom_fetch` is in use.""" await aclose_default_httpx_client() def is_dpop_required(self) -> bool: diff --git a/src/auth0_api_python/utils.py b/src/auth0_api_python/utils.py index 012d969..7349782 100644 --- a/src/auth0_api_python/utils.py +++ b/src/auth0_api_python/utils.py @@ -14,24 +14,6 @@ import httpx from ada_url import URL -# --------------------------------------------------------------------------- -# Shared, lazily-initialized httpx.AsyncClient used by `fetch_jwks` and -# `fetch_oidc_metadata` whenever the caller has NOT supplied a `custom_fetch`. -# -# Why this exists: -# The previous implementation constructed `httpx.AsyncClient()` per call. -# That meant: -# 1. A fresh TCP + TLS handshake to Auth0 on every cache miss (no -# connection pooling / keep-alive across calls). -# 2. httpx's default 5-second connect/read/write/pool timeouts applied. -# Under any non-trivial concurrency — for example, when the in-memory -# JWKS cache expires while N requests are in flight — both effects cause -# `httpx.ConnectTimeout` errors that propagate out of `verify_request` -# and surface to callers as opaque "Unknown auth error" failures. -# -# A single long-lived client gives us connection pooling, explicit timeouts, -# and bounded transport-level retries. -# --------------------------------------------------------------------------- _DEFAULT_HTTPX_CLIENT: Optional[httpx.AsyncClient] = None _DEFAULT_HTTPX_CLIENT_LOCK = asyncio.Lock() @@ -60,11 +42,7 @@ async def _get_default_httpx_client() -> httpx.AsyncClient: async def aclose_default_httpx_client() -> None: - """Close the module-level shared httpx client. - - Useful for tests and for callers that want a clean shutdown. Safe to call - multiple times; safe to call when the client was never created. - """ + """Close the module-level shared httpx client. Idempotent.""" global _DEFAULT_HTTPX_CLIENT if _DEFAULT_HTTPX_CLIENT is not None and not _DEFAULT_HTTPX_CLIENT.is_closed: await _DEFAULT_HTTPX_CLIENT.aclose() diff --git a/tests/test_concurrent_fetch.py b/tests/test_concurrent_fetch.py index a29d675..551ef9a 100644 --- a/tests/test_concurrent_fetch.py +++ b/tests/test_concurrent_fetch.py @@ -1,19 +1,3 @@ -""" -Tests for the shared default httpx client and per-cache-key -single-flight refetch behaviour in `ApiClient`. - -Guards two fixes: - -1. The default httpx client used by `utils.fetch_jwks` / - `utils.fetch_oidc_metadata` (when no `custom_fetch` is supplied) is a - single shared instance with explicit timeouts — not a fresh - `httpx.AsyncClient()` per call relying on httpx's 5-second defaults. - -2. Per-cache-key single-flight on `ApiClient._fetch_jwks` and - `ApiClient._discover`: N concurrent cache misses for the same key - produce exactly ONE upstream HTTP fetch. -""" - import asyncio import pytest @@ -28,13 +12,6 @@ @pytest_asyncio.fixture(autouse=True) async def _reset_default_httpx_client(): - """Ensure each test starts and ends with no shared httpx client. - - The shared client is built lazily on first use; resetting it before - each test guarantees it is (re)created inside the active - `httpx_mock` patch scope, so pytest-httpx can intercept its - requests. - """ await auth0_utils.aclose_default_httpx_client() yield await auth0_utils.aclose_default_httpx_client() @@ -44,10 +21,7 @@ async def _reset_default_httpx_client(): @pytest.mark.asyncio async def test_concurrent_jwks_misses_trigger_single_fetch(httpx_mock: HTTPXMock): - """ - Test that 50 concurrent callers missing the JWKS cache for the same - URI cause exactly one outbound HTTP fetch, not 50. - """ + """N concurrent JWKS cache misses for the same URI cause exactly one upstream fetch.""" httpx_mock.add_response( method="GET", url=JWKS_URL, @@ -65,20 +39,14 @@ async def test_concurrent_jwks_misses_trigger_single_fetch(httpx_mock: HTTPXMock assert all(r == {"keys": []} for r in results) requests = [r for r in httpx_mock.get_requests() if str(r.url) == JWKS_URL] - assert len(requests) == 1, ( - f"Expected exactly 1 outbound JWKS fetch under concurrent miss, " - f"got {len(requests)}" - ) + assert len(requests) == 1 # ===== Single-flight: OIDC discovery ===== @pytest.mark.asyncio async def test_concurrent_oidc_misses_trigger_single_fetch(httpx_mock: HTTPXMock): - """ - Test that 50 concurrent callers missing the OIDC discovery cache - cause exactly one outbound HTTP fetch, not 50. - """ + """N concurrent OIDC discovery cache misses cause exactly one upstream fetch.""" httpx_mock.add_response( method="GET", url=DISCOVERY_URL, @@ -97,20 +65,14 @@ async def test_concurrent_oidc_misses_trigger_single_fetch(httpx_mock: HTTPXMock expected = {"issuer": "https://auth0.local/", "jwks_uri": JWKS_URL} assert all(r == expected for r in results) requests = [r for r in httpx_mock.get_requests() if str(r.url) == DISCOVERY_URL] - assert len(requests) == 1, ( - f"Expected exactly 1 outbound OIDC discovery fetch under " - f"concurrent miss, got {len(requests)}" - ) + assert len(requests) == 1 # ===== Per-key locking ===== @pytest.mark.asyncio async def test_jwks_locks_are_per_cache_key(httpx_mock: HTTPXMock): - """ - Test that concurrent misses for DIFFERENT JWKS URIs are not - serialized behind a single global lock — each URI gets its own. - """ + """Concurrent misses for different JWKS URIs are not serialized behind one global lock.""" uri_a = "https://a.auth0.local/.well-known/jwks.json" uri_b = "https://b.auth0.local/.well-known/jwks.json" httpx_mock.add_response(method="GET", url=uri_a, json={"keys": ["a"]}) @@ -138,30 +100,22 @@ async def test_jwks_locks_are_per_cache_key(httpx_mock: HTTPXMock): @pytest.mark.asyncio async def test_default_httpx_client_is_shared(): - """ - Test that the default httpx client is a singleton across calls when - no `custom_fetch` is supplied — not a fresh client per call. - """ + """The default httpx client is a singleton across calls.""" first = await auth0_utils._get_default_httpx_client() second = await auth0_utils._get_default_httpx_client() - assert first is second, "default httpx client should be a singleton" + assert first is second @pytest.mark.asyncio async def test_default_httpx_client_has_explicit_timeouts(): - """ - Test that the default httpx client has explicit, non-default - timeouts. Regression guard: httpx's default 5-second timeouts - fall over under concurrent load. - """ + """The default httpx client sets explicit, non-default timeouts.""" client = await auth0_utils._get_default_httpx_client() assert client.timeout.connect is not None assert client.timeout.read is not None assert client.timeout.write is not None assert client.timeout.pool is not None - # Read timeout in particular should be generous (>= 5s). assert client.timeout.read >= 5.0 @@ -169,12 +123,7 @@ async def test_default_httpx_client_has_explicit_timeouts(): @pytest.mark.asyncio async def test_aclose_is_idempotent(): - """ - Test that `ApiClient.aclose()` and `aclose_default_httpx_client()` - are safe to call multiple times — including before the client was - ever created — and that the client can be re-created after close. - """ - # Safe to call before the client is ever built. + """`aclose()` is safe to call repeatedly and the client can be re-created afterward.""" await auth0_utils.aclose_default_httpx_client() await auth0_utils.aclose_default_httpx_client() @@ -184,6 +133,5 @@ async def test_aclose_is_idempotent(): await api_client.aclose() await api_client.aclose() # idempotent - # Closed client must be re-creatable on next use. new_client = await auth0_utils._get_default_httpx_client() assert not new_client.is_closed