Skip to content

Commit 239d682

Browse files
fix: pass conformance auth scenarios, add RFC 8707 resource validation (#2010)
1 parent 7c7c13b commit 239d682

File tree

4 files changed

+187
-8
lines changed

4 files changed

+187
-8
lines changed

.github/actions/conformance/client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,27 @@ async def run_client_credentials_basic(server_url: str) -> None:
275275
async def run_auth_code_client(server_url: str) -> None:
276276
"""Authorization code flow (default for auth/* scenarios)."""
277277
callback_handler = ConformanceOAuthCallbackHandler()
278+
storage = InMemoryTokenStorage()
279+
280+
# Check for pre-registered client credentials from context
281+
context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
282+
if context_json:
283+
try:
284+
context = json.loads(context_json)
285+
client_id = context.get("client_id")
286+
client_secret = context.get("client_secret")
287+
if client_id:
288+
await storage.set_client_info(
289+
OAuthClientInformationFull(
290+
client_id=client_id,
291+
client_secret=client_secret,
292+
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
293+
token_endpoint_auth_method="client_secret_basic" if client_secret else "none",
294+
)
295+
)
296+
logger.debug(f"Pre-loaded client credentials: client_id={client_id}")
297+
except json.JSONDecodeError:
298+
logger.exception("Failed to parse MCP_CONFORMANCE_CONTEXT")
278299

279300
oauth_auth = OAuthClientProvider(
280301
server_url=server_url,
@@ -284,7 +305,7 @@ async def run_auth_code_client(server_url: str) -> None:
284305
grant_types=["authorization_code", "refresh_token"],
285306
response_types=["code"],
286307
),
287-
storage=InMemoryTokenStorage(),
308+
storage=storage,
288309
redirect_handler=callback_handler.handle_redirect,
289310
callback_handler=callback_handler.handle_callback,
290311
client_metadata_url="https://conformance-test.local/client-metadata.json",

.github/workflows/conformance.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
name: Conformance Tests
22

33
on:
4-
# Disabled: conformance tests are currently broken in CI
5-
# push:
6-
# branches: [main]
7-
# pull_request:
4+
push:
5+
branches: [main]
6+
pull_request:
87
workflow_dispatch:
98

109
concurrency:
@@ -43,4 +42,4 @@ jobs:
4342
with:
4443
node-version: 24
4544
- run: uv sync --frozen --all-extras --package mcp
46-
- run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
45+
- run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all

src/mcp/client/auth/oauth2.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(
229229
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
230230
timeout: float = 300.0,
231231
client_metadata_url: str | None = None,
232+
validate_resource_url: Callable[[str, str | None], Awaitable[None]] | None = None,
232233
):
233234
"""Initialize OAuth2 authentication.
234235
@@ -243,6 +244,10 @@ def __init__(
243244
advertises client_id_metadata_document_supported=true, this URL will be
244245
used as the client_id instead of performing dynamic client registration.
245246
Must be a valid HTTPS URL with a non-root pathname.
247+
validate_resource_url: Optional callback to override resource URL validation.
248+
Called with (server_url, prm_resource) where prm_resource is the resource
249+
from Protected Resource Metadata (or None if not present). If not provided,
250+
default validation rejects mismatched resources per RFC 8707.
246251
247252
Raises:
248253
ValueError: If client_metadata_url is provided but not a valid HTTPS URL
@@ -263,6 +268,7 @@ def __init__(
263268
timeout=timeout,
264269
client_metadata_url=client_metadata_url,
265270
)
271+
self._validate_resource_url_callback = validate_resource_url
266272
self._initialized = False
267273

268274
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
@@ -476,6 +482,26 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
476482
metadata = OAuthMetadata.model_validate_json(content)
477483
self.context.oauth_metadata = metadata
478484

485+
async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None:
486+
"""Validate that PRM resource matches the server URL per RFC 8707."""
487+
prm_resource = str(prm.resource) if prm.resource else None
488+
489+
if self._validate_resource_url_callback is not None:
490+
await self._validate_resource_url_callback(self.context.server_url, prm_resource)
491+
return
492+
493+
if not prm_resource:
494+
return # pragma: no cover
495+
default_resource = resource_url_from_server_url(self.context.server_url)
496+
# Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs
497+
# (e.g. "https://example.com/") while resource_url_from_server_url may not.
498+
if not default_resource.endswith("/"):
499+
default_resource += "/"
500+
if not prm_resource.endswith("/"):
501+
prm_resource += "/"
502+
if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
503+
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
504+
479505
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
480506
"""HTTPX auth flow integration."""
481507
async with self.context.lock:
@@ -517,6 +543,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
517543

518544
prm = await handle_protected_resource_response(discovery_response)
519545
if prm:
546+
# Validate PRM resource matches server URL (RFC 8707)
547+
await self._validate_resource_match(prm)
520548
self.context.protected_resource_metadata = prm
521549

522550
# todo: try all authorization_servers to find the OASM

tests/client/test_auth.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic import AnyHttpUrl, AnyUrl
1212

1313
from mcp.client.auth import OAuthClientProvider, PKCEParameters
14+
from mcp.client.auth.exceptions import OAuthFlowError
1415
from mcp.client.auth.utils import (
1516
build_oauth_authorization_server_metadata_discovery_urls,
1617
build_protected_resource_metadata_discovery_urls,
@@ -818,6 +819,136 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
818819
assert "resource=" in content
819820

820821

822+
@pytest.mark.anyio
823+
async def test_validate_resource_rejects_mismatched_resource(
824+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
825+
) -> None:
826+
"""Client must reject PRM resource that doesn't match server URL."""
827+
provider = OAuthClientProvider(
828+
server_url="https://api.example.com/v1/mcp",
829+
client_metadata=client_metadata,
830+
storage=mock_storage,
831+
)
832+
provider._initialized = True
833+
834+
prm = ProtectedResourceMetadata(
835+
resource=AnyHttpUrl("https://evil.example.com/mcp"),
836+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
837+
)
838+
with pytest.raises(OAuthFlowError, match="does not match expected"):
839+
await provider._validate_resource_match(prm)
840+
841+
842+
@pytest.mark.anyio
843+
async def test_validate_resource_accepts_matching_resource(
844+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
845+
) -> None:
846+
"""Client must accept PRM resource that matches server URL."""
847+
provider = OAuthClientProvider(
848+
server_url="https://api.example.com/v1/mcp",
849+
client_metadata=client_metadata,
850+
storage=mock_storage,
851+
)
852+
provider._initialized = True
853+
854+
prm = ProtectedResourceMetadata(
855+
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
856+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
857+
)
858+
# Should not raise
859+
await provider._validate_resource_match(prm)
860+
861+
862+
@pytest.mark.anyio
863+
async def test_validate_resource_custom_callback(
864+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
865+
) -> None:
866+
"""Custom callback overrides default validation."""
867+
callback_called_with: list[tuple[str, str | None]] = []
868+
869+
async def custom_validate(server_url: str, prm_resource: str | None) -> None:
870+
callback_called_with.append((server_url, prm_resource))
871+
872+
provider = OAuthClientProvider(
873+
server_url="https://api.example.com/v1/mcp",
874+
client_metadata=client_metadata,
875+
storage=mock_storage,
876+
validate_resource_url=custom_validate,
877+
)
878+
provider._initialized = True
879+
880+
# This would normally fail default validation (different origin),
881+
# but custom callback accepts it
882+
prm = ProtectedResourceMetadata(
883+
resource=AnyHttpUrl("https://evil.example.com/mcp"),
884+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
885+
)
886+
await provider._validate_resource_match(prm)
887+
assert callback_called_with == snapshot([("https://api.example.com/v1/mcp", "https://evil.example.com/mcp")])
888+
889+
890+
@pytest.mark.anyio
891+
async def test_validate_resource_accepts_root_url_with_trailing_slash(
892+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
893+
) -> None:
894+
"""Root URLs with trailing slash normalization should match."""
895+
provider = OAuthClientProvider(
896+
server_url="https://api.example.com",
897+
client_metadata=client_metadata,
898+
storage=mock_storage,
899+
)
900+
provider._initialized = True
901+
902+
prm = ProtectedResourceMetadata(
903+
resource=AnyHttpUrl("https://api.example.com/"),
904+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
905+
)
906+
# Should not raise despite trailing slash difference
907+
await provider._validate_resource_match(prm)
908+
909+
910+
@pytest.mark.anyio
911+
async def test_validate_resource_accepts_server_url_with_trailing_slash(
912+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
913+
) -> None:
914+
"""Server URL with trailing slash should match PRM resource."""
915+
provider = OAuthClientProvider(
916+
server_url="https://api.example.com/v1/mcp/",
917+
client_metadata=client_metadata,
918+
storage=mock_storage,
919+
)
920+
provider._initialized = True
921+
922+
prm = ProtectedResourceMetadata(
923+
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
924+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
925+
)
926+
# Should not raise - both normalize to the same URL with trailing slash
927+
await provider._validate_resource_match(prm)
928+
929+
930+
@pytest.mark.anyio
931+
async def test_get_resource_url_uses_canonical_when_prm_mismatches(
932+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
933+
) -> None:
934+
"""get_resource_url falls back to canonical URL when PRM resource doesn't match."""
935+
provider = OAuthClientProvider(
936+
server_url="https://api.example.com/v1/mcp",
937+
client_metadata=client_metadata,
938+
storage=mock_storage,
939+
)
940+
provider._initialized = True
941+
942+
# Set PRM with a resource that is NOT a parent of the server URL
943+
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
944+
resource=AnyHttpUrl("https://other.example.com/mcp"),
945+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
946+
)
947+
948+
# get_resource_url should return the canonical server URL, not the PRM resource
949+
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")
950+
951+
821952
class TestRegistrationResponse:
822953
"""Test client registration response handling."""
823954

@@ -963,7 +1094,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
9631094
# Send a successful discovery response with minimal protected resource metadata
9641095
discovery_response = httpx.Response(
9651096
200,
966-
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
1097+
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
9671098
request=discovery_request,
9681099
)
9691100

@@ -1116,7 +1247,7 @@ async def test_token_exchange_accepts_201_status(
11161247
# Send a successful discovery response with minimal protected resource metadata
11171248
discovery_response = httpx.Response(
11181249
200,
1119-
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
1250+
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
11201251
request=discovery_request,
11211252
)
11221253

0 commit comments

Comments
 (0)