Skip to content

Commit fdfe5d4

Browse files
fix(auth): avoid SSE OAuth refresh deadlock
1 parent 9773a3f commit fdfe5d4

4 files changed

Lines changed: 224 additions & 0 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,25 @@ def _add_auth_header(self, request: httpx.Request) -> None:
483483
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
484484
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
485485

486+
async def _prepare_request_with_refresh(self, client: httpx.AsyncClient, request: httpx.Request) -> None:
487+
"""Refresh stored tokens and add an auth header for requests sent outside the auth flow."""
488+
async with self.context.lock:
489+
if not self._initialized:
490+
await self._initialize()
491+
492+
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
493+
494+
if self.context.is_token_valid() or not self.context.can_refresh_token():
495+
if self.context.is_token_valid():
496+
self._add_auth_header(request)
497+
return
498+
499+
refresh_request = await self._refresh_token()
500+
refresh_response = await client.send(refresh_request, auth=None)
501+
502+
if await self._handle_refresh_response(refresh_response):
503+
self._add_auth_header(request)
504+
486505
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
487506
content = await response.aread()
488507
metadata = OAuthMetadata.model_validate_json(content)

src/mcp/client/sse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from httpx_sse import SSEError, aconnect_sse
1212

1313
import mcp.types as types
14+
from mcp.client.auth import OAuthClientProvider
1415
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1516
from mcp.shared.message import SessionMessage
1617

@@ -65,10 +66,19 @@ async def sse_client(
6566
async with httpx_client_factory(
6667
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
6768
) as client:
69+
sse_request_kwargs: dict[str, Any] = {}
70+
if isinstance(auth, OAuthClientProvider):
71+
sse_request = httpx.Request("GET", url, headers=headers)
72+
await auth._prepare_request_with_refresh(client, sse_request) # pyright: ignore[reportPrivateUsage]
73+
if "Authorization" in sse_request.headers:
74+
sse_request_kwargs["headers"] = dict(sse_request.headers)
75+
sse_request_kwargs["auth"] = None
76+
6877
async with aconnect_sse(
6978
client,
7079
"GET",
7180
url,
81+
**sse_request_kwargs,
7282
) as event_source:
7383
event_source.response.raise_for_status()
7484
logger.debug("SSE connection established")

tests/client/test_auth.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,91 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
631631
assert "client_id=test_client" in content
632632
assert "client_secret=test_secret" in content
633633

634+
@pytest.mark.anyio
635+
async def test_prepare_request_with_refresh_refreshes_expired_token(
636+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
637+
):
638+
"""Test preflight refresh for streaming requests that cannot drive OAuth inline."""
639+
640+
class FailingAuth(httpx.Auth):
641+
async def async_auth_flow(self, request: httpx.Request):
642+
raise AssertionError("preflight refresh should bypass client auth")
643+
yield request # pragma: no cover
644+
645+
oauth_provider.context.current_tokens = valid_tokens
646+
oauth_provider.context.token_expiry_time = time.time() - 1
647+
oauth_provider.context.client_info = OAuthClientInformationFull(
648+
client_id="test_client",
649+
client_secret="test_secret",
650+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
651+
token_endpoint_auth_method="client_secret_post",
652+
)
653+
oauth_provider._initialized = True
654+
655+
requests: list[httpx.Request] = []
656+
657+
async def handler(request: httpx.Request) -> httpx.Response:
658+
requests.append(request)
659+
return httpx.Response(
660+
200,
661+
json={
662+
"access_token": "refreshed_access_token",
663+
"token_type": "Bearer",
664+
"expires_in": 3600,
665+
"refresh_token": "refreshed_refresh_token",
666+
},
667+
request=request,
668+
)
669+
670+
request = httpx.Request(
671+
"GET",
672+
"https://api.example.com/v1/mcp/sse",
673+
headers={"mcp-protocol-version": "2025-06-18"},
674+
)
675+
676+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=FailingAuth()) as client:
677+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
678+
679+
assert len(requests) == 1
680+
assert requests[0].method == "POST"
681+
assert str(requests[0].url) == "https://api.example.com/token"
682+
assert "grant_type=refresh_token" in requests[0].content.decode()
683+
assert "resource=" in requests[0].content.decode()
684+
assert request.headers["Authorization"] == "Bearer refreshed_access_token"
685+
assert oauth_provider.context.current_tokens is not None
686+
assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token"
687+
assert mock_storage._tokens is not None
688+
assert mock_storage._tokens.access_token == "refreshed_access_token"
689+
690+
@pytest.mark.anyio
691+
async def test_prepare_request_with_refresh_skips_valid_token(
692+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
693+
):
694+
"""Test preflight refresh is a no-op while the current token is still valid."""
695+
oauth_provider.context.current_tokens = valid_tokens
696+
oauth_provider.context.token_expiry_time = time.time() + 1800
697+
oauth_provider.context.client_info = OAuthClientInformationFull(
698+
client_id="test_client",
699+
client_secret="test_secret",
700+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
701+
)
702+
oauth_provider._initialized = True
703+
704+
requests: list[httpx.Request] = []
705+
706+
async def handler(request: httpx.Request) -> httpx.Response:
707+
requests.append(request)
708+
return httpx.Response(500, request=request) # pragma: no cover
709+
710+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
711+
712+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
713+
await oauth_provider._prepare_request_with_refresh(client, request) # type: ignore[reportPrivateUsage]
714+
715+
assert requests == []
716+
assert request.headers["Authorization"] == "Bearer test_access_token"
717+
assert oauth_provider.context.current_tokens is valid_tokens
718+
634719
@pytest.mark.anyio
635720
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
636721
"""Test token exchange with client_secret_basic authentication."""

