Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_serializer, field_validator


class OAuthToken(BaseModel):
Expand Down Expand Up @@ -129,6 +129,18 @@ class OAuthMetadata(BaseModel):
code_challenge_methods_supported: list[str] | None = None
client_id_metadata_document_supported: bool | None = None

@field_serializer("issuer")
def serialize_issuer_without_trailing_slash(self, v: AnyHttpUrl) -> str:
"""Strip trailing slash from issuer URL during serialization.

RFC 8414 examples show issuer URLs without trailing slashes, and some
OAuth clients (Google ADK, IBM MCP Context Forge) require exact match
between discovery URL and returned issuer per RFC 8414 Section 3.3.
Pydantic's AnyHttpUrl automatically adds a trailing slash, which breaks
these clients. See: https://github.com/modelcontextprotocol/python-sdk/issues/1919
"""
return str(v).rstrip("/")


class ProtectedResourceMetadata(BaseModel):
"""RFC 9728 OAuth 2.0 Protected Resource Metadata.
Expand All @@ -151,3 +163,17 @@ class ProtectedResourceMetadata(BaseModel):
dpop_signing_alg_values_supported: list[str] | None = None
# dpop_bound_access_tokens_required default is False, but ommited here for clarity
dpop_bound_access_tokens_required: bool | None = None

@field_serializer("resource")
def serialize_resource_without_trailing_slash(self, v: AnyHttpUrl) -> str:
"""Strip trailing slash from resource URL during serialization.

Same rationale as OAuthMetadata.issuer - RFC specs show URLs without
trailing slashes, and clients may require exact URL matching.
"""
return str(v).rstrip("/")

@field_serializer("authorization_servers")
def serialize_auth_servers_without_trailing_slash(self, v: list[AnyHttpUrl]) -> list[str]:
"""Strip trailing slashes from authorization server URLs during serialization."""
return [str(url).rstrip("/") for url in v]
5 changes: 3 additions & 2 deletions tests/server/auth/test_protected_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC
# For root resource, metadata should be at standard location
response = await root_resource_client.get("/.well-known/oauth-protected-resource")
assert response.status_code == 200
# Note: URLs should NOT have trailing slashes per RFC 8414/9728 (see issue #1919)
assert response.json() == snapshot(
{
"resource": "https://example.com/",
"authorization_servers": ["https://auth.example.com/"],
"resource": "https://example.com",
"authorization_servers": ["https://auth.example.com"],
"scopes_supported": ["read"],
"resource_name": "Root Resource",
"bearer_methods_supported": ["header"],
Expand Down
73 changes: 72 additions & 1 deletion tests/shared/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Tests for OAuth 2.0 shared code."""

from mcp.shared.auth import OAuthMetadata
import json

from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata


def test_oauth():
Expand Down Expand Up @@ -58,3 +60,72 @@ def test_oauth_with_jarm():
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
}
)


class TestIssuerTrailingSlash:
"""Tests for issue #1919: trailing slash in issuer URL.

RFC 8414 examples show issuer URLs without trailing slashes, and some
OAuth clients require exact match between discovery URL and returned issuer.
Pydantic's AnyHttpUrl automatically adds a trailing slash, so we strip it
during serialization.
"""

def test_oauth_metadata_issuer_no_trailing_slash_in_json(self):
"""Serialized issuer should not have trailing slash."""
metadata = OAuthMetadata(
issuer="https://example.com",
authorization_endpoint="https://example.com/oauth2/authorize",
token_endpoint="https://example.com/oauth2/token",
)
serialized = json.loads(metadata.model_dump_json())
assert serialized["issuer"] == "https://example.com"
assert not serialized["issuer"].endswith("/")

def test_oauth_metadata_issuer_with_path_preserves_path(self):
"""Issuer with path should preserve the path, only strip trailing slash."""
metadata = OAuthMetadata(
issuer="https://example.com/auth",
authorization_endpoint="https://example.com/oauth2/authorize",
token_endpoint="https://example.com/oauth2/token",
)
serialized = json.loads(metadata.model_dump_json())
assert serialized["issuer"] == "https://example.com/auth"
assert not serialized["issuer"].endswith("/")

def test_oauth_metadata_issuer_with_path_and_trailing_slash(self):
"""Issuer with path and trailing slash should only strip the trailing slash."""
metadata = OAuthMetadata(
issuer="https://example.com/auth/",
authorization_endpoint="https://example.com/oauth2/authorize",
token_endpoint="https://example.com/oauth2/token",
)
serialized = json.loads(metadata.model_dump_json())
assert serialized["issuer"] == "https://example.com/auth"

def test_protected_resource_metadata_no_trailing_slash(self):
"""ProtectedResourceMetadata.resource should not have trailing slash."""
metadata = ProtectedResourceMetadata(
resource="https://example.com",
authorization_servers=["https://auth.example.com"],
)
serialized = json.loads(metadata.model_dump_json())
assert serialized["resource"] == "https://example.com"
assert not serialized["resource"].endswith("/")

def test_protected_resource_metadata_auth_servers_no_trailing_slash(self):
"""ProtectedResourceMetadata.authorization_servers should not have trailing slashes."""
metadata = ProtectedResourceMetadata(
resource="https://example.com",
authorization_servers=[
"https://auth1.example.com",
"https://auth2.example.com/path",
],
)
serialized = json.loads(metadata.model_dump_json())
assert serialized["authorization_servers"] == [
"https://auth1.example.com",
"https://auth2.example.com/path",
]
for url in serialized["authorization_servers"]:
assert not url.endswith("/")
Loading