From d0debb63fb4a630532c3a9b6d1574a64f64990a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20N=2E=20Eng=C3=B8y?= Date: Sun, 8 Feb 2026 20:13:04 +0100 Subject: [PATCH 1/2] auth: restrict CORS to loopback by default --- src/mcp/server/auth/routes.py | 16 ++++++++- src/mcp/server/auth/settings.py | 8 +++++ src/mcp/server/lowlevel/server.py | 1 + src/mcp/server/mcpserver/server.py | 1 + .../mcpserver/auth/test_auth_integration.py | 35 +++++++++++++++++++ 5 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 08f735f36..266831899 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -20,6 +20,8 @@ from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +DEFAULT_AUTH_CORS_ORIGIN_REGEX = r"^https?://(localhost|127\.0\.0\.1|\[::1\])(?::\d+)?$" + def validate_issuer_url(url: AnyHttpUrl): """Validate that the issuer URL meets OAuth 2.0 requirements. @@ -55,10 +57,17 @@ def validate_issuer_url(url: AnyHttpUrl): def cors_middleware( handler: Callable[[Request], Response | Awaitable[Response]], allow_methods: list[str], + *, + allow_origin_regex: str | None = None, ) -> ASGIApp: + # Default: allow loopback browser clients (e.g., MCP Inspector) without allowing arbitrary sites. + if allow_origin_regex is None: + allow_origin_regex = DEFAULT_AUTH_CORS_ORIGIN_REGEX + cors_app = CORSMiddleware( app=request_response(handler), - allow_origins="*", + allow_origins=[], + allow_origin_regex=allow_origin_regex, allow_methods=allow_methods, allow_headers=[MCP_PROTOCOL_VERSION_HEADER], ) @@ -71,6 +80,7 @@ def create_auth_routes( service_documentation_url: AnyHttpUrl | None = None, client_registration_options: ClientRegistrationOptions | None = None, revocation_options: RevocationOptions | None = None, + cors_origin_regex: str | None = None, ) -> list[Route]: validate_issuer_url(issuer_url) @@ -94,6 +104,7 @@ def create_auth_routes( endpoint=cors_middleware( MetadataHandler(metadata).handle, ["GET", "OPTIONS"], + allow_origin_regex=cors_origin_regex, ), methods=["GET", "OPTIONS"], ), @@ -109,6 +120,7 @@ def create_auth_routes( endpoint=cors_middleware( TokenHandler(provider, client_authenticator).handle, ["POST", "OPTIONS"], + allow_origin_regex=cors_origin_regex, ), methods=["POST", "OPTIONS"], ), @@ -125,6 +137,7 @@ def create_auth_routes( endpoint=cors_middleware( registration_handler.handle, ["POST", "OPTIONS"], + allow_origin_regex=cors_origin_regex, ), methods=["POST", "OPTIONS"], ) @@ -138,6 +151,7 @@ def create_auth_routes( endpoint=cors_middleware( revocation_handler.handle, ["POST", "OPTIONS"], + allow_origin_regex=cors_origin_regex, ), methods=["POST", "OPTIONS"], ) diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index 1649826db..29926dc98 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -21,6 +21,14 @@ class AuthSettings(BaseModel): client_registration_options: ClientRegistrationOptions | None = None revocation_options: RevocationOptions | None = None required_scopes: list[str] | None = None + cors_origin_regex: str | None = Field( + default=None, + description=( + "Regex for allowed browser Origin values on the authorization server endpoints " + "(/token, /register, /.well-known/oauth-authorization-server, etc). " + "If unset, a safe default allows only loopback origins (localhost/127.0.0.1/[::1])." + ), + ) # Resource Server settings (when operating as RS only) resource_server_url: AnyHttpUrl | None = Field( diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 96dcaf1c7..7f9129aad 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -867,6 +867,7 @@ def streamable_http_app( service_documentation_url=auth.service_documentation_url, client_registration_options=auth.client_registration_options, revocation_options=auth.revocation_options, + cors_origin_regex=auth.cors_origin_regex, ) ) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..7c3a86918 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -840,6 +840,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no service_documentation_url=self.settings.auth.service_documentation_url, client_registration_options=self.settings.auth.client_registration_options, revocation_options=self.settings.auth.revocation_options, + cors_origin_regex=self.settings.auth.cors_origin_regex, ) ) diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index a78a86cf0..9eac44994 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -22,6 +22,7 @@ construct_redirect_uri, ) from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes +from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.auth import OAuthClientInformationFull, OAuthToken @@ -325,6 +326,40 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): ] assert metadata["service_documentation"] == "https://docs.example.com/" + @pytest.mark.anyio + async def test_cors_allows_loopback_origin_by_default(self, test_client: httpx.AsyncClient): + origin = "http://localhost:5173" + response = await test_client.get( + "/.well-known/oauth-authorization-server", + headers={"Origin": origin}, + ) + assert response.status_code == 200 + assert response.headers.get("access-control-allow-origin") == origin + + @pytest.mark.anyio + async def test_cors_blocks_non_loopback_origin_by_default(self, test_client: httpx.AsyncClient): + origin = "https://evil.example" + response = await test_client.get( + "/.well-known/oauth-authorization-server", + headers={"Origin": origin}, + ) + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + + @pytest.mark.anyio + async def test_cors_preflight_allows_loopback_origin_by_default(self, test_client: httpx.AsyncClient): + origin = "http://127.0.0.1:3000" + response = await test_client.options( + "/token", + headers={ + "Origin": origin, + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": MCP_PROTOCOL_VERSION_HEADER, + }, + ) + assert response.status_code == 200 + assert response.headers.get("access-control-allow-origin") == origin + @pytest.mark.anyio async def test_token_validation_error(self, test_client: httpx.AsyncClient): """Test token endpoint error - validation error.""" From f7bf8b6feabd6a78279fabc92679af45b26da2d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20N=2E=20Eng=C3=B8y?= Date: Sun, 8 Feb 2026 23:00:22 +0100 Subject: [PATCH 2/2] auth: cover custom CORS origin regex --- .../mcpserver/auth/test_auth_integration.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 9eac44994..f907e3f9a 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -360,6 +360,40 @@ async def test_cors_preflight_allows_loopback_origin_by_default(self, test_clien assert response.status_code == 200 assert response.headers.get("access-control-allow-origin") == origin + @pytest.mark.anyio + async def test_cors_origin_regex_override(self, mock_oauth_provider: MockOAuthProvider): + auth_routes = create_auth_routes( + mock_oauth_provider, + AnyHttpUrl("https://auth.example.com"), + AnyHttpUrl("https://docs.example.com"), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["read", "write", "profile"], + default_scopes=["read", "write"], + ), + revocation_options=RevocationOptions(enabled=True), + cors_origin_regex=r"^https://allowed\.example$", + ) + app = Starlette(routes=auth_routes) + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="https://mcptest.com") as client: + allowed = "https://allowed.example" + blocked = "http://localhost:5173" + + response = await client.get( + "/.well-known/oauth-authorization-server", + headers={"Origin": allowed}, + ) + assert response.status_code == 200 + assert response.headers.get("access-control-allow-origin") == allowed + + response = await client.get( + "/.well-known/oauth-authorization-server", + headers={"Origin": blocked}, + ) + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + @pytest.mark.anyio async def test_token_validation_error(self, test_client: httpx.AsyncClient): """Test token endpoint error - validation error."""