tests/shared/test_sse.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020

2121
import mcp.client.sse
2222
import mcp.types as types
23+
from mcp.client.auth import OAuthClientProvider
2324
from mcp.client.session import ClientSession
2425
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
26+
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
2527
from mcp.server import Server
2628
from mcp.server.sse import SseServerTransport
2729
from mcp.server.transport_security import TransportSecuritySettings
30+
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2831
from mcp.shared.exceptions import McpError
2932
from mcp.types import (
3033
EmptyResult,
@@ -602,3 +605,110 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
602605
assert not isinstance(msg, Exception)
603606
assert isinstance(msg.message.root, types.JSONRPCResponse)
604607
assert msg.message.root.id == 1
608+
609+
610+
@pytest.mark.filterwarnings("ignore::ResourceWarning")
611+
@pytest.mark.anyio
612+
async def test_sse_client_preflights_oauth_refresh_before_streaming() -> None:
613+
"""Regression test for OAuth refresh deadlocks while opening SSE streams."""
614+
615+
class MemoryTokenStorage:
616+
def __init__(self) -> None:
617+
self.tokens: OAuthToken | None = None
618+
self.client_info: OAuthClientInformationFull | None = None
619+
620+
async def get_tokens(self) -> OAuthToken | None:
621+
return self.tokens
622+
623+
async def set_tokens(self, tokens: OAuthToken) -> None:
624+
self.tokens = tokens
625+
626+
async def get_client_info(self) -> OAuthClientInformationFull | None:
627+
return self.client_info
628+
629+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
630+
self.client_info = client_info
631+
632+
class NoStreamAuthProvider(OAuthClientProvider):
633+
async def async_auth_flow(self, request: httpx.Request):
634+
if request.url.path.endswith("/sse"):
635+
raise AssertionError("SSE stream should use the preflight bearer header")
636+
async for auth_request in super().async_auth_flow(request):
637+
yield auth_request
638+
639+
storage = MemoryTokenStorage()
640+
oauth_provider = NoStreamAuthProvider(
641+
server_url="https://api.example.com/v1/mcp",
642+
client_metadata=OAuthClientMetadata(
643+
client_name="Test Client",
644+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
645+
),
646+
storage=storage,
647+
)
648+
oauth_provider.context.current_tokens = OAuthToken(
649+
access_token="expired_access_token",
650+
refresh_token="refresh_token",
651+
expires_in=1,
652+
)
653+
oauth_provider.context.token_expiry_time = time.time() - 1
654+
oauth_provider.context.client_info = OAuthClientInformationFull(
655+
client_id="test_client",
656+
client_secret="test_secret",
657+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
658+
token_endpoint_auth_method="client_secret_post",
659+
)
660+
oauth_provider._initialized = True
661+
662+
events: list[str] = []
663+
664+
async def handler(request: httpx.Request) -> httpx.Response:
665+
if request.url.path == "/token":
666+
events.append("refresh")
667+
assert request.method == "POST"
668+
assert "resource=" in request.content.decode()
669+
return httpx.Response(
670+
200,
671+
json={
672+
"access_token": "refreshed_access_token",
673+
"token_type": "Bearer",
674+
"expires_in": 3600,
675+
"refresh_token": "refreshed_refresh_token",
676+
},
677+
request=request,
678+
)
679+
680+
events.append("sse")
681+
assert request.url.path == "/v1/mcp/sse"
682+
assert request.headers["Authorization"] == "Bearer refreshed_access_token"
683+
return httpx.Response(
684+
200,
685+
headers={"Content-Type": "text/event-stream"},
686+
content=b"event: endpoint\ndata: /messages/?session_id=abc123\n\n",
687+
request=request,
688+
)
689+
690+
def client_factory(
691+
headers: dict[str, str] | None = None,
692+
timeout: httpx.Timeout | None = None,
693+
auth: httpx.Auth | None = None,
694+
) -> httpx.AsyncClient:
695+
assert auth is oauth_provider
696+
return httpx.AsyncClient(
697+
headers=headers,
698+
timeout=timeout,
699+
auth=auth,
700+
transport=httpx.MockTransport(handler),
701+
)
702+
703+
with anyio.fail_after(1):
704+
async with sse_client(
705+
"https://api.example.com/v1/mcp/sse",
706+
headers={MCP_PROTOCOL_VERSION: "2025-06-18"},
707+
auth=oauth_provider,
708+
httpx_client_factory=client_factory,
709+
):
710+
pass
711+
712+
assert events == ["refresh", "sse"]
713+
assert storage.tokens is not None
714+
assert storage.tokens.access_token == "refreshed_access_token"

0 commit comments

Comments
 (0)