Skip to content

Commit b4320fb

Browse files
committed
test: cover refresh failure paths and double-check refresh skip
After dropping the function-level `# pragma: no cover` from `_handle_refresh_response` and removing the per-line pragmas from the refactored Phase 2, the strict-no-cover audit identified covered lines still marked pragma'd and surfaced previously-untested branches. Three new tests close the coverage gaps: * `test_refresh_with_failed_status_clears_tokens` — exercises the ``response.status_code != 200`` branch in `_handle_refresh_response` and the `self._initialized = False` reset on refresh failure. * `test_refresh_with_invalid_json_clears_tokens` — exercises the ValidationError branch when the refresh body is not valid JSON. * `test_double_check_inside_refresh_lock_skips_second_refresh` — uses monkeypatch to flip `is_token_valid` between Phase 1 (False) and the double-check inside `refresh_lock` (True), exercising the branch where a second coroutine's refresh is correctly elided. Also: convert the new tests from the legacy Test* class pattern to plain top-level `test_*` functions per AGENTS.md, and drop unneeded per-line `# pragma: no cover` markers in the refactored auth_flow. Coverage report: 100.00% on `src/mcp/client/auth/oauth2.py`, strict-no-cover clean.
1 parent 54ea6a5 commit b4320fb

2 files changed

Lines changed: 192 additions & 107 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ async def _refresh_token(self) -> httpx.Request:
463463

464464
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
465465

466-
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
466+
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
467467
"""Handle token refresh response. Returns True if successful."""
468468
if response.status_code != 200:
469469
logger.warning(f"Token refresh failed: {response.status_code}")
@@ -544,13 +544,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
544544
refresh_request: httpx.Request | None = None
545545
async with self.context.lock:
546546
if not self.context.is_token_valid() and self.context.can_refresh_token():
547-
refresh_request = await self._refresh_token() # pragma: no cover
547+
refresh_request = await self._refresh_token()
548548
if refresh_request is not None:
549549
# yield runs outside any lock so a long network round trip
550550
# does not block unrelated concurrent requests.
551-
refresh_response = yield refresh_request # pragma: no cover
551+
refresh_response = yield refresh_request
552552
async with self.context.lock:
553-
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
553+
if not await self._handle_refresh_response(refresh_response):
554554
# Refresh failed; fall through to 401 handling below.
555555
self._initialized = False
556556

tests/client/test_auth.py

Lines changed: 188 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -2622,114 +2622,199 @@ async def callback_handler() -> tuple[str, str | None]:
26222622
pass
26232623

26242624

