Skip to content

Commit a2b3844

Browse files
committed
feat: expose schema generator for tool schemas
1 parent 161834d commit a2b3844

5 files changed

Lines changed: 77 additions & 5 deletions

File tree

src/mcp/server/mcpserver/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import anyio
1414
import pydantic_core
15+
from pydantic.json_schema import GenerateJsonSchema
1516
from pydantic.networks import AnyUrl
1617
from pydantic_settings import BaseSettings, SettingsConfigDict
1718
from starlette.applications import Starlette
@@ -149,6 +150,7 @@ def __init__(
149150
dependencies: list[str] | None = None,
150151
lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None,
151152
auth: AuthSettings | None = None,
153+
schema_generator: type[GenerateJsonSchema] | None = None,
152154
):
153155
self.settings = Settings(
154156
debug=debug,
@@ -162,7 +164,11 @@ def __init__(
162164
)
163165
self.dependencies = self.settings.dependencies
164166

165-
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
167+
self._tool_manager = ToolManager(
168+
tools=tools,
169+
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools,
170+
schema_generator=schema_generator,
171+
)
166172
self._resource_manager = ResourceManager(
167173
resources=resources, warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
168174
)

src/mcp/server/mcpserver/tools/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING, Any
66

77
from pydantic import BaseModel, Field
8+
from pydantic.json_schema import GenerateJsonSchema
89

910
from mcp.server.mcpserver.exceptions import ToolError
1011
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
@@ -52,6 +53,7 @@ def from_function(
5253
icons: list[Icon] | None = None,
5354
meta: dict[str, Any] | None = None,
5455
structured_output: bool | None = None,
56+
schema_generator: type[GenerateJsonSchema] | None = None,
5557
) -> Tool:
5658
"""Create a Tool from a function."""
5759
func_name = name or fn.__name__
@@ -71,8 +73,15 @@ def from_function(
7173
fn,
7274
skip_names=[context_kwarg] if context_kwarg is not None else [],
7375
structured_output=structured_output,
76+
schema_generator=schema_generator,
7477
)
75-
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
78+
if schema_generator is None:
79+
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
80+
else:
81+
parameters = func_arg_metadata.arg_model.model_json_schema(
82+
by_alias=True,
83+
schema_generator=schema_generator,
84+
)
7685

7786
return cls(
7887
fn=fn,

src/mcp/server/mcpserver/tools/tool_manager.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from collections.abc import Callable
44
from typing import TYPE_CHECKING, Any
55

6+
from pydantic.json_schema import GenerateJsonSchema
7+
68
from mcp.server.mcpserver.exceptions import ToolError
79
from mcp.server.mcpserver.tools.base import Tool
810
from mcp.server.mcpserver.utilities.logging import get_logger
@@ -18,14 +20,21 @@
1820
class ToolManager:
1921
"""Manages MCPServer tools."""
2022

21-
def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None):
23+
def __init__(
24+
self,
25+
warn_on_duplicate_tools: bool = True,
26+
*,
27+
tools: list[Tool] | None = None,
28+
schema_generator: type[GenerateJsonSchema] | None = None,
29+
):
2230
self._tools: dict[str, Tool] = {}
2331
for tool in tools or ():
2432
if warn_on_duplicate_tools and tool.name in self._tools:
2533
logger.warning(f"Tool already exists: {tool.name}")
2634
self._tools[tool.name] = tool
2735

2836
self.warn_on_duplicate_tools = warn_on_duplicate_tools
37+
self.schema_generator = schema_generator
2938

3039
def get_tool(self, name: str) -> Tool | None:
3140
"""Get tool by name."""
@@ -56,6 +65,7 @@ def add_tool(
5665
icons=icons,
5766
meta=meta,
5867
structured_output=structured_output,
68+
schema_generator=self.schema_generator,
5969
)
6070
existing = self._tools.get(tool.name)
6171
if existing:

src/mcp/server/mcpserver/utilities/func_metadata.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def func_metadata(
171171
func: Callable[..., Any],
172172
skip_names: Sequence[str] = (),
173173
structured_output: bool | None = None,
174+
schema_generator: type[GenerateJsonSchema] | None = None,
174175
) -> FuncMetadata:
175176
"""Given a function, return metadata including a Pydantic model representing its signature.
176177
@@ -201,6 +202,8 @@ def func_metadata(
201202
- TypedDict - converted to a Pydantic model with same fields
202203
- Dataclasses and other annotated classes - converted to Pydantic models
203204
- Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field
205+
schema_generator: Optional Pydantic JSON schema generator to use when producing
206+
structured output schemas. Defaults to strict schema generation.
204207
205208
Returns:
206209
A FuncMetadata object containing:
@@ -302,7 +305,10 @@ def func_metadata(
302305
original_annotation = sig.return_annotation
303306

304307
output_model, output_schema, wrap_output = _try_create_model_and_schema(
305-
original_annotation, return_type_expr, func.__name__
308+
original_annotation,
309+
return_type_expr,
310+
func.__name__,
311+
schema_generator=schema_generator,
306312
)
307313

308314
if output_model is None and structured_output is True:
@@ -323,6 +329,7 @@ def _try_create_model_and_schema(
323329
original_annotation: Any,
324330
type_expr: Any,
325331
func_name: str,
332+
schema_generator: type[GenerateJsonSchema] | None = None,
326333
) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]:
327334
"""Try to create a model and schema for the given annotation without warnings.
328335
@@ -401,7 +408,7 @@ def _try_create_model_and_schema(
401408
# If we successfully created a model, try to get its schema
402409
# Use StrictJsonSchema to raise exceptions instead of warnings
403410
try:
404-
schema = model.model_json_schema(schema_generator=StrictJsonSchema)
411+
schema = model.model_json_schema(schema_generator=schema_generator or StrictJsonSchema)
405412
except (
406413
PydanticUserError,
407414
TypeError,

tests/server/mcpserver/test_server.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
from inline_snapshot import snapshot
88
from pydantic import BaseModel
9+
from pydantic.json_schema import GenerateJsonSchema
910
from starlette.applications import Starlette
1011
from starlette.routing import Mount, Route
1112

@@ -42,6 +43,13 @@
4243
TextResourceContents,
4344
)
4445

46+
47+
class PrimitiveUnionSchema(GenerateJsonSchema):
48+
def __init__(self, *args: Any, **kwargs: Any) -> None:
49+
kwargs["union_format"] = "primitive_type_array"
50+
super().__init__(*args, **kwargs)
51+
52+
4553
pytestmark = pytest.mark.anyio
4654

4755

@@ -74,6 +82,38 @@ def test_dependencies(self):
7482
mcp_no_deps = MCPServer("test")
7583
assert mcp_no_deps.dependencies == []
7684

85+
def test_schema_generator_applies_to_tool_schemas(self):
86+
def convert(value: int | str) -> int | str:
87+
return value
88+
89+
mcp = MCPServer(schema_generator=PrimitiveUnionSchema)
90+
mcp.add_tool(convert)
91+
92+
tool = mcp._tool_manager.get_tool("convert")
93+
assert tool is not None
94+
assert tool.parameters["properties"]["value"] == {"title": "Value", "type": ["integer", "string"]}
95+
assert tool.output_schema is not None
96+
assert tool.output_schema["properties"]["result"] == {"title": "Result", "type": ["integer", "string"]}
97+
98+
def test_default_tool_schema_generation_is_preserved(self):
99+
def convert(value: int | str) -> int | str:
100+
return value
101+
102+
mcp = MCPServer()
103+
mcp.add_tool(convert)
104+
105+
tool = mcp._tool_manager.get_tool("convert")
106+
assert tool is not None
107+
assert tool.parameters["properties"]["value"] == {
108+
"anyOf": [{"type": "integer"}, {"type": "string"}],
109+
"title": "Value",
110+
}
111+
assert tool.output_schema is not None
112+
assert tool.output_schema["properties"]["result"] == {
113+
"anyOf": [{"type": "integer"}, {"type": "string"}],
114+
"title": "Result",
115+
}
116+
77117
async def test_sse_app_returns_starlette_app(self):
78118
"""Test that sse_app returns a Starlette application with correct routes."""
79119
mcp = MCPServer("test")

0 commit comments

Comments
 (0)