Skip to content

Commit e4b4db5

Browse files
committed
fix(oauth): omit resource on refresh token
1 parent f475344 commit e4b4db5

2 files changed

Lines changed: 51 additions & 5 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,10 @@ def get_resource_url(self) -> str:
151151

152152
# If PRM provides a resource that's a valid parent, use it
153153
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
154-
prm_resource = str(self.protected_resource_metadata.resource)
154+
# Pydantic v2 AnyHttpUrl normalizes bare-domain URLs by appending a trailing
155+
# slash (e.g. "https://example.com" -> "https://example.com/"). OAuth
156+
# providers may treat that as a distinct audience, so strip it.
157+
prm_resource = str(self.protected_resource_metadata.resource).rstrip("/")
155158
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
156159
resource = prm_resource
157160

@@ -442,10 +445,6 @@ async def _refresh_token(self) -> httpx.Request:
442445
"client_id": self.context.client_info.client_id,
443446
}
444447

445-
# Only include resource param if conditions are met
446-
if self.context.should_include_resource_param(self.context.protocol_version):
447-
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
448-
449448
# Prepare authentication based on preferred method
450449
headers = {"Content-Type": "application/x-www-form-urlencoded"}
451450
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

tests/client/test_auth.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,24 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
259259
assert context.current_tokens is None
260260
assert context.token_expiry_time is None
261261

262+
def test_get_resource_url_strips_trailing_slash_from_bare_domain_prm(
263+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
264+
) -> None:
265+
"""get_resource_url strips Pydantic AnyHttpUrl trailing slash for bare-domain PRM."""
266+
provider = OAuthClientProvider(
267+
server_url="https://api.example.com/v1/mcp",
268+
client_metadata=client_metadata,
269+
storage=mock_storage,
270+
)
271+
provider._initialized = True
272+
273+
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
274+
resource=AnyHttpUrl("https://api.example.com"),
275+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
276+
)
277+
278+
assert provider.context.get_resource_url() == snapshot("https://api.example.com")
279+
262280

263281
class TestOAuthFlow:
264282
"""Test OAuth flow methods."""
@@ -631,6 +649,35 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
631649
assert "client_id=test_client" in content
632650
assert "client_secret=test_secret" in content
633651

652+
@pytest.mark.anyio
653+
async def test_refresh_token_request_omits_resource_even_when_required_by_protocol(
654+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
655+
) -> None:
656+
"""refresh_token request should not include RFC8707 resource parameter."""
657+
provider = OAuthClientProvider(
658+
server_url="https://api.example.com/v1/mcp",
659+
client_metadata=client_metadata,
660+
storage=mock_storage,
661+
)
662+
provider._initialized = True
663+
provider.context.protocol_version = "2025-06-18"
664+
provider.context.current_tokens = OAuthToken(
665+
access_token="test_access_token",
666+
token_type="bearer",
667+
expires_in=3600,
668+
refresh_token="test_refresh_token",
669+
)
670+
provider.context.client_info = OAuthClientInformationFull(
671+
client_id="test_client",
672+
client_secret="test_secret",
673+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
674+
token_endpoint_auth_method="client_secret_post",
675+
)
676+
677+
request = await provider._refresh_token()
678+
content = request.content.decode()
679+
assert "resource=" not in content
680+
634681
@pytest.mark.anyio
635682
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
636683
"""Test token exchange with client_secret_basic authentication."""

0 commit comments

Comments
 (0)