diff --git a/docs/experimental/index.md b/docs/experimental/index.md index 1d496b3f1..c97fe2a3d 100644 --- a/docs/experimental/index.md +++ b/docs/experimental/index.md @@ -27,10 +27,9 @@ Tasks are useful for: Experimental features are accessed via the `.experimental` property: ```python -# Server-side -@server.experimental.get_task() -async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - ... +# Server-side: enable task support (auto-registers default handlers) +server = Server(name="my-server") +server.experimental.enable_tasks() # Client-side result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) diff --git a/docs/migration.md b/docs/migration.md index 7d30f0ac9..8bd95ff0b 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -351,7 +351,6 @@ The nested `RequestParams.Meta` Pydantic model class has been replaced with a to - `RequestParams.Meta` (Pydantic model) → `RequestParamsMeta` (TypedDict) - Attribute access (`meta.progress_token`) → Dictionary access (`meta.get("progress_token")`) - `progress_token` field changed from `ProgressToken | None = None` to `NotRequired[ProgressToken]` -` **In request context handlers:** @@ -364,11 +363,12 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: await ctx.session.send_progress_notification(ctx.meta.progress_token, 0.5, 100) # After (v2) -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> list[TextContent]: - ctx = server.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: if ctx.meta and "progress_token" in ctx.meta: await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) + ... + +server = Server("my-server", on_call_tool=handle_call_tool) ``` ### `RequestContext` and `ProgressContext` type parameters simplified @@ -471,6 +471,157 @@ await client.read_resource("test://resource") await client.read_resource(str(my_any_url)) ``` +### Lowlevel `Server`: constructor parameters are now keyword-only + +All parameters after `name` are now keyword-only. If you were passing `version` or other parameters positionally, use keyword arguments instead: + +```python +# Before (v1) +server = Server("my-server", "1.0") + +# After (v2) +server = Server("my-server", version="1.0") +``` + +### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params + +The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import Server + +server = Server("my-server") + +@server.list_tools() +async def handle_list_tools(): + return [types.Tool(name="my_tool", description="A tool", inputSchema={})] + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + return [types.TextContent(type="text", text=f"Called {name}")] +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", description="A tool", inputSchema={})]) + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=f"Called {params.name}")], + is_error=False, + ) + +server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) +``` + +**Key differences:** + +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `RequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). +- The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. + +**Notification handlers:** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ProgressNotificationParams + + +async def handle_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + print(f"Progress: {params.progress}/{params.total}") + +server = Server("my-server", on_progress=handle_progress) +``` + +### Lowlevel `Server`: `request_context` property removed + +The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar still exists but should not be needed — use `ctx` directly instead. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import request_ctx + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + ctx = server.request_context # or request_ctx.get() + await ctx.session.send_log_message(level="info", data="Processing...") + return [types.TextContent(type="text", text="Done")] +``` + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import CallToolRequestParams, CallToolResult, TextContent + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + await ctx.session.send_log_message(level="info", data="Processing...") + return CallToolResult( + content=[TextContent(type="text", text="Done")], + is_error=False, + ) +``` + +### `RequestContext`: request-specific fields are now optional + +The `RequestContext` class now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. + +```python +from mcp.server import ServerRequestContext + +# request_id, meta, etc. are available in request handlers +# but None in notification handlers +``` + +### Experimental: task handler decorators removed + +The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. + +Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. + +**Before (v1):** + +```python +server = Server("my-server") +server.experimental.enable_tasks(task_store) + +@server.experimental.get_task() +async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: + ... +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import GetTaskRequestParams, GetTaskResult + + +async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + ... + + +server = Server("my-server") +server.experimental.enable_tasks(on_get_task=custom_get_task) +``` + ## Deprecations @@ -506,16 +657,16 @@ params = CallToolRequestParams( The `streamable_http_app()` method is now available directly on the lowlevel `Server` class, not just `MCPServer`. This allows using the streamable HTTP transport without the MCPServer wrapper. ```python -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams -server = Server("my-server") -# Register handlers... -@server.list_tools() -async def list_tools(): - return [...] +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[...]) + + +server = Server("my-server", on_list_tools=handle_list_tools) -# Create a Starlette app for streamable HTTP app = server.streamable_http_app( streamable_http_path="/mcp", json_response=False, diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 4fb7d9a1d..b838e2dd5 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -4,30 +4,19 @@ Server implementing all MCP features for conformance testing based on Conformance Server Specification. """ +from __future__ import annotations + import asyncio -import base64 import json import logging +from typing import Any import click -from mcp.server.mcpserver import Context, MCPServer -from mcp.server.mcpserver.prompts.base import UserMessage -from mcp.server.session import ServerSession +from mcp import types +from mcp.server.context import ServerRequestContext +from mcp.server.elicitation import ElicitationResult, elicit_with_validation +from mcp.server.lowlevel import Server from mcp.server.streamable_http import EventCallback, EventMessage, EventStore -from mcp.types import ( - AudioContent, - Completion, - CompletionArgument, - CompletionContext, - EmbeddedResource, - ImageContent, - JSONRPCMessage, - PromptReference, - ResourceTemplateReference, - SamplingMessage, - TextContent, - TextResourceContents, -) from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -41,10 +30,10 @@ class InMemoryEventStore(EventStore): """Simple in-memory event store for SSE resumability testing.""" def __init__(self) -> None: - self._events: list[tuple[StreamId, EventId, JSONRPCMessage | None]] = [] + self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) @@ -80,136 +69,14 @@ async def replay_events_after(self, last_event_id: EventId, send_callback: Event # Create event store for SSE resumability (SEP-1699) event_store = InMemoryEventStore() -mcp = MCPServer( - name="mcp-conformance-test-server", -) - - -# Tools -@mcp.tool() -def test_simple_text() -> str: - """Tests simple text content response""" - return "This is a simple text response for testing." - - -@mcp.tool() -def test_image_content() -> list[ImageContent]: - """Tests image content response""" - return [ImageContent(type="image", data=TEST_IMAGE_BASE64, mime_type="image/png")] - - -@mcp.tool() -def test_audio_content() -> list[AudioContent]: - """Tests audio content response""" - return [AudioContent(type="audio", data=TEST_AUDIO_BASE64, mime_type="audio/wav")] - - -@mcp.tool() -def test_embedded_resource() -> list[EmbeddedResource]: - """Tests embedded resource content response""" - return [ - EmbeddedResource( - type="resource", - resource=TextResourceContents( - uri="test://embedded-resource", - mime_type="text/plain", - text="This is an embedded resource content.", - ), - ) - ] - - -@mcp.tool() -def test_multiple_content_types() -> list[TextContent | ImageContent | EmbeddedResource]: - """Tests response with multiple content types (text, image, resource)""" - return [ - TextContent(type="text", text="Multiple content types test:"), - ImageContent(type="image", data=TEST_IMAGE_BASE64, mime_type="image/png"), - EmbeddedResource( - type="resource", - resource=TextResourceContents( - uri="test://mixed-content-resource", - mime_type="application/json", - text='{"test": "data", "value": 123}', - ), - ), - ] - - -@mcp.tool() -async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: - """Tests tool that emits log messages during execution""" - await ctx.info("Tool execution started") - await asyncio.sleep(0.05) - - await ctx.info("Tool processing data") - await asyncio.sleep(0.05) - - await ctx.info("Tool execution completed") - return "Tool with logging executed successfully" - - -@mcp.tool() -async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: - """Tests tool that reports progress notifications""" - await ctx.report_progress(progress=0, total=100, message="Completed step 0 of 100") - await asyncio.sleep(0.05) - - await ctx.report_progress(progress=50, total=100, message="Completed step 50 of 100") - await asyncio.sleep(0.05) - - await ctx.report_progress(progress=100, total=100, message="Completed step 100 of 100") - - # Return progress token as string - progress_token = ( - ctx.request_context.meta.get("progress_token") if ctx.request_context and ctx.request_context.meta else 0 - ) - return str(progress_token) - - -@mcp.tool() -async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str: - """Tests server-initiated sampling (LLM completion request)""" - try: - # Request sampling from client - result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], - max_tokens=100, - ) - - # Since we're not passing tools param, result.content is single content - if result.content.type == "text": - model_response = result.content.text - else: - model_response = "No response" - return f"LLM response: {model_response}" - except Exception as e: - return f"Sampling not supported or error: {str(e)}" +# --- Pydantic models for elicitation --- class UserResponse(BaseModel): response: str = Field(description="User's response") -@mcp.tool() -async def test_elicitation(message: str, ctx: Context[ServerSession, None]) -> str: - """Tests server-initiated elicitation (user input request)""" - try: - # Request user input from client - result = await ctx.elicit(message=message, schema=UserResponse) - - # Type-safe discriminated union narrowing using action field - if result.action == "accept": - content = result.data.model_dump_json() - else: # decline or cancel - content = "{}" - - return f"User response: action={result.action}, content={content}" - except Exception as e: - return f"Elicitation not supported or error: {str(e)}" - - class SEP1034DefaultsSchema(BaseModel): """Schema for testing SEP-1034 elicitation with default values for all primitive types""" @@ -224,24 +91,6 @@ class SEP1034DefaultsSchema(BaseModel): verified: bool = Field(default=True, description="Verification status") -@mcp.tool() -async def test_elicitation_sep1034_defaults(ctx: Context[ServerSession, None]) -> str: - """Tests elicitation with default values for all primitive types (SEP-1034)""" - try: - # Request user input with defaults for all primitive types - result = await ctx.elicit(message="Please provide user information", schema=SEP1034DefaultsSchema) - - # Type-safe discriminated union narrowing using action field - if result.action == "accept": - content = result.data.model_dump_json() - else: # decline or cancel - content = "{}" - - return f"Elicitation result: action={result.action}, content={content}" - except Exception as e: - return f"Elicitation not supported or error: {str(e)}" - - class EnumSchemasTestSchema(BaseModel): """Schema for testing enum schema variations (SEP-1330)""" @@ -283,150 +132,566 @@ class EnumSchemasTestSchema(BaseModel): ) -@mcp.tool() -async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> str: - """Tests elicitation with enum schema variations per SEP-1330""" - try: - result = await ctx.elicit( - message="Please select values using different enum schema types", schema=EnumSchemasTestSchema +# --- Helper to perform elicitation through the low-level API --- + + +async def _elicit( + ctx: ServerRequestContext[Any], + message: str, + schema: type[BaseModel], +) -> ElicitationResult[Any]: + """Elicit information from the client using the low-level ServerRequestContext.""" + return await elicit_with_validation( + session=ctx.session, + message=message, + schema=schema, + related_request_id=ctx.request_id, + ) + + +# --- Tool definitions --- + +TOOLS: list[types.Tool] = [ + types.Tool( + name="test_simple_text", + description="Tests simple text content response", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_image_content", + description="Tests image content response", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_audio_content", + description="Tests audio content response", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_embedded_resource", + description="Tests embedded resource content response", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_multiple_content_types", + description="Tests response with multiple content types (text, image, resource)", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_tool_with_logging", + description="Tests tool that emits log messages during execution", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_tool_with_progress", + description="Tests tool that reports progress notifications", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_sampling", + description="Tests server-initiated sampling (LLM completion request)", + input_schema={ + "type": "object", + "properties": {"prompt": {"type": "string", "description": "Prompt for sampling"}}, + "required": ["prompt"], + }, + ), + types.Tool( + name="test_elicitation", + description="Tests server-initiated elicitation (user input request)", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string", "description": "Message for elicitation"}}, + "required": ["message"], + }, + ), + types.Tool( + name="test_elicitation_sep1034_defaults", + description="Tests elicitation with default values for all primitive types (SEP-1034)", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_elicitation_sep1330_enums", + description="Tests elicitation with enum schema variations per SEP-1330", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_error_handling", + description="Tests error response handling", + input_schema={"type": "object", "properties": {}}, + ), + types.Tool( + name="test_reconnection", + description="Tests SSE polling by closing stream mid-call (SEP-1699)", + input_schema={"type": "object", "properties": {}}, + ), +] + + +# --- Resource definitions --- + +RESOURCES: list[types.Resource] = [ + types.Resource( + uri="test://static-text", + name="Static Text Resource", + description="A static text resource for testing", + mime_type="text/plain", + ), + types.Resource( + uri="test://static-binary", + name="Static Binary Resource", + description="A static binary resource (image) for testing", + mime_type="image/png", + ), + types.Resource( + uri="test://watched-resource", + name="Watched Resource", + description="A resource that can be subscribed to for updates", + mime_type="text/plain", + ), +] + +RESOURCE_TEMPLATES: list[types.ResourceTemplate] = [ + types.ResourceTemplate( + uriTemplate="test://template/{id}/data", + name="Template Resource", + description="A resource template with parameter substitution", + mime_type="application/json", + ), +] + +# --- Prompt definitions --- + +PROMPTS: list[types.Prompt] = [ + types.Prompt( + name="test_simple_prompt", + description="A simple prompt without arguments", + ), + types.Prompt( + name="test_prompt_with_arguments", + description="A prompt with required arguments", + arguments=[ + types.PromptArgument(name="arg1", description="First argument", required=True), + types.PromptArgument(name="arg2", description="Second argument", required=True), + ], + ), + types.Prompt( + name="test_prompt_with_embedded_resource", + description="A prompt that includes an embedded resource", + arguments=[ + types.PromptArgument(name="resourceUri", description="URI of the resource to embed", required=True), + ], + ), + types.Prompt( + name="test_prompt_with_image", + description="A prompt that includes image content", + ), +] + + +# --- Handler implementations --- + + +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult(tools=TOOLS) + + +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + """Handle tool calls.""" + name = params.name + arguments = params.arguments or {} + + if name == "test_simple_text": + return types.CallToolResult( + content=[types.TextContent(type="text", text="This is a simple text response for testing.")] ) - if result.action == "accept": - content = result.data.model_dump_json() - else: - content = "{}" + elif name == "test_image_content": + return types.CallToolResult( + content=[types.ImageContent(type="image", data=TEST_IMAGE_BASE64, mime_type="image/png")] + ) - return f"Elicitation completed: action={result.action}, content={content}" - except Exception as e: - return f"Elicitation not supported or error: {str(e)}" + elif name == "test_audio_content": + return types.CallToolResult( + content=[types.AudioContent(type="audio", data=TEST_AUDIO_BASE64, mime_type="audio/wav")] + ) + elif name == "test_embedded_resource": + return types.CallToolResult( + content=[ + types.EmbeddedResource( + type="resource", + resource=types.TextResourceContents( + uri="test://embedded-resource", + mime_type="text/plain", + text="This is an embedded resource content.", + ), + ) + ] + ) -@mcp.tool() -def test_error_handling() -> str: - """Tests error response handling""" - raise RuntimeError("This tool intentionally returns an error for testing") + elif name == "test_multiple_content_types": + return types.CallToolResult( + content=[ + types.TextContent(type="text", text="Multiple content types test:"), + types.ImageContent(type="image", data=TEST_IMAGE_BASE64, mime_type="image/png"), + types.EmbeddedResource( + type="resource", + resource=types.TextResourceContents( + uri="test://mixed-content-resource", + mime_type="application/json", + text='{"test": "data", "value": 123}', + ), + ), + ] + ) + + elif name == "test_tool_with_logging": + await ctx.session.send_log_message( + level="info", data="Tool execution started", related_request_id=ctx.request_id + ) + await asyncio.sleep(0.05) + + await ctx.session.send_log_message(level="info", data="Tool processing data", related_request_id=ctx.request_id) + await asyncio.sleep(0.05) + + await ctx.session.send_log_message( + level="info", data="Tool execution completed", related_request_id=ctx.request_id + ) + return types.CallToolResult( + content=[types.TextContent(type="text", text="Tool with logging executed successfully")] + ) + + elif name == "test_tool_with_progress": + progress_token = ctx.meta.get("progress_token") if ctx.meta else None + if progress_token is not None: + await ctx.session.send_progress_notification( + progress_token=progress_token, progress=0, total=100, message="Completed step 0 of 100" + ) + await asyncio.sleep(0.05) -@mcp.tool() -async def test_reconnection(ctx: Context[ServerSession, None]) -> str: - """Tests SSE polling by closing stream mid-call (SEP-1699)""" - await ctx.info("Before disconnect") + await ctx.session.send_progress_notification( + progress_token=progress_token, progress=50, total=100, message="Completed step 50 of 100" + ) + await asyncio.sleep(0.05) - await ctx.close_sse_stream() + await ctx.session.send_progress_notification( + progress_token=progress_token, progress=100, total=100, message="Completed step 100 of 100" + ) - await asyncio.sleep(0.2) # Wait for client to reconnect + return types.CallToolResult( + content=[types.TextContent(type="text", text=str(progress_token if progress_token is not None else 0))] + ) + + elif name == "test_sampling": + prompt = str(arguments.get("prompt", "")) + try: + result = await ctx.session.create_message( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text=prompt))], + max_tokens=100, + related_request_id=ctx.request_id, + ) + + if result.content.type == "text": + model_response = result.content.text + else: + model_response = "No response" + + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"LLM response: {model_response}")] + ) + except Exception as e: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Sampling not supported or error: {str(e)}")] + ) + + elif name == "test_elicitation": + message = str(arguments.get("message", "")) + try: + result = await _elicit(ctx, message=message, schema=UserResponse) + + if result.action == "accept": + content = result.data.model_dump_json() + else: + content = "{}" + + return types.CallToolResult( + content=[ + types.TextContent(type="text", text=f"User response: action={result.action}, content={content}") + ] + ) + except Exception as e: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Elicitation not supported or error: {str(e)}")] + ) + + elif name == "test_elicitation_sep1034_defaults": + try: + result = await _elicit(ctx, message="Please provide user information", schema=SEP1034DefaultsSchema) + + if result.action == "accept": + content = result.data.model_dump_json() + else: + content = "{}" + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", text=f"Elicitation result: action={result.action}, content={content}" + ) + ] + ) + except Exception as e: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Elicitation not supported or error: {str(e)}")] + ) + + elif name == "test_elicitation_sep1330_enums": + try: + result = await _elicit( + ctx, message="Please select values using different enum schema types", schema=EnumSchemasTestSchema + ) + + if result.action == "accept": + content = result.data.model_dump_json() + else: + content = "{}" + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", text=f"Elicitation completed: action={result.action}, content={content}" + ) + ] + ) + except Exception as e: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Elicitation not supported or error: {str(e)}")] + ) + + elif name == "test_error_handling": + return types.CallToolResult( + isError=True, + content=[types.TextContent(type="text", text="This tool intentionally returns an error for testing")], + ) - await ctx.info("After reconnect") - return "Reconnection test completed" + elif name == "test_reconnection": + await ctx.session.send_log_message(level="info", data="Before disconnect", related_request_id=ctx.request_id) + if ctx.close_sse_stream: + await ctx.close_sse_stream() -# Resources -@mcp.resource("test://static-text") -def static_text_resource() -> str: - """A static text resource for testing""" - return "This is the content of the static text resource." + await asyncio.sleep(0.2) # Wait for client to reconnect + await ctx.session.send_log_message(level="info", data="After reconnect", related_request_id=ctx.request_id) + return types.CallToolResult(content=[types.TextContent(type="text", text="Reconnection test completed")]) -@mcp.resource("test://static-binary") -def static_binary_resource() -> bytes: - """A static binary resource (image) for testing""" - return base64.b64decode(TEST_IMAGE_BASE64) + raise ValueError(f"Unknown tool: {name}") -@mcp.resource("test://template/{id}/data") -def template_resource(id: str) -> str: - """A resource template with parameter substitution""" - return json.dumps({"id": id, "templateTest": True, "data": f"Data for ID: {id}"}) +async def handle_list_resources( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + """List available resources.""" + return types.ListResourcesResult(resources=RESOURCES) -@mcp.resource("test://watched-resource") -def watched_resource() -> str: - """A resource that can be subscribed to for updates""" - return watched_resource_content +async def handle_list_resource_templates( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListResourceTemplatesResult: + """List available resource templates.""" + return types.ListResourceTemplatesResult(resource_templates=RESOURCE_TEMPLATES) -# Prompts -@mcp.prompt() -def test_simple_prompt() -> list[UserMessage]: - """A simple prompt without arguments""" - return [UserMessage(role="user", content=TextContent(type="text", text="This is a simple prompt for testing."))] +async def handle_read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + """Read a specific resource.""" + uri = str(params.uri) + if uri == "test://static-text": + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=uri, + mime_type="text/plain", + text="This is the content of the static text resource.", + ) + ] + ) -@mcp.prompt() -def test_prompt_with_arguments(arg1: str, arg2: str) -> list[UserMessage]: - """A prompt with required arguments""" - return [ - UserMessage( - role="user", content=TextContent(type="text", text=f"Prompt with arguments: arg1='{arg1}', arg2='{arg2}'") + elif uri == "test://static-binary": + return types.ReadResourceResult( + contents=[ + types.BlobResourceContents( + uri=uri, + mime_type="image/png", + blob=TEST_IMAGE_BASE64, + ) + ] ) - ] - - -@mcp.prompt() -def test_prompt_with_embedded_resource(resourceUri: str) -> list[UserMessage]: - """A prompt that includes an embedded resource""" - return [ - UserMessage( - role="user", - content=EmbeddedResource( - type="resource", - resource=TextResourceContents( - uri=resourceUri, + + elif uri == "test://watched-resource": + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=uri, mime_type="text/plain", - text="Embedded resource content for testing.", - ), - ), - ), - UserMessage(role="user", content=TextContent(type="text", text="Please process the embedded resource above.")), - ] + text=watched_resource_content, + ) + ] + ) + + # Check for template match: test://template/{id}/data + elif uri.startswith("test://template/") and uri.endswith("/data"): + resource_id = uri[len("test://template/") : -len("/data")] + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=uri, + mime_type="application/json", + text=json.dumps({"id": resource_id, "templateTest": True, "data": f"Data for ID: {resource_id}"}), + ) + ] + ) + + raise ValueError(f"Unknown resource: {uri}") -@mcp.prompt() -def test_prompt_with_image() -> list[UserMessage]: - """A prompt that includes image content""" - return [ - UserMessage(role="user", content=ImageContent(type="image", data=TEST_IMAGE_BASE64, mime_type="image/png")), - UserMessage(role="user", content=TextContent(type="text", text="Please analyze the image above.")), - ] +async def handle_list_prompts( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + """List available prompts.""" + return types.ListPromptsResult(prompts=PROMPTS) -# Custom request handlers -# TODO(felix): Add public APIs to MCPServer for subscribe_resource, unsubscribe_resource, -# and set_logging_level to avoid accessing protected _lowlevel_server attribute. -@mcp._lowlevel_server.set_logging_level() # pyright: ignore[reportPrivateUsage] -async def handle_set_logging_level(level: str) -> None: - """Handle logging level changes""" - logger.info(f"Log level set to: {level}") - # In a real implementation, you would adjust the logging level here - # For conformance testing, we just acknowledge the request +async def handle_get_prompt( + ctx: ServerRequestContext[Any], params: types.GetPromptRequestParams +) -> types.GetPromptResult: + """Get a specific prompt by name.""" + name = params.name + arguments = params.arguments or {} + if name == "test_simple_prompt": + return types.GetPromptResult( + description="A simple prompt without arguments", + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="This is a simple prompt for testing."), + ) + ], + ) + + elif name == "test_prompt_with_arguments": + arg1 = arguments.get("arg1", "") + arg2 = arguments.get("arg2", "") + return types.GetPromptResult( + description="A prompt with required arguments", + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=f"Prompt with arguments: arg1='{arg1}', arg2='{arg2}'"), + ) + ], + ) -async def handle_subscribe(uri: str) -> None: - """Handle resource subscription""" - resource_subscriptions.add(str(uri)) - logger.info(f"Subscribed to resource: {uri}") + elif name == "test_prompt_with_embedded_resource": + resource_uri = arguments.get("resourceUri", "") + return types.GetPromptResult( + description="A prompt that includes an embedded resource", + messages=[ + types.PromptMessage( + role="user", + content=types.EmbeddedResource( + type="resource", + resource=types.TextResourceContents( + uri=resource_uri, + mime_type="text/plain", + text="Embedded resource content for testing.", + ), + ), + ), + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Please process the embedded resource above."), + ), + ], + ) + elif name == "test_prompt_with_image": + return types.GetPromptResult( + description="A prompt that includes image content", + messages=[ + types.PromptMessage( + role="user", + content=types.ImageContent(type="image", data=TEST_IMAGE_BASE64, mime_type="image/png"), + ), + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Please analyze the image above."), + ), + ], + ) -async def handle_unsubscribe(uri: str) -> None: - """Handle resource unsubscription""" - resource_subscriptions.discard(str(uri)) - logger.info(f"Unsubscribed from resource: {uri}") + raise ValueError(f"Unknown prompt: {name}") -mcp._lowlevel_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server.unsubscribe_resource()(handle_unsubscribe) # pyright: ignore[reportPrivateUsage] +async def handle_set_logging_level( + ctx: ServerRequestContext[Any], params: types.SetLevelRequestParams +) -> types.EmptyResult: + """Handle logging level changes.""" + logger.info(f"Log level set to: {params.level}") + return types.EmptyResult() -@mcp.completion() -async def _handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, -) -> Completion: - """Handle completion requests""" +async def handle_subscribe_resource( + ctx: ServerRequestContext[Any], params: types.SubscribeRequestParams +) -> types.EmptyResult: + """Handle resource subscription.""" + resource_subscriptions.add(str(params.uri)) + logger.info(f"Subscribed to resource: {params.uri}") + return types.EmptyResult() + + +async def handle_unsubscribe_resource( + ctx: ServerRequestContext[Any], params: types.UnsubscribeRequestParams +) -> types.EmptyResult: + """Handle resource unsubscription.""" + resource_subscriptions.discard(str(params.uri)) + logger.info(f"Unsubscribed from resource: {params.uri}") + return types.EmptyResult() + + +async def handle_completion( + ctx: ServerRequestContext[Any], params: types.CompleteRequestParams +) -> types.CompleteResult: + """Handle completion requests.""" # Basic completion support - returns empty array for conformance # Real implementations would provide contextual suggestions - return Completion(values=[], total=0, has_more=False) + return types.CompleteResult(completion=types.Completion(values=[], total=0, has_more=False)) + + +# --- Server instance --- + +server = Server( + "mcp-conformance-test-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_resources=handle_list_resources, + on_list_resource_templates=handle_list_resource_templates, + on_read_resource=handle_read_resource, + on_subscribe_resource=handle_subscribe_resource, + on_unsubscribe_resource=handle_unsubscribe_resource, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + on_set_logging_level=handle_set_logging_level, + on_completion=handle_completion, +) # CLI @@ -439,6 +704,8 @@ async def _handle_completion( ) def main(port: int, log_level: str) -> int: """Run the MCP Everything Server.""" + import uvicorn + logging.basicConfig( level=getattr(logging, log_level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -447,13 +714,22 @@ def main(port: int, log_level: str) -> int: logger.info(f"Starting MCP Everything Server on port {port}") logger.info(f"Endpoint will be: http://localhost:{port}/mcp") - mcp.run( - transport="streamable-http", - port=port, + starlette_app = server.streamable_http_app( event_store=event_store, retry_interval=100, # 100ms retry interval for SSE polling ) + config = uvicorn.Config( + starlette_app, + host="127.0.0.1", + port=port, + log_level=log_level.lower(), + ) + uvicorn_server = uvicorn.Server(config) + import anyio + + anyio.run(uvicorn_server.serve) + return 0 diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index ff45ae224..a63d0baf9 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -1,6 +1,6 @@ """Simple MCP server demonstrating pagination for tools, resources, and prompts. -This example shows how to use the paginated decorators to handle large lists +This example shows how to use the on_* handler pattern to handle large lists of items that need to be split across multiple pages. """ @@ -9,6 +9,7 @@ import anyio import click from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from starlette.requests import Request @@ -44,6 +45,136 @@ ] +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """Paginated list_tools - returns 5 tools per page.""" + page_size = 5 + + cursor = params.cursor if params is not None else None + if cursor is None: + start_idx = 0 + else: + try: + start_idx = int(cursor) + except (ValueError, TypeError): + return types.ListToolsResult(tools=[], next_cursor=None) + + page_tools = SAMPLE_TOOLS[start_idx : start_idx + page_size] + + next_cursor = None + if start_idx + page_size < len(SAMPLE_TOOLS): + next_cursor = str(start_idx + page_size) + + return types.ListToolsResult(tools=page_tools, next_cursor=next_cursor) + + +async def handle_list_resources( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + """Paginated list_resources - returns 10 resources per page.""" + page_size = 10 + + cursor = params.cursor if params is not None else None + if cursor is None: + start_idx = 0 + else: + try: + start_idx = int(cursor) + except (ValueError, TypeError): + return types.ListResourcesResult(resources=[], next_cursor=None) + + page_resources = SAMPLE_RESOURCES[start_idx : start_idx + page_size] + + next_cursor = None + if start_idx + page_size < len(SAMPLE_RESOURCES): + next_cursor = str(start_idx + page_size) + + return types.ListResourcesResult(resources=page_resources, next_cursor=next_cursor) + + +async def handle_list_prompts( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + """Paginated list_prompts - returns 7 prompts per page.""" + page_size = 7 + + cursor = params.cursor if params is not None else None + if cursor is None: + start_idx = 0 + else: + try: + start_idx = int(cursor) + except (ValueError, TypeError): + return types.ListPromptsResult(prompts=[], next_cursor=None) + + page_prompts = SAMPLE_PROMPTS[start_idx : start_idx + page_size] + + next_cursor = None + if start_idx + page_size < len(SAMPLE_PROMPTS): + next_cursor = str(start_idx + page_size) + + return types.ListPromptsResult(prompts=page_prompts, next_cursor=next_cursor) + + +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + """Handle tool calls.""" + tool = next((t for t in SAMPLE_TOOLS if t.name == params.name), None) + if not tool: + raise ValueError(f"Unknown tool: {params.name}") + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=f"Called tool '{params.name}' with arguments: {params.arguments}", + ) + ] + ) + + +async def handle_read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + """Handle read_resource requests.""" + resource = next((r for r in SAMPLE_RESOURCES if r.uri == params.uri), None) + if not resource: + raise ValueError(f"Unknown resource: {params.uri}") + + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=params.uri, + text=f"Content of {resource.name}: This is sample content for the resource.", + mime_type="text/plain", + ) + ] + ) + + +async def handle_get_prompt( + ctx: ServerRequestContext[Any], params: types.GetPromptRequestParams +) -> types.GetPromptResult: + """Handle get_prompt requests.""" + prompt = next((p for p in SAMPLE_PROMPTS if p.name == params.name), None) + if not prompt: + raise ValueError(f"Unknown prompt: {params.name}") + + message_text = f"This is the prompt '{params.name}'" + if params.arguments: + message_text += f" with arguments: {params.arguments}" + + return types.GetPromptResult( + description=prompt.description, + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=message_text), + ) + ], + ) + + @click.command() @click.option("--port", default=8000, help="Port to listen on for SSE") @click.option( @@ -53,142 +184,15 @@ help="Transport type", ) def main(port: int, transport: str) -> int: - app = Server("mcp-simple-pagination") - - # Paginated list_tools - returns 5 tools per page - @app.list_tools() - async def list_tools_paginated(request: types.ListToolsRequest) -> types.ListToolsResult: - page_size = 5 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListToolsResult(tools=[], next_cursor=None) - - # Get the page of tools - page_tools = SAMPLE_TOOLS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_TOOLS): - next_cursor = str(start_idx + page_size) - - return types.ListToolsResult(tools=page_tools, next_cursor=next_cursor) - - # Paginated list_resources - returns 10 resources per page - @app.list_resources() - async def list_resources_paginated( - request: types.ListResourcesRequest, - ) -> types.ListResourcesResult: - page_size = 10 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListResourcesResult(resources=[], next_cursor=None) - - # Get the page of resources - page_resources = SAMPLE_RESOURCES[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_RESOURCES): - next_cursor = str(start_idx + page_size) - - return types.ListResourcesResult(resources=page_resources, next_cursor=next_cursor) - - # Paginated list_prompts - returns 7 prompts per page - @app.list_prompts() - async def list_prompts_paginated( - request: types.ListPromptsRequest, - ) -> types.ListPromptsResult: - page_size = 7 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListPromptsResult(prompts=[], next_cursor=None) - - # Get the page of prompts - page_prompts = SAMPLE_PROMPTS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_PROMPTS): - next_cursor = str(start_idx + page_size) - - return types.ListPromptsResult(prompts=page_prompts, next_cursor=next_cursor) - - # Implement call_tool handler - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - # Find the tool in our sample data - tool = next((t for t in SAMPLE_TOOLS if t.name == name), None) - if not tool: - raise ValueError(f"Unknown tool: {name}") - - # Simple mock response - return [ - types.TextContent( - type="text", - text=f"Called tool '{name}' with arguments: {arguments}", - ) - ] - - # Implement read_resource handler - @app.read_resource() - async def read_resource(uri: str) -> str: - # Find the resource in our sample data - resource = next((r for r in SAMPLE_RESOURCES if r.uri == uri), None) - if not resource: - raise ValueError(f"Unknown resource: {uri}") - - # Return a simple string - the decorator will convert it to TextResourceContents - return f"Content of {resource.name}: This is sample content for the resource." - - # Implement get_prompt handler - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: - # Find the prompt in our sample data - prompt = next((p for p in SAMPLE_PROMPTS if p.name == name), None) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - # Simple mock response - message_text = f"This is the prompt '{name}'" - if arguments: - message_text += f" with arguments: {arguments}" - - return types.GetPromptResult( - description=prompt.description, - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent(type="text", text=message_text), - ) - ], - ) + app = Server( + "mcp-simple-pagination", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index cbc5a9d68..05271653b 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -1,6 +1,9 @@ +from typing import Any + import anyio import click from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from starlette.requests import Request @@ -30,20 +33,11 @@ def create_messages(context: str | None = None, topic: str | None = None) -> lis return messages -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-prompt") - - @app.list_prompts() - async def list_prompts() -> list[types.Prompt]: - return [ +async def handle_list_prompts( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ types.Prompt( name="simple", title="Simple Assistant Prompt", @@ -62,19 +56,37 @@ async def list_prompts() -> list[types.Prompt]: ], ) ] + ) - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: - if name != "simple": - raise ValueError(f"Unknown prompt: {name}") - if arguments is None: - arguments = {} +async def handle_get_prompt( + ctx: ServerRequestContext[Any], params: types.GetPromptRequestParams +) -> types.GetPromptResult: + if params.name != "simple": + raise ValueError(f"Unknown prompt: {params.name}") - return types.GetPromptResult( - messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), - description="A simple prompt with optional context and topic arguments", - ) + arguments = params.arguments or {} + + return types.GetPromptResult( + messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), + description="A simple prompt with optional context and topic arguments", + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-prompt", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 588d1044a..9dc32e36c 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -1,8 +1,11 @@ +from typing import Any +from urllib.parse import urlparse + import anyio import click from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents from starlette.requests import Request SAMPLE_RESOURCES = { @@ -21,20 +24,11 @@ } -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-resource") - - @app.list_resources() - async def list_resources() -> list[types.Resource]: - return [ +async def handle_list_resources( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ types.Resource( uri=f"file:///{name}.txt", name=name, @@ -44,20 +38,45 @@ async def list_resources() -> list[types.Resource]: ) for name in SAMPLE_RESOURCES.keys() ] + ) - @app.read_resource() - async def read_resource(uri: str): - from urllib.parse import urlparse - parsed = urlparse(uri) - if not parsed.path: - raise ValueError(f"Invalid resource path: {uri}") - name = parsed.path.replace(".txt", "").lstrip("/") +async def handle_read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + parsed = urlparse(params.uri) + if not parsed.path: + raise ValueError(f"Invalid resource path: {params.uri}") + name = parsed.path.replace(".txt", "").lstrip("/") - if name not in SAMPLE_RESOURCES: - raise ValueError(f"Unknown resource: {uri}") + if name not in SAMPLE_RESOURCES: + raise ValueError(f"Unknown resource: {params.uri}") - return [ReadResourceContents(content=SAMPLE_RESOURCES[name]["content"], mime_type="text/plain")] + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=params.uri, + text=SAMPLE_RESOURCES[name]["content"], + mime_type="text/plain", + ) + ] + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-resource", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 9fed2f0aa..fc23fe40a 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -7,6 +7,7 @@ import click import uvicorn from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette @@ -17,60 +18,37 @@ logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on for HTTP") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", -) -@click.option( - "--json-response", - is_flag=True, - default=False, - help="Enable JSON responses instead of SSE streams", -) -def main( - port: int, - log_level: str, - json_response: bool, -) -> None: - # Configure logging - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + interval = params.arguments.get("interval", 1.0) if params.arguments else 1.0 + count = params.arguments.get("count", 5) if params.arguments else 5 + caller = params.arguments.get("caller", "unknown") if params.arguments else "unknown" - app = Server("mcp-streamable-http-stateless-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - await ctx.session.send_log_message( - level="info", - data=f"Notification {i + 1}/{count} from caller: {caller}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i + 1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) - return [ + return types.CallToolResult( + content=[ types.TextContent( type="text", text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), ) ] + ) + - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="start-notification-stream", description=("Sends a stream of notifications with configurable count and interval"), @@ -94,6 +72,38 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> None: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server( + "mcp-streamable-http-stateless-demo", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Create the session manager with true stateless mode session_manager = StreamableHTTPSessionManager( diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index ef03d9b08..7a2164fb3 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -6,6 +6,7 @@ import anyio import click from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette @@ -19,71 +20,48 @@ logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on for HTTP") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", -) -@click.option( - "--json-response", - is_flag=True, - default=False, - help="Enable JSON responses instead of SSE streams", -) -def main( - port: int, - log_level: str, - json_response: bool, -) -> int: - # Configure logging - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - app = Server("mcp-streamable-http-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - # Include more detailed message for resumability demonstration - notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" - await ctx.session.send_log_message( - level="info", - data=notification_msg, - logger="notification_stream", - # Associates this notification with the original request - # Ensures notifications are sent to the correct response stream - # Without this, notifications will either go to: - # - a standalone SSE stream (if GET request is supported) - # - nowhere (if GET request isn't supported) - related_request_id=ctx.request_id, - ) - logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) - - # This will send a resource notificaiton though standalone SSE - # established by GET request - await ctx.session.send_resource_updated(uri="http:///test_resource") - return [ +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + interval = params.arguments.get("interval", 1.0) if params.arguments else 1.0 + count = params.arguments.get("count", 5) if params.arguments else 5 + caller = params.arguments.get("caller", "unknown") if params.arguments else "unknown" + + # Send the specified number of notifications with the given interval + for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" + await ctx.session.send_log_message( + level="info", + data=notification_msg, + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + # This will send a resource notificaiton though standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri="http:///test_resource") + return types.CallToolResult( + content=[ types.TextContent( type="text", text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), ) ] + ) + - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="start-notification-stream", description=("Sends a stream of notifications with configurable count and interval"), @@ -107,6 +85,38 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server( + "mcp-streamable-http-demo", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Create event store for resumability # The InMemoryEventStore enables resumability support for StreamableHTTP transport. diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py index dc689ed94..6af6498a1 100644 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -13,42 +13,16 @@ import click import uvicorn from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount -server = Server("simple-task-interactive") -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - - -async def handle_confirm_delete(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) filename = arguments.get("filename", "unknown.txt") @@ -80,9 +54,8 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -async def handle_write_haiku(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the write_haiku tool - demonstrates sampling.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) topic = arguments.get("topic", "nature") @@ -111,20 +84,56 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult( + tools=[ + types.Tool( + name="confirm_delete", + description="Asks for confirmation before deleting (demonstrates elicitation)", + input_schema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + types.Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + ] + ) + + +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: """Dispatch tool calls to their handlers.""" - if name == "confirm_delete": - return await handle_confirm_delete(arguments) - elif name == "write_haiku": - return await handle_write_haiku(arguments) + if params.name == "confirm_delete": + return await handle_confirm_delete(ctx, params.arguments or {}) + elif params.name == "write_haiku": + return await handle_write_haiku(ctx, params.arguments or {}) else: return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], is_error=True, ) +server = Server( + "simple-task-interactive", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) + +# Enable task support - this auto-registers all handlers +server.experimental.enable_tasks() + + def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: @asynccontextmanager async def app_lifespan(app: Starlette) -> AsyncIterator[None]: diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py index ec16b15ae..c651860b2 100644 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -8,33 +8,16 @@ import click import uvicorn from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount -server = Server("simple-task-server") -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] - - -async def handle_long_running_task(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_long_running_task(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the long_running_task tool - demonstrates status updates.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) async def work(task: ServerTaskContext) -> types.CallToolResult: @@ -52,18 +35,45 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult( + tools=[ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + input_schema={"type": "object", "properties": {}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ) + ] + ) + + +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: """Dispatch tool calls to their handlers.""" - if name == "long_running_task": - return await handle_long_running_task(arguments) + if params.name == "long_running_task": + return await handle_long_running_task(ctx, params.arguments or {}) else: return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], is_error=True, ) +server = Server( + "simple-task-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) + +# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task +server.experimental.enable_tasks() + + @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 1c253a22e..28cffaa10 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -3,6 +3,7 @@ import anyio import click from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from mcp.shared._httpx_utils import create_mcp_http_client from starlette.requests import Request @@ -18,28 +19,11 @@ async def fetch_website( return [types.TextContent(type="text", text=response.text)] -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-website-fetcher") - - @app.call_tool() - async def fetch_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - if name != "fetch": - raise ValueError(f"Unknown tool: {name}") - if "url" not in arguments: - raise ValueError("Missing required argument 'url'") - return await fetch_website(arguments["url"]) - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="fetch", title="Website Fetcher", @@ -56,6 +40,31 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name != "fetch": + raise ValueError(f"Unknown tool: {params.name}") + if not params.arguments or "url" not in params.arguments: + raise ValueError("Missing required argument 'url'") + return types.CallToolResult(content=await fetch_website(params.arguments["url"])) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-website-fetcher", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index 9d7071ca7..30ede39ec 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -20,6 +20,7 @@ import anyio import click from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette @@ -31,88 +32,73 @@ logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR)", -) -@click.option( - "--retry-interval", - default=100, - help="SSE retry interval in milliseconds (sent to client)", -) -def main(port: int, log_level: str, retry_interval: int) -> int: - """Run the SSE Polling Demo server.""" - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - # Create the lowlevel server - app = Server("sse-polling-demo") +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + """Handle tool calls.""" + if params.name == "process_batch": + arguments = params.arguments or {} + items = arguments.get("items", 10) + checkpoint_every = arguments.get("checkpoint_every", 3) - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - """Handle tool calls.""" - ctx = app.request_context + if items < 1 or items > 100: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: items must be between 1 and 100")] + ) + if checkpoint_every < 1 or checkpoint_every > 20: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + ) - if name == "process_batch": - items = arguments.get("items", 10) - checkpoint_every = arguments.get("checkpoint_every", 3) + await ctx.session.send_log_message( + level="info", + data=f"Starting batch processing of {items} items...", + logger="process_batch", + related_request_id=ctx.request_id, + ) - if items < 1 or items > 100: - return [types.TextContent(type="text", text="Error: items must be between 1 and 100")] - if checkpoint_every < 1 or checkpoint_every > 20: - return [types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + for i in range(1, items + 1): + # Simulate work + await anyio.sleep(0.5) + # Report progress await ctx.session.send_log_message( level="info", - data=f"Starting batch processing of {items} items...", + data=f"[{i}/{items}] Processing item {i}", logger="process_batch", related_request_id=ctx.request_id, ) - for i in range(1, items + 1): - # Simulate work - await anyio.sleep(0.5) - - # Report progress + # Checkpoint: close stream to trigger client reconnect + if i % checkpoint_every == 0 and i < items: await ctx.session.send_log_message( level="info", - data=f"[{i}/{items}] Processing item {i}", + data=f"Checkpoint at item {i} - closing SSE stream for polling", logger="process_batch", related_request_id=ctx.request_id, ) - - # Checkpoint: close stream to trigger client reconnect - if i % checkpoint_every == 0 and i < items: - await ctx.session.send_log_message( - level="info", - data=f"Checkpoint at item {i} - closing SSE stream for polling", - logger="process_batch", - related_request_id=ctx.request_id, - ) - if ctx.close_sse_stream: - logger.info(f"Closing SSE stream at checkpoint {i}") - await ctx.close_sse_stream() - # Wait for client to reconnect (must be > retry_interval of 100ms) - await anyio.sleep(0.2) - - return [ + if ctx.close_sse_stream: + logger.info(f"Closing SSE stream at checkpoint {i}") + await ctx.close_sse_stream() + # Wait for client to reconnect (must be > retry_interval of 100ms) + await anyio.sleep(0.2) + + return types.CallToolResult( + content=[ types.TextContent( type="text", text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", ) ] + ) + + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")]) - return [types.TextContent(type="text", text=f"Unknown tool: {name}")] - @app.list_tools() - async def list_tools() -> list[types.Tool]: - """List available tools.""" - return [ +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult( + tools=[ types.Tool( name="process_batch", description=( @@ -136,6 +122,34 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR)", +) +@click.option( + "--retry-interval", + default=100, + help="SSE retry interval in milliseconds (sent to client)", +) +def main(port: int, log_level: str, retry_interval: int) -> int: + """Run the SSE Polling Demo server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # Create the lowlevel server + app = Server( + "sse-polling-demo", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Create event store for resumability event_store = InMemoryEventStore() diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index fd73a54cd..14d4f6ee0 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -7,55 +7,53 @@ """ import asyncio +import random from datetime import datetime -from typing import Any import mcp.server.stdio from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions -# Create low-level server instance -server = Server("structured-output-lowlevel-example") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with their schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get weather information (simulated)", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number"}, - "conditions": {"type": "string"}, - "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, - "wind_speed": {"type": "number"}, - "timestamp": {"type": "string", "format": "date-time"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get weather information (simulated)", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], + }, + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number"}, + "conditions": {"type": "string"}, + "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, + "wind_speed": {"type": "number"}, + "timestamp": {"type": "string", "format": "date-time"}, + }, + "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], }, - "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], - }, - ), - ] + ), + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> Any: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool call with structured output.""" - if name == "get_weather": - # city = arguments["city"] # Would be used with real weather API + if params.name == "get_weather": + # city = (params.arguments or {})["city"] # Would be used with real weather API # Simulate weather data (in production, call a real weather API) - import random - weather_conditions = ["sunny", "cloudy", "rainy", "partly cloudy", "foggy"] weather_data = { @@ -66,12 +64,23 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "timestamp": datetime.now().isoformat(), } - # Return structured data only + # Return structured data as CallToolResult # The low-level server will serialize this to JSON content automatically - return weather_data + return types.CallToolResult( + content=[types.TextContent(type="text", text=str(weather_data))], + structured_content=weather_data, + ) else: - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +# Create low-level server instance +server = Server( + "structured-output-lowlevel-example", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index 0d4432504..4f92e9fe8 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -3,35 +3,38 @@ """ import asyncio +from typing import Any import mcp.server.stdio from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions -# Create a server instance -server = Server("example-server") - -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt( + ctx: ServerRequestContext[Any], params: types.GetPromptRequestParams +) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -44,6 +47,14 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +# Create a server instance +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 725f5711a..f3edf20a2 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -7,40 +7,50 @@ import mcp.server.stdio from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: +async def handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = str((params.arguments or {}).get("message", "")) return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index da8ff7bdf..8628909b7 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -2,12 +2,14 @@ uv run examples/snippets/servers/lowlevel/lifespan.py """ +import asyncio from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any import mcp.server.stdio from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -44,40 +46,46 @@ async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] - - -@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context + # Access lifespan context from the ctx parameter db = ctx.lifespan_context["db"] # Execute query - results = await db.query(arguments["query"]) + results = await db.query((params.arguments or {})["query"]) + + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - return [types.TextContent(type="text", text=f"Query results: {results}")] + +# Pass lifespan to server +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -98,6 +106,4 @@ async def run(): if __name__ == "__main__": - import asyncio - asyncio.run(run()) diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index cad8f67da..4d0fea827 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -7,43 +7,44 @@ import mcp.server.stdio from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + "condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments or {})["city"] # Simulated weather data - in production, call a weather API weather_data = { @@ -53,12 +54,23 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: "city": city, # Include the requested city } - # low-level server will validate structured output against the tool's + # Return as CallToolResult with structured_content for structured output. + # The low-level server will validate structured output against the tool's # output schema, and additionally serialize it into a TextContent block # for backwards compatibility with pre-2025-06-18 clients. - return weather_data + return types.CallToolResult( + content=[types.TextContent(type="text", text=str(weather_data))], + structured_content=weather_data, + ) else: - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index bb406653e..fa9aa42fd 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -1,22 +1,21 @@ -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with MCP server constructor kwargs.""" from mcp import types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server -# Initialize the server -server = Server("paginated-server") - # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 - # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + # Extract cursor from params + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else int(cursor) @@ -32,3 +31,7 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +# Initialize the server +server = Server("paginated-server", on_list_resources=handle_list_resources) diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a2dada3af..aab5c33f7 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -1,5 +1,6 @@ +from .context import ServerRequestContext from .lowlevel import NotificationOptions, Server from .mcpserver import MCPServer from .models import InitializationOptions -__all__ = ["Server", "MCPServer", "NotificationOptions", "InitializationOptions"] +__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"] diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 80ae5912b..91aa9a645 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -160,10 +160,7 @@ async def run_task( RuntimeError: If task support is not enabled or task_metadata is missing Example: - @server.call_tool() - async def handle_tool(name: str, args: dict): - ctx = server.request_context - + async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( message="Are you sure?", diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 991221bd0..b2268bc1c 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -44,17 +44,14 @@ class TaskResultHandler: 5. Returns the final result Usage: - # Create handler with store and queue - handler = TaskResultHandler(task_store, message_queue) - - # Register it with the server - @server.experimental.get_task_result() - async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = server.request_context - return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) + async def handle_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + ... + + server.experimental.enable_tasks( + on_task_result=handle_task_result, + ) """ def __init__( diff --git a/src/mcp/server/lowlevel/__init__.py b/src/mcp/server/lowlevel/__init__.py index 66df38991..37191ba1a 100644 --- a/src/mcp/server/lowlevel/__init__.py +++ b/src/mcp/server/lowlevel/__init__.py @@ -1,3 +1,3 @@ from .server import NotificationOptions, Server -__all__ = ["Server", "NotificationOptions"] +__all__ = ["NotificationOptions", "Server"] diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 9b472c023..8ac268728 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -7,10 +7,12 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import Any, Generic +from typing_extensions import TypeVar + +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import MCPError from mcp.shared.experimental.tasks.helpers import cancel_task from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore @@ -18,16 +20,16 @@ from mcp.shared.experimental.tasks.store import TaskStore from mcp.types import ( INVALID_PARAMS, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, - ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, TasksCallCapability, @@ -36,13 +38,12 @@ TasksToolsCapability, ) -if TYPE_CHECKING: - from mcp.server.lowlevel.server import Server - logger = logging.getLogger(__name__) +LifespanResultT = TypeVar("LifespanResultT", default=Any) + -class ExperimentalHandlers: +class ExperimentalHandlers(Generic[LifespanResultT]): """Experimental request/notification handlers. WARNING: These APIs are experimental and may change without notice. @@ -50,13 +51,13 @@ class ExperimentalHandlers: def __init__( self, - server: Server, - request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], - notification_handlers: dict[type, Callable[..., Awaitable[None]]], - ): - self._server = server - self._request_handlers = request_handlers - self._notification_handlers = notification_handlers + add_request_handler: Callable[ + [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None + ], + has_handler: Callable[[str], bool], + ) -> None: + self._add_request_handler = add_request_handler + self._has_handler = has_handler self._task_support: TaskSupport | None = None @property @@ -66,16 +67,13 @@ def task_support(self) -> TaskSupport | None: def update_capabilities(self, capabilities: ServerCapabilities) -> None: # Only add tasks capability if handlers are registered - if not any( - req_type in self._request_handlers - for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest] - ): + if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): return capabilities.tasks = ServerTasksCapability() - if ListTasksRequest in self._request_handlers: + if self._has_handler("tasks/list"): capabilities.tasks.list = TasksListCapability() - if CancelTaskRequest in self._request_handlers: + if self._has_handler("tasks/cancel"): capabilities.tasks.cancel = TasksCancelCapability() capabilities.tasks.requests = ServerTasksRequestsCapability( @@ -86,15 +84,35 @@ def enable_tasks( self, store: TaskStore | None = None, queue: TaskMessageQueue | None = None, + *, + on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] + | None = None, + on_task_result: Callable[ + [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] + ] + | None = None, + on_list_tasks: Callable[ + [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] + ] + | None = None, + on_cancel_task: Callable[ + [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult] + ] + | None = None, ) -> TaskSupport: """Enable experimental task support. - This sets up the task infrastructure and auto-registers default handlers - for tasks/get, tasks/result, tasks/list, and tasks/cancel. + This sets up the task infrastructure and registers handlers for + tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers + can be provided via the on_* kwargs; any not provided will use defaults. Args: store: Custom TaskStore implementation (defaults to InMemoryTaskStore) queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) + on_get_task: Custom handler for tasks/get + on_task_result: Custom handler for tasks/result + on_list_tasks: Custom handler for tasks/list + on_cancel_task: Custom handler for tasks/cancel Returns: The TaskSupport configuration object @@ -117,24 +135,27 @@ def enable_tasks( queue = InMemoryTaskMessageQueue() self._task_support = TaskSupport(store=store, queue=queue) - - # Auto-register default handlers - self._register_default_task_handlers() - - return self._task_support - - def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" - assert self._task_support is not None - support = self._task_support - - # Register get_task handler if not already registered - if GetTaskRequest not in self._request_handlers: - - async def _default_get_task(req: GetTaskRequest) -> ServerResult: - task = await support.store.get_task(req.params.task_id) + task_support = self._task_support + + # Register user-provided handlers + if on_get_task is not None: + self._add_request_handler("tasks/get", on_get_task) + if on_task_result is not None: + self._add_request_handler("tasks/result", on_task_result) + if on_list_tasks is not None: + self._add_request_handler("tasks/list", on_list_tasks) + if on_cancel_task is not None: + self._add_request_handler("tasks/cancel", on_cancel_task) + + # Fill in defaults for any not provided + if not self._has_handler("tasks/get"): + + async def _default_get_task( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams + ) -> GetTaskResult: + task = await task_support.store.get_task(params.task_id) if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {req.params.task_id}") + raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( task_id=task.task_id, status=task.status, @@ -145,136 +166,39 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: poll_interval=task.poll_interval, ) - self._request_handlers[GetTaskRequest] = _default_get_task + self._add_request_handler("tasks/get", _default_get_task) - # Register get_task_result handler if not already registered - if GetTaskPayloadRequest not in self._request_handlers: + if not self._has_handler("tasks/result"): - async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = self._server.request_context - result = await support.handler.handle(req, ctx.session, ctx.request_id) + async def _default_get_task_result( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + assert ctx.request_id is not None + req = GetTaskPayloadRequest(params=params) + result = await task_support.handler.handle(req, ctx.session, ctx.request_id) return result - self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result + self._add_request_handler("tasks/result", _default_get_task_result) - # Register list_tasks handler if not already registered - if ListTasksRequest not in self._request_handlers: + if not self._has_handler("tasks/list"): - async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) + async def _default_list_tasks( + ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListTasksResult: + cursor = params.cursor if params else None + tasks, next_cursor = await task_support.store.list_tasks(cursor) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - self._request_handlers[ListTasksRequest] = _default_list_tasks - - # Register cancel_task handler if not already registered - if CancelTaskRequest not in self._request_handlers: - - async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - result = await cancel_task(support.store, req.params.task_id) - return result - - self._request_handlers[CancelTaskRequest] = _default_cancel_task - - def list_tasks( - self, - ) -> Callable[ - [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], - Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ]: - """Register a handler for listing tasks. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: - logger.debug("Registering handler for ListTasksRequest") - wrapper = create_call_wrapper(func, ListTasksRequest) - - async def handler(req: ListTasksRequest) -> ListTasksResult: - result = await wrapper(req) - return result - - self._request_handlers[ListTasksRequest] = handler - return func - - return decorator - - def get_task( - self, - ) -> Callable[ - [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] - ]: - """Register a handler for getting task status. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], - ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: - logger.debug("Registering handler for GetTaskRequest") - wrapper = create_call_wrapper(func, GetTaskRequest) - - async def handler(req: GetTaskRequest) -> GetTaskResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskRequest] = handler - return func - - return decorator - - def get_task_result( - self, - ) -> Callable[ - [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], - Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ]: - """Register a handler for getting task results/payload. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: - logger.debug("Registering handler for GetTaskPayloadRequest") - wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - - async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskPayloadRequest] = handler - return func - - return decorator - - def cancel_task( - self, - ) -> Callable[ - [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], - Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ]: - """Register a handler for cancelling tasks. - - WARNING: This API is experimental and may change without notice. - """ + self._add_request_handler("tasks/list", _default_list_tasks) - def decorator( - func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: - logger.debug("Registering handler for CancelTaskRequest") - wrapper = create_call_wrapper(func, CancelTaskRequest) + if not self._has_handler("tasks/cancel"): - async def handler(req: CancelTaskRequest) -> CancelTaskResult: - result = await wrapper(req) + async def _default_cancel_task( + ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams + ) -> CancelTaskResult: + result = await cancel_task(task_support.store, params.task_id) return result - self._request_handlers[CancelTaskRequest] = handler - return func + self._add_request_handler("tasks/cancel", _default_cancel_task) - return decorator + return task_support diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py deleted file mode 100644 index d17697090..000000000 --- a/src/mcp/server/lowlevel/func_inspection.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from collections.abc import Callable -from typing import Any, TypeVar, get_type_hints - -T = TypeVar("T") -R = TypeVar("R") - - -def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: - """Create a wrapper function that knows how to call func with the request object. - - Returns a wrapper function that takes the request and calls func appropriately. - - The wrapper handles three calling patterns: - 1. Positional-only parameter typed as request_type (no default): func(req) - 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. No request parameter or parameter with default: func() - """ - try: - sig = inspect.signature(func) - type_hints = get_type_hints(func) - except (ValueError, TypeError, NameError): # pragma: no cover - return lambda _: func() - - # Check for positional-only parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind == inspect.Parameter.POSITIONAL_ONLY: - param_type = type_hints.get(param_name) - if param_type == request_type: # pragma: no branch - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - # Found positional-only parameter with correct type and no default - return lambda req: func(req) - - # Check for any positional/keyword parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch - param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - - # Found keyword parameter with correct type and no default - # Need to capture param_name in closure properly - def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: - return lambda req: func(**{name: req}) - - return make_keyword_wrapper(param_name) - - # No request parameter found - use old style - return lambda _: func() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 96dcaf1c7..3d09e9dc7 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -2,82 +2,49 @@ This module provides a framework for creating an MCP (Model Context Protocol) server. It allows you to easily define and handle various types of requests and notifications -in an asynchronous manner. +using constructor-based handler registration. Usage: -1. Create a Server instance: - server = Server("your_server_name") - -2. Define request handlers using decorators: - @server.list_prompts() - async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult: - # Implementation - - @server.get_prompt() - async def handle_get_prompt( - name: str, arguments: dict[str, str] | None - ) -> types.GetPromptResult: - # Implementation - - @server.list_tools() - async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult: - # Implementation - - @server.call_tool() - async def handle_call_tool( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - # Implementation - - @server.list_resource_templates() - async def handle_list_resource_templates() -> list[types.ResourceTemplate]: - # Implementation - -3. Define notification handlers if needed: - @server.progress_notification() - async def handle_progress( - progress_token: str | int, progress: float, total: float | None, - message: str | None - ) -> None: - # Implementation - -4. Run the server: +1. Define handler functions: + async def my_list_tools(ctx, params): + return types.ListToolsResult(tools=[...]) + + async def my_call_tool(ctx, params): + return types.CallToolResult(content=[...]) + +2. Create a Server instance with on_* handlers: + server = Server( + "your_server_name", + on_list_tools=my_list_tools, + on_call_tool=my_call_tool, + ) + +3. Run the server: async def main(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) asyncio.run(main()) -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. +The Server class dispatches incoming requests and notifications to registered +handler callables by method string. """ from __future__ import annotations -import base64 import contextvars -import json import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic, TypeAlias, cast +from typing import Any, Generic import anyio -import jsonschema from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware @@ -94,30 +61,20 @@ async def main(): from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError +from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) -RequestT = TypeVar("RequestT", default=Any) - -# type aliases for tool call results -StructuredContent: TypeAlias = dict[str, Any] -UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] -CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] -# This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[ServerRequestContext[Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[ServerRequestContext[Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -128,7 +85,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Args: @@ -140,10 +97,15 @@ async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[s yield {} -class Server(Generic[LifespanResultT, RequestT]): +async def _ping_handler(ctx: ServerRequestContext[Any], params: types.RequestParams | None) -> types.EmptyResult: + return types.EmptyResult() + + +class Server(Generic[LifespanResultT]): def __init__( self, name: str, + *, version: str | None = None, title: str | None = None, description: str | None = None, @@ -151,9 +113,80 @@ def __init__( website_url: str | None = None, icons: list[types.Icon] | None = None, lifespan: Callable[ - [Server[LifespanResultT, RequestT]], + [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + # Request handlers + on_list_tools: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListToolsResult], + ] + | None = None, + on_call_tool: Callable[ + [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], + Awaitable[types.CallToolResult], + ] + | None = None, + on_list_resources: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourcesResult], + ] + | None = None, + on_list_resource_templates: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourceTemplatesResult], + ] + | None = None, + on_read_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams], + Awaitable[types.ReadResourceResult], + ] + | None = None, + on_subscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.SubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_unsubscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.UnsubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_list_prompts: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListPromptsResult], + ] + | None = None, + on_get_prompt: Callable[ + [ServerRequestContext[LifespanResultT], types.GetPromptRequestParams], + Awaitable[types.GetPromptResult], + ] + | None = None, + on_completion: Callable[ + [ServerRequestContext[LifespanResultT], types.CompleteRequestParams], + Awaitable[types.CompleteResult], + ] + | None = None, + on_set_logging_level: Callable[ + [ServerRequestContext[LifespanResultT], types.SetLevelRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_ping: Callable[ + [ServerRequestContext[LifespanResultT], types.RequestParams | None], + Awaitable[types.EmptyResult], + ] = _ping_handler, + # Notification handlers + on_roots_list_changed: Callable[ + [ServerRequestContext[LifespanResultT], types.NotificationParams | None], + Awaitable[None], + ] + | None = None, + on_progress: Callable[ + [ServerRequestContext[LifespanResultT], types.ProgressNotificationParams], + Awaitable[None], + ] + | None = None, ): self.name = name self.version = version @@ -163,15 +196,72 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self._tool_cache: dict[str, types.Tool] = {} - self._experimental_handlers: ExperimentalHandlers | None = None + self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} + self._notification_handlers: dict[ + str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] + ] = {} + self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) + # Populate internal handler dicts from on_* kwargs + self._request_handlers.update( + { + method: handler + for method, handler in { + "ping": on_ping, + "prompts/list": on_list_prompts, + "prompts/get": on_get_prompt, + "resources/list": on_list_resources, + "resources/templates/list": on_list_resource_templates, + "resources/read": on_read_resource, + "resources/subscribe": on_subscribe_resource, + "resources/unsubscribe": on_unsubscribe_resource, + "tools/list": on_list_tools, + "tools/call": on_call_tool, + "logging/setLevel": on_set_logging_level, + "completion/complete": on_completion, + }.items() + if handler is not None + } + ) + + self._notification_handlers.update( + { + method: handler + for method, handler in { + "notifications/roots/list_changed": on_roots_list_changed, + "notifications/progress": on_progress, + }.items() + if handler is not None + } + ) + + def _add_request_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + ) -> None: + """Add a request handler, silently replacing any existing handler for the same method.""" + self._request_handlers[method] = handler + + def _add_notification_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]], + ) -> None: + """Add a notification handler, silently replacing any existing handler for the same method.""" + self._notification_handlers[method] = handler + + def _has_handler(self, method: str) -> bool: + """Check if a handler is registered for the given method.""" + return method in self._request_handlers or method in self._notification_handlers + + # TODO: Rethink capabilities API. Currently capabilities are derived from registered + # handlers but require NotificationOptions to be passed externally for list_changed + # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities + # entirely from server state (e.g. constructor params for list_changed) instead of + # requiring callers to assemble them at create_initialization_options() time. def create_initialization_options( self, notification_options: NotificationOptions | None = None, @@ -214,25 +304,26 @@ def get_capabilities( completions_capability = None # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: + if "prompts/list" in self._request_handlers: prompts_capability = types.PromptsCapability(list_changed=notification_options.prompts_changed) # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: + if "resources/list" in self._request_handlers: resources_capability = types.ResourcesCapability( - subscribe=False, list_changed=notification_options.resources_changed + subscribe="resources/subscribe" in self._request_handlers, + list_changed=notification_options.resources_changed, ) # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: + if "tools/list" in self._request_handlers: tools_capability = types.ToolsCapability(list_changed=notification_options.tools_changed) # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: + if "logging/setLevel" in self._request_handlers: logging_capability = types.LoggingCapability() # Set completions capabilities if handler exists - if types.CompleteRequest in self.request_handlers: + if "completion/complete" in self._request_handlers: completions_capability = types.CompletionsCapability() capabilities = types.ServerCapabilities( @@ -248,12 +339,7 @@ def get_capabilities( return capabilities @property - def request_context(self) -> ServerRequestContext[LifespanResultT, RequestT]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - @property - def experimental(self) -> ExperimentalHandlers: + def experimental(self) -> ExperimentalHandlers[LifespanResultT]: """Experimental APIs for tasks and other features. WARNING: These APIs are experimental and may change without notice. @@ -261,7 +347,10 @@ def experimental(self) -> ExperimentalHandlers: # We create this inline so we only add these capabilities _if_ they're actually used if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers) + self._experimental_handlers = ExperimentalHandlers( + add_request_handler=self._add_request_handler, + has_handler=self._has_handler, + ) return self._experimental_handlers @property @@ -278,374 +367,6 @@ def session_manager(self) -> StreamableHTTPSessionManager: ) return self._session_manager # pragma: no cover - def list_prompts(self): - def decorator( - func: Callable[[], Awaitable[list[types.Prompt]]] - | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], - ): - logger.debug("Registering handler for PromptListRequest") - - wrapper = create_call_wrapper(func, types.ListPromptsRequest) - - async def handler(req: types.ListPromptsRequest): - result = await wrapper(req) - # Handle both old style (list[Prompt]) and new style (ListPromptsResult) - if isinstance(result, types.ListPromptsResult): - return result - else: - # Old style returns list[Prompt] - return types.ListPromptsResult(prompts=result) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return prompt_get - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator( - func: Callable[[], Awaitable[list[types.Resource]]] - | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], - ): - logger.debug("Registering handler for ListResourcesRequest") - - wrapper = create_call_wrapper(func, types.ListResourcesRequest) - - async def handler(req: types.ListResourcesRequest): - result = await wrapper(req) - # Handle both old style (list[Resource]) and new style (ListResourcesResult) - if isinstance(result, types.ListResourcesResult): - return result - else: - # Old style returns list[Resource] - return types.ListResourcesResult(resources=result) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ListResourceTemplatesResult(resource_templates=templates) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator( - func: Callable[[str], Awaitable[str | bytes | Iterable[ReadResourceContents]]], - ): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - - def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any] | None = None): - # Note: ResourceContents uses Field(alias="_meta"), so we must use the alias key - meta_kwargs: dict[str, Any] = {"_meta": meta} if meta is not None else {} - match data: - case str() as data: - return types.TextResourceContents( - uri=req.params.uri, - text=data, - mime_type=mime_type or "text/plain", - **meta_kwargs, - ) - case bytes() as data: # pragma: no branch - return types.BlobResourceContents( - uri=req.params.uri, - blob=base64.b64encode(data).decode(), - mime_type=mime_type or "application/octet-stream", - **meta_kwargs, - ) - - match result: - case str() | bytes() as data: # pragma: lax no cover - warnings.warn( - "Returning str or bytes from read_resource is deprecated. " - "Use Iterable[ReadResourceContents] instead.", - DeprecationWarning, - stacklevel=2, - ) - content = create_content(data, None) - case Iterable() as contents: - contents_list = [ - create_content( - content_item.content, content_item.mime_type, getattr(content_item, "meta", None) - ) - for content_item in contents - ] - return types.ReadResourceResult(contents=contents_list) - case _: # pragma: no cover - raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - - return types.ReadResourceResult(contents=[content]) # pragma: no cover - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.EmptyResult() - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator( - func: Callable[[], Awaitable[list[types.Tool]]] - | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], - ): - logger.debug("Registering handler for ListToolsRequest") - - wrapper = create_call_wrapper(func, types.ListToolsRequest) - - async def handler(req: types.ListToolsRequest): - result = await wrapper(req) - - # Handle both old style (list[Tool]) and new style (ListToolsResult) - if isinstance(result, types.ListToolsResult): - # Refresh the tool cache with returned tools - for tool in result.tools: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return result - else: - # Old style returns list[Tool] - # Clear and refresh the entire tool cache - self._tool_cache.clear() - for tool in result: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return types.ListToolsResult(tools=result) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def _make_error_result(self, error_message: str) -> types.CallToolResult: - """Create a CallToolResult with an error.""" - return types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) - - async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: - """Get tool definition from cache, refreshing if necessary. - - Returns the Tool object if found, None otherwise. - """ - if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: - logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) - - tool = self._tool_cache.get(tool_name) - if tool is None: - logger.warning("Tool '%s' not listed, no validation will be performed", tool_name) - - return tool - - def call_tool(self, *, validate_input: bool = True): - """Register a tool call handler. - - Args: - validate_input: If True, validates input against inputSchema. Default is True. - - The handler validates input against inputSchema (if validate_input=True), calls the tool function, - and builds a CallToolResult with the results: - - Unstructured content (iterable of ContentBlock): returned in content - - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - - Both: returned in content and structuredContent - - If outputSchema is defined, validates structuredContent or errors if missing. - """ - - def decorator( - func: Callable[ - [str, dict[str, Any]], - Awaitable[ - UnstructuredContent - | StructuredContent - | CombinationContent - | types.CallToolResult - | types.CreateTaskResult - ], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - tool_name = req.params.name - arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) - - # input validation - if validate_input and tool: - try: - jsonschema.validate(instance=arguments, schema=tool.input_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Input validation error: {e.message}") - - # tool call - results = await func(tool_name, arguments) - - # output normalization - unstructured_content: UnstructuredContent - maybe_structured_content: StructuredContent | None - if isinstance(results, types.CallToolResult): - return results - elif isinstance(results, types.CreateTaskResult): - # Task-augmented execution returns task info instead of result - return results - elif isinstance(results, tuple) and len(results) == 2: - # tool returned both structured and unstructured content - unstructured_content, maybe_structured_content = cast(CombinationContent, results) - elif isinstance(results, dict): - # tool returned structured content only - maybe_structured_content = cast(StructuredContent, results) - unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif hasattr(results, "__iter__"): - # tool returned unstructured content only - unstructured_content = cast(UnstructuredContent, results) - maybe_structured_content = None - else: # pragma: no cover - return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") - - # output validation - if tool and tool.output_schema is not None: - if maybe_structured_content is None: - return self._make_error_result( - "Output validation error: outputSchema defined but no structured output returned" - ) - else: - try: - jsonschema.validate(instance=maybe_structured_content, schema=tool.output_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Output validation error: {e.message}") - - # result - return types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # by _handle_request, which converts it to an error response with code -32042 - raise - except Exception as e: - return self._make_error_result(str(e)) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progress_token, - req.params.progress, - req.params.total, - req.params.message, - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceTemplateReference, - types.CompletionArgument, - types.CompletionContext | None, - ], - Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return decorator - async def run( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -715,7 +436,7 @@ async def _handle_message( if raise_exceptions: raise message case _: - await self._handle_notification(message) + await self._handle_notification(message, session, lifespan_context) for warning in w: # pragma: lax no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -730,10 +451,9 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): + if handler := self._request_handlers.get(req.method): logger.debug("Dispatching request of type %s", type(req).__name__) - token = None try: # Extract request context and close_sse_stream from message metadata request_data = None @@ -746,32 +466,32 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - # Set our global state that can be retrieved via - # app.get_request_context() client_capabilities = session.client_params.capabilities if session.client_params else None task_support = self._experimental_handlers.task_support if self._experimental_handlers else None # Get task metadata from request params if present task_metadata = None if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) - token = request_ctx.set( - ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - response = await handler(req) + token = request_ctx.set(ctx) + try: + response = await handler(ctx, req.params) + finally: + request_ctx.reset(token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): @@ -781,10 +501,6 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err), data=None) - finally: - # Reset the global state after we are done - if token is not None: # pragma: no branch - request_ctx.reset(token) await message.respond(response) else: # pragma: no cover @@ -792,12 +508,29 @@ async def _handle_request( logger.debug("Response sent") - async def _handle_notification(self, notify: Any): - if handler := self.notification_handlers.get(type(notify)): # type: ignore + async def _handle_notification( + self, + notify: types.ClientNotification, + session: ServerSession, + lifespan_context: LifespanResultT, + ) -> None: + if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") @@ -914,5 +647,3 @@ def streamable_http_app( ) -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.EmptyResult() diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..a26254632 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -2,7 +2,9 @@ from __future__ import annotations +import base64 import inspect +import json import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager @@ -29,7 +31,7 @@ from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation from mcp.server.elicitation import elicit_url as _elicit_url from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT, Server +from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager @@ -42,7 +44,30 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.shared.exceptions import MCPError +from mcp.types import ( + Annotations, + BlobResourceContents, + CallToolRequestParams, + CallToolResult, + CompleteRequestParams, + CompleteResult, + Completion, + ContentBlock, + GetPromptRequestParams, + GetPromptResult, + Icon, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + ToolAnnotations, +) from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -91,9 +116,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: MCPServer[LifespanResultT], lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[Server[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: +) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: Server[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -132,6 +157,9 @@ def __init__( auth=auth, ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", title=title, @@ -140,13 +168,17 @@ def __init__( website_url=website_url, icons=icons, version=version, + on_list_tools=self._handle_list_tools, + on_call_tool=self._handle_call_tool, + on_list_resources=self._handle_list_resources, + on_read_resource=self._handle_read_resource, + on_list_resource_templates=self._handle_list_resource_templates, + on_list_prompts=self._handle_list_prompts, + on_get_prompt=self._handle_get_prompt, # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: # pragma: no cover @@ -164,9 +196,6 @@ def __init__( self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._custom_starlette_routes: list[Route] = [] - # Set up MCP protocol handlers - self._setup_handlers() - # Configure logging configure_logging(self.settings.log_level) @@ -263,18 +292,80 @@ def run( case "streamable-http": # pragma: no cover anyio.run(lambda: self.run_streamable_http_async(**kwargs)) - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._lowlevel_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # MCPServer does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. - self._lowlevel_server.call_tool(validate_input=False)(self.call_tool) - self._lowlevel_server.list_resources()(self.list_resources) - self._lowlevel_server.read_resource()(self.read_resource) - self._lowlevel_server.list_prompts()(self.list_prompts) - self._lowlevel_server.get_prompt()(self.get_prompt) - self._lowlevel_server.list_resource_templates()(self.list_resource_templates) + async def _handle_list_tools( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=await self.list_tools()) + + async def _handle_call_tool( + self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams + ) -> CallToolResult: + try: + result = await self.call_tool(params.name, params.arguments or {}) + except MCPError: + raise + except Exception as e: + return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True) + if isinstance(result, CallToolResult): + return result + if isinstance(result, tuple) and len(result) == 2: + unstructured_content, structured_content = result + return CallToolResult( + content=list(unstructured_content), # type: ignore[arg-type] + structured_content=structured_content, # type: ignore[arg-type] + ) + if isinstance(result, dict): + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(result, indent=2))], + structured_content=result, + ) + return CallToolResult(content=list(result)) + + async def _handle_list_resources( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=await self.list_resources()) + + async def _handle_read_resource( + self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams + ) -> ReadResourceResult: + results = await self.read_resource(params.uri) + contents: list[TextResourceContents | BlobResourceContents] = [] + for item in results: + if isinstance(item.content, bytes): + contents.append( + BlobResourceContents( + uri=params.uri, + blob=base64.b64encode(item.content).decode(), + mime_type=item.mime_type or "application/octet-stream", + _meta=item.meta, + ) + ) + else: + contents.append( + TextResourceContents( + uri=params.uri, + text=item.content, + mime_type=item.mime_type or "text/plain", + _meta=item.meta, + ) + ) + return ReadResourceResult(contents=contents) + + async def _handle_list_resource_templates( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates()) + + async def _handle_list_prompts( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=await self.list_prompts()) + + async def _handle_get_prompt( + self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams + ) -> GetPromptResult: + return await self.get_prompt(params.name, params.arguments) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -298,7 +389,7 @@ def get_context(self) -> Context[LifespanResultT, Request]: during a request; outside a request, most methods will error. """ try: - request_context = self._lowlevel_server.request_context + request_context = request_ctx.get() except LookupError: request_context = None return Context(request_context=request_context, mcp_server=self) @@ -486,7 +577,24 @@ async def handle_completion(ref, argument, context): return Completion(values=["option1", "option2"]) return None """ - return self._lowlevel_server.completion() + + def decorator(func: _CallableT) -> _CallableT: + async def handler( + ctx: ServerRequestContext[LifespanResultT], params: CompleteRequestParams + ) -> CompleteResult: + result = await func(params.ref, params.argument, params.context) + return CompleteResult( + completion=result if result is not None else Completion(values=[], total=None, has_more=None), + ) + + # TODO(maxisbey): remove private access — completion needs post-construction + # handler registration, find a better pattern for this + self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] + "completion/complete", handler + ) + return func + + return decorator def add_resource(self, resource: Resource) -> None: """Add a resource to the server. diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a3..6925aa556 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,30 +6,22 @@ Common usage pattern: ``` - server = Server(name) - - @server.call_tool() - async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: + async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: # Check client capabilities before proceeding if ctx.session.check_client_capability( types.ClientCapabilities(experimental={"advanced_tools": dict()}) ): - # Perform advanced tool operations - result = await perform_advanced_tool_operation(arguments) + result = await perform_advanced_tool_operation(params.arguments) else: - # Fall back to basic tool operations - result = await perform_basic_tool_operation(arguments) - + result = await perform_basic_tool_operation(params.arguments) return result - @server.list_prompts() - async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: - # Access session for any necessary checks or operations + async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: if ctx.session.client_params: - # Customize prompts based on client initialization parameters - return generate_custom_prompts(ctx.session.client_params) - else: - return default_prompts + return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) + return ListPromptsResult(prompts=default_prompts) + + server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) ``` The ServerSession class is typically used internally by the Server class and should not diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index a954b24a4..0af16f770 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -60,7 +60,7 @@ class StreamableHTTPSessionManager: def __init__( self, - app: Server[Any, Any], + app: Server[Any], event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index 2facc2a49..bbcee2d02 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -13,8 +13,12 @@ @dataclass(kw_only=True) class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests.""" + """Common context for handling incoming requests. + + For request handlers, request_id is always populated. + For notification handlers, request_id is None. + """ - request_id: RequestId - meta: RequestParamsMeta | None session: SessionT + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index 38ca802da..bd1781cb5 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -72,9 +72,8 @@ async def cancel_task( - Task is already in a terminal state (completed, failed, cancelled) Example: - @server.experimental.cancel_task() - async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: - return await cancel_task(store, request.params.taskId) + async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: + return await cancel_task(store, params.task_id) """ task = await store.get_task(task_id) if task is None: diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d..1858eeac3 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,6 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Any from mcp.types import JSONRPCMessage, RequestId @@ -30,8 +31,10 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None - # Request-specific context (e.g., headers, auth info) - request_context: object | None = None + # Transport-specific request context (e.g. starlette Request for HTTP + # transports, None for stdio). Typed as Any because the server layer is + # transport-agnostic. + request_context: Any = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index d483ae54b..50a99893a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -12,6 +12,7 @@ from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( CallToolResult, @@ -38,36 +39,45 @@ pytestmark = pytest.mark.anyio -@pytest.fixture -def simple_server() -> Server: - """Create a simple MCP server for testing.""" - server = Server(name="test_server") +async def _handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> ListResourcesResult: + return ListResourcesResult( + resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")] + ) - @server.list_resources() - async def handle_list_resources(): - return [Resource(uri="memory://test", name="Test Resource", description="A test resource")] - @server.subscribe_resource() - async def handle_subscribe_resource(uri: str): - pass +async def _handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + return EmptyResult() - @server.unsubscribe_resource() - async def handle_unsubscribe_resource(uri: str): - pass - @server.set_logging_level() - async def handle_set_logging_level(level: str): - pass +async def _handle_unsubscribe_resource( + ctx: ServerRequestContext, params: types.UnsubscribeRequestParams +) -> EmptyResult: + return EmptyResult() - @server.completion() - async def handle_completion( - ref: types.PromptReference | types.ResourceTemplateReference, - argument: types.CompletionArgument, - context: types.CompletionContext | None, - ) -> types.Completion | None: - return types.Completion(values=[]) - return server +async def _handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + return EmptyResult() + + +async def _handle_completion( + ctx: ServerRequestContext, params: types.CompleteRequestParams +) -> types.CompleteResult: + return types.CompleteResult(completion=types.Completion(values=[])) + + +@pytest.fixture +def simple_server() -> Server: + """Create a simple MCP server for testing.""" + return Server( + name="test_server", + on_list_resources=_handle_list_resources, + on_subscribe_resource=_handle_subscribe_resource, + on_unsubscribe_resource=_handle_unsubscribe_resource, + on_set_logging_level=_handle_set_logging_level, + on_completion=_handle_completion, + ) @pytest.fixture @@ -202,19 +212,16 @@ async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() - server = Server(name="test_server") - @server.progress_notification() async def handle_progress_notification( - progress_token: str | int, - progress: float = 0.0, - total: float | None = None, - message: str | None = None, + ctx: ServerRequestContext, params: types.ProgressNotificationParams ) -> None: nonlocal received_from_client - received_from_client = {"progress_token": progress_token, "progress": progress} + received_from_client = {"progress_token": params.progress_token, "progress": params.progress} event.set() + server = Server(name="test_server", on_progress=handle_progress_notification) + async with Client(server) as client: await client.send_progress_notification(progress_token="token123", progress=50.0) await event.wait() diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 5cca8c194..7e2a521d7 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -1,14 +1,13 @@ """Tests for Unicode handling in streamable HTTP transport. Verifies that Unicode text is correctly transmitted and received in both directions -(server→client and client→server) using the streamable HTTP transport. +(server->client and client->server) using the streamable HTTP transport. """ import multiprocessing import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager -from typing import Any import pytest from starlette.applications import Starlette @@ -18,6 +17,7 @@ from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool from tests.test_helpers import wait_for_server @@ -47,54 +47,65 @@ def run_unicode_server(port: int) -> None: # pragma: no cover import uvicorn # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") - - @server.list_tools() - async def list_tools() -> list[Tool]: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: """List tools with Unicode descriptions.""" - return [ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description=( + "🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨" + ), + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, + }, + "required": ["text"], }, - "required": ["text"], - }, - ), - ] + ), + ] + ) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult: """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", "") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] + if params.name == "echo_unicode": + arguments = params.arguments or {} + text = arguments.get("text", "") + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=f"Echo: {text}", + ) + ] + ) else: - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") - @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: """List prompts with Unicode names and descriptions.""" - return [ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], + ) + ] + ) - @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: + async def handle_get_prompt( + ctx: ServerRequestContext, params: types.GetPromptRequestParams + ) -> types.GetPromptResult: """Get a prompt with Unicode content.""" - if name == "unicode_prompt": + if params.name == "unicode_prompt": return types.GetPromptResult( messages=[ types.PromptMessage( @@ -106,7 +117,15 @@ async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPr ) ] ) - raise ValueError(f"Unknown prompt: {name}") + raise ValueError(f"Unknown prompt: {params.name}") + + server = Server( + name="unicode_test_server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) # Create the session manager session_manager = StreamableHTTPSessionManager( @@ -177,7 +196,7 @@ async def test_streamable_http_client_unicode_tool_call(running_unicode_server: async with ClientSession(read_stream, write_stream) as session: await session.initialize() - # Test 1: List tools (server→client Unicode in descriptions) + # Test 1: List tools (server->client Unicode in descriptions) tools = await session.list_tools() assert len(tools.tools) == 1 @@ -188,7 +207,7 @@ async def test_streamable_http_client_unicode_tool_call(running_unicode_server: assert "🔤" in echo_tool.description assert "👋" in echo_tool.description - # Test 2: Send Unicode text in tool call (client→server→client) + # Test 2: Send Unicode text in tool call (client->server->client) for test_name, test_string in UNICODE_TEST_STRINGS.items(): result = await session.call_tool("echo_unicode", arguments={"text": test_string}) @@ -209,7 +228,7 @@ async def test_streamable_http_client_unicode_prompts(running_unicode_server: st async with ClientSession(read_stream, write_stream) as session: await session.initialize() - # Test 1: List prompts (server→client Unicode in descriptions) + # Test 1: List prompts (server->client Unicode in descriptions) prompts = await session.list_prompts() assert len(prompts.prompts) == 1 @@ -218,7 +237,7 @@ async def test_streamable_http_client_unicode_prompts(running_unicode_server: st assert prompt.description is not None assert "Слой хранилища, где располагаются" in prompt.description - # Test 2: Get prompt with Unicode content (server→client) + # Test 2: Get prompt with Unicode content (server->client) result = await session.get_prompt("unicode_prompt", arguments={}) assert len(result.messages) == 1 diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 4d7c53db2..9c53376e9 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -4,8 +4,9 @@ from mcp import Client, types from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import ListToolsRequest, ListToolsResult +from mcp.types import ListToolsResult from .conftest import StreamSpyCollection @@ -105,13 +106,17 @@ async def test_list_tools_with_strict_server_validation( async def test_list_tools_with_lowlevel_server(): """Test that list_tools works with a lowlevel Server using params.""" - server = Server("test-lowlevel") - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: # Echo back what cursor we received in the tool description - cursor = request.params.cursor if request.params else None - return ListToolsResult(tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})]) + cursor = params.cursor if params else None + return ListToolsResult( + tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})] + ) + + server = Server("test-lowlevel", on_list_tools=handle_list_tools) async with Client(server) as client: result = await client.list_tools() diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index cc93d303b..08ce46447 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -7,9 +7,10 @@ import jsonschema import pytest -from mcp import Client +from mcp import Client, types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server -from mcp.types import Tool +from mcp.types import CallToolResult, ListToolsResult, Tool @contextmanager @@ -41,8 +42,6 @@ def selective_mock(instance: Any = None, schema: Any = None, *args: Any, **kwarg @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_basemodel(): """Test that client validates structured content against schema for BaseModel outputs""" - # Create a malicious low-level server that returns invalid structured content - server = Server("test-server") # Define the expected schema for our tool output_schema = { @@ -52,22 +51,34 @@ async def test_tool_structured_output_client_side_validation_basemodel(): "title": "UserOutput", } - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="get_user", - description="Get user data", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="get_user", + description="Get user data", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + ) + + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> CallToolResult: # Return invalid structured content - age is string instead of integer - # The low-level server will wrap this in CallToolResult - return {"name": "John", "age": "invalid"} # Invalid: age should be int + return CallToolResult( + content=[], + structured_content={"name": "John", "age": "invalid"}, # Invalid: age should be int + ) + + server = Server( + "test-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Test that client validates the structured content with bypass_server_output_validation(): @@ -82,7 +93,6 @@ async def call_tool(name: str, arguments: dict[str, Any]): @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_primitive(): """Test that client validates structured content for primitive outputs""" - server = Server("test-server") # Primitive types are wrapped in {"result": value} output_schema = { @@ -92,21 +102,34 @@ async def test_tool_structured_output_client_side_validation_primitive(): "title": "calculate_Output", } - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="calculate", - description="Calculate something", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="calculate", + description="Calculate something", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + ) + + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> CallToolResult: # Return invalid structured content - result is string instead of integer - return {"result": "not_a_number"} # Invalid: should be int + return CallToolResult( + content=[], + structured_content={"result": "not_a_number"}, # Invalid: should be int + ) + + server = Server( + "test-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) with bypass_server_output_validation(): async with Client(server) as client: @@ -119,26 +142,38 @@ async def call_tool(name: str, arguments: dict[str, Any]): @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_dict_typed(): """Test that client validates dict[str, T] structured content""" - server = Server("test-server") # dict[str, int] schema output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="get_scores", - description="Get scores", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="get_scores", + description="Get scores", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + ) + + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> CallToolResult: # Return invalid structured content - values should be integers - return {"alice": "100", "bob": "85"} # Invalid: values should be int + return CallToolResult( + content=[], + structured_content={"alice": "100", "bob": "85"}, # Invalid: values should be int + ) + + server = Server( + "test-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) with bypass_server_output_validation(): async with Client(server) as client: @@ -151,7 +186,6 @@ async def call_tool(name: str, arguments: dict[str, Any]): @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_missing_required(): """Test that client validates missing required fields""" - server = Server("test-server") output_schema = { "type": "object", @@ -160,21 +194,34 @@ async def test_tool_structured_output_client_side_validation_missing_required(): "title": "PersonOutput", } - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="get_person", - description="Get person data", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="get_person", + description="Get person data", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + ) + + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> CallToolResult: # Return structured content missing required field 'email' - return {"name": "John", "age": 30} # Missing required 'email' + return CallToolResult( + content=[], + structured_content={"name": "John", "age": 30}, # Missing required 'email' + ) + + server = Server( + "test-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) with bypass_server_output_validation(): async with Client(server) as client: @@ -187,17 +234,27 @@ async def call_tool(name: str, arguments: dict[str, Any]): @pytest.mark.anyio async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): """Test that client logs warning when tool is not in list_tools but has output_schema""" - server = Server("test-server") - @server.list_tools() - async def list_tools() -> list[Tool]: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: # Return empty list - tool is not listed - return [] + return ListToolsResult(tools=[]) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> CallToolResult: # Server still responds to the tool call with structured content - return {"result": 42} + return CallToolResult( + content=[], + structured_content={"result": 42}, + ) + + server = Server( + "test-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Set logging level to capture warnings caplog.set_level(logging.WARNING) diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 30ecb0ac3..c39464fb4 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -5,28 +5,32 @@ from mcp import Client from mcp.client._memory import InMemoryTransport from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import Resource +from mcp.types import ListResourcesResult, PaginatedRequestParams, Resource -@pytest.fixture -def simple_server() -> Server: - """Create a simple MCP server for testing.""" - server = Server(name="test_server") - - # pragma: no cover - handler exists only to register a resource capability. - # Transport tests verify stream creation, not handler invocation. - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ +async def _handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListResourcesResult: # pragma: no cover + return ListResourcesResult( + resources=[ Resource( uri="memory://test", name="Test Resource", description="A test resource", ) ] + ) - return server + +@pytest.fixture +def simple_server() -> Server: + """Create a simple MCP server for testing.""" + return Server( + name="test_server", + on_list_resources=_handle_list_resources, + ) @pytest.fixture diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index f21abf4d0..000a9ce90 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -10,6 +10,7 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -21,16 +22,17 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, ClientResult, CreateTaskResult, - GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + ListToolsResult, + PaginatedRequestParams, ServerNotification, ServerRequest, TaskMetadata, @@ -52,17 +54,15 @@ class AppContext: async def test_session_experimental_get_task() -> None: """Test session.experimental.get_task() method.""" # Note: We bypass the normal lifespan mechanism - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="test_tool", description="Test", input_schema={"type": "object"})]) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + app: AppContext = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None @@ -81,11 +81,10 @@ async def do_work(): raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def on_get_task(ctx: ServerRequestContext[Any], params: GetTaskRequestParams) -> GetTaskResult: + app: AppContext = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -96,6 +95,13 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) + server: Server[AppContext] = Server( # type: ignore[assignment] + "test-server", + on_list_tools=on_list_tools, + on_call_tool=on_call_tool, # type: ignore[arg-type] + ) + server.experimental.enable_tasks(on_get_task=on_get_task) + # Set up streams server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -159,17 +165,15 @@ async def run_server(app_context: AppContext): @pytest.mark.anyio async def test_session_experimental_get_task_result() -> None: """Test session.experimental.get_task_result() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="test_tool", description="Test", input_schema={"type": "object"})]) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + app: AppContext = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None @@ -190,16 +194,22 @@ async def do_work(): raise NotImplementedError - @server.experimental.get_task_result() - async def handle_get_task_result( - request: GetTaskPayloadRequest, + async def on_task_result( + ctx: ServerRequestContext[Any], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app: AppContext = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) + server: Server[AppContext] = Server( # type: ignore[assignment] + "test-server", + on_list_tools=on_list_tools, + on_call_tool=on_call_tool, # type: ignore[arg-type] + ) + server.experimental.enable_tasks(on_task_result=on_task_result) + # Set up streams server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -265,17 +275,15 @@ async def run_server(app_context: AppContext): @pytest.mark.anyio async def test_session_experimental_list_tasks() -> None: """Test TaskClient.list_tasks() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="test_tool", description="Test", input_schema={"type": "object"})]) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + app: AppContext = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None @@ -294,12 +302,18 @@ async def do_work(): raise NotImplementedError - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: - app = server.request_context.lifespan_context - tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + async def on_list_tasks(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListTasksResult: + app: AppContext = ctx.lifespan_context + tasks_list, next_cursor = await app.store.list_tasks(cursor=params.cursor if params else None) return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) + server: Server[AppContext] = Server( # type: ignore[assignment] + "test-server", + on_list_tools=on_list_tools, + on_call_tool=on_call_tool, # type: ignore[arg-type] + ) + server.experimental.enable_tasks(on_list_tasks=on_list_tasks) + # Set up streams server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -360,17 +374,15 @@ async def run_server(app_context: AppContext): @pytest.mark.anyio async def test_session_experimental_cancel_task() -> None: """Test TaskClient.cancel_task() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="test_tool", description="Test", input_schema={"type": "object"})]) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + app: AppContext = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None @@ -380,11 +392,10 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextCon raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def on_get_task(ctx: ServerRequestContext[Any], params: GetTaskRequestParams) -> GetTaskResult: + app: AppContext = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -395,14 +406,13 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" - await app.store.update_task(request.params.task_id, status="cancelled") + async def on_cancel_task(ctx: ServerRequestContext[Any], params: CancelTaskRequestParams) -> CancelTaskResult: + app: AppContext = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" + await app.store.update_task(params.task_id, status="cancelled") # CancelTaskResult extends Task, so we need to return the updated task info - updated_task = await app.store.get_task(request.params.task_id) + updated_task = await app.store.get_task(params.task_id) assert updated_task is not None return CancelTaskResult( task_id=updated_task.task_id, @@ -412,6 +422,13 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: ttl=updated_task.ttl, ) + server: Server[AppContext] = Server( # type: ignore[assignment] + "test-server", + on_list_tools=on_list_tools, + on_call_tool=on_call_tool, # type: ignore[arg-type] + ) + server.experimental.enable_tasks(on_get_task=on_get_task, on_cancel_task=on_cancel_task) + # Set up streams server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index 41cecc129..86e0c707f 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -18,6 +18,7 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -38,8 +39,8 @@ GetTaskRequest, GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerNotification, ServerRequest, TaskMetadata, @@ -69,29 +70,30 @@ async def test_task_lifecycle_with_task_execution() -> None: 3. Return CreateTaskResult immediately 4. Work executes in background, auto-fails on exception """ - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] store = InMemoryTaskStore() - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - input_schema={ - "type": "object", - "properties": {"input": {"type": "string"}}, - }, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: + return ListTasksResult( + tools=[ + Tool( + name="process_data", + description="Process data asynchronously", + input_schema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "process_data" and ctx.experimental.is_task: + if params.name == "process_data" and ctx.experimental.is_task: # 1. Create task in store task_metadata = ctx.experimental.task_metadata assert task_metadata is not None @@ -106,7 +108,7 @@ async def do_work(): async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("Processing input...") # Simulate work - input_value = arguments.get("input", "") + input_value = params.arguments.get("input", "") if params.arguments else "" result_text = f"Processed: {input_value.upper()}" await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) # Signal completion @@ -120,12 +122,12 @@ async def do_work(): raise NotImplementedError - # Register task query handlers (delegate to store) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task( + ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams + ) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -136,21 +138,34 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) # Return as GetTaskPayloadResult (which accepts extra fields) return GetTaskPayloadResult(**result.model_dump()) - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: raise NotImplementedError + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext] = Server( + "test-tasks", + on_list_tools=handle_list_tools, # type: ignore[arg-type] + on_call_tool=handle_call_tool, # type: ignore[arg-type] + ) + + server.experimental.enable_tasks( + on_get_task=handle_get_task, # type: ignore[arg-type] + on_task_result=handle_get_task_result, # type: ignore[arg-type] + on_list_tasks=handle_list_tasks, # type: ignore[arg-type] + ) + # Set up client-server communication server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -231,25 +246,26 @@ async def run_server(app_context: AppContext): @pytest.mark.anyio async def test_task_auto_fails_on_exception() -> None: """Test that task_execution automatically fails the task on unhandled exception.""" - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] store = InMemoryTaskStore() - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object", "properties": {}}, - ) - ] + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: + return ListTasksResult( + tools=[ + Tool( + name="failing_task", + description="A task that fails", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "failing_task" and ctx.experimental.is_task: + if params.name == "failing_task" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) @@ -272,11 +288,12 @@ async def do_failing_work(): raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task( + ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams + ) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -287,6 +304,17 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext] = Server( + "test-tasks-failure", + on_list_tools=handle_list_tools, # type: ignore[arg-type] + on_call_tool=handle_call_tool, # type: ignore[arg-type] + ) + + server.experimental.enable_tasks( + on_get_task=handle_get_task, # type: ignore[arg-type] + ) + # Set up streams server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py index 0d5d1df77..b8cf8cbc9 100644 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -17,6 +17,7 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.experimental.task_support import TaskSupport @@ -26,16 +27,15 @@ from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, - CancelTaskRequest, CancelTaskResult, CreateTaskResult, - GetTaskPayloadRequest, GetTaskPayloadResult, - GetTaskRequest, GetTaskResult, - ListTasksRequest, ListTasksResult, + ListToolsResult, + PaginatedRequestParams, TextContent, Tool, ToolExecution, @@ -52,29 +52,25 @@ async def test_run_task_basic_flow() -> None: 4. Work completes in background 5. Client polls and sees completed status """ - server = Server("test-run-task") - - # One-line setup - server.experimental.enable_tasks() - # Track when work completes and capture received meta work_completed = Event() received_meta: list[str | None] = [None] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="simple_task", - description="A simple task", - input_schema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="simple_task", + description="A simple task", + input_schema={"type": "object", "properties": {"input": {"type": "string"}}}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) # Capture the meta from the request (if present) @@ -83,13 +79,18 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu async def work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Working...") - input_val = arguments.get("input", "default") + input_val = (params.arguments or {}).get("input", "default") result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) work_completed.set() return result return await ctx.experimental.run_task(work) + server = Server("test-run-task", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + + # One-line setup + server.experimental.enable_tasks() + # Set up streams server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -142,25 +143,23 @@ async def run_client() -> None: @pytest.mark.anyio async def test_run_task_auto_fails_on_exception() -> None: """Test that run_task automatically fails the task when work raises.""" - server = Server("test-run-task-fail") - server.experimental.enable_tasks() - work_failed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="failing_task", + description="A task that fails", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -169,6 +168,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-run-task-fail", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -249,32 +251,32 @@ async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> N """Test that enable_tasks() doesn't override already-registered handlers.""" server = Server("test-custom-handlers") - # Register custom handlers BEFORE enable_tasks (never called, just for registration) - @server.experimental.get_task() - async def custom_get_task(req: GetTaskRequest) -> GetTaskResult: + # Register custom handlers via enable_tasks with on_* kwargs + async def custom_get_task(ctx: ServerRequestContext[Any], params: Any) -> GetTaskResult: raise NotImplementedError - @server.experimental.get_task_result() - async def custom_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + async def custom_get_task_result(ctx: ServerRequestContext[Any], params: Any) -> GetTaskPayloadResult: raise NotImplementedError - @server.experimental.list_tasks() - async def custom_list_tasks(req: ListTasksRequest) -> ListTasksResult: + async def custom_list_tasks(ctx: ServerRequestContext[Any], params: Any) -> ListTasksResult: raise NotImplementedError - @server.experimental.cancel_task() - async def custom_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: + async def custom_cancel_task(ctx: ServerRequestContext[Any], params: Any) -> CancelTaskResult: raise NotImplementedError - # Now enable tasks - should NOT override our custom handlers - server.experimental.enable_tasks() + # Enable tasks with custom handlers + server.experimental.enable_tasks( + on_get_task=custom_get_task, + on_task_result=custom_get_task_result, + on_list_tasks=custom_list_tasks, + on_cancel_task=custom_cancel_task, + ) # Verify our custom handlers are still registered (not replaced by defaults) - # The handlers dict should contain our custom handlers - assert GetTaskRequest in server.request_handlers - assert GetTaskPayloadRequest in server.request_handlers - assert ListTasksRequest in server.request_handlers - assert CancelTaskRequest in server.request_handlers + assert "tasks/get" in server._request_handlers + assert "tasks/result" in server._request_handlers + assert "tasks/list" in server._request_handlers + assert "tasks/cancel" in server._request_handlers @pytest.mark.anyio @@ -345,26 +347,24 @@ async def work(task: ServerTaskContext) -> CallToolResult: @pytest.mark.anyio async def test_run_task_with_model_immediate_response() -> None: """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - server = Server("test-run-task-immediate") - server.experimental.enable_tasks() - work_completed = Event() immediate_response_text = "Processing your request..." - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="task_with_immediate", - description="A task with immediate response", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="task_with_immediate", + description="A task with immediate response", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -373,6 +373,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) + server = Server("test-run-task-immediate", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -405,25 +408,23 @@ async def run_client() -> None: @pytest.mark.anyio async def test_run_task_doesnt_complete_if_already_terminal() -> None: """Test that run_task doesn't auto-complete if work manually completed the task.""" - server = Server("test-already-complete") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_complete_task", - description="A task that manually completes", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="manual_complete_task", + description="A task that manually completes", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -436,6 +437,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-already-complete", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -471,25 +475,23 @@ async def run_client() -> None: @pytest.mark.anyio async def test_run_task_doesnt_fail_if_already_terminal() -> None: """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - server = Server("test-already-failed") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_cancel_task", - description="A task that manually cancels then raises", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="manual_cancel_task", + description="A task that manually cancels then raises", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def on_call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -501,6 +503,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-already-failed", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 8005380d2..eebb25f98 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -8,6 +8,7 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -23,7 +24,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, CancelTaskRequestParams, CancelTaskResult, ClientResult, @@ -31,7 +31,6 @@ GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, JSONRPCError, @@ -39,8 +38,8 @@ JSONRPCResponse, ListTasksRequest, ListTasksResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, SamplingMessage, ServerCapabilities, ServerNotification, @@ -79,13 +78,17 @@ async def test_list_tasks_handler() -> None: ), ] - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None + ) -> ListTasksResult: return ListTasksResult(tasks=test_tasks) - handler = server.request_handlers[ListTasksRequest] - request = ListTasksRequest(method="tasks/list") - result = await handler(request) + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) + + handler = server._request_handlers["tasks/list"] + # Create a minimal ctx for direct handler testing + dummy_ctx: Any = None + result = await handler(dummy_ctx, None) assert isinstance(result, ServerResult) assert isinstance(result, ListTasksResult) @@ -99,11 +102,10 @@ async def test_get_task_handler() -> None: """Test that experimental get_task handler works.""" server = Server("test") - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + async def handle_get_task(ctx: ServerRequestContext[Any], params: GetTaskRequestParams) -> GetTaskResult: now = datetime.now(timezone.utc) return GetTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="working", created_at=now, last_updated_at=now, @@ -111,12 +113,12 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=1000, ) - handler = server.request_handlers[GetTaskRequest] - request = GetTaskRequest( - method="tasks/get", - params=GetTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + handler = server._request_handlers["tasks/get"] + params = GetTaskRequestParams(task_id="test-task-123") + dummy_ctx: Any = None + result = await handler(dummy_ctx, params) assert isinstance(result, ServerResult) assert isinstance(result, GetTaskResult) @@ -129,16 +131,17 @@ async def test_get_task_result_handler() -> None: """Test that experimental get_task_result handler works.""" server = Server("test") - @server.experimental.get_task_result() - async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + async def handle_get_task_result( + ctx: ServerRequestContext[Any], params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: return GetTaskPayloadResult() - handler = server.request_handlers[GetTaskPayloadRequest] - request = GetTaskPayloadRequest( - method="tasks/result", - params=GetTaskPayloadRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server.experimental.enable_tasks(on_task_result=handle_get_task_result) + + handler = server._request_handlers["tasks/result"] + params = GetTaskPayloadRequestParams(task_id="test-task-123") + dummy_ctx: Any = None + result = await handler(dummy_ctx, params) assert isinstance(result, ServerResult) assert isinstance(result, GetTaskPayloadResult) @@ -149,23 +152,22 @@ async def test_cancel_task_handler() -> None: """Test that experimental cancel_task handler works.""" server = Server("test") - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel_task(ctx: ServerRequestContext[Any], params: CancelTaskRequestParams) -> CancelTaskResult: now = datetime.now(timezone.utc) return CancelTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="cancelled", created_at=now, last_updated_at=now, ttl=60000, ) - handler = server.request_handlers[CancelTaskRequest] - request = CancelTaskRequest( - method="tasks/cancel", - params=CancelTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) + + handler = server._request_handlers["tasks/cancel"] + params = CancelTaskRequestParams(task_id="test-task-123") + dummy_ctx: Any = None + result = await handler(dummy_ctx, params) assert isinstance(result, ServerResult) assert isinstance(result, CancelTaskResult) @@ -178,14 +180,19 @@ async def test_server_capabilities_include_tasks() -> None: """Test that server capabilities include tasks when handlers are registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None + ) -> ListTasksResult: raise NotImplementedError - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel_task(ctx: ServerRequestContext[Any], params: CancelTaskRequestParams) -> CancelTaskResult: raise NotImplementedError + server.experimental.enable_tasks( + on_list_tasks=handle_list_tasks, + on_cancel_task=handle_cancel_task, + ) + capabilities = server.get_capabilities( notification_options=NotificationOptions(), experimental_capabilities={}, @@ -203,11 +210,13 @@ async def test_server_capabilities_partial_tasks() -> None: """Test capabilities with only some task handlers registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None + ) -> ListTasksResult: raise NotImplementedError # Only list_tasks registered, not cancel_task + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) capabilities = server.get_capabilities( notification_options=NotificationOptions(), @@ -216,40 +225,44 @@ async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: assert capabilities.tasks is not None assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is None # Not registered + assert capabilities.tasks.cancel is not None # enable_tasks registers default cancel_task @pytest.mark.anyio async def test_tool_with_task_execution_metadata() -> None: """Test that tools can declare task execution mode.""" - server = Server("test") - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] + async def handle_list_tools( + ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="quick_tool", + description="Fast tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_FORBIDDEN), + ), + Tool( + name="long_tool", + description="Long running tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ), + Tool( + name="flexible_tool", + description="Can be either", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_OPTIONAL), + ), + ] + ) - tools_handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list") - result = await tools_handler(request) + server = Server("test", on_list_tools=handle_list_tools) + + tools_handler = server._request_handlers["tools/list"] + dummy_ctx: Any = None + result = await tools_handler(dummy_ctx, None) assert isinstance(result, ServerResult) assert isinstance(result, ListToolsResult) @@ -266,26 +279,28 @@ async def list_tools(): @pytest.mark.anyio async def test_task_metadata_in_call_tool_request() -> None: """Test that task metadata is accessible via RequestContext when calling a tool.""" - server = Server("test") captured_task_metadata: TaskMetadata | None = None - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="long_task", - description="A long running task", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support="optional"), - ) - ] + async def handle_list_tools( + ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="long_task", + description="A long running task", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support="optional"), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: nonlocal captured_task_metadata - ctx = server.request_context captured_task_metadata = ctx.experimental.task_metadata - return [TextContent(type="text", text="done")] + return CallToolResult(content=[TextContent(type="text", text="done")]) + + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -347,24 +362,26 @@ async def handle_messages(): @pytest.mark.anyio async def test_task_metadata_is_task_property() -> None: """Test that RequestContext.experimental.is_task works correctly.""" - server = Server("test") is_task_values: list[bool] = [] - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="test_tool", - description="Test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] + async def handle_list_tools( + ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="Test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: is_task_values.append(ctx.experimental.is_task) - return [TextContent(type="text", text="done")] + return CallToolResult(content=[TextContent(type="text", text="done")]) + + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -499,7 +516,7 @@ async def run_server() -> None: # Test get_task (default handler - found) get_result = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id)), + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)), GetTaskResult, ) assert get_result.task_id == task.task_id @@ -508,7 +525,7 @@ async def run_server() -> None: # Test get_task (default handler - not found path) with pytest.raises(MCPError, match="not found"): await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task")), + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent-task")), GetTaskResult, ) @@ -530,7 +547,7 @@ async def run_server() -> None: # Test cancel_task (default handler) cancel_result = await client_session.send_request( - CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id)), CancelTaskResult + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)), CancelTaskResult ) assert cancel_result.task_id == task.task_id assert cancel_result.status == "cancelled" diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 57122da7b..dd990e5dc 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -18,6 +18,7 @@ from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions from mcp.shared._context import RequestContext @@ -26,6 +27,7 @@ from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, CreateMessageRequestParams, CreateMessageResult, @@ -35,6 +37,8 @@ ErrorData, GetTaskPayloadResult, GetTaskResult, + ListToolsResult, + PaginatedRequestParams, SamplingMessage, TaskMetadata, TextContent, @@ -181,24 +185,21 @@ async def test_scenario1_normal_tool_normal_elicitation() -> None: Server calls session.elicit() directly, client responds immediately. """ - server = Server("test-scenario1") elicit_received = Event() tool_result: list[str] = [] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def on_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: # Normal elicitation - expects immediate response result = await ctx.session.elicit( message="Please confirm the action", @@ -209,6 +210,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario1", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -262,27 +265,24 @@ async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: Server calls session.experimental.elicit_as_task(), client creates a task for the elicitation and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2") elicit_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def on_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: # Task-augmented elicitation - server polls client result = await ctx.session.experimental.elicit_as_task( message="Please confirm the action", @@ -294,6 +294,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario2", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -342,26 +344,22 @@ async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: Client calls tool as task. Inside the task, server uses task.elicit() which queues the request and delivers via tasks/result. """ - server = Server("test-scenario3") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def on_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -377,6 +375,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario3", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -452,29 +453,25 @@ async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> Non 5. Server gets the ElicitResult and completes the tool task 6. Client's tasks/result returns with the CallToolResult """ - server = Server("test-scenario4") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def on_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -491,6 +488,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -553,27 +553,24 @@ async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: Server calls session.experimental.create_message_as_task(), client creates a task for the sampling and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2-sampling") sampling_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + ) + ] + ) + async def on_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: # Task-augmented sampling - server polls client result = await ctx.session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], @@ -587,6 +584,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append(response_text) return CallToolResult(content=[TextContent(type="text", text=response_text)]) + server = Server("test-scenario2-sampling", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams @@ -636,29 +635,25 @@ async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() which sends task-augmented sampling. Client creates its own task for the sampling, and server polls the client. """ - server = Server("test-scenario4-sampling") - server.experimental.enable_tasks() - sampling_received = Event() work_completed = Event() # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def on_list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def on_call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -677,6 +672,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4-sampling", on_list_tools=on_list_tools, on_call_tool=on_call_tool) # type: ignore[arg-type] + server.experimental.enable_tasks() + task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py index d00ce40a4..344789b36 100644 --- a/tests/experimental/tasks/test_spec_compliance.py +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -11,16 +11,17 @@ import pytest from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY from mcp.types import ( - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, CreateTaskResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, Task, ) @@ -48,10 +49,11 @@ def test_server_with_list_tasks_handler_declares_list_capability() -> None: """Server with list_tasks handler declares tasks.list capability.""" server: Server = Server("test") - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: + async def handle_list(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError + server.experimental.enable_tasks(on_list_tasks=handle_list) + caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.list is not None @@ -61,10 +63,11 @@ def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: """Server with cancel_task handler declares tasks.cancel capability.""" server: Server = Server("test") - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: raise NotImplementedError + server.experimental.enable_tasks(on_cancel_task=handle_cancel) + caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.cancel is not None @@ -76,10 +79,11 @@ def test_server_with_get_task_handler_declares_requests_tools_call_capability() """ server: Server = Server("test") - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: + async def handle_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError + server.experimental.enable_tasks(on_get_task=handle_get) + caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.requests is not None @@ -91,10 +95,11 @@ def test_server_without_list_handler_has_no_list_capability() -> None: server: Server = Server("test") # Register only get_task (not list_tasks) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: + async def handle_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError + server.experimental.enable_tasks(on_get_task=handle_get) + caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.list is None @@ -105,10 +110,11 @@ def test_server_without_cancel_handler_has_no_cancel_capability() -> None: server: Server = Server("test") # Register only get_task (not cancel_task) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: + async def handle_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError + server.experimental.enable_tasks(on_get_task=handle_get) + caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.cancel is None @@ -118,18 +124,21 @@ def test_server_with_all_task_handlers_has_full_capability() -> None: """Server with all task handlers declares complete tasks capability.""" server: Server = Server("test") - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: + async def handle_list(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: raise NotImplementedError - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: + async def handle_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError + server.experimental.enable_tasks( + on_list_tasks=handle_list, + on_cancel_task=handle_cancel, + on_get_task=handle_get, + ) + caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.list is not None diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 39e2c6f2a..e78d4cce2 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -21,11 +21,8 @@ def get_user_profile(user_id: str) -> str: # pragma: no cover return f"Profile data for user {user_id}" # Get the list of resource templates using the underlying server - # Note: list_resource_templates() returns a decorator that wraps the handler - # The handler returns a ServerResult with a ListResourceTemplatesResult inside - result = await mcp._lowlevel_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest(params=None) - ) + # The handler receives (ctx, params) and returns ListResourceTemplatesResult + result = await mcp._lowlevel_server._request_handlers["resources/templates/list"](None, None) # type: ignore[arg-type] assert isinstance(result, types.ListResourceTemplatesResult) templates = result.resource_templates diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index e738017f8..6a06631bb 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -3,8 +3,8 @@ import pytest from mcp import Client, types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.mcpserver import MCPServer pytestmark = pytest.mark.anyio @@ -58,7 +58,6 @@ def get_image_as_bytes() -> bytes: async def test_lowlevel_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" - server = Server("test") # Create a small test image as bytes image_bytes = b"fake_image_data" @@ -74,18 +73,43 @@ async def test_lowlevel_resource_mime_type(): ), ] - @server.list_resources() - async def handle_list_resources(): - return test_resources - - @server.read_resource() - async def handle_read_resource(uri: str): - if str(uri) == "test://image": - return [ReadResourceContents(content=base64_string, mime_type="image/png")] - elif str(uri) == "test://image_bytes": - return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + return types.ListResourcesResult(resources=test_resources) + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + uri = str(params.uri) + if uri == "test://image": + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=uri, + text=base64_string, + mime_type="image/png", + ) + ] + ) + elif uri == "test://image_bytes": + return types.ReadResourceResult( + contents=[ + types.BlobResourceContents( + uri=uri, + blob=base64.b64encode(image_bytes).decode("utf-8"), + mime_type="image/png", + ) + ] + ) raise Exception(f"Resource not found: {uri}") # pragma: no cover + server = Server( + "test", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) + # Test that resources are listed with correct mime type async with Client(server) as client: # List resources and verify mime types diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index e6ff56877..55f533e27 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -13,8 +13,8 @@ import pytest from mcp import Client, types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents pytestmark = pytest.mark.anyio @@ -26,24 +26,36 @@ async def test_relative_uri_roundtrip(): the server would fail to serialize resources with relative URIs, or the URI would be transformed during the roundtrip. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="user", uri="users/me"), - types.Resource(name="config", uri="./config"), - types.Resource(name="parent", uri="../parent/resource"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ - ReadResourceContents( - content=f"data for {uri}", - mime_type="text/plain", - ) - ] + + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ + types.Resource(name="user", uri="users/me"), + types.Resource(name="config", uri="./config"), + types.Resource(name="parent", uri="../parent/resource"), + ] + ) + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=f"data for {params.uri}", + mime_type="text/plain", + ) + ] + ) + + server = Server( + "test", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) async with Client(server) as client: # List should return the exact URIs we specified @@ -67,18 +79,35 @@ async def test_custom_scheme_uri_roundtrip(): Some MCP servers use custom schemes like "custom://resource". These should work end-to-end. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="custom", uri="custom://my-resource"), - types.Resource(name="file", uri="file:///path/to/file"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ReadResourceContents(content="data", mime_type="text/plain")] + + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ + types.Resource(name="custom", uri="custom://my-resource"), + types.Resource(name="file", uri="file:///path/to/file"), + ] + ) + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text="data", + mime_type="text/plain", + ) + ] + ) + + server = Server( + "test", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) async with Client(server) as client: resources = await client.list_resources() diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 44b17d337..331ecc897 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -10,19 +10,13 @@ """ import base64 -from typing import cast import pytest -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp import Client, types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import Server -from mcp.types import ( - BlobResourceContents, - ReadResourceRequest, - ReadResourceRequestParams, - ReadResourceResult, - ServerResult, -) +from mcp.types import BlobResourceContents @pytest.mark.anyio @@ -31,53 +25,59 @@ async def test_server_base64_encoding_issue(): This test will: 1. Set up a server that returns binary data - 2. Extract the base64-encoded blob from the server's response + 2. Read the resource through the client 3. Verify the encoded data can be properly validated by BlobResourceContents BEFORE FIX: The test will fail because server uses urlsafe_b64encode AFTER FIX: The test will pass because server uses standard b64encode """ - server = Server("test") - # Create binary data that will definitely result in + and / characters # when encoded with standard base64 binary_data = bytes(list(range(255)) * 4) - # Register a resource handler that returns our test data - @server.read_resource() - async def read_resource(uri: str) -> list[ReadResourceContents]: - return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[ReadResourceRequest] - - # Create a request - request = ReadResourceRequest( - params=ReadResourceRequestParams(uri="test://resource"), + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ + types.Resource(uri="test://resource", name="test resource"), + ] + ) + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + return types.ReadResourceResult( + contents=[ + types.BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(binary_data).decode("utf-8"), + mime_type="application/octet-stream", + ) + ] + ) + + server = Server( + "test", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, ) - # Call the handler to get the response - result: ServerResult = await handler(request) - - # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result) - blob_content = read_result.contents[0] + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - # First verify our test data actually produces different encodings - urlsafe_b64 = base64.urlsafe_b64encode(binary_data).decode() - standard_b64 = base64.b64encode(binary_data).decode() - assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate" - " encoding difference" + blob_content = result.contents[0] - # Now validate the server's output with BlobResourceContents.model_validate - # Before the fix: This should fail with "Invalid base64" because server - # uses urlsafe_b64encode - # After the fix: This should pass because server will use standard b64encode - model_dict = blob_content.model_dump() + # First verify our test data actually produces different encodings + urlsafe_b64 = base64.urlsafe_b64encode(binary_data).decode() + standard_b64 = base64.b64encode(binary_data).decode() + assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate encoding difference" - # Direct validation - this will fail before fix, pass after fix - blob_model = BlobResourceContents.model_validate(model_dict) + # Validate the response with BlobResourceContents.model_validate + model_dict = blob_content.model_dump() + blob_model = BlobResourceContents.model_validate(model_dict) - # Verify we can decode the data back correctly - decoded = base64.b64decode(blob_model.blob) - assert decoded == binary_data + # Verify we can decode the data back correctly + decoded = base64.b64decode(blob_model.blob) + assert decoded == binary_data diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index cd27698e6..3514d26ef 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,8 +1,6 @@ """Test to reproduce issue #88: Random error thrown on response.""" -from collections.abc import Sequence from pathlib import Path -from typing import Any import anyio import pytest @@ -11,10 +9,11 @@ from mcp import types from mcp.client.session import ClientSession +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import ContentBlock, TextContent +from mcp.types import TextContent @pytest.mark.anyio @@ -32,36 +31,47 @@ async def test_notification_validation_error(tmp_path: Path): - Slow operations use minimal timeout (10ms) for quick test execution """ - server = Server(name="test") request_count = 0 slow_request_lock = anyio.Event() - @server.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow", - description="A slow tool", - input_schema={"type": "object"}, - ), - types.Tool( - name="fast", - description="A fast tool", - input_schema={"type": "object"}, - ), - ] - - @server.call_tool() - async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="slow", + description="A slow tool", + input_schema={"type": "object"}, + ), + types.Tool( + name="fast", + description="A fast tool", + input_schema={"type": "object"}, + ), + ] + ) + + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult: nonlocal request_count request_count += 1 - if name == "slow": + if params.name == "slow": await slow_request_lock.wait() # it should timeout here - return [TextContent(type="text", text=f"slow {request_count}")] - elif name == "fast": - return [TextContent(type="text", text=f"fast {request_count}")] - return [TextContent(type="text", text=f"unknown {request_count}")] # pragma: no cover + return types.CallToolResult(content=[TextContent(type="text", text=f"slow {request_count}")]) + elif params.name == "fast": + return types.CallToolResult(content=[TextContent(type="text", text=f"fast {request_count}")]) + return types.CallToolResult( + content=[TextContent(type="text", text=f"unknown {request_count}")] + ) # pragma: no cover + + server = Server( + name="test", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py deleted file mode 100644 index 9cb2b561a..000000000 --- a/tests/server/lowlevel/test_func_inspection.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Unit tests for func_inspection module. - -Tests the create_call_wrapper function which determines how to call handler functions -with different parameter signatures and type hints. -""" - -from typing import Any, Generic, TypeVar - -import pytest - -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest, PaginatedRequestParams - -T = TypeVar("T") - - -@pytest.mark.anyio -async def test_no_params_returns_deprecated_wrapper() -> None: - """Test: def foo() - should call without request.""" - called_without_request = False - - async def handler() -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_param_with_default_returns_deprecated_wrapper() -> None: - """Test: def foo(thing: int = 1) - should call without request.""" - called_without_request = False - - async def handler(thing: int = 1) -> list[str]: - nonlocal called_without_request - called_without_request = True - return [f"test-{thing}"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request (uses default value) - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test-1"] - - -@pytest.mark.anyio -async def test_typed_request_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor="test-cursor")) - await wrapper(request) - - assert received_request is not None - assert received_request is request - params = getattr(received_request, "params", None) - assert params is not None - assert params.cursor == "test-cursor" - - -@pytest.mark.anyio -async def test_typed_request_with_default_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest, thing: int = 1) - should pass request through.""" - received_request = None - received_thing = None - - async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: - nonlocal received_request, received_thing - received_request = req - received_thing = thing - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - assert received_thing == 1 # default value - - -@pytest.mark.anyio -async def test_optional_typed_request_with_default_none_is_deprecated() -> None: - """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - old style.""" - called_without_request = False - - async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_untyped_request_param_is_deprecated() -> None: - """Test: def foo(req) - should call without request.""" - called = False - - async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] # pragma: no cover - nonlocal called - called = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_any_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Any) - should call without request.""" - - async def handler(req: Any) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_generic_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Generic[T]) - should call without request.""" - - async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_wrong_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: str) - should call without request.""" - - async def handler(req: str) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_required_param_before_typed_request_attempts_to_pass() -> None: - """Test: def foo(thing: int, req: ListPromptsRequest) - attempts to pass request (will fail at runtime).""" - received_request = None - - async def handler(thing: int, req: ListPromptsRequest) -> list[str]: # pragma: no cover - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper will attempt to pass request, but it will fail at runtime - # because 'thing' is required and has no default - request = ListPromptsRequest(method="prompts/list", params=None) - - # This will raise TypeError because 'thing' is missing - with pytest.raises(TypeError, match="missing 1 required positional argument: 'thing'"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_positional_only_param_with_correct_type() -> None: - """Test: def foo(req: ListPromptsRequest, /) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest, /) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_keyword_only_param_with_correct_type() -> None: - """Test: def foo(*, req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(*, req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler with keyword argument - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_different_request_types() -> None: - """Test that wrapper works with different request types.""" - # Test with ListResourcesRequest - received_request = None - - async def handler(req: ListResourcesRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListResourcesRequest) - - request = ListResourcesRequest(method="resources/list", params=None) - await wrapper(request) - - assert received_request is request - - # Test with ListToolsRequest - received_request = None - - async def handler2(req: ListToolsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper2 = create_call_wrapper(handler2, ListToolsRequest) - - request2 = ListToolsRequest(method="tools/list", params=None) - await wrapper2(request2) - - assert received_request is request2 - - -@pytest.mark.anyio -async def test_mixed_params_with_typed_request() -> None: - """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" - - async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Will fail at runtime due to missing 'a' - request = ListPromptsRequest(method="prompts/list", params=None) - - with pytest.raises(TypeError, match="missing 1 required positional argument: 'a'"): - await wrapper(request) diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 6bf4cddb3..00cad01b5 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -1,81 +1,61 @@ -"""Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" - -import warnings +"""Basic tests for on_list_prompts, on_list_resources, and on_list_tools handlers without pagination.""" import pytest from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, Prompt, Resource, - ServerResult, Tool, ) +pytestmark = pytest.mark.anyio + -@pytest.mark.anyio async def test_list_prompts_basic() -> None: """Test basic prompt listing without pagination.""" - server = Server("test") - test_prompts = [ Prompt(name="prompt1", description="First prompt"), Prompt(name="prompt2", description="Second prompt"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_prompts(ctx: ServerRequestContext, params: None) -> ListPromptsResult: + return ListPromptsResult(prompts=test_prompts) - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return test_prompts + server = Server("test", on_list_prompts=handle_list_prompts) - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + assert "prompts/list" in server._request_handlers + result = await server._request_handlers["prompts/list"](None, None) # type: ignore[arg-type] - assert isinstance(result, ServerResult) assert isinstance(result, ListPromptsResult) assert result.prompts == test_prompts -@pytest.mark.anyio async def test_list_resources_basic() -> None: """Test basic resource listing without pagination.""" - server = Server("test") - test_resources = [ Resource(uri="file:///test1.txt", name="Test 1"), Resource(uri="file:///test2.txt", name="Test 2"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_resources(ctx: ServerRequestContext, params: None) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return test_resources + server = Server("test", on_list_resources=handle_list_resources) - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) + assert "resources/list" in server._request_handlers + result = await server._request_handlers["resources/list"](None, None) # type: ignore[arg-type] - assert isinstance(result, ServerResult) assert isinstance(result, ListResourcesResult) assert result.resources == test_resources -@pytest.mark.anyio async def test_list_tools_basic() -> None: """Test basic tool listing without pagination.""" - server = Server("test") - test_tools = [ Tool( name="tool1", @@ -102,80 +82,52 @@ async def test_list_tools_basic() -> None: ), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_tools(ctx: ServerRequestContext, params: None) -> ListToolsResult: + return ListToolsResult(tools=test_tools) - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return test_tools + server = Server("test", on_list_tools=handle_list_tools) - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) + assert "tools/list" in server._request_handlers + result = await server._request_handlers["tools/list"](None, None) # type: ignore[arg-type] - assert isinstance(result, ServerResult) assert isinstance(result, ListToolsResult) assert result.tools == test_tools -@pytest.mark.anyio async def test_list_prompts_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] + async def handle_list_prompts(ctx: ServerRequestContext, params: None) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + server = Server("test", on_list_prompts=handle_list_prompts) + result = await server._request_handlers["prompts/list"](None, None) # type: ignore[arg-type] - assert isinstance(result, ServerResult) assert isinstance(result, ListPromptsResult) assert result.prompts == [] -@pytest.mark.anyio async def test_list_resources_empty() -> None: """Test listing with empty results.""" - server = Server("test") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_resources(ctx: ServerRequestContext, params: None) -> ListResourcesResult: + return ListResourcesResult(resources=[]) - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] + server = Server("test", on_list_resources=handle_list_resources) + result = await server._request_handlers["resources/list"](None, None) # type: ignore[arg-type] - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) assert isinstance(result, ListResourcesResult) assert result.resources == [] -@pytest.mark.anyio async def test_list_tools_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] + async def handle_list_tools(ctx: ServerRequestContext, params: None) -> ListToolsResult: + return ListToolsResult(tools=[]) - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) + server = Server("test", on_list_tools=handle_list_tools) + result = await server._request_handlers["tools/list"](None, None) # type: ignore[arg-type] - assert isinstance(result, ServerResult) assert isinstance(result, ListToolsResult) assert result.tools == [] diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index 081fb262a..731264eb5 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -1,111 +1,92 @@ +"""Tests for pagination support in on_list_prompts, on_list_resources, and on_list_tools handlers.""" + import pytest from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, PaginatedRequestParams, - ServerResult, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_list_prompts_pagination() -> None: - server = Server("test") - test_cursor = "test-cursor-123" - # Track what request was received - received_request: ListPromptsRequest | None = None +async def test_list_prompts_pagination() -> None: + received_params: PaginatedRequestParams | None = "NOT_SET" # type: ignore[assignment] - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - nonlocal received_request - received_request = request + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + nonlocal received_params + received_params = params return ListPromptsResult(prompts=[], next_cursor="next") - handler = server.request_handlers[ListPromptsRequest] + server = Server("test", on_list_prompts=handle_list_prompts) - # Test: No cursor provided -> handler receives request with None params - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) + # Test: No cursor provided -> handler receives None params + result = await server._request_handlers["prompts/list"](None, None) # type: ignore[arg-type] + assert received_params is None + assert isinstance(result, ListPromptsResult) - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Test: Cursor provided -> handler receives params with cursor + test_cursor = "test-cursor-123" + params = PaginatedRequestParams(cursor=test_cursor) + result2 = await server._request_handlers["prompts/list"](None, params) # type: ignore[arg-type] + assert received_params is not None + assert received_params.cursor == test_cursor + assert isinstance(result2, ListPromptsResult) -@pytest.mark.anyio async def test_list_resources_pagination() -> None: - server = Server("test") - test_cursor = "resource-cursor-456" - - # Track what request was received - received_request: ListResourcesRequest | None = None + received_params: PaginatedRequestParams | None = "NOT_SET" # type: ignore[assignment] - @server.list_resources() - async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: - nonlocal received_request - received_request = request + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + nonlocal received_params + received_params = params return ListResourcesResult(resources=[], next_cursor="next") - handler = server.request_handlers[ListResourcesRequest] + server = Server("test", on_list_resources=handle_list_resources) - # Test: No cursor provided -> handler receives request with None params - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) + # Test: No cursor provided + result = await server._request_handlers["resources/list"](None, None) # type: ignore[arg-type] + assert received_params is None + assert isinstance(result, ListResourcesResult) - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListResourcesRequest( - method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) - ) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Test: Cursor provided + test_cursor = "resource-cursor-456" + params = PaginatedRequestParams(cursor=test_cursor) + result2 = await server._request_handlers["resources/list"](None, params) # type: ignore[arg-type] + assert received_params is not None + assert received_params.cursor == test_cursor + assert isinstance(result2, ListResourcesResult) -@pytest.mark.anyio async def test_list_tools_pagination() -> None: - server = Server("test") - test_cursor = "tools-cursor-789" - - # Track what request was received - received_request: ListToolsRequest | None = None + received_params: PaginatedRequestParams | None = "NOT_SET" # type: ignore[assignment] - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: - nonlocal received_request - received_request = request + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: + nonlocal received_params + received_params = params return ListToolsResult(tools=[], next_cursor="next") - handler = server.request_handlers[ListToolsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + server = Server("test", on_list_tools=handle_list_tools) + + # Test: No cursor provided + result = await server._request_handlers["tools/list"](None, None) # type: ignore[arg-type] + assert received_params is None + assert isinstance(result, ListToolsResult) + + # Test: Cursor provided + test_cursor = "tools-cursor-789" + params = PaginatedRequestParams(cursor=test_cursor) + result2 = await server._request_handlers["tools/list"](None, params) # type: ignore[arg-type] + assert received_params is not None + assert received_params.cursor == test_cursor + assert isinstance(result2, ListToolsResult) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 6d1634f2e..c1e6e605b 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -1,11 +1,10 @@ """Test that cancelled requests don't cause double responses.""" -from typing import Any - import anyio import pytest from mcp import Client, types +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import MCPError from mcp.types import ( @@ -14,6 +13,9 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, + ListToolsResult, + PaginatedRequestParams, + TextContent, Tool, ) @@ -22,34 +24,36 @@ async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" - server = Server("test-server") - # Track tool calls call_count = 0 ev_first_call = anyio.Event() first_request_id = None - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="Tool for testing", - input_schema={}, - ) - ] + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="Tool for testing", + input_schema={}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal call_count, first_request_id - if name == "test_tool": + if params.name == "test_tool": call_count += 1 if call_count == 1: - first_request_id = server.request_context.request_id + first_request_id = ctx.request_id ev_first_call.set() await anyio.sleep(5) # First call is slow - return [types.TextContent(type="text", text=f"Call number: {call_count}")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async with Client(server) as client: # First request (will be cancelled) diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 5a8d67f09..9abcbde2f 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -1,15 +1,14 @@ """Tests for completion handler with context functionality.""" -from typing import Any - import pytest from mcp import Client +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server from mcp.types import ( + CompleteRequestParams, + CompleteResult, Completion, - CompletionArgument, - CompletionContext, PromptReference, ResourceTemplateReference, ) @@ -18,23 +17,22 @@ @pytest.mark.anyio async def test_completion_handler_receives_context(): """Test that the completion handler receives context correctly.""" - server = Server("test-server") # Track what the handler receives - received_args: dict[str, Any] = {} + received_params: CompleteRequestParams | None = None - @server.completion() async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - received_args["ref"] = ref - received_args["argument"] = argument - received_args["context"] = context + ctx: ServerRequestContext, params: CompleteRequestParams + ) -> CompleteResult: + nonlocal received_params + received_params = params # Return test completion - return Completion(values=["test-completion"], total=1, has_more=False) + return CompleteResult( + completion=Completion(values=["test-completion"], total=1, has_more=False), + ) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test with context @@ -45,28 +43,29 @@ async def handle_completion( ) # Verify handler received the context - assert received_args["context"] is not None - assert received_args["context"].arguments == {"previous": "value"} + assert received_params is not None + assert received_params.context is not None + assert received_params.context.arguments == {"previous": "value"} assert result.completion.values == ["test-completion"] @pytest.mark.anyio async def test_completion_backward_compatibility(): """Test that completion works without context (backward compatibility).""" - server = Server("test-server") context_was_none = False - @server.completion() async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + ctx: ServerRequestContext, params: CompleteRequestParams + ) -> CompleteResult: nonlocal context_was_none - context_was_none = context is None + context_was_none = params.context is None + + return CompleteResult( + completion=Completion(values=["no-context-completion"], total=1, has_more=False), + ) - return Completion(values=["no-context-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test without context @@ -82,30 +81,46 @@ async def handle_completion( @pytest.mark.anyio async def test_dependent_completion_scenario(): """Test a real-world scenario with dependent completions.""" - server = Server("test-server") - @server.completion() async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + ctx: ServerRequestContext, params: CompleteRequestParams + ) -> CompleteResult: + ref = params.ref + argument = params.argument + context = params.context + # Simulate database/table completion scenario if isinstance(ref, ResourceTemplateReference): if ref.uri == "db://{database}/{table}": if argument.name == "database": # Complete database names - return Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) + return CompleteResult( + completion=Completion( + values=["users_db", "products_db", "analytics_db"], total=3, has_more=False + ), + ) elif argument.name == "table": # Complete table names based on selected database if context and context.arguments: db = context.arguments.get("database") if db == "users_db": - return Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) + return CompleteResult( + completion=Completion( + values=["users", "sessions", "permissions"], total=3, has_more=False + ), + ) elif db == "products_db": - return Completion(values=["products", "categories", "inventory"], total=3, has_more=False) + return CompleteResult( + completion=Completion( + values=["products", "categories", "inventory"], total=3, has_more=False + ), + ) + + return CompleteResult( # pragma: no cover + completion=Completion(values=[], total=0, has_more=False), + ) - return Completion(values=[], total=0, has_more=False) # pragma: no cover + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # First, complete database @@ -136,14 +151,14 @@ async def handle_completion( @pytest.mark.anyio async def test_completion_error_on_missing_context(): """Test that server can raise error when required context is missing.""" - server = Server("test-server") - @server.completion() async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + ctx: ServerRequestContext, params: CompleteRequestParams + ) -> CompleteResult: + ref = params.ref + argument = params.argument + context = params.context + if isinstance(ref, ResourceTemplateReference): if ref.uri == "db://{database}/{table}": if argument.name == "table": @@ -154,9 +169,15 @@ async def handle_completion( # Normal completion if context is provided db = context.arguments.get("database") if db == "test_db": - return Completion(values=["users", "orders", "products"], total=3, has_more=False) + return CompleteResult( + completion=Completion(values=["users", "orders", "products"], total=3, has_more=False), + ) + + return CompleteResult( # pragma: no cover + completion=Completion(values=[], total=0, has_more=False), + ) - return Completion(values=[], total=0, has_more=False) # pragma: no cover + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Try to complete table without database context - should raise error diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a303664a5..e714bf5df 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -2,18 +2,20 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import pytest from pydantic import TypeAdapter +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( + CallToolRequestParams, + CallToolResult, ClientCapabilities, Implementation, InitializeRequestParams, @@ -39,20 +41,18 @@ async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: finally: context["shutdown"] = True - server = Server[dict[str, bool]]("test", lifespan=test_lifespan) - - # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) - # Create a tool that accesses lifespan context - @server.call_tool() - async def check_lifespan(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def check_lifespan(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: assert isinstance(ctx.lifespan_context, dict) assert ctx.lifespan_context["started"] assert not ctx.lifespan_context["shutdown"] - return [TextContent(type="text", text="true")] + return CallToolResult(content=[TextContent(type="text", text="true")]) + + server = Server[dict[str, bool]]("test", lifespan=test_lifespan, on_call_tool=check_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py index 3f977bcc1..3f6df0982 100644 --- a/tests/server/test_lowlevel_input_validation.py +++ b/tests/server/test_lowlevel_input_validation.py @@ -9,17 +9,28 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ClientResult, + ListToolsResult, + PaginatedRequestParams, + ServerNotification, + ServerRequest, + TextContent, + Tool, +) async def run_tool_test( tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[list[TextContent]]], + call_tool_handler: Callable[[ServerRequestContext, CallToolRequestParams], Awaitable[CallToolResult]], test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], ) -> CallToolResult | None: """Helper to run a tool test with minimal boilerplate. @@ -32,16 +43,14 @@ async def run_tool_test( Returns: The result of the tool call """ - server = Server("test") - result = None - @server.list_tools() - async def list_tools(): - return tools + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=tools) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - return await call_tool_handler(name, arguments) + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=call_tool_handler) + result = None server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -118,12 +127,13 @@ def create_add_tool() -> Tool: async def test_valid_tool_call(): """Test that valid arguments pass validation.""" - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "add": - result = arguments["a"] + arguments["b"] - return [TextContent(type="text", text=f"Result: {result}")] + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "add": + assert params.arguments is not None + result = params.arguments["a"] + params.arguments["b"] + return CallToolResult(content=[TextContent(type="text", text=f"Result: {result}")]) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("add", {"a": 5, "b": 3}) @@ -141,11 +151,28 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_invalid_tool_call_missing_required(): - """Test that missing required arguments fail validation.""" + """Test that missing required arguments fail validation. + + Note: With the new low-level server API, input validation is the handler's + responsibility. The handler returns an error CallToolResult for invalid input. + """ - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + # Handler performs its own validation + arguments = params.arguments or {} + if "a" not in arguments or "b" not in arguments: + missing = [k for k in ["a", "b"] if k not in arguments] + return CallToolResult( + content=[ + TextContent( + type="text", + text=f"Input validation error: '{missing[0]}' is a required property", + ) + ], + is_error=True, + ) + result = arguments["a"] + arguments["b"] # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Result: {result}")]) # pragma: no cover async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("add", {"a": 5}) # missing 'b' @@ -164,11 +191,28 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_invalid_tool_call_wrong_type(): - """Test that wrong argument types fail validation.""" + """Test that wrong argument types fail validation. + + Note: With the new low-level server API, input validation is the handler's + responsibility. + """ - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + # Handler performs its own validation + arguments = params.arguments or {} + for key in ["a", "b"]: + if key in arguments and not isinstance(arguments[key], (int, float)): + return CallToolResult( + content=[ + TextContent( + type="text", + text=f"Input validation error: '{arguments[key]}' is not of type 'number'", + ) + ], + is_error=True, + ) + result = arguments["a"] + arguments["b"] # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Result: {result}")]) # pragma: no cover async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("add", {"a": "five", "b": 3}) # 'a' should be number @@ -187,7 +231,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_cache_refresh_on_missing_tool(): - """Test that tool cache is refreshed when tool is not found.""" + """Test that tool call works even without listing tools first.""" tools = [ Tool( name="multiply", @@ -203,12 +247,13 @@ async def test_cache_refresh_on_missing_tool(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "multiply": - result = arguments["x"] * arguments["y"] - return [TextContent(type="text", text=f"Result: {result}")] + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "multiply": + assert params.arguments is not None + result = params.arguments["x"] * params.arguments["y"] + return CallToolResult(content=[TextContent(type="text", text=f"Result: {result}")]) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: # Call tool without first listing tools (cache should be empty) @@ -228,7 +273,11 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_enum_constraint_validation(): - """Test that enum constraints are validated.""" + """Test that enum constraints are validated. + + Note: With the new low-level server API, input validation is the handler's + responsibility. + """ tools = [ Tool( name="greet", @@ -244,9 +293,23 @@ async def test_enum_constraint_validation(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation failure - raise RuntimeError("Should not reach here") + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + # Handler performs its own validation + arguments = params.arguments or {} + valid_titles = ["Mr", "Ms", "Dr"] + if "title" in arguments and arguments["title"] not in valid_titles: + return CallToolResult( + content=[ + TextContent( + type="text", + text=f"Input validation error: '{arguments['title']}' is not one of {valid_titles}", + ) + ], + is_error=True, + ) + return CallToolResult( # pragma: no cover + content=[TextContent(type="text", text=f"Hello {arguments.get('title', '')} {arguments['name']}")] + ) async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("greet", {"name": "Smith", "title": "Prof"}) # Invalid title @@ -265,7 +328,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_tool_not_in_list_logs_warning(caplog: pytest.LogCaptureFixture): - """Test that calling a tool not in list_tools logs a warning and skips validation.""" + """Test that calling a tool not in list_tools still works.""" tools = [ Tool( name="add", @@ -281,31 +344,24 @@ async def test_tool_not_in_list_logs_warning(caplog: pytest.LogCaptureFixture): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This should be reached since validation is skipped for unknown tools - if name == "unknown_tool": - # Even with invalid arguments, this should execute since validation is skipped - return [TextContent(type="text", text="Unknown tool executed without validation")] + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + # This should be reached since the handler handles all tool calls + if params.name == "unknown_tool": + return CallToolResult(content=[TextContent(type="text", text="Unknown tool executed without validation")]) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: # Call a tool that's not in the list with invalid arguments - # This should trigger the warning about validation not being performed return await client_session.call_tool("unknown_tool", {"invalid": "args"}) with caplog.at_level(logging.WARNING): result = await run_tool_test(tools, call_tool_handler, test_callback) - # Verify results - should succeed because validation is skipped for unknown tools + # Verify results - should succeed because handler handles all calls assert result is not None assert not result.is_error assert len(result.content) == 1 assert result.content[0].type == "text" assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Unknown tool executed without validation" - - # Verify warning was logged - assert any( - "Tool 'unknown_tool' not listed, no validation will be performed" in record.message for record in caplog.records - ) diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py index 92d9c047c..0916b9c28 100644 --- a/tests/server/test_lowlevel_output_validation.py +++ b/tests/server/test_lowlevel_output_validation.py @@ -9,17 +9,28 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ClientResult, + ListToolsResult, + PaginatedRequestParams, + ServerNotification, + ServerRequest, + TextContent, + Tool, +) async def run_tool_test( tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[Any]], + call_tool_handler: Callable[[ServerRequestContext, CallToolRequestParams], Awaitable[CallToolResult]], test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], ) -> CallToolResult | None: """Helper to run a tool test with minimal boilerplate. @@ -32,17 +43,15 @@ async def run_tool_test( Returns: The result of the tool call """ - server = Server("test") - result = None + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=tools) - @server.list_tools() - async def list_tools(): - return tools + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=call_tool_handler) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - return await call_tool_handler(name, arguments) + result = None server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) @@ -116,11 +125,12 @@ async def test_content_only_without_output_schema(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "echo": - return [TextContent(type="text", text=f"Echo: {arguments['message']}")] + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + assert params.arguments is not None + return CallToolResult(content=[TextContent(type="text", text=f"Echo: {params.arguments['message']}")]) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("echo", {"message": "Hello"}) @@ -139,7 +149,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_dict_only_without_output_schema(): - """Test returning dict only when no outputSchema is defined.""" + """Test returning dict as structured_content when no outputSchema is defined.""" tools = [ Tool( name="get_info", @@ -152,11 +162,15 @@ async def test_dict_only_without_output_schema(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "get_info": - return {"status": "ok", "data": {"value": 42}} + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "get_info": + data: dict[str, Any] = {"status": "ok", "data": {"value": 42}} + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(data))], + structured_content=data, + ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("get_info", {}) @@ -176,7 +190,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_both_content_and_dict_without_output_schema(): - """Test returning both content and dict when no outputSchema is defined.""" + """Test returning both content and structured_content when no outputSchema is defined.""" tools = [ Tool( name="process", @@ -189,13 +203,14 @@ async def test_both_content_and_dict_without_output_schema(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "process": - content = [TextContent(type="text", text="Processing complete")] - data = {"result": "success", "count": 10} - return (content, data) + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "process": + return CallToolResult( + content=[TextContent(type="text", text="Processing complete")], + structured_content={"result": "success", "count": 10}, + ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("process", {}) @@ -214,7 +229,12 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_content_only_with_output_schema_error(): - """Test error when outputSchema is defined but only content is returned.""" + """Test that returning content without structured_content when outputSchema is defined results in error. + + Note: With the new low-level server API, handlers return CallToolResult directly. + The handler is responsible for returning the appropriate error when outputSchema + requirements are not met. + """ tools = [ Tool( name="structured_tool", @@ -233,9 +253,18 @@ async def test_content_only_with_output_schema_error(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # This returns only content, but outputSchema expects structured data - return [TextContent(type="text", text="This is not structured")] + # With the new API, the handler is responsible for validation + return CallToolResult( + content=[ + TextContent( + type="text", + text="Output validation error: outputSchema defined but no structured output returned", + ) + ], + is_error=True, + ) async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("structured_tool", {}) @@ -277,13 +306,18 @@ async def test_valid_dict_with_output_schema(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "calc": - x = arguments["x"] - y = arguments["y"] - return {"sum": x + y, "product": x * y} + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "calc": + assert params.arguments is not None + x = params.arguments["x"] + y = params.arguments["y"] + data: dict[str, Any] = {"sum": x + y, "product": x * y} + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(data))], + structured_content=data, + ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("calc", {"x": 3, "y": 4}) @@ -302,7 +336,12 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_invalid_dict_with_output_schema(): - """Test dict output that doesn't match outputSchema.""" + """Test dict output that doesn't match outputSchema. + + Note: With the new low-level server API, handlers return CallToolResult directly. + The handler is responsible for returning the appropriate error when outputSchema + validation fails. + """ tools = [ Tool( name="user_info", @@ -322,12 +361,20 @@ async def test_invalid_dict_with_output_schema(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "user_info": - # Missing required 'age' field - return {"name": "Alice"} + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "user_info": + # Missing required 'age' field - handler reports the error + return CallToolResult( + content=[ + TextContent( + type="text", + text="Output validation error: 'age' is a required property", + ) + ], + is_error=True, + ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("user_info", {}) @@ -346,7 +393,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_both_content_and_valid_dict_with_output_schema(): - """Test returning both content and valid dict with outputSchema.""" + """Test returning both content and valid structured_content with outputSchema.""" tools = [ Tool( name="analyze", @@ -369,13 +416,15 @@ async def test_both_content_and_valid_dict_with_output_schema(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "analyze": - content = [TextContent(type="text", text=f"Analysis of: {arguments['text']}")] - data = {"sentiment": "positive", "confidence": 0.95} - return (content, data) + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "analyze": + assert params.arguments is not None + return CallToolResult( + content=[TextContent(type="text", text=f"Analysis of: {params.arguments['text']}")], + structured_content={"sentiment": "positive", "confidence": 0.95}, + ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("analyze", {"text": "Great job!"}) @@ -393,7 +442,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_tool_call_result(): - """Test returning ToolCallResult when no outputSchema is defined.""" + """Test returning CallToolResult directly.""" tools = [ Tool( name="get_info", @@ -406,15 +455,15 @@ async def test_tool_call_result(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> CallToolResult: - if name == "get_info": + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "get_info": return CallToolResult( content=[TextContent(type="text", text="Results calculated")], structured_content={"status": "ok", "data": {"value": 42}}, _meta={"some": "metadata"}, ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("get_info", {}) @@ -434,7 +483,12 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio async def test_output_schema_type_validation(): - """Test outputSchema validates types correctly.""" + """Test outputSchema validates types correctly. + + Note: With the new low-level server API, handlers return CallToolResult directly. + The handler is responsible for returning the appropriate error when outputSchema + validation fails. + """ tools = [ Tool( name="stats", @@ -455,12 +509,20 @@ async def test_output_schema_type_validation(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "stats": - # Wrong type for 'count' - should be integer - return {"count": "five", "average": 2.5, "items": ["a", "b"]} + async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + if params.name == "stats": + # Wrong type for 'count' - should be integer, handler reports the error + return CallToolResult( + content=[ + TextContent( + type="text", + text="Output validation error: 'five' is not of type 'integer'", + ) + ], + is_error=True, + ) else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") async def test_callback(client_session: ClientSession) -> CallToolResult: return await client_session.call_tool("stats", {}) diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 68543136e..5b27413f9 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -5,39 +5,52 @@ from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations +from mcp.types import ( + ClientResult, + ListToolsResult, + PaginatedRequestParams, + ServerNotification, + ServerRequest, + Tool, + ToolAnnotations, +) @pytest.mark.anyio async def test_lowlevel_server_tool_annotations(): """Test that tool annotations work in low-level server.""" - server = Server("test") # Create a tool with annotations - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="echo", - description="Echo a message back", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + description="Echo a message back", + input_schema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + "required": ["message"], }, - "required": ["message"], - }, - annotations=ToolAnnotations( - title="Echo Tool", - read_only_hint=True, - ), - ) - ] + annotations=ToolAnnotations( + title="Echo Tool", + read_only_hint=True, + ), + ) + ] + ) + + server = Server("test", on_list_tools=handle_list_tools) tools_result = None server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 88fd1e38f..9f9c6bbfb 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -1,11 +1,12 @@ -from collections.abc import Iterable +import base64 from pathlib import Path from tempfile import NamedTemporaryFile import pytest from mcp import types -from mcp.server.lowlevel.server import ReadResourceContents, Server +from mcp.server.context import ServerRequestContext +from mcp.server.lowlevel.server import Server @pytest.fixture @@ -23,22 +24,34 @@ def temp_file(): @pytest.mark.anyio async def test_read_resource_text(temp_file: Path): - server = Server("test") - - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content="Hello World", mime_type="text/plain")] + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text="Hello World", + mime_type="text/plain", + ) + ] + ) + + server = Server("test", on_read_resource=handle_read_resource) # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] + handler = server._request_handlers["resources/read"] + + # Create a mock context + from unittest.mock import MagicMock + + mock_ctx = MagicMock(spec=ServerRequestContext) - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) + # Create params + params = types.ReadResourceRequestParams(uri=temp_file.as_uri()) # Call the handler - result = await handler(request) + result = await handler(mock_ctx, params) assert isinstance(result, types.ReadResourceResult) assert len(result.contents) == 1 @@ -50,22 +63,34 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: @pytest.mark.anyio async def test_read_resource_binary(temp_file: Path): - server = Server("test") - - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")] + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + return types.ReadResourceResult( + contents=[ + types.BlobResourceContents( + uri=str(params.uri), + blob=base64.standard_b64encode(b"Hello World").decode(), + mime_type="application/octet-stream", + ) + ] + ) + + server = Server("test", on_read_resource=handle_read_resource) # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] + handler = server._request_handlers["resources/read"] + + # Create a mock context + from unittest.mock import MagicMock - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) + mock_ctx = MagicMock(spec=ServerRequestContext) + + # Create params + params = types.ReadResourceRequestParams(uri=temp_file.as_uri()) # Call the handler - result = await handler(request) + result = await handler(mock_ctx, params) assert isinstance(result, types.ReadResourceResult) assert len(result.contents) == 1 @@ -76,27 +101,34 @@ async def read_resource(uri: str) -> Iterable[ReadResourceContents]: @pytest.mark.anyio async def test_read_resource_default_mime(temp_file: Path): - server = Server("test") - - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ - ReadResourceContents( - content="Hello World", - # No mime_type specified, should default to text/plain - ) - ] + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text="Hello World", + mime_type="text/plain", + ) + ] + ) + + server = Server("test", on_read_resource=handle_read_resource) # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] + handler = server._request_handlers["resources/read"] + + # Create a mock context + from unittest.mock import MagicMock + + mock_ctx = MagicMock(spec=ServerRequestContext) - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) + # Create params + params = types.ReadResourceRequestParams(uri=temp_file.as_uri()) # Call the handler - result = await handler(request) + result = await handler(mock_ctx, params) assert isinstance(result, types.ReadResourceResult) assert len(result.contents) == 1 diff --git a/tests/server/test_session.py b/tests/server/test_session.py index d353e46e4..1fa542719 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -6,6 +6,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -14,11 +15,15 @@ from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, + CompleteRequestParams, + CompleteResult, Completion, CompletionArgument, - CompletionContext, CompletionsCapability, InitializedNotification, + ListPromptsResult, + ListResourcesResult, + PaginatedRequestParams, Prompt, PromptReference, PromptsCapability, @@ -85,47 +90,56 @@ async def run_server(): @pytest.mark.anyio async def test_server_capabilities(): - server = Server("test") notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} - # Initially no capabilities + # Initially no capabilities (no handlers registered) + server = Server("test") caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts is None assert caps.resources is None assert caps.completions is None - # Add a prompts handler - @server.list_prompts() - async def list_prompts() -> list[Prompt]: # pragma: no cover - return [] + # Create server with a prompts handler + async def list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: # pragma: no cover + return ListPromptsResult(prompts=[]) + server = Server("test", on_list_prompts=list_prompts) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources is None assert caps.completions is None - # Add a resources handler - @server.list_resources() - async def list_resources() -> list[Resource]: # pragma: no cover - return [] + # Create server with both prompts and resources handlers + async def list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: # pragma: no cover + return ListResourcesResult(resources=[]) + server = Server("test", on_list_prompts=list_prompts, on_list_resources=list_resources) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) assert caps.completions is None - # Add a complete handler - @server.completion() - async def complete( # pragma: no cover - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - return Completion( - values=["completion1", "completion2"], + # Create server with prompts, resources, and completion handlers + async def complete( + ctx: ServerRequestContext, params: CompleteRequestParams + ) -> CompleteResult: # pragma: no cover + return CompleteResult( + completion=Completion( + values=["completion1", "completion2"], + ), ) + server = Server( + "test", + on_list_prompts=list_prompts, + on_list_resources=list_resources, + on_completion=complete, + ) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 31238b9ff..e82f5bb9d 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -1,25 +1,30 @@ +from typing import Any + import pytest -from mcp import Client +from mcp import Client, types from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.types import EmptyResult, Resource -@pytest.fixture -def mcp_server() -> Server: - server = Server(name="test_server") - - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ +async def handle_list_resources( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: # pragma: no cover + return types.ListResourcesResult( + resources=[ Resource( uri="memory://test", name="Test Resource", description="A test resource", ) ] + ) - return server + +@pytest.fixture +def mcp_server() -> Server: + return Server(name="test_server", on_list_resources=handle_list_resources) @pytest.mark.anyio diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ab117f1f0..1a35f8af5 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -7,6 +7,7 @@ from mcp import Client, types from mcp.client.session import ClientSession from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -23,27 +24,6 @@ async def test_bidirectional_progress_notifications(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) - # Run a server session so we can send progress updates in tool - async def run_server(): - # Create a server session - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="ProgressTestServer", - server_version="0.1.0", - capabilities=server.get_capabilities(NotificationOptions(), {}), - ), - ) as server_session: - global serv_sesh - - serv_sesh = server_session - async for message in server_session.incoming_messages: - try: - await server._handle_message(message, server_session, {}) - except Exception as e: # pragma: no cover - raise e - # Track progress updates server_progress_updates: list[dict[str, Any]] = [] client_progress_updates: list[dict[str, Any]] = [] @@ -52,42 +32,40 @@ async def run_server(): server_progress_token = "server_token_123" client_progress_token = "client_token_456" - # Create a server with progress capability - server = Server(name="ProgressTestServer") + serv_sesh: ServerSession | None = None - # Register progress handler - @server.progress_notification() async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + ctx: ServerRequestContext[Any], params: types.ProgressNotificationParams + ) -> None: server_progress_updates.append( { - "token": progress_token, - "progress": progress, - "total": total, - "message": message, + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, } ) - # Register list tool handler - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list[types.TextContent]: + async def handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams + ) -> types.CallToolResult: + assert serv_sesh is not None # Make sure we received a progress token - if name == "test_tool": + if params.name == "test_tool": + arguments = params.arguments if arguments and "_meta" in arguments: progressToken = arguments["_meta"]["progressToken"] @@ -122,9 +100,37 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ else: # pragma: no cover raise ValueError("Progress token not sent.") - return [types.TextContent(type="text", text="Tool executed successfully")] + return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) + + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + # Create a server with progress capability + server = Server( + name="ProgressTestServer", + on_progress=handle_progress, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + # Run a server session so we can send progress updates in tool + async def run_server(): + nonlocal serv_sesh + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressTestServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + serv_sesh = server_session + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, {}) + except Exception as e: # pragma: no cover + raise e # Client message handler to store progress notifications async def handle_client_message( @@ -217,22 +223,22 @@ async def test_progress_context_manager(): # Track progress updates server_progress_updates: list[dict[str, Any]] = [] - server = Server(name="ProgressContextTestServer") - - progress_token = None - - # Register progress handler - @server.progress_notification() async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + ctx: ServerRequestContext[Any], params: types.ProgressNotificationParams + ) -> None: server_progress_updates.append( - {"token": progress_token, "progress": progress, "total": total, "message": message} + { + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, + } ) + server = Server(name="ProgressContextTestServer", on_progress=handle_progress) + + progress_token = None + # Run server session to receive progress updates async def run_server(): # Create a server session @@ -333,31 +339,39 @@ def mock_log_exception(msg: str, *args: Any, **kwargs: Any) -> None: async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: raise ValueError("Progress callback failed!") - # Create a server with a tool that sends progress notifications - server = Server(name="TestProgressServer") - - @server.call_tool() - async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: - if name == "progress_tool": + async def handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams + ) -> types.CallToolResult: + if params.name == "progress_tool": # Send a progress notification - await server.request_context.session.send_progress_notification( - progress_token=server.request_context.request_id, + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, progress=50.0, total=100.0, message="Halfway done", ) - return [types.TextContent(type="text", text="progress_result")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] + return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + input_schema={}, + ) + ] + ) + + # Create a server with a tool that sends progress notifications + server = Server( + name="TestProgressServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Test with mocked logging with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 182b4671d..02cc6fc78 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,10 +1,9 @@ -from typing import Any - import anyio import pytest from mcp import Client, types from mcp.client.session import ClientSession +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams @@ -17,7 +16,6 @@ JSONRPCError, JSONRPCRequest, JSONRPCResponse, - TextContent, ) @@ -41,30 +39,36 @@ async def test_request_cancellation(): ev_cancelled = anyio.Event() request_id = None - # Create a server with a slow tool - server = Server(name="TestSessionServer") - - # Register the tool handler - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + async def handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams + ) -> types.CallToolResult: nonlocal request_id, ev_tool_called - if name == "slow_tool": - request_id = server.request_context.request_id + if params.name == "slow_tool": + request_id = ctx.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel - return [] # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - # Register the tool so it shows up in list_tools - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow_tool", - description="A slow tool that takes 10 seconds to complete", - input_schema={}, - ) - ] + return types.CallToolResult(content=[]) # pragma: no cover + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + input_schema={}, + ) + ] + ) + + # Create a server with a slow tool + server = Server( + name="TestSessionServer", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) async def make_request(client: Client): nonlocal ev_cancelled diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index e8ed01b46..1383a4bb9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -23,6 +23,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError @@ -42,6 +43,56 @@ SERVER_NAME = "test_server_for_SSE" +async def _handle_read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: # pragma: no cover + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + return types.ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, text=f"Read {parsed.netloc}", mimeType="text/plain")] + ) + if parsed.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return types.ReadResourceResult( + contents=[ + TextResourceContents(uri=params.uri, text=f"Slow response from {parsed.netloc}", mimeType="text/plain") + ] + ) + + raise MCPError(code=404, message="OOPS! no resource with that URI was found") + + +async def _handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: # pragma: no cover + return types.ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + +async def _handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams +) -> types.CallToolResult: # pragma: no cover + return types.CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +def _create_server_test() -> Server: # pragma: no cover + return Server( + SERVER_NAME, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -54,38 +105,6 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] - - # Test fixtures def make_server_app() -> Starlette: # pragma: no cover """Create test Starlette app with SSE transport""" @@ -94,7 +113,7 @@ def make_server_app() -> Starlette: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() + server = _create_server_test() async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: @@ -336,47 +355,48 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser assert isinstance(ping_result, EmptyResult) -# Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): # pragma: no cover - def __init__(self): - super().__init__("request_context_server") - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - headers_info = {} - context = self.request_context - if context.request: - headers_info = dict(context.request.headers) - - if name == "echo_headers": - return [TextContent(type="text", text=json.dumps(headers_info))] - elif name == "echo_context": - context_data = { - "request_id": args.get("request_id"), - "headers": headers_info, - } - return [TextContent(type="text", text=json.dumps(context_data))] - - return [TextContent(type="text", text=f"Called {name}")] - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echoes request headers", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echoes request context", - input_schema={ - "type": "object", - "properties": {"request_id": {"type": "string"}}, - "required": ["request_id"], - }, - ), - ] +# Request context handler functions for the context server +async def _ctx_handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams +) -> types.CallToolResult: # pragma: no cover + headers_info: dict[str, str] = {} + if ctx.request: + headers_info = dict(ctx.request.headers) + + if params.name == "echo_headers": + return types.CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + elif params.name == "echo_context": + args = params.arguments or {} + context_data = { + "request_id": args.get("request_id"), + "headers": headers_info, + } + return types.CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return types.CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +async def _ctx_handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: # pragma: no cover + return types.ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echoes request headers", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + input_schema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), + ] + ) def run_context_server(server_port: int) -> None: # pragma: no cover @@ -386,7 +406,11 @@ def run_context_server(server_port: int) -> None: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() + context_server = Server( + "request_context_server", + on_call_tool=_ctx_handle_call_tool, + on_list_tools=_ctx_handle_list_tools, + ) async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b04b92026..4e78cff14 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -29,6 +29,7 @@ from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -124,263 +125,288 @@ async def replay_events_after( # pragma: no cover return target_stream_id -# Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - self._lock = None # Will be initialized in async context - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise ValueError(f"Unknown resource: {uri}") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_tool_with_standalone_notification", - description="A test tool that sends a notification", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_sampling_tool", - description="A tool that triggers server-side sampling", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="wait_for_lock_with_notification", - description="A tool that sends a notification and waits for lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="release_lock", - description="A tool that releases the lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_stream_close", - description="A tool that closes SSE stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_notifications_and_close", - description="Tool that sends notification1, closes stream, sends notification2, notification3", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, - }, - ), - Tool( - name="tool_with_standalone_stream_close", - description="Tool that closes standalone GET stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), +# Module-level state for test server (used across handler functions) +_server_state: dict[str, anyio.Event | None] = {"lock": None} + + +async def _test_handle_read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: # pragma: no cover + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + return types.ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, text=f"Read {parsed.netloc}", mimeType="text/plain")] + ) + if parsed.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return types.ReadResourceResult( + contents=[ + TextResourceContents(uri=params.uri, text=f"Slow response from {parsed.netloc}", mimeType="text/plain") ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context + raise ValueError(f"Unknown resource: {uri}") - # When the tool is called, send a notification to test GET stream - if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated(uri="http://test_resource") - return [TextContent(type="text", text=f"Called {name}")] - elif name == "long_running_with_checkpoints": - # Send notifications that are part of the response stream - # This simulates a long-running tool that sends logs +async def _test_handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: # pragma: no cover + return types.ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_stream_closes", + description="Tool that closes SSE stream multiple times during execution", + input_schema={ + "type": "object", + "properties": { + "checkpoints": {"type": "integer", "default": 3}, + "sleep_time": {"type": "number", "default": 0.2}, + }, + }, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + ] + ) - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, # need for stream association - ) - await anyio.sleep(0.1) +async def _test_handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams +) -> types.CallToolResult: # pragma: no cover + name = params.name + args = params.arguments or {} - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated(uri="http://test_resource") + return types.CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - return [TextContent(type="text", text="Completed!")] + elif name == "long_running_with_checkpoints": + # Send notifications that are part of the response stream + # This simulates a long-running tool that sends logs - elif name == "test_sampling_tool": - # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="Server needs client sampling"), - ) - ], - max_tokens=100, - related_request_id=ctx.request_id, + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, # need for stream association + ) + + await anyio.sleep(0.1) + + await ctx.session.send_log_message( + level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) + + return types.CallToolResult(content=[TextContent(type="text", text="Completed!")]) + + elif name == "test_sampling_tool": + # Test sampling by requesting the client to sample a message + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="Server needs client sampling"), ) + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) - # Return the sampling result in the tool response - # Since we're not passing tools param, result.content is single content - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) - return [ - TextContent( - type="text", - text=f"Response from sampling: {response}", - ) - ] - - elif name == "wait_for_lock_with_notification": - # Initialize lock if not already done - if self._lock is None: - self._lock = anyio.Event() - - # First send a notification - await ctx.session.send_log_message( - level="info", - data="First notification before lock", - logger="lock_tool", - related_request_id=ctx.request_id, + # Return the sampling result in the tool response + # Since we're not passing tools param, result.content is single content + if sampling_result.content.type == "text": + response = sampling_result.content.text + else: + response = str(sampling_result.content) + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=f"Response from sampling: {response}", ) + ] + ) - # Now wait for the lock to be released - await self._lock.wait() + elif name == "wait_for_lock_with_notification": + # Initialize lock if not already done + if _server_state["lock"] is None: + _server_state["lock"] = anyio.Event() + + # First send a notification + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - # Send second notification after lock is released - await ctx.session.send_log_message( - level="info", - data="Second notification after lock", - logger="lock_tool", - related_request_id=ctx.request_id, - ) + # Now wait for the lock to be released + await _server_state["lock"].wait() - return [TextContent(type="text", text="Completed")] + # Send second notification after lock is released + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "release_lock": - assert self._lock is not None, "Lock must be initialized before releasing" + return types.CallToolResult(content=[TextContent(type="text", text="Completed")]) - # Release the lock - self._lock.set() - return [TextContent(type="text", text="Lock released")] + elif name == "release_lock": + assert _server_state["lock"] is not None, "Lock must be initialized before releasing" - elif name == "tool_with_stream_close": - # Send notification before closing - await ctx.session.send_log_message( - level="info", - data="Before close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream (triggers client reconnect) - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Continue processing (events stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="After close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="Done")] - - elif name == "tool_with_multiple_notifications_and_close": - # Send notification1 - await ctx.session.send_log_message( - level="info", - data="notification1", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Send notification2, notification3 (stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="notification2", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - await ctx.session.send_log_message( - level="info", - data="notification3", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="All notifications sent")] + # Release the lock + _server_state["lock"].set() + return types.CallToolResult(content=[TextContent(type="text", text="Lock released")]) + + elif name == "tool_with_stream_close": + # Send notification before closing + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + # Close SSE stream (triggers client reconnect) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Continue processing (events stored in event_store) + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return types.CallToolResult(content=[TextContent(type="text", text="Done")]) + + elif name == "tool_with_multiple_notifications_and_close": + # Send notification1 + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + # Close SSE stream + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Send notification2, notification3 (stored in event_store) + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return types.CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) + + elif name == "tool_with_multiple_stream_closes": + num_checkpoints = args.get("checkpoints", 3) + sleep_time = args.get("sleep_time", 0.2) + + for i in range(num_checkpoints): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) + if ctx.close_sse_stream: + await ctx.close_sse_stream() - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) + await anyio.sleep(sleep_time) - if ctx.close_sse_stream: - await ctx.close_sse_stream() + return types.CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - await anyio.sleep(sleep_time) + elif name == "tool_with_standalone_stream_close": + # Test for GET stream reconnection + # 1. Send unsolicited notification via GET stream (no related_request_id) + await ctx.session.send_resource_updated(uri="http://notification_1") - return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + # Small delay to ensure notification is flushed before closing + await anyio.sleep(0.1) - elif name == "tool_with_standalone_stream_close": - # Test for GET stream reconnection - # 1. Send unsolicited notification via GET stream (no related_request_id) - await ctx.session.send_resource_updated(uri="http://notification_1") + # 2. Close the standalone GET stream + if ctx.close_standalone_sse_stream: + await ctx.close_standalone_sse_stream() - # Small delay to ensure notification is flushed before closing - await anyio.sleep(0.1) + # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) + await anyio.sleep(1.5) - # 2. Close the standalone GET stream - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + # 4. Send another notification on the new GET stream connection + await ctx.session.send_resource_updated(uri="http://notification_2") - # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) - await anyio.sleep(1.5) + return types.CallToolResult(content=[TextContent(type="text", text="Standalone stream close test done")]) - # 4. Send another notification on the new GET stream connection - await ctx.session.send_resource_updated(uri="http://notification_2") + return types.CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - return [TextContent(type="text", text="Standalone stream close test done")] - return [TextContent(type="text", text=f"Called {name}")] +def _create_test_server() -> Server: # pragma: no cover + """Create a test Server instance with handlers registered via constructor kwargs.""" + return Server( + SERVER_NAME, + on_read_resource=_test_handle_read_resource, + on_list_tools=_test_handle_list_tools, + on_call_tool=_test_handle_call_tool, + ) def create_app( @@ -396,7 +422,7 @@ def create_app( retry_interval: Retry interval in milliseconds for SSE polling. """ # Create server instance - server = ServerTest() + server = _create_test_server() # Create the session manager security_settings = TransportSecuritySettings( @@ -1384,70 +1410,80 @@ async def sampling_callback( assert captured_message_params.messages[0].content.text == "Server needs client sampling" -# Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__("ContextAwareServer") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echo request headers from context", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echo request context with custom data", - input_schema={ - "type": "object", - "properties": { - "request_id": {"type": "string"}, - }, - "required": ["request_id"], +# Context-aware server handler functions for testing request context propagation + + +async def _ctx_list_tools( # pragma: no cover + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echo request headers from context", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + input_schema={ + "type": "object", + "properties": { + "request_id": {"type": "string"}, }, - ), + "required": ["request_id"], + }, + ), + ] + ) + + +async def _ctx_call_tool( # pragma: no cover + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams +) -> types.CallToolResult: + name = params.name + args = params.arguments or {} + + if name == "echo_headers": + # Access the request object from context + headers_info = {} + if ctx.request and isinstance(ctx.request, Request): + headers_info = dict(ctx.request.headers) + return types.CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + + elif name == "echo_context": + # Return full context information + context_data: dict[str, Any] = { + "request_id": args.get("request_id"), + "headers": {}, + "method": None, + "path": None, + } + if ctx.request and isinstance(ctx.request, Request): + request = ctx.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=json.dumps(context_data), + ) ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - - if name == "echo_headers": - # Access the request object from context - headers_info = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return [TextContent(type="text", text=json.dumps(headers_info))] - - elif name == "echo_context": - # Return full context information - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return [ - TextContent( - type="text", - text=json.dumps(context_data), - ) - ] - - return [TextContent(type="text", text=f"Unknown tool: {name}")] + return types.CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # Server runner for context-aware testing def run_context_aware_server(port: int): # pragma: no cover """Run the context-aware test server.""" - server = ContextAwareServerTest() + server = Server( + "ContextAwareServer", + on_list_tools=_ctx_list_tools, + on_call_tool=_ctx_call_tool, + ) session_manager = StreamableHTTPSessionManager( app=server, diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 07e19195d..85de6e49b 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,10 +12,11 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp import MCPError +from mcp import MCPError, types from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client from mcp.server import Server +from mcp.server.context import ServerRequestContext from mcp.server.websocket import websocket_server from mcp.types import EmptyResult, InitializeResult, ReadResourceResult, TextContent, TextResourceContents, Tool from tests.test_helpers import wait_for_server @@ -35,42 +36,54 @@ def server_url(server_port: int) -> str: return f"ws://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) +# Module-level handlers for the test server - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - raise MCPError(code=404, message="OOPS! no resource with that URI was found") +async def handle_read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: # pragma: no cover + parsed = urlparse(str(params.uri)) + if parsed.scheme == "foobar": + return types.ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=f"Read {parsed.netloc}")]) + elif parsed.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return types.ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, text=f"Slow response from {parsed.netloc}")] + ) - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] + raise MCPError(code=404, message="OOPS! no resource with that URI was found") - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] + +async def handle_list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: # pragma: no cover + return types.ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + +async def handle_call_tool( + ctx: ServerRequestContext[Any], params: types.CallToolRequestParams +) -> types.CallToolResult: # pragma: no cover + return types.CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) # Test fixtures def make_server_app() -> Starlette: # pragma: no cover """Create test Starlette app with WebSocket transport""" - server = ServerTest() + server = Server( + SERVER_NAME, + on_read_resource=handle_read_resource, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) async def handle_ws(websocket: WebSocket): async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: