Skip to content

Commit 3f4901e

Browse files
fix(auth): avoid SSE OAuth refresh deadlock
1 parent 9773a3f commit 3f4901e

4 files changed

Lines changed: 507 additions & 0 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,31 @@ def _add_auth_header(self, request: httpx.Request) -> None:
483483
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
484484
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
485485

486+
async def prepare_request_with_refresh(self, client: httpx.AsyncClient, request: httpx.Request) -> None:
487+
"""Refresh stored tokens and add an auth header for requests sent outside the auth flow."""
488+
async with self.context.lock:
489+
if not self._initialized:
490+
await self._initialize()
491+
492+
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
493+
if protocol_version is not None:
494+
self.context.protocol_version = protocol_version
495+
496+
if self.context.is_token_valid():
497+
self._add_auth_header(request)
498+
return
499+
500+
if not self.context.can_refresh_token():
501+
return
502+
503+
refresh_request = await self._refresh_token()
504+
refresh_response = await client.send(refresh_request, auth=None)
505+
506+
if not await self._handle_refresh_response(refresh_response):
507+
return
508+
509+
self._add_auth_header(request)
510+
486511
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
487512
content = await response.aread()
488513
metadata = OAuthMetadata.model_validate_json(content)

src/mcp/client/sse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from httpx_sse import SSEError, aconnect_sse
1212

1313
import mcp.types as types
14+
from mcp.client.auth import OAuthClientProvider
1415
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1516
from mcp.shared.message import SessionMessage
1617

