Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 59 additions & 17 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,18 @@ class OAuthContext:
token_expiry_time: float | None = None

# State
#
# `lock` guards short-lived reads/writes of provider state (initialization
# flag, token cache mutation, protocol_version assignment). On the
# steady-state path it is held only while mutating state and is released
# before the refresh request or the original request is yielded, so a
# long-running request (e.g. GET SSE long-poll) does not block unrelated
# concurrent requests. The exceptional 401 full-OAuth path still yields its
# sub-requests while holding this lock.
#
# `refresh_lock` provides single-flight semantics for token refresh: only
# one concurrent refresh fires; other waiters block on this lock, then
# re-check the token cache and proceed without re-refreshing.
lock: anyio.Lock = field(default_factory=anyio.Lock)
refresh_lock: anyio.Lock = field(default_factory=anyio.Lock)

def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
Expand Down Expand Up @@ -452,7 +463,7 @@ async def _refresh_token(self) -> httpx.Request:

return httpx.Request("POST", token_url, data=refresh_data, headers=headers)

async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
"""Handle token refresh response. Returns True if successful."""
if response.status_code != 200:
logger.warning(f"Token refresh failed: {response.status_code}")
Expand Down Expand Up @@ -504,7 +515,17 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
"""HTTPX auth flow integration.

Lock scope:
``self.context.lock`` is held only while reading/mutating provider
state. The actual HTTP request yield (which may be a long-poll GET
SSE stream) runs outside any lock so concurrent unrelated requests
are not blocked. ``self.context.refresh_lock`` provides
single-flight semantics for token refresh.
"""
# === Phase 1: state read + refresh decision (brief context.lock) ===
needs_refresh = False
async with self.context.lock:
if not self._initialized:
await self._initialize() # pragma: no cover
Expand All @@ -513,20 +534,40 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover

if not await self._handle_refresh_response(refresh_response): # pragma: no cover
# Refresh failed, need full re-authentication
self._initialized = False

if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

if response.status_code == 401:
needs_refresh = True

# === Phase 2: single-flight token refresh (yield outside context.lock) ===
if needs_refresh:
async with self.context.refresh_lock:
# Re-check under context.lock: another coroutine may already have
# refreshed while we were waiting on refresh_lock.
refresh_request: httpx.Request | None = None
async with self.context.lock:
if not self.context.is_token_valid() and self.context.can_refresh_token():
refresh_request = await self._refresh_token()
if refresh_request is not None:
# yield runs outside any lock so a long network round trip
# does not block unrelated concurrent requests.
refresh_response = yield refresh_request
async with self.context.lock:
if not await self._handle_refresh_response(refresh_response):
# Refresh failed; fall through to 401 handling below.
self._initialized = False

# === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

# === Phase 4: 401 / 403 full OAuth flow ===
# NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
# token exchange) under context.lock. This is the existing behavior and
# is acceptable because the 401 path is exceptional and not concurrent
# with steady-state traffic. A future refactor could narrow the lock
# here in the same pattern as Phase 1-2.
if response.status_code == 401:
async with self.context.lock:
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
Expand Down Expand Up @@ -619,7 +660,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Retry with new tokens
self._add_auth_header(request)
yield request
elif response.status_code == 403:
elif response.status_code == 403:
async with self.context.lock:
# Step 1: Extract error field from WWW-Authenticate header
error = extract_field_from_www_auth(response, "error")