2625-
class TestConcurrentRequestsDoNotDeadlock:
2626-
"""Regression tests for #1326.
2625+
@pytest.mark.anyio
2626+
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
2627+
oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2628+
):
2629+
"""Regression for #1326: a second request reaches its yield while the
2630+
first is still suspended (= simulating a server-side long-poll).
26272631
2628-
Ensures that ``OAuthClientProvider.async_auth_flow`` does not serialize
2629-
concurrent unrelated requests behind a long-running one (e.g. GET SSE
2630-
long-poll). The fix narrows ``context.lock`` to state mutation only; the
2631-
actual ``yield request`` runs outside any lock.
2632+
Before the lock-scope fix, ``async_auth_flow`` held ``context.lock``
2633+
across ``yield request``. A GET SSE long-poll would therefore hold the
2634+
lock for the entire SSE lifetime, blocking any concurrent request
2635+
waiting on the same provider's lock.
26322636
"""
2637+
# Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2638+
# flow (Phase 4) is triggered — we exercise the steady-state Phase 3
2639+
# yield path that previously held the lock.
2640+
oauth_provider.context.current_tokens = valid_tokens
2641+
oauth_provider.context.token_expiry_time = time.time() + 1800
2642+
oauth_provider.context.client_info = OAuthClientInformationFull(
2643+
client_id="test_client_id",
2644+
client_secret="test_client_secret",
2645+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2646+
)
2647+
oauth_provider._initialized = True
26332648

2634-
@pytest.mark.anyio
2635-
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
2636-
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2637-
):
2638-
"""A second request must reach its yield while the first is still
2639-
suspended at its yield (= simulating a server-side long-poll).
2640-
2641-
Before this fix, ``async_auth_flow`` held ``context.lock`` across
2642-
``yield request``. A GET SSE long-poll would therefore hold the lock
2643-
for the entire SSE lifetime, blocking any concurrent request waiting
2644-
on the same provider's lock and producing a multi-second stall.
2645-
"""
2646-
# Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2647-
# flow (Phase 4) is triggered — we want to exercise the steady-state
2648-
# Phase 3 yield path that previously held the lock.
2649-
oauth_provider.context.current_tokens = valid_tokens
2650-
oauth_provider.context.token_expiry_time = time.time() + 1800
2651-
oauth_provider.context.client_info = OAuthClientInformationFull(
2652-
client_id="test_client_id",
2653-
client_secret="test_client_secret",
2654-
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2655-
)
2656-
oauth_provider._initialized = True
2649+
# Flow 1: drive to yield, then leave suspended (simulating long-poll).
2650+
slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2651+
slow_flow = oauth_provider.async_auth_flow(slow_request)
2652+
yielded_slow = await slow_flow.__anext__()
2653+
assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token"
26572654

2658-
# Flow 1: simulate a slow request. Drive it to its yield, then
2659-
# deliberately do not send a response — it stays suspended at the
2660-
# yield, just like a GET SSE long-poll waiting for the next event.
2661-
slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2662-
slow_flow = oauth_provider.async_auth_flow(slow_request)
2663-
yielded_slow = await slow_flow.__anext__()
2664-
assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token"
2665-
2666-
# Flow 2: a concurrent request on the same provider. With the fix,
2667-
# context.lock is not held during Flow 1's yield, so Flow 2 reaches
2668-
# its yield almost immediately. Without the fix, this would block
2669-
# until Flow 1 receives a response — i.e., it would hit the timeout.
2670-
fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2671-
fast_flow = oauth_provider.async_auth_flow(fast_request)
2672-
with anyio.fail_after(1.0):
2673-
yielded_fast = await fast_flow.__anext__()
2674-
assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token"
2675-
2676-
# Clean up both generators in deterministic order.
2677-
with contextlib.suppress(StopAsyncIteration):
2678-
await fast_flow.asend(httpx.Response(200, request=yielded_fast))
2679-
with contextlib.suppress(StopAsyncIteration):
2680-
await slow_flow.asend(httpx.Response(200, request=yielded_slow))
2655+
# Flow 2: concurrent request. With the fix this reaches its yield
2656+
# immediately; without the fix it would block on context.lock.
2657+
fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2658+
fast_flow = oauth_provider.async_auth_flow(fast_request)
2659+
with anyio.fail_after(5):
2660+
yielded_fast = await fast_flow.__anext__()
2661+
assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token"
26812662

2682-
@pytest.mark.anyio
2683-
async def test_concurrent_token_refresh_is_single_flight(
2684-
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2685-
):
2686-
"""When concurrent requests both observe an expired token, only one
2687-
refresh request is sent: ``refresh_lock`` provides single-flight
2688-
semantics so the second waiter re-checks state and proceeds without
2689-
re-triggering refresh.
2690-
"""
2691-
# Mark the token as expired so the next auth_flow run enters Phase 2.
2692-
oauth_provider.context.current_tokens = valid_tokens
2693-
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2694-
oauth_provider.context.client_info = OAuthClientInformationFull(
2695-
client_id="test_client_id",
2696-
client_secret="test_client_secret",
2697-
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2698-
)
2699-
oauth_provider._initialized = True
2663+
with contextlib.suppress(StopAsyncIteration):
2664+
await fast_flow.asend(httpx.Response(200, request=yielded_fast))
2665+
with contextlib.suppress(StopAsyncIteration):
2666+
await slow_flow.asend(httpx.Response(200, request=yielded_slow))
27002667

2701-
# Flow A: drive it to the refresh yield and suspend there.
2702-
request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
2703-
flow_a = oauth_provider.async_auth_flow(request_a)
2704-
refresh_a = await flow_a.__anext__()
2705-
assert "grant_type=refresh_token" in refresh_a.read().decode()
27062668

2707-
# Complete Flow A's refresh with a fresh token.
2708-
refresh_response = httpx.Response(
2709-
200,
2710-
content=(
2711-
b'{"access_token": "new_access_token", "token_type": "Bearer", '
2712-
b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2713-
),
2714-
request=refresh_a,
2715-
)
2716-
request_a_post = await flow_a.asend(refresh_response)
2717-
assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"
2718-
2719-
# Flow B starts after Flow A's refresh has completed. Because token
2720-
# state was updated under context.lock, Flow B observes the fresh
2721-
# token in Phase 1, skips Phase 2 entirely, and reaches its yield
2722-
# directly. No second refresh request is sent.
2723-
request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
2724-
flow_b = oauth_provider.async_auth_flow(request_b)
2725-
with anyio.fail_after(1.0):
2726-
request_b_yielded = await flow_b.__anext__()
2727-
assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"
2728-
# Confirm Flow B yielded the original POST, not a refresh request.
2729-
assert request_b_yielded.method == "POST"
2730-
2731-
# Clean up.
2732-
with contextlib.suppress(StopAsyncIteration):
2733-
await flow_b.asend(httpx.Response(200, request=request_b_yielded))
2669+
@pytest.mark.anyio
2670+
async def test_refresh_lock_double_check_skips_redundant_refresh(
2671+
oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2672+
):
2673+
"""Two flows enter Phase 2 with an expired token. After the first
2674+
completes a refresh, the second observes the fresh token via the
2675+
Phase 2 double-check inside ``refresh_lock`` (or directly in Phase 1
2676+
if it arrives late) and skips its own refresh.
2677+
"""
2678+
oauth_provider.context.current_tokens = valid_tokens
2679+
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2680+
oauth_provider.context.client_info = OAuthClientInformationFull(
2681+
client_id="test_client_id",
2682+
client_secret="test_client_secret",
2683+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2684+
)
2685+
oauth_provider._initialized = True
2686+
2687+
# Flow A: drive to refresh yield, then complete refresh.
2688+
request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
2689+
flow_a = oauth_provider.async_auth_flow(request_a)
2690+
refresh_a = await flow_a.__anext__()
2691+
assert "grant_type=refresh_token" in refresh_a.read().decode()
2692+
2693+
refresh_response = httpx.Response(
2694+
200,
2695+
content=(
2696+
b'{"access_token": "new_access_token", "token_type": "Bearer", '
2697+
b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2698+
),
2699+
request=refresh_a,
2700+
)
2701+
request_a_post = await flow_a.asend(refresh_response)
2702+
assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"
2703+
2704+
# Flow B: state already refreshed; Phase 1 sees valid token, skips Phase 2.
2705+
request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
2706+
flow_b = oauth_provider.async_auth_flow(request_b)
2707+
with anyio.fail_after(5):
2708+
request_b_yielded = await flow_b.__anext__()
2709+
assert request_b_yielded.method == "POST"
2710+
assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"
2711+
2712+
with contextlib.suppress(StopAsyncIteration):
2713+
await flow_b.asend(httpx.Response(200, request=request_b_yielded))
2714+
with contextlib.suppress(StopAsyncIteration):
2715+
await flow_a.asend(httpx.Response(200, request=request_a_post))
2716+
2717+
2718+
@pytest.mark.anyio
2719+
async def test_refresh_with_failed_status_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
2720+
"""A non-2xx refresh response clears stored tokens and marks the provider
2721+
uninitialized so the next request triggers a full OAuth flow.
2722+
"""
2723+
oauth_provider.context.current_tokens = valid_tokens
2724+
oauth_provider.context.token_expiry_time = time.time() - 100
2725+
oauth_provider.context.client_info = OAuthClientInformationFull(
2726+
client_id="test_client_id",
2727+
client_secret="test_client_secret",
2728+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2729+
)
2730+
oauth_provider._initialized = True
2731+
2732+
request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2733+
flow = oauth_provider.async_auth_flow(request)
2734+
refresh_request = await flow.__anext__()
2735+
assert "grant_type=refresh_token" in refresh_request.read().decode()
2736+
2737+
# Refresh server returns 401.
2738+
refresh_response = httpx.Response(401, content=b'{"error": "invalid_grant"}', request=refresh_request)
2739+
with contextlib.suppress(StopAsyncIteration):
2740+
# After failed refresh, the flow proceeds to Phase 3 yielding the
2741+
# original request without a fresh Authorization header. We don't
2742+
# exercise the subsequent 401/full OAuth path here.
2743+
await flow.asend(refresh_response)
2744+
2745+
assert oauth_provider.context.current_tokens is None
2746+
2747+
2748+
@pytest.mark.anyio
2749+
async def test_refresh_with_invalid_json_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
2750+
"""A 200 refresh response with a malformed body clears stored tokens —
2751+
the pydantic ValidationError branch is taken.
2752+
"""
2753+
oauth_provider.context.current_tokens = valid_tokens
2754+
oauth_provider.context.token_expiry_time = time.time() - 100
2755+
oauth_provider.context.client_info = OAuthClientInformationFull(
2756+
client_id="test_client_id",
2757+
client_secret="test_client_secret",
2758+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2759+
)
2760+
oauth_provider._initialized = True
2761+
2762+
request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2763+
flow = oauth_provider.async_auth_flow(request)
2764+
refresh_request = await flow.__anext__()
2765+
2766+
# Body does not parse as OAuthToken.
2767+
refresh_response = httpx.Response(200, content=b"not json", request=refresh_request)
2768+
with contextlib.suppress(StopAsyncIteration):
2769+
await flow.asend(refresh_response)
2770+
2771+
assert oauth_provider.context.current_tokens is None
2772+
2773+
2774+
@pytest.mark.anyio
2775+
async def test_double_check_inside_refresh_lock_skips_second_refresh(
2776+
oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken, monkeypatch: pytest.MonkeyPatch
2777+
):
2778+
"""Exercise the double-check branch inside ``refresh_lock``: ``is_token_valid``
2779+
returns False in Phase 1 (= the flow decides to refresh) but True inside
2780+
the inner ``context.lock`` block (= another coroutine refreshed while we
2781+
were waiting on ``refresh_lock``). The flow must skip ``_refresh_token``
2782+
and proceed straight to Phase 3.
2783+
"""
2784+
oauth_provider.context.current_tokens = valid_tokens
2785+
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2786+
oauth_provider.context.client_info = OAuthClientInformationFull(
2787+
client_id="test_client_id",
2788+
client_secret="test_client_secret",
2789+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2790+
)
2791+
oauth_provider._initialized = True
2792+
2793+
# Toggle is_token_valid: False on the first call (Phase 1 decision),
2794+
# True on the second (double-check inside refresh_lock).
2795+
call_count = {"n": 0}
2796+
original_is_valid = oauth_provider.context.__class__.is_token_valid
2797+
2798+
def fake_is_token_valid(self: object) -> bool:
2799+
call_count["n"] += 1
2800+
if call_count["n"] == 1:
2801+
return False
2802+
# By the second call, "another coroutine" refreshed; reset token expiry
2803+
# so callers downstream see a valid token.
2804+
oauth_provider.context.token_expiry_time = time.time() + 1800
2805+
return True
2806+
2807+
monkeypatch.setattr(oauth_provider.context.__class__, "is_token_valid", fake_is_token_valid)
2808+
try:
2809+
request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2810+
flow = oauth_provider.async_auth_flow(request)
2811+
# No refresh yield is expected — the flow goes directly to its own
2812+
# request yield with the (now-valid) token header attached.
2813+
with anyio.fail_after(5):
2814+
yielded = await flow.__anext__()
2815+
assert yielded.method == "POST"
2816+
assert yielded.headers.get("Authorization") == "Bearer test_access_token"
27342817
with contextlib.suppress(StopAsyncIteration):
2735-
await flow_a.asend(httpx.Response(200, request=request_a_post))
2818+
await flow.asend(httpx.Response(200, request=yielded))
2819+
finally:
2820+
monkeypatch.setattr(oauth_provider.context.__class__, "is_token_valid", original_is_valid)

0 commit comments

Comments
 (0)