From a2b3844dd99a1995d09b4aa1e266e5fba9243811 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Fri, 15 May 2026 08:12:36 -0700 Subject: [PATCH] feat: expose schema generator for tool schemas --- src/mcp/server/mcpserver/server.py | 8 +++- src/mcp/server/mcpserver/tools/base.py | 11 ++++- .../server/mcpserver/tools/tool_manager.py | 12 +++++- .../mcpserver/utilities/func_metadata.py | 11 ++++- tests/server/mcpserver/test_server.py | 40 +++++++++++++++++++ 5 files changed, 77 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b..374a1f72a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -12,6 +12,7 @@ import anyio import pydantic_core +from pydantic.json_schema import GenerateJsonSchema from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -149,6 +150,7 @@ def __init__( dependencies: list[str] | None = None, lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, + schema_generator: type[GenerateJsonSchema] | None = None, ): self.settings = Settings( debug=debug, @@ -162,7 +164,11 @@ def __init__( ) self.dependencies = self.settings.dependencies - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._tool_manager = ToolManager( + tools=tools, + warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools, + schema_generator=schema_generator, + ) self._resource_manager = ResourceManager( resources=resources, warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources ) diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 754313eb8..b572c0790 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field +from pydantic.json_schema import GenerateJsonSchema from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter @@ -52,6 +53,7 @@ def from_function( icons: list[Icon] | None = None, meta: dict[str, Any] | None = None, structured_output: bool | None = None, + schema_generator: type[GenerateJsonSchema] | None = None, ) -> Tool: """Create a Tool from a function.""" func_name = name or fn.__name__ @@ -71,8 +73,15 @@ def from_function( fn, skip_names=[context_kwarg] if context_kwarg is not None else [], structured_output=structured_output, + schema_generator=schema_generator, ) - parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) + if schema_generator is None: + parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) + else: + parameters = func_arg_metadata.arg_model.model_json_schema( + by_alias=True, + schema_generator=schema_generator, + ) return cls( fn=fn, diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index eef4911f9..9dd580212 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -3,6 +3,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any +from pydantic.json_schema import GenerateJsonSchema + from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools.base import Tool from mcp.server.mcpserver.utilities.logging import get_logger @@ -18,7 +20,13 @@ class ToolManager: """Manages MCPServer tools.""" - def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None): + def __init__( + self, + warn_on_duplicate_tools: bool = True, + *, + tools: list[Tool] | None = None, + schema_generator: type[GenerateJsonSchema] | None = None, + ): self._tools: dict[str, Tool] = {} for tool in tools or (): if warn_on_duplicate_tools and tool.name in self._tools: @@ -26,6 +34,7 @@ def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | self._tools[tool.name] = tool self.warn_on_duplicate_tools = warn_on_duplicate_tools + self.schema_generator = schema_generator def get_tool(self, name: str) -> Tool | None: """Get tool by name.""" @@ -56,6 +65,7 @@ def add_tool( icons=icons, meta=meta, structured_output=structured_output, + schema_generator=self.schema_generator, ) existing = self._tools.get(tool.name) if existing: diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 4a7610637..536af5ac2 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -171,6 +171,7 @@ def func_metadata( func: Callable[..., Any], skip_names: Sequence[str] = (), structured_output: bool | None = None, + schema_generator: type[GenerateJsonSchema] | None = None, ) -> FuncMetadata: """Given a function, return metadata including a Pydantic model representing its signature. @@ -201,6 +202,8 @@ def func_metadata( - TypedDict - converted to a Pydantic model with same fields - Dataclasses and other annotated classes - converted to Pydantic models - Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field + schema_generator: Optional Pydantic JSON schema generator to use when producing + structured output schemas. Defaults to strict schema generation. Returns: A FuncMetadata object containing: @@ -302,7 +305,10 @@ def func_metadata( original_annotation = sig.return_annotation output_model, output_schema, wrap_output = _try_create_model_and_schema( - original_annotation, return_type_expr, func.__name__ + original_annotation, + return_type_expr, + func.__name__, + schema_generator=schema_generator, ) if output_model is None and structured_output is True: @@ -323,6 +329,7 @@ def _try_create_model_and_schema( original_annotation: Any, type_expr: Any, func_name: str, + schema_generator: type[GenerateJsonSchema] | None = None, ) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]: """Try to create a model and schema for the given annotation without warnings. @@ -401,7 +408,7 @@ def _try_create_model_and_schema( # If we successfully created a model, try to get its schema # Use StrictJsonSchema to raise exceptions instead of warnings try: - schema = model.model_json_schema(schema_generator=StrictJsonSchema) + schema = model.model_json_schema(schema_generator=schema_generator or StrictJsonSchema) except ( PydanticUserError, TypeError, diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3457ec944..b4477b2bd 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -6,6 +6,7 @@ import pytest from inline_snapshot import snapshot from pydantic import BaseModel +from pydantic.json_schema import GenerateJsonSchema from starlette.applications import Starlette from starlette.routing import Mount, Route @@ -42,6 +43,13 @@ TextResourceContents, ) + +class PrimitiveUnionSchema(GenerateJsonSchema): + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["union_format"] = "primitive_type_array" + super().__init__(*args, **kwargs) + + pytestmark = pytest.mark.anyio @@ -74,6 +82,38 @@ def test_dependencies(self): mcp_no_deps = MCPServer("test") assert mcp_no_deps.dependencies == [] + def test_schema_generator_applies_to_tool_schemas(self): + def convert(value: int | str) -> int | str: + return value + + mcp = MCPServer(schema_generator=PrimitiveUnionSchema) + mcp.add_tool(convert) + + tool = mcp._tool_manager.get_tool("convert") + assert tool is not None + assert tool.parameters["properties"]["value"] == {"title": "Value", "type": ["integer", "string"]} + assert tool.output_schema is not None + assert tool.output_schema["properties"]["result"] == {"title": "Result", "type": ["integer", "string"]} + + def test_default_tool_schema_generation_is_preserved(self): + def convert(value: int | str) -> int | str: + return value + + mcp = MCPServer() + mcp.add_tool(convert) + + tool = mcp._tool_manager.get_tool("convert") + assert tool is not None + assert tool.parameters["properties"]["value"] == { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Value", + } + assert tool.output_schema is not None + assert tool.output_schema["properties"]["result"] == { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Result", + } + async def test_sse_app_returns_starlette_app(self): """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test")