|
28 | 28 | is_valid_client_metadata_url, |
29 | 29 | should_use_client_metadata_url, |
30 | 30 | ) |
| 31 | +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION |
31 | 32 | from mcp.shared.auth import ( |
32 | 33 | OAuthClientInformationFull, |
33 | 34 | OAuthClientMetadata, |
@@ -631,6 +632,209 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, |
631 | 632 | assert "client_id=test_client" in content |
632 | 633 | assert "client_secret=test_secret" in content |
633 | 634 |
|
| 635 | + @pytest.mark.anyio |
| 636 | + async def test_prepare_request_with_refresh_refreshes_expired_token( |
| 637 | + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken |
| 638 | + ): |
| 639 | + """Test preflight refresh for streaming requests that cannot drive OAuth inline.""" |
| 640 | + |
| 641 | + class FailingAuth(httpx.Auth): |
| 642 | + async def async_auth_flow(self, request: httpx.Request): # pragma: no cover |
| 643 | + raise AssertionError("preflight refresh should bypass client auth") |
| 644 | + yield request |
| 645 | + |
| 646 | + oauth_provider.context.current_tokens = valid_tokens |
| 647 | + oauth_provider.context.token_expiry_time = time.time() - 1 |
| 648 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 649 | + client_id="test_client", |
| 650 | + client_secret="test_secret", |
| 651 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 652 | + token_endpoint_auth_method="client_secret_post", |
| 653 | + ) |
| 654 | + oauth_provider._initialized = True |
| 655 | + |
| 656 | + requests: list[httpx.Request] = [] |
| 657 | + |
| 658 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 659 | + requests.append(request) |
| 660 | + return httpx.Response( |
| 661 | + 200, |
| 662 | + json={ |
| 663 | + "access_token": "refreshed_access_token", |
| 664 | + "token_type": "Bearer", |
| 665 | + "expires_in": 3600, |
| 666 | + "refresh_token": "refreshed_refresh_token", |
| 667 | + }, |
| 668 | + request=request, |
| 669 | + ) |
| 670 | + |
| 671 | + request = httpx.Request( |
| 672 | + "GET", |
| 673 | + "https://api.example.com/v1/mcp/sse", |
| 674 | + headers={MCP_PROTOCOL_VERSION: "2025-06-18"}, |
| 675 | + ) |
| 676 | + |
| 677 | + async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=FailingAuth()) as client: |
| 678 | + await oauth_provider.prepare_request_with_refresh(client, request) |
| 679 | + |
| 680 | + assert len(requests) == 1 |
| 681 | + assert requests[0].method == "POST" |
| 682 | + assert str(requests[0].url) == "https://api.example.com/token" |
| 683 | + assert "grant_type=refresh_token" in requests[0].content.decode() |
| 684 | + assert "resource=" in requests[0].content.decode() |
| 685 | + assert request.headers["Authorization"] == "Bearer refreshed_access_token" |
| 686 | + assert oauth_provider.context.current_tokens is not None |
| 687 | + assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token" |
| 688 | + assert mock_storage._tokens is not None |
| 689 | + assert mock_storage._tokens.access_token == "refreshed_access_token" |
| 690 | + |
| 691 | + @pytest.mark.anyio |
| 692 | + async def test_prepare_request_with_refresh_skips_valid_token( |
| 693 | + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken |
| 694 | + ): |
| 695 | + """Test preflight refresh is a no-op while the current token is still valid.""" |
| 696 | + oauth_provider.context.current_tokens = valid_tokens |
| 697 | + oauth_provider.context.token_expiry_time = time.time() + 1800 |
| 698 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 699 | + client_id="test_client", |
| 700 | + client_secret="test_secret", |
| 701 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 702 | + ) |
| 703 | + oauth_provider._initialized = True |
| 704 | + |
| 705 | + requests: list[httpx.Request] = [] |
| 706 | + |
| 707 | + async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover |
| 708 | + requests.append(request) |
| 709 | + return httpx.Response(500, request=request) |
| 710 | + |
| 711 | + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") |
| 712 | + |
| 713 | + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: |
| 714 | + await oauth_provider.prepare_request_with_refresh(client, request) |
| 715 | + |
| 716 | + assert requests == [] |
| 717 | + assert request.headers["Authorization"] == "Bearer test_access_token" |
| 718 | + assert oauth_provider.context.current_tokens is valid_tokens |
| 719 | + |
| 720 | + @pytest.mark.anyio |
| 721 | + async def test_prepare_request_with_refresh_preserves_protocol_version_without_header( |
| 722 | + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken |
| 723 | + ): |
| 724 | + """Test preflight refresh preserves an existing protocol version when the request has no header.""" |
| 725 | + oauth_provider.context.current_tokens = valid_tokens |
| 726 | + oauth_provider.context.token_expiry_time = time.time() - 1 |
| 727 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 728 | + client_id="test_client", |
| 729 | + client_secret="test_secret", |
| 730 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 731 | + token_endpoint_auth_method="client_secret_post", |
| 732 | + ) |
| 733 | + oauth_provider.context.protocol_version = "2025-06-18" |
| 734 | + oauth_provider._initialized = True |
| 735 | + |
| 736 | + requests: list[httpx.Request] = [] |
| 737 | + |
| 738 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 739 | + requests.append(request) |
| 740 | + return httpx.Response( |
| 741 | + 200, |
| 742 | + json={ |
| 743 | + "access_token": "refreshed_access_token", |
| 744 | + "token_type": "Bearer", |
| 745 | + "expires_in": 3600, |
| 746 | + "refresh_token": "refreshed_refresh_token", |
| 747 | + }, |
| 748 | + request=request, |
| 749 | + ) |
| 750 | + |
| 751 | + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") |
| 752 | + |
| 753 | + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: |
| 754 | + await oauth_provider.prepare_request_with_refresh(client, request) |
| 755 | + |
| 756 | + assert len(requests) == 1 |
| 757 | + assert "resource=" in requests[0].content.decode() |
| 758 | + assert oauth_provider.context.protocol_version == "2025-06-18" |
| 759 | + assert request.headers["Authorization"] == "Bearer refreshed_access_token" |
| 760 | + |
| 761 | + @pytest.mark.anyio |
| 762 | + async def test_prepare_request_with_refresh_initializes_storage( |
| 763 | + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken |
| 764 | + ): |
| 765 | + """Test preflight refresh loads persisted OAuth state before preparing the request.""" |
| 766 | + client_info = OAuthClientInformationFull( |
| 767 | + client_id="test_client", |
| 768 | + client_secret="test_secret", |
| 769 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 770 | + ) |
| 771 | + await mock_storage.set_tokens(valid_tokens) |
| 772 | + await mock_storage.set_client_info(client_info) |
| 773 | + |
| 774 | + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") |
| 775 | + |
| 776 | + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda request: httpx.Response(500))) as client: |
| 777 | + await oauth_provider.prepare_request_with_refresh(client, request) |
| 778 | + |
| 779 | + assert request.headers["Authorization"] == "Bearer test_access_token" |
| 780 | + assert oauth_provider.context.current_tokens is valid_tokens |
| 781 | + assert oauth_provider.context.client_info is client_info |
| 782 | + |
| 783 | + @pytest.mark.anyio |
| 784 | + async def test_prepare_request_with_refresh_skips_without_refresh_token(self, oauth_provider: OAuthClientProvider): |
| 785 | + """Test preflight refresh leaves the request alone when refresh is not possible.""" |
| 786 | + oauth_provider.context.current_tokens = OAuthToken( |
| 787 | + access_token="expired_access_token", |
| 788 | + refresh_token=None, |
| 789 | + expires_in=1, |
| 790 | + ) |
| 791 | + oauth_provider.context.token_expiry_time = time.time() - 1 |
| 792 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 793 | + client_id="test_client", |
| 794 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 795 | + ) |
| 796 | + oauth_provider._initialized = True |
| 797 | + |
| 798 | + requests: list[httpx.Request] = [] |
| 799 | + |
| 800 | + async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover |
| 801 | + requests.append(request) |
| 802 | + return httpx.Response(500, request=request) |
| 803 | + |
| 804 | + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") |
| 805 | + |
| 806 | + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: |
| 807 | + await oauth_provider.prepare_request_with_refresh(client, request) |
| 808 | + |
| 809 | + assert requests == [] |
| 810 | + assert "Authorization" not in request.headers |
| 811 | + |
| 812 | + @pytest.mark.anyio |
| 813 | + async def test_prepare_request_with_refresh_keeps_request_unauthenticated_after_refresh_failure( |
| 814 | + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken |
| 815 | + ): |
| 816 | + """Test failed preflight refresh does not add a stale bearer header.""" |
| 817 | + oauth_provider.context.current_tokens = valid_tokens |
| 818 | + oauth_provider.context.token_expiry_time = time.time() - 1 |
| 819 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 820 | + client_id="test_client", |
| 821 | + client_secret="test_secret", |
| 822 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 823 | + token_endpoint_auth_method="client_secret_post", |
| 824 | + ) |
| 825 | + oauth_provider._initialized = True |
| 826 | + |
| 827 | + async def handler(request: httpx.Request) -> httpx.Response: |
| 828 | + return httpx.Response(400, request=request) |
| 829 | + |
| 830 | + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") |
| 831 | + |
| 832 | + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: |
| 833 | + await oauth_provider.prepare_request_with_refresh(client, request) |
| 834 | + |
| 835 | + assert "Authorization" not in request.headers |
| 836 | + assert oauth_provider.context.current_tokens is None |
| 837 | + |
634 | 838 | @pytest.mark.anyio |
635 | 839 | async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider): |
636 | 840 | """Test token exchange with client_secret_basic authentication.""" |
|
0 commit comments