Expand Down
200 changes: 200 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for refactored OAuth client authentication implementation."""

import base64
import contextlib
import time
from unittest import mock
from urllib.parse import parse_qs, quote, unquote, urlparse

import anyio
import httpx
import pytest
from inline_snapshot import Is, snapshot
Expand Down Expand Up @@ -2618,3 +2620,201 @@ async def callback_handler() -> tuple[str, str | None]:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass


@pytest.mark.anyio
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
    oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
    """Regression test for #1326: a pending request must not block a second one.

    The first flow is advanced to the point where it yields its request and is
    then left suspended, standing in for a server-side long-poll. A second flow
    on the same provider must still reach its own yield. Previously
    ``async_auth_flow`` kept ``context.lock`` held across ``yield request``, so
    the second flow would have waited on the lock for the whole lifetime of the
    first request.
    """
    # With a valid, unexpired token neither the refresh path nor the full
    # OAuth flow is entered; only the plain "attach header and yield" path runs.
    ctx = oauth_provider.context
    ctx.current_tokens = valid_tokens
    ctx.token_expiry_time = time.time() + 1800
    ctx.client_info = OAuthClientInformationFull(
        client_id="test_client_id",
        client_secret="test_client_secret",
        redirect_uris=[AnyUrl("http://localhost:3030/callback")],
    )
    oauth_provider._initialized = True

    expected_auth = "Bearer test_access_token"

    # First flow: advance to its yield and leave it parked there.
    pending_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
    pending_flow = oauth_provider.async_auth_flow(pending_request)
    pending_yield = await pending_flow.__anext__()
    assert pending_yield.headers.get("Authorization") == expected_auth

    # Second flow: must reach its yield while the first is still parked. The
    # timeout turns a would-be deadlock into a test failure.
    second_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
    second_flow = oauth_provider.async_auth_flow(second_request)
    with anyio.fail_after(5):
        second_yield = await second_flow.__anext__()
    assert second_yield.headers.get("Authorization") == expected_auth

    # Drain both generators with a 200 response, second flow first.
    for flow, sent in ((second_flow, second_yield), (pending_flow, pending_yield)):
        with contextlib.suppress(StopAsyncIteration):
            await flow.asend(httpx.Response(200, request=sent))


@pytest.mark.anyio
async def test_refresh_lock_double_check_skips_redundant_refresh(
    oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
    """After one flow refreshes an expired token, a later flow must not refresh again.

    Flow A starts with an expired token, yields a refresh request and, once the
    refresh succeeds, yields its own request carrying the new access token.
    Flow B is started only after that refresh has completed, so it observes the
    fresh token during its initial state check and goes straight to yielding
    its own request. The re-check performed while holding ``refresh_lock`` is
    not reached in this sequential scenario.
    """
    oauth_provider.context.current_tokens = valid_tokens
    oauth_provider.context.token_expiry_time = time.time() - 100  # expired
    oauth_provider.context.client_info = OAuthClientInformationFull(
        client_id="test_client_id",
        client_secret="test_client_secret",
        redirect_uris=[AnyUrl("http://localhost:3030/callback")],
    )
    oauth_provider._initialized = True

    # Flow A: drive to the refresh yield, then complete the refresh.
    request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
    flow_a = oauth_provider.async_auth_flow(request_a)
    refresh_a = await flow_a.__anext__()
    assert "grant_type=refresh_token" in refresh_a.read().decode()

    refresh_response = httpx.Response(
        200,
        content=(
            b'{"access_token": "new_access_token", "token_type": "Bearer", '
            b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
        ),
        request=refresh_a,
    )
    request_a_post = await flow_a.asend(refresh_response)
    # The flow yields the caller's own request object after the refresh.
    assert request_a_post is request_a
    assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"

    # Flow B: state is already refreshed, so its first yield must be its own
    # request. The identity check matters because a refresh request is also a
    # POST, so the method alone cannot tell the two apart.
    request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
    flow_b = oauth_provider.async_auth_flow(request_b)
    with anyio.fail_after(5):
        request_b_yielded = await flow_b.__anext__()
    assert request_b_yielded is request_b
    assert request_b_yielded.method == "POST"
    assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"

    with contextlib.suppress(StopAsyncIteration):
        await flow_b.asend(httpx.Response(200, request=request_b_yielded))
    with contextlib.suppress(StopAsyncIteration):
        await flow_a.asend(httpx.Response(200, request=request_a_post))


@pytest.mark.anyio
async def test_refresh_with_failed_status_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
    """A failed (non-200) refresh response clears stored tokens and marks the
    provider uninitialized.

    The flow is driven only as far as the step after the failed refresh; the
    subsequent 401 / full OAuth path is not exercised here.
    """
    oauth_provider.context.current_tokens = valid_tokens
    oauth_provider.context.token_expiry_time = time.time() - 100
    oauth_provider.context.client_info = OAuthClientInformationFull(
        client_id="test_client_id",
        client_secret="test_client_secret",
        redirect_uris=[AnyUrl("http://localhost:3030/callback")],
    )
    oauth_provider._initialized = True

    request = httpx.Request("POST", "https://api.example.com/v1/mcp")
    flow = oauth_provider.async_auth_flow(request)
    refresh_request = await flow.__anext__()
    assert "grant_type=refresh_token" in refresh_request.read().decode()

    # Refresh server returns 401.
    refresh_response = httpx.Response(401, content=b'{"error": "invalid_grant"}', request=refresh_request)
    with contextlib.suppress(StopAsyncIteration):
        # After the failed refresh the flow moves on to yielding the original
        # request; that request is not answered in this test.
        await flow.asend(refresh_response)

    assert oauth_provider.context.current_tokens is None
    # The flow resets the initialization flag when the refresh handler
    # reports failure, so both halves of the documented contract are checked.
    assert oauth_provider._initialized is False


@pytest.mark.anyio
async def test_refresh_with_invalid_json_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
    """A 200 refresh response whose body cannot be parsed as a token leaves the
    provider with no stored tokens.
    """
    # Expired token plus client info: the flow's first yield is a refresh request.
    ctx = oauth_provider.context
    ctx.current_tokens = valid_tokens
    ctx.token_expiry_time = time.time() - 100
    ctx.client_info = OAuthClientInformationFull(
        client_id="test_client_id",
        client_secret="test_client_secret",
        redirect_uris=[AnyUrl("http://localhost:3030/callback")],
    )
    oauth_provider._initialized = True

    original_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
    auth_flow = oauth_provider.async_auth_flow(original_request)
    token_request = await auth_flow.__anext__()

    # Status is 200, but the payload is not a token document.
    malformed_response = httpx.Response(200, content=b"not json", request=token_request)
    with contextlib.suppress(StopAsyncIteration):
        await auth_flow.asend(malformed_response)

    assert oauth_provider.context.current_tokens is None


@pytest.mark.anyio
async def test_double_check_inside_refresh_lock_skips_second_refresh(
    oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken, monkeypatch: pytest.MonkeyPatch
):
    """Exercise the re-check performed while holding ``refresh_lock``.

    ``is_token_valid`` is patched to return False on its first call (the flow
    decides a refresh is needed) and True on every later call (as if another
    coroutine had refreshed while this flow waited on ``refresh_lock``). The
    flow must therefore skip issuing a refresh request and yield the caller's
    own request with the existing token attached.
    """
    oauth_provider.context.current_tokens = valid_tokens
    oauth_provider.context.token_expiry_time = time.time() - 100  # expired
    oauth_provider.context.client_info = OAuthClientInformationFull(
        client_id="test_client_id",
        client_secret="test_client_secret",
        redirect_uris=[AnyUrl("http://localhost:3030/callback")],
    )
    oauth_provider._initialized = True

    call_count = 0

    def fake_is_token_valid(self: object) -> bool:
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            return False
        # From the second call on, "another coroutine" has refreshed; move the
        # expiry into the future so downstream readers see a valid token.
        oauth_provider.context.token_expiry_time = time.time() + 1800
        return True

    # The monkeypatch fixture restores the original attribute at teardown, so
    # no manual try/finally restoration is needed.
    monkeypatch.setattr(type(oauth_provider.context), "is_token_valid", fake_is_token_valid)

    request = httpx.Request("POST", "https://api.example.com/v1/mcp")
    flow = oauth_provider.async_auth_flow(request)
    # No refresh yield is expected: the first thing yielded must be the
    # caller's own request with the (now valid) token header attached.
    with anyio.fail_after(5):
        yielded = await flow.__anext__()
    # A refresh request would also be a POST, so check identity as well.
    assert yielded is request
    assert yielded.method == "POST"
    assert yielded.headers.get("Authorization") == "Bearer test_access_token"
    # Initial decision + re-check under refresh_lock + check before sending:
    # fewer than three calls means the re-check branch was never reached.
    assert call_count >= 3

    with contextlib.suppress(StopAsyncIteration):
        await flow.asend(httpx.Response(200, request=yielded))
Loading