|
20 | 20 |
|
21 | 21 | import mcp.client.sse |
22 | 22 | import mcp.types as types |
| 23 | +from mcp.client.auth import OAuthClientProvider |
23 | 24 | from mcp.client.session import ClientSession |
24 | 25 | from mcp.client.sse import _extract_session_id_from_endpoint, sse_client |
| 26 | +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION |
25 | 27 | from mcp.server import Server |
26 | 28 | from mcp.server.sse import SseServerTransport |
27 | 29 | from mcp.server.transport_security import TransportSecuritySettings |
| 30 | +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken |
28 | 31 | from mcp.shared.exceptions import McpError |
29 | 32 | from mcp.types import ( |
30 | 33 | EmptyResult, |
@@ -602,3 +605,110 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: |
602 | 605 | assert not isinstance(msg, Exception) |
603 | 606 | assert isinstance(msg.message.root, types.JSONRPCResponse) |
604 | 607 | assert msg.message.root.id == 1 |
| 608 | + |
| 609 | + |
| 610 | +@pytest.mark.filterwarnings("ignore::ResourceWarning") |
| 611 | +@pytest.mark.anyio |
| 612 | +async def test_sse_client_preflights_oauth_refresh_before_streaming() -> None: |
| 613 | + """Regression test for OAuth refresh deadlocks while opening SSE streams.""" |
| 614 | + |
| 615 | + class MemoryTokenStorage: |
| 616 | + def __init__(self) -> None: |
| 617 | + self.tokens: OAuthToken | None = None |
| 618 | + self.client_info: OAuthClientInformationFull | None = None |
| 619 | + |
| 620 | + async def get_tokens(self) -> OAuthToken | None: |
| 621 | + return self.tokens |
| 622 | + |
| 623 | + async def set_tokens(self, tokens: OAuthToken) -> None: |
| 624 | + self.tokens = tokens |
| 625 | + |
| 626 | + async def get_client_info(self) -> OAuthClientInformationFull | None: |
| 627 | + return self.client_info |
| 628 | + |
| 629 | + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: |
| 630 | + self.client_info = client_info |
| 631 | + |
| 632 | + class NoStreamAuthProvider(OAuthClientProvider): |
| 633 | + async def async_auth_flow(self, request: httpx.Request): |
| 634 | + if request.url.path.endswith("/sse"): |
| 635 | + raise AssertionError("SSE stream should use the preflight bearer header") |
| 636 | + async for auth_request in super().async_auth_flow(request): |
| 637 | + yield auth_request |
| 638 | + |
| 639 | + storage = MemoryTokenStorage() |
| 640 | + oauth_provider = NoStreamAuthProvider( |
| 641 | + server_url="https://api.example.com/v1/mcp", |
| 642 | + client_metadata=OAuthClientMetadata( |
| 643 | + client_name="Test Client", |
| 644 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 645 | + ), |
| 646 | + storage=storage, |
| 647 | + ) |
| 648 | + oauth_provider.context.current_tokens = OAuthToken( |
| 649 | + access_token="expired_access_token", |
| 650 | + refresh_token="refresh_token", |
| 651 | + expires_in=1, |
| 652 | + ) |
| 653 | + oauth_provider.context.token_expiry_time = time.time() - 1 |
| 654 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 655 | + client_id="test_client", |
| 656 | + client_secret="test_secret", |
| 657 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 658 | + token_endpoint_auth_method="client_secret_post", |
| 659 | + ) |
| 660 | + oauth_provider._initialized = True |
| 661 | + |
| 662 | + events: list[str] = [] |
| 663 | + |
| 664 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 665 | + if request.url.path == "/token": |
| 666 | + events.append("refresh") |
| 667 | + assert request.method == "POST" |
| 668 | + assert "resource=" in request.content.decode() |
| 669 | + return httpx.Response( |
| 670 | + 200, |
| 671 | + json={ |
| 672 | + "access_token": "refreshed_access_token", |
| 673 | + "token_type": "Bearer", |
| 674 | + "expires_in": 3600, |
| 675 | + "refresh_token": "refreshed_refresh_token", |
| 676 | + }, |
| 677 | + request=request, |
| 678 | + ) |
| 679 | + |
| 680 | + events.append("sse") |
| 681 | + assert request.url.path == "/v1/mcp/sse" |
| 682 | + assert request.headers["Authorization"] == "Bearer refreshed_access_token" |
| 683 | + return httpx.Response( |
| 684 | + 200, |
| 685 | + headers={"Content-Type": "text/event-stream"}, |
| 686 | + content=b"event: endpoint\ndata: /messages/?session_id=abc123\n\n", |
| 687 | + request=request, |
| 688 | + ) |
| 689 | + |
| 690 | + def client_factory( |
| 691 | + headers: dict[str, str] | None = None, |
| 692 | + timeout: httpx.Timeout | None = None, |
| 693 | + auth: httpx.Auth | None = None, |
| 694 | + ) -> httpx.AsyncClient: |
| 695 | + assert auth is oauth_provider |
| 696 | + return httpx.AsyncClient( |
| 697 | + headers=headers, |
| 698 | + timeout=timeout, |
| 699 | + auth=auth, |
| 700 | + transport=httpx.MockTransport(handler), |
| 701 | + ) |
| 702 | + |
| 703 | + with anyio.fail_after(1): |
| 704 | + async with sse_client( |
| 705 | + "https://api.example.com/v1/mcp/sse", |
| 706 | + headers={MCP_PROTOCOL_VERSION: "2025-06-18"}, |
| 707 | + auth=oauth_provider, |
| 708 | + httpx_client_factory=client_factory, |
| 709 | + ): |
| 710 | + pass |
| 711 | + |
| 712 | + assert events == ["refresh", "sse"] |
| 713 | + assert storage.tokens is not None |
| 714 | + assert storage.tokens.access_token == "refreshed_access_token" |
0 commit comments