From ec818afd53169ae2a8480d027fdc1cf8cb01ae65 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Thu, 14 May 2026 19:17:54 -0700 Subject: [PATCH] feat(auth): expose access token subject --- src/mcp/server/auth/provider.py | 4 +++- src/mcp/server/mcpserver/context.py | 7 +++++++ .../auth/middleware/test_auth_context.py | 5 +++++ tests/server/mcpserver/test_server.py | 18 ++++++++++++++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a854..f3292d170d 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -40,6 +40,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None + claims: dict[str, Any] | None = None RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index e87388eee9..dee8296965 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -5,6 +5,7 @@ from pydantic import AnyUrl, BaseModel +from mcp.server.auth.middleware.auth_context import get_access_token from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext from mcp.server.elicitation import ( ElicitationResult, @@ -213,6 +214,12 @@ def client_id(self) -> str | None: """Get the client ID if available.""" return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + @property + def subject(self) -> str | None: + """Get the subject from the authenticated access token if available.""" + access_token = get_access_token() + return access_token.subject if access_token else None + @property def request_id(self) -> str: """Get the unique ID for this request.""" diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index 66481bcf79..ffeadd0973 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -41,6 +41,8 @@ def valid_access_token() -> AccessToken: client_id="test_client", scopes=["read", "write"], expires_at=int(time.time()) + 3600, # 1 hour from now + subject="user_123", + claims={"tenant_id": "tenant_456"}, ) @@ -77,6 +79,9 @@ async def send(message: Message) -> None: # pragma: no cover # Verify the access token was available during the call assert app.access_token_during_call == valid_access_token + assert app.access_token_during_call is not None + assert app.access_token_during_call.subject == "user_123" + assert app.access_token_during_call.claims == {"tenant_id": "tenant_456"} # Verify context is reset after middleware assert auth_context_var.get() is None diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3457ec944a..cc690a03af 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -10,6 +10,9 @@ from starlette.routing import Mount, Route from mcp.client import Client +from mcp.server.auth.middleware.auth_context import auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer @@ -1516,3 +1519,18 @@ async def test_report_progress_passes_related_request_id(): message="halfway", related_request_id="req-abc-123", ) + + +def test_context_subject_reads_authenticated_access_token(): + """Test that Context exposes the authenticated token subject.""" + access_token = AccessToken( + token="valid_token", + client_id="test_client", + scopes=["read"], + subject="user_123", + ) + token = auth_context_var.set(AuthenticatedUser(access_token)) + try: + assert Context().subject == "user_123" + finally: + auth_context_var.reset(token)