Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/mcp/server/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata

DEFAULT_AUTH_CORS_ORIGIN_REGEX = r"^https?://(localhost|127\.0\.0\.1|\[::1\])(?::\d+)?$"


def validate_issuer_url(url: AnyHttpUrl):
"""Validate that the issuer URL meets OAuth 2.0 requirements.
Expand Down Expand Up @@ -55,10 +57,17 @@ def validate_issuer_url(url: AnyHttpUrl):
def cors_middleware(
handler: Callable[[Request], Response | Awaitable[Response]],
allow_methods: list[str],
*,
allow_origin_regex: str | None = None,
) -> ASGIApp:
# Default: allow loopback browser clients (e.g., MCP Inspector) without allowing arbitrary sites.
if allow_origin_regex is None:
allow_origin_regex = DEFAULT_AUTH_CORS_ORIGIN_REGEX

cors_app = CORSMiddleware(
app=request_response(handler),
allow_origins="*",
allow_origins=[],
allow_origin_regex=allow_origin_regex,
allow_methods=allow_methods,
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
)
Expand All @@ -71,6 +80,7 @@ def create_auth_routes(
service_documentation_url: AnyHttpUrl | None = None,
client_registration_options: ClientRegistrationOptions | None = None,
revocation_options: RevocationOptions | None = None,
cors_origin_regex: str | None = None,
) -> list[Route]:
validate_issuer_url(issuer_url)

Expand All @@ -94,6 +104,7 @@ def create_auth_routes(
endpoint=cors_middleware(
MetadataHandler(metadata).handle,
["GET", "OPTIONS"],
allow_origin_regex=cors_origin_regex,
),
methods=["GET", "OPTIONS"],
),
Expand All @@ -109,6 +120,7 @@ def create_auth_routes(
endpoint=cors_middleware(
TokenHandler(provider, client_authenticator).handle,
["POST", "OPTIONS"],
allow_origin_regex=cors_origin_regex,
),
methods=["POST", "OPTIONS"],
),
Expand All @@ -125,6 +137,7 @@ def create_auth_routes(
endpoint=cors_middleware(
registration_handler.handle,
["POST", "OPTIONS"],
allow_origin_regex=cors_origin_regex,
),
methods=["POST", "OPTIONS"],
)
Expand All @@ -138,6 +151,7 @@ def create_auth_routes(
endpoint=cors_middleware(
revocation_handler.handle,
["POST", "OPTIONS"],
allow_origin_regex=cors_origin_regex,
),
methods=["POST", "OPTIONS"],
)
Expand Down
8 changes: 8 additions & 0 deletions src/mcp/server/auth/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ class AuthSettings(BaseModel):
client_registration_options: ClientRegistrationOptions | None = None
revocation_options: RevocationOptions | None = None
required_scopes: list[str] | None = None
cors_origin_regex: str | None = Field(
default=None,
description=(
"Regex for allowed browser Origin values on the authorization server endpoints "
"(/token, /register, /.well-known/oauth-authorization-server, etc). "
"If unset, a safe default allows only loopback origins (localhost/127.0.0.1/[::1])."
),
)

# Resource Server settings (when operating as RS only)
resource_server_url: AnyHttpUrl | None = Field(
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ def streamable_http_app(
service_documentation_url=auth.service_documentation_url,
client_registration_options=auth.client_registration_options,
revocation_options=auth.revocation_options,
cors_origin_regex=auth.cors_origin_regex,
)
)

Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no
service_documentation_url=self.settings.auth.service_documentation_url,
client_registration_options=self.settings.auth.client_registration_options,
revocation_options=self.settings.auth.revocation_options,
cors_origin_regex=self.settings.auth.cors_origin_regex,
)
)

Expand Down
69 changes: 69 additions & 0 deletions tests/server/mcpserver/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
construct_redirect_uri,
)
from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken


Expand Down Expand Up @@ -325,6 +326,74 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
]
assert metadata["service_documentation"] == "https://docs.example.com/"

@pytest.mark.anyio
async def test_cors_allows_loopback_origin_by_default(self, test_client: httpx.AsyncClient):
origin = "http://localhost:5173"
response = await test_client.get(
"/.well-known/oauth-authorization-server",
headers={"Origin": origin},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == origin

@pytest.mark.anyio
async def test_cors_blocks_non_loopback_origin_by_default(self, test_client: httpx.AsyncClient):
origin = "https://evil.example"
response = await test_client.get(
"/.well-known/oauth-authorization-server",
headers={"Origin": origin},
)
assert response.status_code == 200
assert "access-control-allow-origin" not in response.headers

@pytest.mark.anyio
async def test_cors_preflight_allows_loopback_origin_by_default(self, test_client: httpx.AsyncClient):
origin = "http://127.0.0.1:3000"
response = await test_client.options(
"/token",
headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": MCP_PROTOCOL_VERSION_HEADER,
},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == origin

@pytest.mark.anyio
async def test_cors_origin_regex_override(self, mock_oauth_provider: MockOAuthProvider):
auth_routes = create_auth_routes(
mock_oauth_provider,
AnyHttpUrl("https://auth.example.com"),
AnyHttpUrl("https://docs.example.com"),
client_registration_options=ClientRegistrationOptions(
enabled=True,
valid_scopes=["read", "write", "profile"],
default_scopes=["read", "write"],
),
revocation_options=RevocationOptions(enabled=True),
cors_origin_regex=r"^https://allowed\.example$",
)
app = Starlette(routes=auth_routes)

async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="https://mcptest.com") as client:
allowed = "https://allowed.example"
blocked = "http://localhost:5173"

response = await client.get(
"/.well-known/oauth-authorization-server",
headers={"Origin": allowed},
)
assert response.status_code == 200
assert response.headers.get("access-control-allow-origin") == allowed

response = await client.get(
"/.well-known/oauth-authorization-server",
headers={"Origin": blocked},
)
assert response.status_code == 200
assert "access-control-allow-origin" not in response.headers

@pytest.mark.anyio
async def test_token_validation_error(self, test_client: httpx.AsyncClient):
"""Test token endpoint error - validation error."""
Expand Down