diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f5775..c5bb838a9e 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -114,7 +114,18 @@ class OAuthContext: token_expiry_time: float | None = None # State + # + # `lock` guards short-lived reads/writes of provider state (initialization + # flag, token cache mutation, protocol_version assignment). It is held only + # while mutating state and is released before any HTTP request is yielded + # so a long-running request (e.g. GET SSE long-poll) does not block + # unrelated concurrent requests. + # + # `refresh_lock` provides single-flight semantics for token refresh: only + # one concurrent refresh fires; other waiters block on this lock, then + # re-check the token cache and proceed without re-refreshing. lock: anyio.Lock = field(default_factory=anyio.Lock) + refresh_lock: anyio.Lock = field(default_factory=anyio.Lock) def get_authorization_base_url(self, server_url: str) -> str: """Extract base URL by removing path component.""" @@ -452,7 +463,7 @@ async def _refresh_token(self) -> httpx.Request: return httpx.Request("POST", token_url, data=refresh_data, headers=headers) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -504,7 +515,17 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: - """HTTPX auth flow integration.""" + """HTTPX auth flow integration. + + Lock scope: + ``self.context.lock`` is held only while reading/mutating provider + state. The actual HTTP request yield (which may be a long-poll GET + SSE stream) runs outside any lock so concurrent unrelated requests + are not blocked. ``self.context.refresh_lock`` provides + single-flight semantics for token refresh. + """ + # === Phase 1: state read + refresh decision (brief context.lock) === + needs_refresh = False async with self.context.lock: if not self._initialized: await self._initialize() # pragma: no cover @@ -513,20 +534,40 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) if not self.context.is_token_valid() and self.context.can_refresh_token(): - # Try to refresh token - refresh_request = await self._refresh_token() # pragma: no cover - refresh_response = yield refresh_request # pragma: no cover - - if not await self._handle_refresh_response(refresh_response): # pragma: no cover - # Refresh failed, need full re-authentication - self._initialized = False - - if self.context.is_token_valid(): - self._add_auth_header(request) - - response = yield request - - if response.status_code == 401: + needs_refresh = True + + # === Phase 2: single-flight token refresh (yield outside context.lock) === + if needs_refresh: + async with self.context.refresh_lock: + # Re-check under context.lock: another coroutine may already have + # refreshed while we were waiting on refresh_lock. + refresh_request: httpx.Request | None = None + async with self.context.lock: + if not self.context.is_token_valid() and self.context.can_refresh_token(): + refresh_request = await self._refresh_token() + if refresh_request is not None: + # yield runs outside any lock so a long network round trip + # does not block unrelated concurrent requests. + refresh_response = yield refresh_request + async with self.context.lock: + if not await self._handle_refresh_response(refresh_response): + # Refresh failed; fall through to 401 handling below. + self._initialized = False + + # === Phase 3: send request (no lock; safe for long-poll GET SSE) === + if self.context.is_token_valid(): + self._add_auth_header(request) + + response = yield request + + # === Phase 4: 401 / 403 full OAuth flow === + # NOTE: Phase 4 yields multiple sub-requests (discovery, registration, + # token exchange) under context.lock. This is the existing behavior and + # is acceptable because the 401 path is exceptional and not concurrent + # with steady-state traffic. A future refactor could narrow the lock + # here in the same pattern as Phase 1-2. + if response.status_code == 401: + async with self.context.lock: # Perform full OAuth flow try: # OAuth flow must be inline due to generator constraints @@ -619,7 +660,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Retry with new tokens self._add_auth_header(request) yield request - elif response.status_code == 403: + elif response.status_code == 403: + async with self.context.lock: # Step 1: Extract error field from WWW-Authenticate header error = extract_field_from_www_auth(response, "error") diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c92..8b0f71b22d 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,10 +1,12 @@ """Tests for refactored OAuth client authentication implementation.""" import base64 +import contextlib import time from unittest import mock from urllib.parse import parse_qs, quote, unquote, urlparse +import anyio import httpx import pytest from inline_snapshot import Is, snapshot @@ -2618,3 +2620,201 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +@pytest.mark.anyio +async def test_concurrent_request_not_blocked_by_pending_long_running_request( + oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken +): + """Regression for #1326: a second request reaches its yield while the + first is still suspended (= simulating a server-side long-poll). + + Before the lock-scope fix, ``async_auth_flow`` held ``context.lock`` + across ``yield request``. A GET SSE long-poll would therefore hold the + lock for the entire SSE lifetime, blocking any concurrent request + waiting on the same provider's lock. + """ + # Set up valid tokens so neither refresh (Phase 2) nor full OAuth + # flow (Phase 4) is triggered — we exercise the steady-state Phase 3 + # yield path that previously held the lock. + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() + 1800 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + # Flow 1: drive to yield, then leave suspended (simulating long-poll). + slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + slow_flow = oauth_provider.async_auth_flow(slow_request) + yielded_slow = await slow_flow.__anext__() + assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token" + + # Flow 2: concurrent request. With the fix this reaches its yield + # immediately; without the fix it would block on context.lock. + fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp") + fast_flow = oauth_provider.async_auth_flow(fast_request) + with anyio.fail_after(5): + yielded_fast = await fast_flow.__anext__() + assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token" + + with contextlib.suppress(StopAsyncIteration): + await fast_flow.asend(httpx.Response(200, request=yielded_fast)) + with contextlib.suppress(StopAsyncIteration): + await slow_flow.asend(httpx.Response(200, request=yielded_slow)) + + +@pytest.mark.anyio +async def test_refresh_lock_double_check_skips_redundant_refresh( + oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken +): + """Two flows enter Phase 2 with an expired token. After the first + completes a refresh, the second observes the fresh token via the + Phase 2 double-check inside ``refresh_lock`` (or directly in Phase 1 + if it arrives late) and skips its own refresh. + """ + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 100 # expired + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + # Flow A: drive to refresh yield, then complete refresh. + request_a = httpx.Request("GET", "https://api.example.com/v1/mcp") + flow_a = oauth_provider.async_auth_flow(request_a) + refresh_a = await flow_a.__anext__() + assert "grant_type=refresh_token" in refresh_a.read().decode() + + refresh_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", ' + b'"expires_in": 3600, "refresh_token": "new_refresh_token"}' + ), + request=refresh_a, + ) + request_a_post = await flow_a.asend(refresh_response) + assert request_a_post.headers.get("Authorization") == "Bearer new_access_token" + + # Flow B: state already refreshed; Phase 1 sees valid token, skips Phase 2. + request_b = httpx.Request("POST", "https://api.example.com/v1/mcp") + flow_b = oauth_provider.async_auth_flow(request_b) + with anyio.fail_after(5): + request_b_yielded = await flow_b.__anext__() + assert request_b_yielded.method == "POST" + assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token" + + with contextlib.suppress(StopAsyncIteration): + await flow_b.asend(httpx.Response(200, request=request_b_yielded)) + with contextlib.suppress(StopAsyncIteration): + await flow_a.asend(httpx.Response(200, request=request_a_post)) + + +@pytest.mark.anyio +async def test_refresh_with_failed_status_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + """A non-2xx refresh response clears stored tokens and marks the provider + uninitialized so the next request triggers a full OAuth flow. + """ + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 100 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + request = httpx.Request("POST", "https://api.example.com/v1/mcp") + flow = oauth_provider.async_auth_flow(request) + refresh_request = await flow.__anext__() + assert "grant_type=refresh_token" in refresh_request.read().decode() + + # Refresh server returns 401. + refresh_response = httpx.Response(401, content=b'{"error": "invalid_grant"}', request=refresh_request) + with contextlib.suppress(StopAsyncIteration): + # After failed refresh, the flow proceeds to Phase 3 yielding the + # original request without a fresh Authorization header. We don't + # exercise the subsequent 401/full OAuth path here. + await flow.asend(refresh_response) + + assert oauth_provider.context.current_tokens is None + + +@pytest.mark.anyio +async def test_refresh_with_invalid_json_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + """A 200 refresh response with a malformed body clears stored tokens — + the pydantic ValidationError branch is taken. + """ + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 100 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + request = httpx.Request("POST", "https://api.example.com/v1/mcp") + flow = oauth_provider.async_auth_flow(request) + refresh_request = await flow.__anext__() + + # Body does not parse as OAuthToken. + refresh_response = httpx.Response(200, content=b"not json", request=refresh_request) + with contextlib.suppress(StopAsyncIteration): + await flow.asend(refresh_response) + + assert oauth_provider.context.current_tokens is None + + +@pytest.mark.anyio +async def test_double_check_inside_refresh_lock_skips_second_refresh( + oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken, monkeypatch: pytest.MonkeyPatch +): + """Exercise the double-check branch inside ``refresh_lock``: ``is_token_valid`` + returns False in Phase 1 (= the flow decides to refresh) but True inside + the inner ``context.lock`` block (= another coroutine refreshed while we + were waiting on ``refresh_lock``). The flow must skip ``_refresh_token`` + and proceed straight to Phase 3. + """ + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 100 # expired + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + # Toggle is_token_valid: False on the first call (Phase 1 decision), + # True on the second (double-check inside refresh_lock). + call_count = {"n": 0} + original_is_valid = oauth_provider.context.__class__.is_token_valid + + def fake_is_token_valid(self: object) -> bool: + call_count["n"] += 1 + if call_count["n"] == 1: + return False + # By the second call, "another coroutine" refreshed; reset token expiry + # so callers downstream see a valid token. + oauth_provider.context.token_expiry_time = time.time() + 1800 + return True + + monkeypatch.setattr(oauth_provider.context.__class__, "is_token_valid", fake_is_token_valid) + try: + request = httpx.Request("POST", "https://api.example.com/v1/mcp") + flow = oauth_provider.async_auth_flow(request) + # No refresh yield is expected — the flow goes directly to its own + # request yield with the (now-valid) token header attached. + with anyio.fail_after(5): + yielded = await flow.__anext__() + assert yielded.method == "POST" + assert yielded.headers.get("Authorization") == "Bearer test_access_token" + with contextlib.suppress(StopAsyncIteration): + await flow.asend(httpx.Response(200, request=yielded)) + finally: + monkeypatch.setattr(oauth_provider.context.__class__, "is_token_valid", original_is_valid)