@@ -65,10 +66,19 @@ async def sse_client(
6566
async with httpx_client_factory(
6667
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
6768
) as client:
69+
sse_request_kwargs: dict[str, Any] = {}
70+
if isinstance(auth, OAuthClientProvider):
71+
sse_request = httpx.Request("GET", url, headers=headers)
72+
await auth.prepare_request_with_refresh(client, sse_request)
73+
if "Authorization" in sse_request.headers:
74+
sse_request_kwargs["headers"] = dict(sse_request.headers)
75+
sse_request_kwargs["auth"] = None
76+
6877
async with aconnect_sse(
6978
client,
7079
"GET",
7180
url,
81+
**sse_request_kwargs,
7282
) as event_source:
7383
event_source.response.raise_for_status()
7484
logger.debug("SSE connection established")

tests/client/test_auth.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
is_valid_client_metadata_url,
2929
should_use_client_metadata_url,
3030
)
31+
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3132
from mcp.shared.auth import (
3233
OAuthClientInformationFull,
3334
OAuthClientMetadata,
@@ -631,6 +632,209 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
631632
assert "client_id=test_client" in content
632633
assert "client_secret=test_secret" in content
633634

635+
@pytest.mark.anyio
636+
async def test_prepare_request_with_refresh_refreshes_expired_token(
637+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
638+
):
639+
"""Test preflight refresh for streaming requests that cannot drive OAuth inline."""
640+
641+
class FailingAuth(httpx.Auth):
642+
async def async_auth_flow(self, request: httpx.Request): # pragma: no cover
643+
raise AssertionError("preflight refresh should bypass client auth")
644+
yield request
645+
646+
oauth_provider.context.current_tokens = valid_tokens
647+
oauth_provider.context.token_expiry_time = time.time() - 1
648+
oauth_provider.context.client_info = OAuthClientInformationFull(
649+
client_id="test_client",
650+
client_secret="test_secret",
651+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
652+
token_endpoint_auth_method="client_secret_post",
653+
)
654+
oauth_provider._initialized = True
655+
656+
requests: list[httpx.Request] = []
657+
658+
async def handler(request: httpx.Request) -> httpx.Response:
659+
requests.append(request)
660+
return httpx.Response(
661+
200,
662+
json={
663+
"access_token": "refreshed_access_token",
664+
"token_type": "Bearer",
665+
"expires_in": 3600,
666+
"refresh_token": "refreshed_refresh_token",
667+
},
668+
request=request,
669+
)
670+
671+
request = httpx.Request(
672+
"GET",
673+
"https://api.example.com/v1/mcp/sse",
674+
headers={MCP_PROTOCOL_VERSION: "2025-06-18"},
675+
)
676+
677+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=FailingAuth()) as client:
678+
await oauth_provider.prepare_request_with_refresh(client, request)
679+
680+
assert len(requests) == 1
681+
assert requests[0].method == "POST"
682+
assert str(requests[0].url) == "https://api.example.com/token"
683+
assert "grant_type=refresh_token" in requests[0].content.decode()
684+
assert "resource=" in requests[0].content.decode()
685+
assert request.headers["Authorization"] == "Bearer refreshed_access_token"
686+
assert oauth_provider.context.current_tokens is not None
687+
assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token"
688+
assert mock_storage._tokens is not None
689+
assert mock_storage._tokens.access_token == "refreshed_access_token"
690+
691+
@pytest.mark.anyio
692+
async def test_prepare_request_with_refresh_skips_valid_token(
693+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
694+
):
695+
"""Test preflight refresh is a no-op while the current token is still valid."""
696+
oauth_provider.context.current_tokens = valid_tokens
697+
oauth_provider.context.token_expiry_time = time.time() + 1800
698+
oauth_provider.context.client_info = OAuthClientInformationFull(
699+
client_id="test_client",
700+
client_secret="test_secret",
701+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
702+
)
703+
oauth_provider._initialized = True
704+
705+
requests: list[httpx.Request] = []
706+
707+
async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover
708+
requests.append(request)
709+
return httpx.Response(500, request=request)
710+
711+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
712+
713+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
714+
await oauth_provider.prepare_request_with_refresh(client, request)
715+
716+
assert requests == []
717+
assert request.headers["Authorization"] == "Bearer test_access_token"
718+
assert oauth_provider.context.current_tokens is valid_tokens
719+
720+
@pytest.mark.anyio
721+
async def test_prepare_request_with_refresh_preserves_protocol_version_without_header(
722+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
723+
):
724+
"""Test preflight refresh preserves an existing protocol version when the request has no header."""
725+
oauth_provider.context.current_tokens = valid_tokens
726+
oauth_provider.context.token_expiry_time = time.time() - 1
727+
oauth_provider.context.client_info = OAuthClientInformationFull(
728+
client_id="test_client",
729+
client_secret="test_secret",
730+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
731+
token_endpoint_auth_method="client_secret_post",
732+
)
733+
oauth_provider.context.protocol_version = "2025-06-18"
734+
oauth_provider._initialized = True
735+
736+
requests: list[httpx.Request] = []
737+
738+
async def handler(request: httpx.Request) -> httpx.Response:
739+
requests.append(request)
740+
return httpx.Response(
741+
200,
742+
json={
743+
"access_token": "refreshed_access_token",
744+
"token_type": "Bearer",
745+
"expires_in": 3600,
746+
"refresh_token": "refreshed_refresh_token",
747+
},
748+
request=request,
749+
)
750+
751+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
752+
753+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
754+
await oauth_provider.prepare_request_with_refresh(client, request)
755+
756+
assert len(requests) == 1
757+
assert "resource=" in requests[0].content.decode()
758+
assert oauth_provider.context.protocol_version == "2025-06-18"
759+
assert request.headers["Authorization"] == "Bearer refreshed_access_token"
760+
761+
@pytest.mark.anyio
762+
async def test_prepare_request_with_refresh_initializes_storage(
763+
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
764+
):
765+
"""Test preflight refresh loads persisted OAuth state before preparing the request."""
766+
client_info = OAuthClientInformationFull(
767+
client_id="test_client",
768+
client_secret="test_secret",
769+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
770+
)
771+
await mock_storage.set_tokens(valid_tokens)
772+
await mock_storage.set_client_info(client_info)
773+
774+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
775+
776+
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda request: httpx.Response(500))) as client:
777+
await oauth_provider.prepare_request_with_refresh(client, request)
778+
779+
assert request.headers["Authorization"] == "Bearer test_access_token"
780+
assert oauth_provider.context.current_tokens is valid_tokens
781+
assert oauth_provider.context.client_info is client_info
782+
783+
@pytest.mark.anyio
784+
async def test_prepare_request_with_refresh_skips_without_refresh_token(self, oauth_provider: OAuthClientProvider):
785+
"""Test preflight refresh leaves the request alone when refresh is not possible."""
786+
oauth_provider.context.current_tokens = OAuthToken(
787+
access_token="expired_access_token",
788+
refresh_token=None,
789+
expires_in=1,
790+
)
791+
oauth_provider.context.token_expiry_time = time.time() - 1
792+
oauth_provider.context.client_info = OAuthClientInformationFull(
793+
client_id="test_client",
794+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
795+
)
796+
oauth_provider._initialized = True
797+
798+
requests: list[httpx.Request] = []
799+
800+
async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover
801+
requests.append(request)
802+
return httpx.Response(500, request=request)
803+
804+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
805+
806+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
807+
await oauth_provider.prepare_request_with_refresh(client, request)
808+
809+
assert requests == []
810+
assert "Authorization" not in request.headers
811+
812+
@pytest.mark.anyio
813+
async def test_prepare_request_with_refresh_keeps_request_unauthenticated_after_refresh_failure(
814+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
815+
):
816+
"""Test failed preflight refresh does not add a stale bearer header."""
817+
oauth_provider.context.current_tokens = valid_tokens
818+
oauth_provider.context.token_expiry_time = time.time() - 1
819+
oauth_provider.context.client_info = OAuthClientInformationFull(
820+
client_id="test_client",
821+
client_secret="test_secret",
822+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
823+
token_endpoint_auth_method="client_secret_post",
824+
)
825+
oauth_provider._initialized = True
826+
827+
async def handler(request: httpx.Request) -> httpx.Response:
828+
return httpx.Response(400, request=request)
829+
830+
request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")
831+
832+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
833+
await oauth_provider.prepare_request_with_refresh(client, request)
834+
835+
assert "Authorization" not in request.headers
836+
assert oauth_provider.context.current_tokens is None
837+
634838
@pytest.mark.anyio
635839
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
636840
"""Test token exchange with client_secret_basic authentication."""

0 commit comments

Comments
 (0)