diff --git a/AGENTS.md b/AGENTS.md index 3615e713a..95377efc4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -68,12 +68,6 @@ strands-agents/ │ │ │ ├── _executor.py # Base executor │ │ │ ├── concurrent.py # Thread/process pool │ │ │ └── sequential.py # Sequential execution -│ │ ├── mcp/ # Model Context Protocol -│ │ │ ├── mcp_client.py # MCP client implementation -│ │ │ ├── mcp_agent_tool.py # MCP tool wrapper -│ │ │ ├── mcp_types.py # MCP type definitions -│ │ │ ├── mcp_tasks.py # Task-augmented execution config -│ │ │ └── mcp_instrumentation.py # MCP telemetry │ │ └── structured_output/ # Structured output handling │ │ ├── structured_output_tool.py │ │ ├── structured_output_utils.py @@ -110,6 +104,13 @@ strands-agents/ │ │ ├── a2a.py # A2A protocol types │ │ └── models/ # Model-specific types │ │ +│ ├── mcp/ # Model Context Protocol +│ │ ├── mcp_client.py # MCP client implementation +│ │ ├── mcp_agent_tool.py # MCP tool wrapper +│ │ ├── mcp_types.py # MCP type definitions +│ │ ├── mcp_tasks.py # Task-augmented execution config +│ │ └── mcp_instrumentation.py # MCP telemetry +│ │ │ ├── session/ # Session management │ │ ├── session_manager.py # Base interface │ │ ├── file_session_manager.py # File-based storage @@ -443,7 +444,7 @@ Enable tasks by passing a `TasksConfig` to `MCPClient`: ```python from datetime import timedelta -from strands.tools.mcp import MCPClient, TasksConfig +from strands.mcp import MCPClient, TasksConfig # Enable with defaults (ttl=1min, poll_timeout=5min) client = MCPClient(transport, tasks_config={}) @@ -474,9 +475,9 @@ Task-augmented execution is used when ALL conditions are met: ### Key Files -- `src/strands/tools/mcp/mcp_tasks.py` - `TasksConfig` and defaults -- `src/strands/tools/mcp/mcp_client.py` - Task execution logic (`_call_tool_as_task_and_poll_async`) -- `tests/strands/tools/mcp/test_mcp_client_tasks.py` - Unit tests +- `src/strands/mcp/mcp_tasks.py` - `TasksConfig` and defaults +- `src/strands/mcp/mcp_client.py` - Task execution logic (`_call_tool_as_task_and_poll_async`) +- `tests/strands/mcp/test_mcp_client_tasks.py` - Unit tests - `tests_integ/mcp/test_mcp_client_tasks.py` - Integration tests - `tests_integ/mcp/task_echo_server.py` - Test server with task support diff --git a/README.md b/README.md index 173adc006..3e8ea99e6 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ Seamlessly integrate Model Context Protocol (MCP) servers: ```python from strands import Agent -from strands.tools.mcp import MCPClient +from strands.mcp import MCPClient from mcp import stdio_client, StdioServerParameters aws_docs_client = MCPClient( diff --git a/src/strands/mcp/__init__.py b/src/strands/mcp/__init__.py new file mode 100644 index 000000000..8d2c1daa2 --- /dev/null +++ b/src/strands/mcp/__init__.py @@ -0,0 +1,14 @@ +"""Model Context Protocol (MCP) integration. + +This package provides integration with the Model Context Protocol (MCP), allowing agents to use tools provided by MCP +servers. + +- Docs: https://www.anthropic.com/news/model-context-protocol +""" + +from .mcp_agent_tool import MCPAgentTool +from .mcp_client import MCPClient, ToolFilters +from .mcp_tasks import TasksConfig +from .mcp_types import MCPTransport + +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] diff --git a/src/strands/mcp/mcp_agent_tool.py b/src/strands/mcp/mcp_agent_tool.py new file mode 100644 index 000000000..5ae4d2c9e --- /dev/null +++ b/src/strands/mcp/mcp_agent_tool.py @@ -0,0 +1,119 @@ +"""MCP Agent Tool module for adapting Model Context Protocol tools to the agent framework. + +This module provides the MCPAgentTool class which serves as an adapter between +MCP (Model Context Protocol) tools and the agent framework's tool interface. +It allows MCP tools to be seamlessly integrated and used within the agent ecosystem. +""" + +import logging +from datetime import timedelta +from typing import TYPE_CHECKING, Any + +from mcp.types import Tool as MCPTool +from typing_extensions import override + +from ..types._events import ToolResultEvent +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse + +if TYPE_CHECKING: + from .mcp_client import MCPClient + +logger = logging.getLogger(__name__) + + +class MCPAgentTool(AgentTool): + """Adapter class that wraps an MCP tool and exposes it as an AgentTool. + + This class bridges the gap between the MCP protocol's tool representation + and the agent framework's tool interface, allowing MCP tools to be used + seamlessly within the agent framework. + """ + + def __init__( + self, + mcp_tool: MCPTool, + mcp_client: "MCPClient", + name_override: str | None = None, + timeout: timedelta | None = None, + ) -> None: + """Initialize a new MCPAgentTool instance. + + Args: + mcp_tool: The MCP tool to adapt + mcp_client: The MCP server connection to use for tool invocation + name_override: Optional name to use for the agent tool (for disambiguation) + If None, uses the original MCP tool name + timeout: Optional timeout duration for tool execution + """ + super().__init__() + logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) + self.mcp_tool = mcp_tool + self.mcp_client = mcp_client + self._agent_tool_name = name_override or mcp_tool.name + self.timeout = timeout + + @property + def tool_name(self) -> str: + """Get the name of the tool. + + Returns: + str: The agent-facing name of the tool (may be disambiguated) + """ + return self._agent_tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the specification of the tool. + + This method converts the MCP tool specification to the agent framework's + ToolSpec format, including the input schema, description, and optional output schema. + + Returns: + ToolSpec: The tool specification in the agent framework format + """ + description: str = self.mcp_tool.description or f"Tool which performs {self.mcp_tool.name}" + + spec: ToolSpec = { + "inputSchema": {"json": self.mcp_tool.inputSchema}, + "name": self.tool_name, # Use agent-facing name in spec + "description": description, + } + + if self.mcp_tool.outputSchema: + spec["outputSchema"] = {"json": self.mcp_tool.outputSchema} + + return spec + + @property + def tool_type(self) -> str: + """Get the type of the tool. + + Returns: + str: The type of the tool, always "python" for MCP tools + """ + return "python" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Stream the MCP tool. + + This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and + input arguments. + + Args: + tool_use: The tool use request containing tool ID and parameters. + invocation_state: Context for the tool invocation, including agent state. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Tool events with the last being the tool result. + """ + logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"]) + + result = await self.mcp_client.call_tool_async( + tool_use_id=tool_use["toolUseId"], + name=self.mcp_tool.name, # Use original MCP name for server communication + arguments=tool_use["input"], + read_timeout_seconds=self.timeout, + ) + yield ToolResultEvent(result) diff --git a/src/strands/mcp/mcp_client.py b/src/strands/mcp/mcp_client.py new file mode 100644 index 000000000..ecfab3f6e --- /dev/null +++ b/src/strands/mcp/mcp_client.py @@ -0,0 +1,1212 @@ +"""Model Context Protocol (MCP) server connection management module. + +This module provides the MCPClient class which handles connections to MCP servers. +It manages the lifecycle of MCP connections, including initialization, tool discovery, +tool invocation, and proper cleanup of resources. The connection runs in a background +thread to avoid blocking the main application thread while maintaining communication +with the MCP service. +""" + +import asyncio +import base64 +import contextvars +import json +import logging +import threading +import uuid +from asyncio import AbstractEventLoop +from collections.abc import Callable, Coroutine, Sequence +from concurrent import futures +from datetime import timedelta +from re import Pattern +from types import TracebackType +from typing import Any, TypeVar, cast + +import anyio +from mcp import ClientSession, ListToolsResult +from mcp.client.session import ElicitationFnT +from mcp.shared.exceptions import McpError +from mcp.types import ( + BlobResourceContents, + ElicitationRequiredErrorData, + GetPromptResult, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + TextResourceContents, +) +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import EmbeddedResource as MCPEmbeddedResource +from mcp.types import ImageContent as MCPImageContent +from mcp.types import TextContent as MCPTextContent +from pydantic import AnyUrl +from typing_extensions import Protocol, TypedDict + +from ..tools.tool_provider import ToolProvider +from ..types import PaginatedList +from ..types.exceptions import MCPClientInitializationError, ToolProviderException +from ..types.media import ImageFormat +from ..types.tools import AgentTool, ToolResultContent, ToolResultStatus +from .mcp_agent_tool import MCPAgentTool +from .mcp_instrumentation import mcp_instrumentation +from .mcp_tasks import DEFAULT_TASK_CONFIG, DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL, TasksConfig +from .mcp_types import MCPToolResult, MCPTransport + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class _ToolFilterCallback(Protocol): + def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... + + +_ToolMatcher = str | Pattern[str] | _ToolFilterCallback + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + """ + + allowed: list[_ToolMatcher] + rejected: list[_ToolMatcher] + + +MIME_TO_FORMAT: dict[str, ImageFormat] = { + "image/jpeg": "jpeg", + "image/jpg": "jpeg", + "image/png": "png", + "image/gif": "gif", + "image/webp": "webp", +} + +CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = ( + "the client session is not running. Ensure the agent is used within " + "the MCP client context manager. For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror" +) + +# Non-fatal error patterns that should not cause connection collapse +_NON_FATAL_ERROR_PATTERNS = [ + # Occurs when client receives response with unrecognized ID + # Can occur after a client-side timeout + # See: https://github.com/modelcontextprotocol/python-sdk/blob/c51936f61f35a15f0b1f8fb6887963e5baee1506/src/mcp/shared/session.py#L421 + "unknown request id", +] + + +class MCPClient(ToolProvider): + """Represents a connection to a Model Context Protocol (MCP) server. + + This class implements a context manager pattern for efficient connection management, + allowing reuse of the same connection for multiple tool calls to reduce latency. + It handles the creation, initialization, and cleanup of MCP connections. + + The connection runs in a background thread to avoid blocking the main application thread + while maintaining communication with the MCP service. When structured content is available + from MCP tools, it will be returned as the last item in the content array of the ToolResult. + """ + + def __init__( + self, + transport_callable: Callable[[], MCPTransport], + *, + startup_timeout: int = 30, + tool_filters: ToolFilters | None = None, + prefix: str | None = None, + elicitation_callback: ElicitationFnT | None = None, + tasks_config: TasksConfig | None = None, + ) -> None: + """Initialize a new MCP Server connection. + + Args: + transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple. + startup_timeout: Timeout after which MCP server initialization should be cancelled. + Defaults to 30. + tool_filters: Optional filters to apply to tools. + prefix: Optional prefix for tool names. + elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + tasks_config: Configuration for MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools that support it. + See TasksConfig for details. This feature is experimental and subject to change. + """ + self._startup_timeout = startup_timeout + self._tool_filters = tool_filters + self._prefix = prefix + self._elicitation_callback = elicitation_callback + + mcp_instrumentation() + self._session_id = uuid.uuid4() + self._log_debug_with_thread("initializing MCPClient connection") + # Main thread blocks until future completes + self._init_future: futures.Future[None] = futures.Future() + # Set within the inner loop as it needs the asyncio loop + self._close_future: asyncio.futures.Future[None] | None = None + self._close_exception: None | Exception = None + # Do not want to block other threads while close event is false + self._transport_callable = transport_callable + + self._background_thread: threading.Thread | None = None + self._background_thread_session: ClientSession | None = None + self._background_thread_event_loop: AbstractEventLoop | None = None + self._loaded_tools: list[MCPAgentTool] | None = None + self._tool_provider_started = False + self.server_instructions: str | None = None + self._consumers: set[Any] = set() + + # Task support configuration and caching + self._tasks_config = tasks_config + self._server_task_capable: bool | None = None + + # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) + if self._is_tasks_enabled(): + from mcp.types import TaskExecutionMode + + self._tool_task_support_cache: dict[str, TaskExecutionMode] = {} + + def __enter__(self) -> "MCPClient": + """Context manager entry point which initializes the MCP server connection. + + TODO: Refactor to lazy initialization pattern following idiomatic Python. + Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead. + """ + return self.start() + + def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None: + """Context manager exit point that cleans up resources.""" + self.stop(exc_type, exc_val, exc_tb) + + def start(self) -> "MCPClient": + """Starts the background thread and waits for initialization. + + This method starts the background thread that manages the MCP connection + and blocks until the connection is ready or times out. + + Returns: + self: The MCPClient instance + + Raises: + Exception: If the MCP connection fails to initialize within the timeout period + """ + if self._is_session_active(): + raise MCPClientInitializationError("the client session is currently running") + + self._log_debug_with_thread("entering MCPClient context") + # Copy context vars to propagate to the background thread + # This ensures that context set in the main thread is accessible in the background thread + # See: https://github.com/strands-agents/sdk-python/issues/1440 + ctx = contextvars.copy_context() + self._background_thread = threading.Thread(target=ctx.run, args=(self._background_task,), daemon=True) + self._background_thread.start() + self._log_debug_with_thread("background thread started, waiting for ready event") + try: + # Blocking main thread until session is initialized in other thread or if the thread stops + self._init_future.result(timeout=self._startup_timeout) + self._log_debug_with_thread("the client initialization was successful") + except futures.TimeoutError as e: + logger.exception("client initialization timed out") + # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit + self.stop(None, None, None) + raise MCPClientInitializationError( + f"background thread did not start in {self._startup_timeout} seconds" + ) from e + except Exception as e: + logger.exception("client failed to initialize") + # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit + self.stop(None, None, None) + raise MCPClientInitializationError("the client initialization failed") from e + return self + + # ToolProvider interface methods + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: + """Load and return tools from the MCP server. + + This method implements the ToolProvider interface by loading tools + from the MCP server and caching them for reuse. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of AgentTool instances from the MCP server. + """ + logger.debug( + "started=<%s>, cached_tools=<%s> | loading tools", + self._tool_provider_started, + self._loaded_tools is not None, + ) + + if not self._tool_provider_started: + try: + logger.debug("starting MCP client") + self.start() + self._tool_provider_started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._loaded_tools is None: + logger.debug("loading tools from MCP server") + self._loaded_tools = [] + pagination_token = None + page_count = 0 + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + # Use constructor defaults for prefix and filters in load_tools + paginated_tools = self.list_tools_sync( + pagination_token, prefix=self._prefix, tool_filters=self._tool_filters + ) + + # Tools are already filtered by list_tools_sync, so add them all + for tool in paginated_tools: + self._loaded_tools.append(tool) + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._loaded_tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) + + return self._loaded_tools + + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + """ + self._consumers.add(consumer_id) + logger.debug("added provider consumer, count=%d", len(self._consumers)) + + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + This method is idempotent - calling it multiple times with the same ID + has no additional effect after the first call. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + Uses existing synchronous stop() method for safe cleanup. + """ + self._consumers.discard(consumer_id) + logger.debug("removed provider consumer, count=%d", len(self._consumers)) + + if not self._consumers and self._tool_provider_started: + logger.debug("no consumers remaining, cleaning up") + try: + self.stop(None, None, None) # Existing sync method - safe for finalizers + self._tool_provider_started = False + self._loaded_tools = None + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # MCP-specific methods + + def stop(self, exc_type: BaseException | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: + """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. + + This method is defensive and can handle partial initialization states that may occur + if start() fails partway through initialization. + + Resources to cleanup: + - _background_thread: Thread running the async event loop + - _background_thread_session: MCP ClientSession (auto-closed by context manager) + - _background_thread_event_loop: AsyncIO event loop in background thread + - _close_future: AsyncIO future to signal thread shutdown + - _close_exception: Exception that caused the background thread shutdown; None if a normal shutdown occurred. + - _init_future: Future for initialization synchronization + + Cleanup order: + 1. Signal close future to background thread (if session initialized) + 2. Wait for background thread to complete + 3. Reset all state for reuse + + Args: + exc_type: Exception type if an exception was raised in the context + exc_val: Exception value if an exception was raised in the context + exc_tb: Exception traceback if an exception was raised in the context + """ + self._log_debug_with_thread("exiting MCPClient context") + + # Only try to signal close future if we have a background thread + if self._background_thread is not None: + # Signal close future if event loop exists + if self._background_thread_event_loop is not None: + + async def _set_close_event() -> None: + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) + + # Not calling _invoke_on_background_thread since the session does not need to exist + # we only need the thread and event loop to exist. + asyncio.run_coroutine_threadsafe(coro=_set_close_event(), loop=self._background_thread_event_loop) + + self._log_debug_with_thread("waiting for background thread to join") + self._background_thread.join() + + if self._background_thread_event_loop is not None: + self._background_thread_event_loop.close() + + self._log_debug_with_thread("background thread is closed, MCPClient context exited") + + # Reset fields to allow instance reuse + self._init_future = futures.Future() + self._background_thread = None + self._background_thread_session = None + self._background_thread_event_loop = None + self._session_id = uuid.uuid4() + self._loaded_tools = None + self._tool_provider_started = False + self._consumers = set() + self._server_task_capable = None + self._tool_task_support_cache = {} + + if self._close_exception: + exception = self._close_exception + self._close_exception = None + raise RuntimeError("Connection to the MCP server was closed") from exception + + def list_tools_sync( + self, + pagination_token: str | None = None, + prefix: str | None = None, + tool_filters: ToolFilters | None = None, + ) -> PaginatedList[MCPAgentTool]: + """Synchronously retrieves the list of available tools from the MCP server. + + This method calls the asynchronous list_tools method on the MCP session + and adapts the returned tools to the AgentTool interface. + + Args: + pagination_token: Optional token for pagination + prefix: Optional prefix to apply to tool names. If None, uses constructor default. + If explicitly provided (including empty string), overrides constructor default. + tool_filters: Optional filters to apply to tools. If None, uses constructor default. + If explicitly provided (including empty dict), overrides constructor default. + + Returns: + List[AgentTool]: A list of available tools adapted to the AgentTool interface + """ + self._log_debug_with_thread("listing MCP tools synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + effective_prefix = self._prefix if prefix is None else prefix + effective_filters = self._tool_filters if tool_filters is None else tool_filters + + async def _list_tools_async() -> ListToolsResult: + return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) + + list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() + self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) + + mcp_tools = [] + for tool in list_tools_response.tools: + if self._is_tasks_enabled(): + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support or "forbidden" + + # Apply prefix if specified + if effective_prefix: + prefixed_name = f"{effective_prefix}_{tool.name}" + mcp_tool = MCPAgentTool(tool, self, name_override=prefixed_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", tool.name, prefixed_name) + else: + mcp_tool = MCPAgentTool(tool, self) + + # Apply filters if specified + if self._should_include_tool_with_filters(mcp_tool, effective_filters): + mcp_tools.append(mcp_tool) + + self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) + return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) + + def list_prompts_sync(self, pagination_token: str | None = None) -> ListPromptsResult: + """Synchronously retrieves the list of available prompts from the MCP server. + + This method calls the asynchronous list_prompts method on the MCP session + and returns the raw ListPromptsResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListPromptsResult: The raw MCP response containing prompts and pagination info + """ + self._log_debug_with_thread("listing MCP prompts synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_prompts_async() -> ListPromptsResult: + return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token) + + list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() + self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts)) + for prompt in list_prompts_result.prompts: + self._log_debug_with_thread(prompt.name) + + return list_prompts_result + + def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult: + """Synchronously retrieves a prompt from the MCP server. + + Args: + prompt_id: The ID of the prompt to retrieve + args: Optional arguments to pass to the prompt + + Returns: + GetPromptResult: The prompt response from the MCP server + """ + self._log_debug_with_thread("getting MCP prompt synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _get_prompt_async() -> GetPromptResult: + return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args) + + get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() + self._log_debug_with_thread("received prompt from MCP server") + + return get_prompt_result + + def list_resources_sync(self, pagination_token: str | None = None) -> ListResourcesResult: + """Synchronously retrieves the list of available resources from the MCP server. + + This method calls the asynchronous list_resources method on the MCP session + and returns the raw ListResourcesResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListResourcesResult: The raw MCP response containing resources and pagination info + """ + self._log_debug_with_thread("listing MCP resources synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_resources_async() -> ListResourcesResult: + return await cast(ClientSession, self._background_thread_session).list_resources(cursor=pagination_token) + + list_resources_result: ListResourcesResult = self._invoke_on_background_thread(_list_resources_async()).result() + self._log_debug_with_thread("received %d resources from MCP server", len(list_resources_result.resources)) + + return list_resources_result + + def read_resource_sync(self, uri: AnyUrl | str) -> ReadResourceResult: + """Synchronously reads a resource from the MCP server. + + Args: + uri: The URI of the resource to read + + Returns: + ReadResourceResult: The resource content from the MCP server + """ + self._log_debug_with_thread("reading MCP resource synchronously: %s", uri) + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _read_resource_async() -> ReadResourceResult: + # Convert string to AnyUrl if needed + resource_uri = AnyUrl(uri) if isinstance(uri, str) else uri + return await cast(ClientSession, self._background_thread_session).read_resource(resource_uri) + + read_resource_result: ReadResourceResult = self._invoke_on_background_thread(_read_resource_async()).result() + self._log_debug_with_thread("received resource content from MCP server") + + return read_resource_result + + def list_resource_templates_sync(self, pagination_token: str | None = None) -> ListResourceTemplatesResult: + """Synchronously retrieves the list of available resource templates from the MCP server. + + Resource templates define URI patterns that can be used to access resources dynamically. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListResourceTemplatesResult: The raw MCP response containing resource templates and pagination info + """ + self._log_debug_with_thread("listing MCP resource templates synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_resource_templates_async() -> ListResourceTemplatesResult: + return await cast(ClientSession, self._background_thread_session).list_resource_templates( + cursor=pagination_token + ) + + list_resource_templates_result: ListResourceTemplatesResult = self._invoke_on_background_thread( + _list_resource_templates_async() + ).result() + self._log_debug_with_thread( + "received %d resource templates from MCP server", len(list_resource_templates_result.resourceTemplates) + ) + + return list_resource_templates_result + + def _create_call_tool_coroutine( + self, + name: str, + arguments: dict[str, Any] | None, + read_timeout_seconds: timedelta | None, + meta: dict[str, Any] | None = None, + ) -> Coroutine[Any, Any, MCPCallToolResult]: + """Create the appropriate coroutine for calling a tool. + + This method encapsulates the decision logic for whether to use task-augmented + execution or direct call_tool, returning the appropriate coroutine. + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + read_timeout_seconds: Optional timeout for the tool call. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). + + Returns: + A coroutine that will execute the tool call. + """ + use_task = self._should_use_task(name) + + if use_task: + self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + + async def _call_as_task() -> MCPCallToolResult: + # When task-augmented execution is used, use the read_timeout_seconds parameter + # (which is a timedelta) for the polling timeout. + return await self._call_tool_as_task_and_poll_async( + name, arguments, poll_timeout=read_timeout_seconds, meta=meta + ) + + return _call_as_task() + else: + self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) + + async def _call_tool_direct() -> MCPCallToolResult: + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds, meta=meta + ) + + return _call_tool_direct() + + def call_tool_sync( + self, + tool_use_id: str, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, + ) -> MCPToolResult: + """Synchronously calls a tool on the MCP server. + + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. + + Args: + tool_use_id: Unique identifier for this tool use + name: Name of the tool to call + arguments: Optional arguments to pass to the tool + read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) + + Returns: + MCPToolResult: The result of the tool call + """ + self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + try: + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() + return self._handle_tool_result(tool_use_id, call_tool_result) + except Exception as e: + logger.exception("tool execution failed") + return self._handle_tool_execution_error(tool_use_id, e) + + async def call_tool_async( + self, + tool_use_id: str, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, + ) -> MCPToolResult: + """Asynchronously calls a tool on the MCP server. + + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. + + Args: + tool_use_id: Unique identifier for this tool use + name: Name of the tool to call + arguments: Optional arguments to pass to the tool + read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) + + Returns: + MCPToolResult: The result of the tool call + """ + self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id) + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + try: + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + future = self._invoke_on_background_thread(coro) + call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) + return self._handle_tool_result(tool_use_id, call_tool_result) + except Exception as e: + logger.exception("tool execution failed") + return self._handle_tool_execution_error(tool_use_id, e) + + def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: + """Create error ToolResult with consistent logging and elicitation callback support. + + Args: + tool_use_id: Unique identifier for this tool use. + exception: The exception that occurred during tool execution. + + Returns: + MCPToolResult: Error result containing either the elicitation data or the + original exception message. + """ + if isinstance(exception, McpError) and exception.error.code == -32042: + try: + error_data = ElicitationRequiredErrorData.model_validate(exception.error.data) + elicitations = [e.model_dump(exclude_none=True) for e in error_data.elicitations] + + return MCPToolResult( + status="error", + toolUseId=tool_use_id, + content=[ + {"text": (f"MCP Elicitation required: [{str(exception)}] with data {json.dumps(elicitations)}")} + ], + ) + except Exception: + logger.debug("Failed to parse ElicitationRequiredErrorData from -32042 error", exc_info=True) + + return MCPToolResult( + status="error", + toolUseId=tool_use_id, + content=[{"text": f"Tool execution failed: {str(exception)}"}], + ) + + def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult: + """Maps MCP tool result to the agent's MCPToolResult format. + + This method processes the content from the MCP tool call result and converts it to the format + expected by the framework. + + Args: + tool_use_id: Unique identifier for this tool use + call_tool_result: The result from the MCP tool call + + Returns: + MCPToolResult: The converted tool result + """ + self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) + + # Build a typed list of ToolResultContent. + mapped_contents: list[ToolResultContent] = [ + mc + for content in call_tool_result.content + if (mc := self._map_mcp_content_to_tool_result_content(content)) is not None + ] + + status: ToolResultStatus = "error" if call_tool_result.isError else "success" + self._log_debug_with_thread("tool execution completed with status: %s", status) + result = MCPToolResult( + status=status, + toolUseId=tool_use_id, + content=mapped_contents, + ) + + if call_tool_result.structuredContent: + result["structuredContent"] = call_tool_result.structuredContent + if call_tool_result.meta: + result["metadata"] = call_tool_result.meta + + return result + + async def _async_background_thread(self) -> None: + """Asynchronous method that runs in the background thread to manage the MCP connection. + + This method establishes the transport connection, creates and initializes the MCP session, + signals readiness to the main thread, and waits for a close signal. + """ + self._log_debug_with_thread("starting async background thread for MCP connection") + + # Initialized here so that it has the asyncio loop + self._close_future = asyncio.Future() + + try: + async with self._transport_callable() as (read_stream, write_stream, *_): + self._log_debug_with_thread("transport connection established") + async with ClientSession( + read_stream, + write_stream, + message_handler=self._handle_error_message, + elicitation_callback=self._elicitation_callback, + ) as session: + self._log_debug_with_thread("initializing MCP session") + init_result = await session.initialize() + + self._log_debug_with_thread("session initialized successfully") + # Store server instructions from InitializeResult for Host applications + self.server_instructions = init_result.instructions + # Store the session for use while we await the close event + self._background_thread_session = session + + # Cache server task capability immediately after initialization + # Capabilities are exchanged during session.initialize(), so this is available now + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + self._log_debug_with_thread( + "server_task_capable=<%s> | cached server task capability", self._server_task_capable + ) + + # Signal that the session has been created and is ready for use + self._init_future.set_result(None) + + self._log_debug_with_thread("waiting for close signal") + # Keep background thread running until signaled to close. + # Thread is not blocked as this a future + await self._close_future + + self._log_debug_with_thread("close signal received") + except Exception as e: + # If we encounter an exception and the future is still running, + # it means it was encountered during the initialization phase. + if not self._init_future.done(): + self._init_future.set_exception(e) + else: + # _close_future is automatically cancelled by the framework which doesn't provide us with the useful + # exception, so instead we store the exception in a different field where stop() can read it + self._close_exception = e + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) + + self._log_debug_with_thread( + "encountered exception on background thread after initialization %s", str(e) + ) + + # Raise an exception if the underlying client raises an exception in a message + # This happens when the underlying client has an http timeout error + async def _handle_error_message(self, message: Exception | Any) -> None: + if isinstance(message, Exception): + error_msg = str(message).lower() + if any(pattern in error_msg for pattern in _NON_FATAL_ERROR_PATTERNS): + self._log_debug_with_thread("ignoring non-fatal MCP session error: %s", message) + else: + raise message + await anyio.lowlevel.checkpoint() + + def _background_task(self) -> None: + """Sets up and runs the event loop in the background thread. + + This method creates a new event loop for the background thread, + sets it as the current event loop, and runs the async_background_thread + coroutine until completion. In this case "until completion" means until the _close_future is resolved. + This allows for a long-running event loop. + """ + self._log_debug_with_thread("setting up background task event loop") + # Clear any running-loop state leaked by OpenTelemetry's ThreadingInstrumentor, which wraps Thread.run() + # and can propagate the parent thread's event loop reference, causing run_until_complete() to fail. + asyncio._set_running_loop(None) + self._background_thread_event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._background_thread_event_loop) + self._background_thread_event_loop.run_until_complete(self._async_background_thread()) + + def _map_mcp_content_to_tool_result_content( + self, + content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, + ) -> ToolResultContent | None: + """Maps MCP content types to tool result content types. + + This method converts MCP-specific content types to the generic + ToolResultContent format used by the agent framework. + + Args: + content: The MCP content to convert + + Returns: + ToolResultContent or None: The converted content, or None if the content type is not supported + """ + if isinstance(content, MCPTextContent): + self._log_debug_with_thread("mapping MCP text content") + return {"text": content.text} + elif isinstance(content, MCPImageContent): + self._log_debug_with_thread("mapping MCP image content with mime type: %s", content.mimeType) + return { + "image": { + "format": MIME_TO_FORMAT[content.mimeType], + "source": {"bytes": base64.b64decode(content.data)}, + } + } + elif isinstance(content, MCPEmbeddedResource): + """ + TODO: Include URI information in results. + Models may find it useful to be aware not only of the information, + but the location of the information too. + + This may be difficult without taking an opinionated position. For example, + a content block may need to indicate that the following Image content block + is of particular URI. + """ + + self._log_debug_with_thread("mapping MCP embedded resource content") + + resource = content.resource + if isinstance(resource, TextResourceContents): + return {"text": resource.text} + elif isinstance(resource, BlobResourceContents): + try: + raw_bytes = base64.b64decode(resource.blob) + except Exception: + self._log_debug_with_thread("embedded resource blob could not be decoded - dropping") + return None + + if resource.mimeType and ( + resource.mimeType.startswith("text/") + or resource.mimeType + in ( + "application/json", + "application/xml", + "application/javascript", + "application/yaml", + "application/x-yaml", + ) + or resource.mimeType.endswith(("+json", "+xml")) + ): + try: + return {"text": raw_bytes.decode("utf-8", errors="replace")} + except Exception: + pass + + if resource.mimeType in MIME_TO_FORMAT: + return { + "image": { + "format": MIME_TO_FORMAT[resource.mimeType], + "source": {"bytes": raw_bytes}, + } + } + + self._log_debug_with_thread("embedded resource blob with non-textual/unknown mimeType - dropping") + return None + + return None # type: ignore[unreachable] # Defensive: future MCP resource types + else: + self._log_debug_with_thread("unhandled content type: %s - dropping content", content.__class__.__name__) + return None + + def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Logger helper to help differentiate logs coming from MCPClient background thread.""" + formatted_msg = msg % args if args else msg + logger.debug( + "[Thread: %s, Session: %s] %s", threading.current_thread().name, self._session_id, formatted_msg, **kwargs + ) + + def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]: + # save a reference to this so that even if it's reset we have the original + close_future = self._close_future + + if ( + self._background_thread_session is None + or self._background_thread_event_loop is None + or close_future is None + ): + raise MCPClientInitializationError("the client session was not initialized") + + async def run_async() -> T: + # Fix for strands-agents/sdk-python/issues/995 - cancel all pending invocations if/when the session closes + invoke_event = asyncio.create_task(coro) + tasks: list[asyncio.Task | asyncio.Future] = [ + invoke_event, + close_future, + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if done.pop() == close_future: + self._log_debug_with_thread("event loop for the server closed before the invoke completed") + raise RuntimeError("Connection to the MCP server was closed") + else: + return await invoke_event + + invoke_future = asyncio.run_coroutine_threadsafe(coro=run_async(), loop=self._background_thread_event_loop) + return invoke_future + + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on constructor filters.""" + return self._should_include_tool_with_filters(tool, self._tool_filters) + + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: ToolFilters | None) -> bool: + """Check if a tool should be included based on provided filters.""" + if not filters: + return True + + # Apply allowed filter + if "allowed" in filters: + if not self._matches_patterns(tool, filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in filters: + if self._matches_patterns(tool, filters["rejected"]): + return False + + return True + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif isinstance(pattern, Pattern): + if pattern.match(tool.mcp_tool.name): + return True + elif isinstance(pattern, str): + if pattern == tool.mcp_tool.name: + return True + return False + + def _is_session_active(self) -> bool: + if self._background_thread is None or not self._background_thread.is_alive(): + return False + + if self._close_future is not None and self._close_future.done(): + return False + + return True + + def _is_tasks_enabled(self) -> bool: + """Check if tasks feature is enabled. + + Tasks are enabled if tasks config is defined and not None. + + Returns: + True if task-augmented execution is enabled, False otherwise. + """ + return self._tasks_config is not None + + def _get_task_config(self) -> TasksConfig: + """Returns the task execution configuration, configured with defaults if not specified.""" + task_config = self._tasks_config or DEFAULT_TASK_CONFIG + return TasksConfig( + ttl=task_config.get("ttl", DEFAULT_TASK_TTL), + poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), + ) + + def _has_server_task_support(self) -> bool: + """Check if the MCP server supports task-augmented tool calls. + + Returns the capability value that was cached immediately after session initialization. + Server capabilities are exchanged during the MCP handshake, so this is available + as soon as start() completes. + + Returns: + True if server supports task-augmented tool calls, False otherwise. + """ + return self._server_task_capable or False + + def _should_use_task(self, tool_name: str) -> bool: + """Determine if task-augmented execution should be used for a tool. + + Task-augmented execution requires: + 1. tasks config is enabled (opt-in check) + 2. Server supports tasks (capability check) + 3. Tool taskSupport is 'required' or 'optional' + + Args: + tool_name: Name of the tool to check. + + Returns: + True if task-augmented execution should be used, False otherwise. + """ + # Opt-in check: tasks must be explicitly enabled via tasks config + if not self._is_tasks_enabled(): + return False + + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_OPTIONAL, TASK_REQUIRED + + # Server capability check (per MCP spec) + if not self._has_server_task_support(): + return False + + # Tool-level capability check (cached during list_tools_sync) + task_support = self._tool_task_support_cache.get(tool_name) + + # Use tasks for TASK_REQUIRED or TASK_OPTIONAL when server supports + if task_support == TASK_REQUIRED or task_support == TASK_OPTIONAL: + return True + + # Default: 'forbidden', None, or unknown -> don't use tasks + return False + + def _create_task_error_result(self, message: str) -> MCPCallToolResult: + """Create an error MCPCallToolResult with consistent formatting. + + This helper reduces duplication in task error handling paths. + + Args: + message: The error message to include in the result. + + Returns: + MCPCallToolResult with isError=True and the message as text content. + """ + return MCPCallToolResult( + isError=True, + content=[MCPTextContent(type="text", text=message)], + ) + + # ================================================================================== + # Task-Augmented Tool Execution + # ================================================================================== + # + # The MCP spec defines task-augmented execution for long-running tools. The flow is: + # + # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) + # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() + # 3. If not using tasks: call_tool() directly + # + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + # ================================================================================== + + async def _call_tool_as_task_and_poll_async( + self, + name: str, + arguments: dict[str, Any] | None = None, + ttl: timedelta | None = None, + poll_timeout: timedelta | None = None, + meta: dict[str, Any] | None = None, + ) -> MCPCallToolResult: + """Call a tool using task-augmented execution and poll until completion. + + This method implements the MCP task workflow: + 1. Creates a task via call_tool_as_task + 2. Polls using poll_task until terminal status (with timeout protection) + 3. Gets the final result using get_task_result + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + ttl: Task time-to-live. Uses configured value if not specified. + poll_timeout: Timeout for polling. Uses configured value if not specified. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). + + Returns: + MCPCallToolResult: The final tool result after task completion. + """ + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_STATUS_CANCELLED, TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, GetTaskResult + + session = cast(ClientSession, self._background_thread_session) + + # Precedence: arg > config > default + timeout = poll_timeout or self._get_task_config().get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT) + ttl = ttl or self._get_task_config().get("ttl", DEFAULT_TASK_TTL) + ttl_ms = int(ttl.total_seconds() * 1000) + + # Step 1: Create the task + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl_ms) + create_result = await session.experimental.call_tool_as_task( + name=name, + arguments=arguments, + ttl=ttl_ms, + meta=meta, + ) + task_id = create_result.task.taskId + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) + + # Step 2: Poll until terminal status (with timeout protection) + # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility + async def _poll_until_terminal() -> GetTaskResult | None: + """Inner function to poll task status until terminal state.""" + final = None + async for task in session.experimental.poll_task(task_id): + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | task status update", + name, + task_id, + task.status, + ) + final = task + return final + + try: + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout.total_seconds()) + except asyncio.TimeoutError: + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, timeout_seconds=<%s> | task polling timed out", + name, + task_id, + timeout.total_seconds(), + ) + return self._create_task_error_result( + f"Task {task_id} polling timed out after {timeout.total_seconds()} seconds" + ) + + # Step 3: Handle terminal status + if final_status is None: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) + return self._create_task_error_result(f"Task {task_id} polling completed without status") + + if final_status.status == TASK_STATUS_FAILED: + error_msg = final_status.statusMessage or "Task failed" + self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) + return self._create_task_error_result(error_msg) + + if final_status.status == TASK_STATUS_CANCELLED: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) + return self._create_task_error_result("Task was cancelled") + + # Step 4: Get the actual result for completed tasks (with error handling for race conditions) + if final_status.status == TASK_STATUS_COMPLETED: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) + try: + result = await session.experimental.get_task_result(task_id, MCPCallToolResult) + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) + return result + except Exception as e: + # Handle race condition: task completed but result retrieval failed + # (e.g., result expired, network error, server restarted) + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) + ) + return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") + + # Unexpected status - return as error + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", + name, + task_id, + final_status.status, + ) + return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/src/strands/mcp/mcp_instrumentation.py b/src/strands/mcp/mcp_instrumentation.py new file mode 100644 index 000000000..5e64cc3d5 --- /dev/null +++ b/src/strands/mcp/mcp_instrumentation.py @@ -0,0 +1,337 @@ +"""OpenTelemetry instrumentation for Model Context Protocol (MCP) tracing. + +Enables distributed tracing across MCP client-server boundaries by injecting +OpenTelemetry context into MCP request metadata (_meta field) and extracting +it on the server side, creating unified traces that span from agent calls +through MCP tool executions. + +Based on: https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-mcp +Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 +""" + +from collections.abc import AsyncGenerator, Callable +from contextlib import _AsyncGeneratorContextManager, asynccontextmanager +from dataclasses import dataclass +from typing import Any + +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCMessage, JSONRPCRequest +from opentelemetry import context, propagate +from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper + +# Module-level flag to ensure instrumentation is applied only once +_instrumentation_applied = False + + +@dataclass(slots=True, frozen=True) +class ItemWithContext: + """Wrapper for items that need to carry OpenTelemetry context. + + Used to preserve tracing context across async boundaries in MCP sessions, + ensuring that distributed traces remain connected even when messages are + processed asynchronously. + + Attributes: + item: The original item being wrapped + ctx: The OpenTelemetry context associated with the item + """ + + item: Any + ctx: context.Context + + +def mcp_instrumentation() -> None: + """Apply OpenTelemetry instrumentation patches to MCP components. + + This function instruments three key areas of MCP communication: + 1. Client-side: Injects tracing context into tool call requests + 2. Transport-level: Extracts context from incoming messages + 3. Session-level: Manages bidirectional context flow + + The patches enable distributed tracing by: + - Adding OpenTelemetry context to the _meta field of MCP requests + - Extracting and activating context on the server side + - Preserving context across async message processing boundaries + + This function is idempotent - multiple calls will not accumulate wrappers. + """ + global _instrumentation_applied + + # Return early if instrumentation has already been applied + if _instrumentation_applied: + return + + def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any) -> Any: + """Patch MCP client to inject OpenTelemetry context into tool calls. + + Intercepts outgoing MCP requests and injects the current OpenTelemetry + context into the request's _meta field for tools/call methods. This + enables server-side context extraction and trace continuation. + + Args: + wrapped: The original function being wrapped + instance: The instance the method is being called on + args: Positional arguments to the wrapped function + kwargs: Keyword arguments to the wrapped function + + Returns: + Result of the wrapped function call + """ + if len(args) < 1: + return wrapped(*args, **kwargs) + + request = args[0] + method = getattr(request.root, "method", None) + + if method != "tools/call": + return wrapped(*args, **kwargs) + + try: + if hasattr(request.root, "params") and request.root.params: + # Handle Pydantic models + if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"): + params_dict = request.root.params.model_dump(by_alias=True) + # Add _meta with tracing context + meta = params_dict.get("_meta") if params_dict.get("_meta") is not None else {} + params_dict["_meta"] = meta + propagate.get_global_textmap().inject(meta) + + # Recreate the Pydantic model with the updated data + # This preserves the original model type and avoids serialization warnings + params_class = type(request.root.params) + try: + request.root.params = params_class.model_validate(params_dict) + except Exception: + # Fallback to dict if model recreation fails + request.root.params = params_dict + + elif isinstance(request.root.params, dict): + # Handle dict params directly + meta = request.root.params.setdefault("_meta", {}) + propagate.get_global_textmap().inject(meta) + + return wrapped(*args, **kwargs) + + except Exception: + return wrapped(*args, **kwargs) + + def transport_wrapper() -> Callable[ + [Callable[..., Any], Any, Any, Any], _AsyncGeneratorContextManager[tuple[Any, Any]] + ]: + """Create a wrapper for MCP transport connections. + + Returns a context manager that wraps transport read/write streams + with context extraction capabilities. The wrapped reader will + automatically extract OpenTelemetry context from incoming messages. + + Returns: + An async context manager that yields wrapped transport streams + """ + + @asynccontextmanager + async def traced_method( + wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any + ) -> AsyncGenerator[tuple[Any, Any], None]: + async with wrapped(*args, **kwargs) as result: + try: + read_stream, write_stream = result + except ValueError: + read_stream, write_stream, _ = result + yield TransportContextExtractingReader(read_stream), write_stream + + return traced_method + + def session_init_wrapper() -> Callable[[Any, Any, tuple[Any, ...], dict[str, Any]], None]: + """Create a wrapper for MCP session initialization. + + Wraps session message streams to enable bidirectional context flow. + The reader extracts and activates context, while the writer preserves + context for async processing. + + Returns: + A function that wraps session initialization + """ + + def traced_method( + wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + wrapped(*args, **kwargs) + reader = getattr(instance, "_incoming_message_stream_reader", None) + writer = getattr(instance, "_incoming_message_stream_writer", None) + if reader and writer: + instance._incoming_message_stream_reader = SessionContextAttachingReader(reader) + instance._incoming_message_stream_writer = SessionContextSavingWriter(writer) + + return traced_method + + # Apply patches + wrap_function_wrapper("mcp.shared.session", "BaseSession.send_request", patch_mcp_client) + + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.streamable_http", "StreamableHTTPServerTransport.connect", transport_wrapper() + ), + "mcp.server.streamable_http", + ) + + register_post_import_hook( + lambda _: wrap_function_wrapper("mcp.server.session", "ServerSession.__init__", session_init_wrapper()), + "mcp.server.session", + ) + + # Mark instrumentation as applied + _instrumentation_applied = True + + +class TransportContextExtractingReader(ObjectProxy): + """A proxy reader that extracts OpenTelemetry context from MCP messages. + + Wraps an async message stream reader to automatically extract and activate + OpenTelemetry context from the _meta field of incoming MCP requests. This + enables server-side trace continuation from client-injected context. + + The reader handles both SessionMessage and JSONRPCMessage formats, and + supports both dict and Pydantic model parameter structures. + """ + + def __init__(self, wrapped: Any) -> None: + """Initialize the context-extracting reader. + + Args: + wrapped: The original async stream reader to wrap + """ + super().__init__(wrapped) + + async def __aenter__(self) -> Any: + """Enter the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + """Exit the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def __aiter__(self) -> AsyncGenerator[Any, None]: + """Iterate over messages, extracting and activating context as needed. + + For each incoming message, checks if it contains tracing context in + the _meta field. If found, extracts and activates the context for + the duration of message processing, then properly detaches it. + + Yields: + Messages from the wrapped stream, processed under the appropriate + OpenTelemetry context + """ + async for item in self.__wrapped__: + if isinstance(item, SessionMessage): + request = item.message.root + elif type(item) is JSONRPCMessage: + request = item.root + else: + yield item + continue + + if isinstance(request, JSONRPCRequest) and request.params: + # Handle both dict and Pydantic model params + if hasattr(request.params, "get"): + # Dict-like access + meta = request.params.get("_meta") + elif hasattr(request.params, "_meta"): + # Direct attribute access for Pydantic models + meta = getattr(request.params, "_meta", None) + else: + meta = None + + if meta: + extracted_context = propagate.extract(meta) + restore = context.attach(extracted_context) + try: + yield item + continue + finally: + context.detach(restore) + yield item + + +class SessionContextSavingWriter(ObjectProxy): + """A proxy writer that preserves OpenTelemetry context with outgoing items. + + Wraps an async message stream writer to capture the current OpenTelemetry + context and associate it with outgoing items. This enables context + preservation across async boundaries in MCP session processing. + """ + + def __init__(self, wrapped: Any) -> None: + """Initialize the context-saving writer. + + Args: + wrapped: The original async stream writer to wrap + """ + super().__init__(wrapped) + + async def __aenter__(self) -> Any: + """Enter the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + """Exit the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def send(self, item: Any) -> Any: + """Send an item while preserving the current OpenTelemetry context. + + Captures the current context and wraps the item with it, enabling + the receiving side to restore the appropriate tracing context. + + Args: + item: The item to send through the stream + + Returns: + Result of sending the wrapped item + """ + ctx = context.get_current() + return await self.__wrapped__.send(ItemWithContext(item, ctx)) + + +class SessionContextAttachingReader(ObjectProxy): + """A proxy reader that restores OpenTelemetry context from wrapped items. + + Wraps an async message stream reader to detect ItemWithContext instances + and restore their associated OpenTelemetry context during processing. + This completes the context preservation cycle started by SessionContextSavingWriter. + """ + + def __init__(self, wrapped: Any) -> None: + """Initialize the context-attaching reader. + + Args: + wrapped: The original async stream reader to wrap + """ + super().__init__(wrapped) + + async def __aenter__(self) -> Any: + """Enter the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + """Exit the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def __aiter__(self) -> AsyncGenerator[Any, None]: + """Iterate over items, restoring context for ItemWithContext instances. + + For items wrapped with context, temporarily activates the associated + OpenTelemetry context during processing, then properly detaches it. + Regular items are yielded without context modification. + + Yields: + Unwrapped items processed under their associated OpenTelemetry context + """ + async for item in self.__wrapped__: + if isinstance(item, ItemWithContext): + restore = context.attach(item.ctx) + try: + yield item.item + finally: + context.detach(restore) + else: + yield item diff --git a/src/strands/mcp/mcp_tasks.py b/src/strands/mcp/mcp_tasks.py new file mode 100644 index 000000000..36537f7df --- /dev/null +++ b/src/strands/mcp/mcp_tasks.py @@ -0,0 +1,33 @@ +"""Task-augmented tool execution configuration for MCP. + +This module provides configuration types and defaults for the experimental MCP Tasks feature. +""" + +from datetime import timedelta + +from typing_extensions import TypedDict + + +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + When enabled, supported tool calls use the MCP task workflow: + create task -> poll for completion -> get result. + + Warning: + This is an experimental feature in the 2025-11-25 MCP specification and + both the specification and the Strands Agents implementation of this + feature are subject to change. + + Attributes: + ttl: Task time-to-live. Defaults to 1 minute. + poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. + """ + + ttl: timedelta + poll_timeout: timedelta + + +DEFAULT_TASK_TTL = timedelta(minutes=1) +DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) +DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) diff --git a/src/strands/mcp/mcp_types.py b/src/strands/mcp/mcp_types.py new file mode 100644 index 000000000..f9ee5ac40 --- /dev/null +++ b/src/strands/mcp/mcp_types.py @@ -0,0 +1,67 @@ +"""Type definitions for MCP integration.""" + +from contextlib import AbstractAsyncContextManager +from typing import Any + +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.client.streamable_http import GetSessionIdCallback +from mcp.shared.memory import MessageStream +from mcp.shared.message import SessionMessage +from typing_extensions import NotRequired + +from ..types.tools import ToolResult + +""" +MCPTransport defines the interface for MCP transport implementations. This abstracts +communication with an MCP server, hiding details of the underlying transport mechanism (WebSocket, stdio, etc.). + +It represents an async context manager that yields a tuple of read and write streams for MCP communication. +When used with `async with`, it should establish the connection and yield the streams, then clean up +when the context is exited. + +The read stream receives messages from the client (or exceptions if parsing fails), while the write +stream sends messages to the client. + +Example implementation (simplified): +```python +@contextlib.asynccontextmanager +async def my_transport_implementation(): + # Set up connection + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + # Start background tasks to handle actual I/O + async with anyio.create_task_group() as tg: + tg.start_soon(reader_task, read_stream_writer) + tg.start_soon(writer_task, write_stream_reader) + + # Yield the streams to the caller + yield (read_stream, write_stream) +``` +""" +# GetSessionIdCallback was added for HTTP Streaming but was not applied to the MessageStream type +# https://github.com/modelcontextprotocol/python-sdk/blob/ed25167fa5d715733437996682e20c24470e8177/src/mcp/client/streamable_http.py#L418 +_MessageStreamWithGetSessionIdCallback = tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback +] +MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] + + +class MCPToolResult(ToolResult): + """Result of an MCP tool execution. + + Extends the base ToolResult with MCP-specific structured content support. + The structuredContent field contains optional JSON data returned by MCP tools + that provides structured results beyond the standard text/image/document content. + + Attributes: + structuredContent: Optional JSON object containing structured data returned + by the MCP tool. This allows MCP tools to return complex data structures + that can be processed programmatically by agents or other tools. + metadata: Optional arbitrary metadata returned by the MCP tool. This field allows + MCP servers to attach custom metadata to tool results (e.g., token usage, + performance metrics, or business-specific tracking information). + """ + + structuredContent: NotRequired[dict[str, Any]] + metadata: NotRequired[dict[str, Any]] diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index 8d2c1daa2..f3b8adf2b 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -1,14 +1,18 @@ -"""Model Context Protocol (MCP) integration. +"""Deprecated: MCP integration has moved to ``strands.mcp``. -This package provides integration with the Model Context Protocol (MCP), allowing agents to use tools provided by MCP -servers. - -- Docs: https://www.anthropic.com/news/model-context-protocol +This module re-exports the public API from its new location and emits a +``DeprecationWarning`` at import time. Update imports to ``strands.mcp``. """ -from .mcp_agent_tool import MCPAgentTool -from .mcp_client import MCPClient, ToolFilters -from .mcp_tasks import TasksConfig -from .mcp_types import MCPTransport +import warnings + +from ...mcp import MCPAgentTool, MCPClient, MCPTransport, TasksConfig, ToolFilters + +warnings.warn( + "strands.tools.mcp has moved to strands.mcp. " + "Import from strands.mcp instead; strands.tools.mcp will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) __all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index bedd93f24..00906df9c 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -1,119 +1,13 @@ -"""MCP Agent Tool module for adapting Model Context Protocol tools to the agent framework. +"""Deprecated: moved to ``strands.mcp.mcp_agent_tool``.""" -This module provides the MCPAgentTool class which serves as an adapter between -MCP (Model Context Protocol) tools and the agent framework's tool interface. -It allows MCP tools to be seamlessly integrated and used within the agent ecosystem. -""" +import warnings -import logging -from datetime import timedelta -from typing import TYPE_CHECKING, Any +from ...mcp.mcp_agent_tool import * # noqa: F401, F403 +from ...mcp.mcp_agent_tool import MCPAgentTool # noqa: F401 -from mcp.types import Tool as MCPTool -from typing_extensions import override - -from ...types._events import ToolResultEvent -from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse - -if TYPE_CHECKING: - from .mcp_client import MCPClient - -logger = logging.getLogger(__name__) - - -class MCPAgentTool(AgentTool): - """Adapter class that wraps an MCP tool and exposes it as an AgentTool. - - This class bridges the gap between the MCP protocol's tool representation - and the agent framework's tool interface, allowing MCP tools to be used - seamlessly within the agent framework. - """ - - def __init__( - self, - mcp_tool: MCPTool, - mcp_client: "MCPClient", - name_override: str | None = None, - timeout: timedelta | None = None, - ) -> None: - """Initialize a new MCPAgentTool instance. - - Args: - mcp_tool: The MCP tool to adapt - mcp_client: The MCP server connection to use for tool invocation - name_override: Optional name to use for the agent tool (for disambiguation) - If None, uses the original MCP tool name - timeout: Optional timeout duration for tool execution - """ - super().__init__() - logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) - self.mcp_tool = mcp_tool - self.mcp_client = mcp_client - self._agent_tool_name = name_override or mcp_tool.name - self.timeout = timeout - - @property - def tool_name(self) -> str: - """Get the name of the tool. - - Returns: - str: The agent-facing name of the tool (may be disambiguated) - """ - return self._agent_tool_name - - @property - def tool_spec(self) -> ToolSpec: - """Get the specification of the tool. - - This method converts the MCP tool specification to the agent framework's - ToolSpec format, including the input schema, description, and optional output schema. - - Returns: - ToolSpec: The tool specification in the agent framework format - """ - description: str = self.mcp_tool.description or f"Tool which performs {self.mcp_tool.name}" - - spec: ToolSpec = { - "inputSchema": {"json": self.mcp_tool.inputSchema}, - "name": self.tool_name, # Use agent-facing name in spec - "description": description, - } - - if self.mcp_tool.outputSchema: - spec["outputSchema"] = {"json": self.mcp_tool.outputSchema} - - return spec - - @property - def tool_type(self) -> str: - """Get the type of the tool. - - Returns: - str: The type of the tool, always "python" for MCP tools - """ - return "python" - - @override - async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: - """Stream the MCP tool. - - This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and - input arguments. - - Args: - tool_use: The tool use request containing tool ID and parameters. - invocation_state: Context for the tool invocation, including agent state. - **kwargs: Additional keyword arguments for future extensibility. - - Yields: - Tool events with the last being the tool result. - """ - logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"]) - - result = await self.mcp_client.call_tool_async( - tool_use_id=tool_use["toolUseId"], - name=self.mcp_tool.name, # Use original MCP name for server communication - arguments=tool_use["input"], - read_timeout_seconds=self.timeout, - ) - yield ToolResultEvent(result) +warnings.warn( + "strands.tools.mcp.mcp_agent_tool has moved to strands.mcp.mcp_agent_tool. " + "Import from strands.mcp.mcp_agent_tool instead; strands.tools.mcp will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index e81dc7130..a73e8345b 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -1,1212 +1,13 @@ -"""Model Context Protocol (MCP) server connection management module. +"""Deprecated: moved to ``strands.mcp.mcp_client``.""" -This module provides the MCPClient class which handles connections to MCP servers. -It manages the lifecycle of MCP connections, including initialization, tool discovery, -tool invocation, and proper cleanup of resources. The connection runs in a background -thread to avoid blocking the main application thread while maintaining communication -with the MCP service. -""" +import warnings -import asyncio -import base64 -import contextvars -import json -import logging -import threading -import uuid -from asyncio import AbstractEventLoop -from collections.abc import Callable, Coroutine, Sequence -from concurrent import futures -from datetime import timedelta -from re import Pattern -from types import TracebackType -from typing import Any, TypeVar, cast +from ...mcp.mcp_client import * # noqa: F401, F403 +from ...mcp.mcp_client import MCPClient, ToolFilters # noqa: F401 -import anyio -from mcp import ClientSession, ListToolsResult -from mcp.client.session import ElicitationFnT -from mcp.shared.exceptions import McpError -from mcp.types import ( - BlobResourceContents, - ElicitationRequiredErrorData, - GetPromptResult, - ListPromptsResult, - ListResourcesResult, - ListResourceTemplatesResult, - ReadResourceResult, - TextResourceContents, +warnings.warn( + "strands.tools.mcp.mcp_client has moved to strands.mcp.mcp_client. " + "Import from strands.mcp.mcp_client instead; strands.tools.mcp will be removed in a future release.", + DeprecationWarning, + stacklevel=2, ) -from mcp.types import CallToolResult as MCPCallToolResult -from mcp.types import EmbeddedResource as MCPEmbeddedResource -from mcp.types import ImageContent as MCPImageContent -from mcp.types import TextContent as MCPTextContent -from pydantic import AnyUrl -from typing_extensions import Protocol, TypedDict - -from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError, ToolProviderException -from ...types.media import ImageFormat -from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus -from ..tool_provider import ToolProvider -from .mcp_agent_tool import MCPAgentTool -from .mcp_instrumentation import mcp_instrumentation -from .mcp_tasks import DEFAULT_TASK_CONFIG, DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL, TasksConfig -from .mcp_types import MCPToolResult, MCPTransport - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -class _ToolFilterCallback(Protocol): - def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... - - -_ToolMatcher = str | Pattern[str] | _ToolFilterCallback - - -class ToolFilters(TypedDict, total=False): - """Filters for controlling which MCP tools are loaded and available. - - Tools are filtered in this order: - 1. If 'allowed' is specified, only tools matching these patterns are included - 2. Tools matching 'rejected' patterns are then excluded - """ - - allowed: list[_ToolMatcher] - rejected: list[_ToolMatcher] - - -MIME_TO_FORMAT: dict[str, ImageFormat] = { - "image/jpeg": "jpeg", - "image/jpg": "jpeg", - "image/png": "png", - "image/gif": "gif", - "image/webp": "webp", -} - -CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = ( - "the client session is not running. Ensure the agent is used within " - "the MCP client context manager. For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror" -) - -# Non-fatal error patterns that should not cause connection collapse -_NON_FATAL_ERROR_PATTERNS = [ - # Occurs when client receives response with unrecognized ID - # Can occur after a client-side timeout - # See: https://github.com/modelcontextprotocol/python-sdk/blob/c51936f61f35a15f0b1f8fb6887963e5baee1506/src/mcp/shared/session.py#L421 - "unknown request id", -] - - -class MCPClient(ToolProvider): - """Represents a connection to a Model Context Protocol (MCP) server. - - This class implements a context manager pattern for efficient connection management, - allowing reuse of the same connection for multiple tool calls to reduce latency. - It handles the creation, initialization, and cleanup of MCP connections. - - The connection runs in a background thread to avoid blocking the main application thread - while maintaining communication with the MCP service. When structured content is available - from MCP tools, it will be returned as the last item in the content array of the ToolResult. - """ - - def __init__( - self, - transport_callable: Callable[[], MCPTransport], - *, - startup_timeout: int = 30, - tool_filters: ToolFilters | None = None, - prefix: str | None = None, - elicitation_callback: ElicitationFnT | None = None, - tasks_config: TasksConfig | None = None, - ) -> None: - """Initialize a new MCP Server connection. - - Args: - transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple. - startup_timeout: Timeout after which MCP server initialization should be cancelled. - Defaults to 30. - tool_filters: Optional filters to apply to tools. - prefix: Optional prefix for tool names. - elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. - tasks_config: Configuration for MCP task-augmented execution for long-running tools. - If provided (not None), enables task-augmented execution for tools that support it. - See TasksConfig for details. This feature is experimental and subject to change. - """ - self._startup_timeout = startup_timeout - self._tool_filters = tool_filters - self._prefix = prefix - self._elicitation_callback = elicitation_callback - - mcp_instrumentation() - self._session_id = uuid.uuid4() - self._log_debug_with_thread("initializing MCPClient connection") - # Main thread blocks until future completes - self._init_future: futures.Future[None] = futures.Future() - # Set within the inner loop as it needs the asyncio loop - self._close_future: asyncio.futures.Future[None] | None = None - self._close_exception: None | Exception = None - # Do not want to block other threads while close event is false - self._transport_callable = transport_callable - - self._background_thread: threading.Thread | None = None - self._background_thread_session: ClientSession | None = None - self._background_thread_event_loop: AbstractEventLoop | None = None - self._loaded_tools: list[MCPAgentTool] | None = None - self._tool_provider_started = False - self.server_instructions: str | None = None - self._consumers: set[Any] = set() - - # Task support configuration and caching - self._tasks_config = tasks_config - self._server_task_capable: bool | None = None - - # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) - if self._is_tasks_enabled(): - from mcp.types import TaskExecutionMode - - self._tool_task_support_cache: dict[str, TaskExecutionMode] = {} - - def __enter__(self) -> "MCPClient": - """Context manager entry point which initializes the MCP server connection. - - TODO: Refactor to lazy initialization pattern following idiomatic Python. - Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead. - """ - return self.start() - - def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None: - """Context manager exit point that cleans up resources.""" - self.stop(exc_type, exc_val, exc_tb) - - def start(self) -> "MCPClient": - """Starts the background thread and waits for initialization. - - This method starts the background thread that manages the MCP connection - and blocks until the connection is ready or times out. - - Returns: - self: The MCPClient instance - - Raises: - Exception: If the MCP connection fails to initialize within the timeout period - """ - if self._is_session_active(): - raise MCPClientInitializationError("the client session is currently running") - - self._log_debug_with_thread("entering MCPClient context") - # Copy context vars to propagate to the background thread - # This ensures that context set in the main thread is accessible in the background thread - # See: https://github.com/strands-agents/sdk-python/issues/1440 - ctx = contextvars.copy_context() - self._background_thread = threading.Thread(target=ctx.run, args=(self._background_task,), daemon=True) - self._background_thread.start() - self._log_debug_with_thread("background thread started, waiting for ready event") - try: - # Blocking main thread until session is initialized in other thread or if the thread stops - self._init_future.result(timeout=self._startup_timeout) - self._log_debug_with_thread("the client initialization was successful") - except futures.TimeoutError as e: - logger.exception("client initialization timed out") - # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit - self.stop(None, None, None) - raise MCPClientInitializationError( - f"background thread did not start in {self._startup_timeout} seconds" - ) from e - except Exception as e: - logger.exception("client failed to initialize") - # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit - self.stop(None, None, None) - raise MCPClientInitializationError("the client initialization failed") from e - return self - - # ToolProvider interface methods - async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: - """Load and return tools from the MCP server. - - This method implements the ToolProvider interface by loading tools - from the MCP server and caching them for reuse. - - Args: - **kwargs: Additional arguments for future compatibility. - - Returns: - List of AgentTool instances from the MCP server. - """ - logger.debug( - "started=<%s>, cached_tools=<%s> | loading tools", - self._tool_provider_started, - self._loaded_tools is not None, - ) - - if not self._tool_provider_started: - try: - logger.debug("starting MCP client") - self.start() - self._tool_provider_started = True - logger.debug("MCP client started successfully") - except Exception as e: - logger.error("error=<%s> | failed to start MCP client", e) - raise ToolProviderException(f"Failed to start MCP client: {e}") from e - - if self._loaded_tools is None: - logger.debug("loading tools from MCP server") - self._loaded_tools = [] - pagination_token = None - page_count = 0 - - while True: - logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) - # Use constructor defaults for prefix and filters in load_tools - paginated_tools = self.list_tools_sync( - pagination_token, prefix=self._prefix, tool_filters=self._tool_filters - ) - - # Tools are already filtered by list_tools_sync, so add them all - for tool in paginated_tools: - self._loaded_tools.append(tool) - - logger.debug( - "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", - page_count, - len(paginated_tools), - len(self._loaded_tools), - ) - - pagination_token = paginated_tools.pagination_token - page_count += 1 - - if pagination_token is None: - break - - logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) - - return self._loaded_tools - - def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: - """Add a consumer to this tool provider. - - Synchronous to prevent GC deadlocks when called from Agent finalizers. - """ - self._consumers.add(consumer_id) - logger.debug("added provider consumer, count=%d", len(self._consumers)) - - def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: - """Remove a consumer from this tool provider. - - This method is idempotent - calling it multiple times with the same ID - has no additional effect after the first call. - - Synchronous to prevent GC deadlocks when called from Agent finalizers. - Uses existing synchronous stop() method for safe cleanup. - """ - self._consumers.discard(consumer_id) - logger.debug("removed provider consumer, count=%d", len(self._consumers)) - - if not self._consumers and self._tool_provider_started: - logger.debug("no consumers remaining, cleaning up") - try: - self.stop(None, None, None) # Existing sync method - safe for finalizers - self._tool_provider_started = False - self._loaded_tools = None - except Exception as e: - logger.error("error=<%s> | failed to cleanup MCP client", e) - raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e - - # MCP-specific methods - - def stop(self, exc_type: BaseException | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: - """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. - - This method is defensive and can handle partial initialization states that may occur - if start() fails partway through initialization. - - Resources to cleanup: - - _background_thread: Thread running the async event loop - - _background_thread_session: MCP ClientSession (auto-closed by context manager) - - _background_thread_event_loop: AsyncIO event loop in background thread - - _close_future: AsyncIO future to signal thread shutdown - - _close_exception: Exception that caused the background thread shutdown; None if a normal shutdown occurred. - - _init_future: Future for initialization synchronization - - Cleanup order: - 1. Signal close future to background thread (if session initialized) - 2. Wait for background thread to complete - 3. Reset all state for reuse - - Args: - exc_type: Exception type if an exception was raised in the context - exc_val: Exception value if an exception was raised in the context - exc_tb: Exception traceback if an exception was raised in the context - """ - self._log_debug_with_thread("exiting MCPClient context") - - # Only try to signal close future if we have a background thread - if self._background_thread is not None: - # Signal close future if event loop exists - if self._background_thread_event_loop is not None: - - async def _set_close_event() -> None: - if self._close_future and not self._close_future.done(): - self._close_future.set_result(None) - - # Not calling _invoke_on_background_thread since the session does not need to exist - # we only need the thread and event loop to exist. - asyncio.run_coroutine_threadsafe(coro=_set_close_event(), loop=self._background_thread_event_loop) - - self._log_debug_with_thread("waiting for background thread to join") - self._background_thread.join() - - if self._background_thread_event_loop is not None: - self._background_thread_event_loop.close() - - self._log_debug_with_thread("background thread is closed, MCPClient context exited") - - # Reset fields to allow instance reuse - self._init_future = futures.Future() - self._background_thread = None - self._background_thread_session = None - self._background_thread_event_loop = None - self._session_id = uuid.uuid4() - self._loaded_tools = None - self._tool_provider_started = False - self._consumers = set() - self._server_task_capable = None - self._tool_task_support_cache = {} - - if self._close_exception: - exception = self._close_exception - self._close_exception = None - raise RuntimeError("Connection to the MCP server was closed") from exception - - def list_tools_sync( - self, - pagination_token: str | None = None, - prefix: str | None = None, - tool_filters: ToolFilters | None = None, - ) -> PaginatedList[MCPAgentTool]: - """Synchronously retrieves the list of available tools from the MCP server. - - This method calls the asynchronous list_tools method on the MCP session - and adapts the returned tools to the AgentTool interface. - - Args: - pagination_token: Optional token for pagination - prefix: Optional prefix to apply to tool names. If None, uses constructor default. - If explicitly provided (including empty string), overrides constructor default. - tool_filters: Optional filters to apply to tools. If None, uses constructor default. - If explicitly provided (including empty dict), overrides constructor default. - - Returns: - List[AgentTool]: A list of available tools adapted to the AgentTool interface - """ - self._log_debug_with_thread("listing MCP tools synchronously") - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - effective_prefix = self._prefix if prefix is None else prefix - effective_filters = self._tool_filters if tool_filters is None else tool_filters - - async def _list_tools_async() -> ListToolsResult: - return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) - - list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() - self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) - - mcp_tools = [] - for tool in list_tools_response.tools: - if self._is_tasks_enabled(): - # Cache taskSupport for task-augmented execution decisions - task_support = None - if tool.execution is not None and tool.execution.taskSupport is not None: - task_support = tool.execution.taskSupport - self._tool_task_support_cache[tool.name] = task_support or "forbidden" - - # Apply prefix if specified - if effective_prefix: - prefixed_name = f"{effective_prefix}_{tool.name}" - mcp_tool = MCPAgentTool(tool, self, name_override=prefixed_name) - logger.debug("tool_rename=<%s->%s> | renamed tool", tool.name, prefixed_name) - else: - mcp_tool = MCPAgentTool(tool, self) - - # Apply filters if specified - if self._should_include_tool_with_filters(mcp_tool, effective_filters): - mcp_tools.append(mcp_tool) - - self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) - return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) - - def list_prompts_sync(self, pagination_token: str | None = None) -> ListPromptsResult: - """Synchronously retrieves the list of available prompts from the MCP server. - - This method calls the asynchronous list_prompts method on the MCP session - and returns the raw ListPromptsResult with pagination support. - - Args: - pagination_token: Optional token for pagination - - Returns: - ListPromptsResult: The raw MCP response containing prompts and pagination info - """ - self._log_debug_with_thread("listing MCP prompts synchronously") - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - async def _list_prompts_async() -> ListPromptsResult: - return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token) - - list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() - self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts)) - for prompt in list_prompts_result.prompts: - self._log_debug_with_thread(prompt.name) - - return list_prompts_result - - def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult: - """Synchronously retrieves a prompt from the MCP server. - - Args: - prompt_id: The ID of the prompt to retrieve - args: Optional arguments to pass to the prompt - - Returns: - GetPromptResult: The prompt response from the MCP server - """ - self._log_debug_with_thread("getting MCP prompt synchronously") - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - async def _get_prompt_async() -> GetPromptResult: - return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args) - - get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() - self._log_debug_with_thread("received prompt from MCP server") - - return get_prompt_result - - def list_resources_sync(self, pagination_token: str | None = None) -> ListResourcesResult: - """Synchronously retrieves the list of available resources from the MCP server. - - This method calls the asynchronous list_resources method on the MCP session - and returns the raw ListResourcesResult with pagination support. - - Args: - pagination_token: Optional token for pagination - - Returns: - ListResourcesResult: The raw MCP response containing resources and pagination info - """ - self._log_debug_with_thread("listing MCP resources synchronously") - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - async def _list_resources_async() -> ListResourcesResult: - return await cast(ClientSession, self._background_thread_session).list_resources(cursor=pagination_token) - - list_resources_result: ListResourcesResult = self._invoke_on_background_thread(_list_resources_async()).result() - self._log_debug_with_thread("received %d resources from MCP server", len(list_resources_result.resources)) - - return list_resources_result - - def read_resource_sync(self, uri: AnyUrl | str) -> ReadResourceResult: - """Synchronously reads a resource from the MCP server. - - Args: - uri: The URI of the resource to read - - Returns: - ReadResourceResult: The resource content from the MCP server - """ - self._log_debug_with_thread("reading MCP resource synchronously: %s", uri) - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - async def _read_resource_async() -> ReadResourceResult: - # Convert string to AnyUrl if needed - resource_uri = AnyUrl(uri) if isinstance(uri, str) else uri - return await cast(ClientSession, self._background_thread_session).read_resource(resource_uri) - - read_resource_result: ReadResourceResult = self._invoke_on_background_thread(_read_resource_async()).result() - self._log_debug_with_thread("received resource content from MCP server") - - return read_resource_result - - def list_resource_templates_sync(self, pagination_token: str | None = None) -> ListResourceTemplatesResult: - """Synchronously retrieves the list of available resource templates from the MCP server. - - Resource templates define URI patterns that can be used to access resources dynamically. - - Args: - pagination_token: Optional token for pagination - - Returns: - ListResourceTemplatesResult: The raw MCP response containing resource templates and pagination info - """ - self._log_debug_with_thread("listing MCP resource templates synchronously") - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - async def _list_resource_templates_async() -> ListResourceTemplatesResult: - return await cast(ClientSession, self._background_thread_session).list_resource_templates( - cursor=pagination_token - ) - - list_resource_templates_result: ListResourceTemplatesResult = self._invoke_on_background_thread( - _list_resource_templates_async() - ).result() - self._log_debug_with_thread( - "received %d resource templates from MCP server", len(list_resource_templates_result.resourceTemplates) - ) - - return list_resource_templates_result - - def _create_call_tool_coroutine( - self, - name: str, - arguments: dict[str, Any] | None, - read_timeout_seconds: timedelta | None, - meta: dict[str, Any] | None = None, - ) -> Coroutine[Any, Any, MCPCallToolResult]: - """Create the appropriate coroutine for calling a tool. - - This method encapsulates the decision logic for whether to use task-augmented - execution or direct call_tool, returning the appropriate coroutine. - - Args: - name: Name of the tool to call. - arguments: Optional arguments to pass to the tool. - read_timeout_seconds: Optional timeout for the tool call. - meta: Optional metadata to pass to the tool call per MCP spec (_meta). - - Returns: - A coroutine that will execute the tool call. - """ - use_task = self._should_use_task(name) - - if use_task: - self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) - - async def _call_as_task() -> MCPCallToolResult: - # When task-augmented execution is used, use the read_timeout_seconds parameter - # (which is a timedelta) for the polling timeout. - return await self._call_tool_as_task_and_poll_async( - name, arguments, poll_timeout=read_timeout_seconds, meta=meta - ) - - return _call_as_task() - else: - self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) - - async def _call_tool_direct() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds, meta=meta - ) - - return _call_tool_direct() - - def call_tool_sync( - self, - tool_use_id: str, - name: str, - arguments: dict[str, Any] | None = None, - read_timeout_seconds: timedelta | None = None, - meta: dict[str, Any] | None = None, - ) -> MCPToolResult: - """Synchronously calls a tool on the MCP server. - - This method automatically uses task-augmented execution when appropriate, - based on server capabilities and tool-level taskSupport settings. - - Args: - tool_use_id: Unique identifier for this tool use - name: Name of the tool to call - arguments: Optional arguments to pass to the tool - read_timeout_seconds: Optional timeout for the tool call - meta: Optional metadata to pass to the tool call per MCP spec (_meta) - - Returns: - MCPToolResult: The result of the tool call - """ - self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) - call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() - return self._handle_tool_result(tool_use_id, call_tool_result) - except Exception as e: - logger.exception("tool execution failed") - return self._handle_tool_execution_error(tool_use_id, e) - - async def call_tool_async( - self, - tool_use_id: str, - name: str, - arguments: dict[str, Any] | None = None, - read_timeout_seconds: timedelta | None = None, - meta: dict[str, Any] | None = None, - ) -> MCPToolResult: - """Asynchronously calls a tool on the MCP server. - - This method automatically uses task-augmented execution when appropriate, - based on server capabilities and tool-level taskSupport settings. - - Args: - tool_use_id: Unique identifier for this tool use - name: Name of the tool to call - arguments: Optional arguments to pass to the tool - read_timeout_seconds: Optional timeout for the tool call - meta: Optional metadata to pass to the tool call per MCP spec (_meta) - - Returns: - MCPToolResult: The result of the tool call - """ - self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id) - if not self._is_session_active(): - raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - - try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) - future = self._invoke_on_background_thread(coro) - call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) - return self._handle_tool_result(tool_use_id, call_tool_result) - except Exception as e: - logger.exception("tool execution failed") - return self._handle_tool_execution_error(tool_use_id, e) - - def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: - """Create error ToolResult with consistent logging and elicitation callback support. - - Args: - tool_use_id: Unique identifier for this tool use. - exception: The exception that occurred during tool execution. - - Returns: - MCPToolResult: Error result containing either the elicitation data or the - original exception message. - """ - if isinstance(exception, McpError) and exception.error.code == -32042: - try: - error_data = ElicitationRequiredErrorData.model_validate(exception.error.data) - elicitations = [e.model_dump(exclude_none=True) for e in error_data.elicitations] - - return MCPToolResult( - status="error", - toolUseId=tool_use_id, - content=[ - {"text": (f"MCP Elicitation required: [{str(exception)}] with data {json.dumps(elicitations)}")} - ], - ) - except Exception: - logger.debug("Failed to parse ElicitationRequiredErrorData from -32042 error", exc_info=True) - - return MCPToolResult( - status="error", - toolUseId=tool_use_id, - content=[{"text": f"Tool execution failed: {str(exception)}"}], - ) - - def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult: - """Maps MCP tool result to the agent's MCPToolResult format. - - This method processes the content from the MCP tool call result and converts it to the format - expected by the framework. - - Args: - tool_use_id: Unique identifier for this tool use - call_tool_result: The result from the MCP tool call - - Returns: - MCPToolResult: The converted tool result - """ - self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) - - # Build a typed list of ToolResultContent. - mapped_contents: list[ToolResultContent] = [ - mc - for content in call_tool_result.content - if (mc := self._map_mcp_content_to_tool_result_content(content)) is not None - ] - - status: ToolResultStatus = "error" if call_tool_result.isError else "success" - self._log_debug_with_thread("tool execution completed with status: %s", status) - result = MCPToolResult( - status=status, - toolUseId=tool_use_id, - content=mapped_contents, - ) - - if call_tool_result.structuredContent: - result["structuredContent"] = call_tool_result.structuredContent - if call_tool_result.meta: - result["metadata"] = call_tool_result.meta - - return result - - async def _async_background_thread(self) -> None: - """Asynchronous method that runs in the background thread to manage the MCP connection. - - This method establishes the transport connection, creates and initializes the MCP session, - signals readiness to the main thread, and waits for a close signal. - """ - self._log_debug_with_thread("starting async background thread for MCP connection") - - # Initialized here so that it has the asyncio loop - self._close_future = asyncio.Future() - - try: - async with self._transport_callable() as (read_stream, write_stream, *_): - self._log_debug_with_thread("transport connection established") - async with ClientSession( - read_stream, - write_stream, - message_handler=self._handle_error_message, - elicitation_callback=self._elicitation_callback, - ) as session: - self._log_debug_with_thread("initializing MCP session") - init_result = await session.initialize() - - self._log_debug_with_thread("session initialized successfully") - # Store server instructions from InitializeResult for Host applications - self.server_instructions = init_result.instructions - # Store the session for use while we await the close event - self._background_thread_session = session - - # Cache server task capability immediately after initialization - # Capabilities are exchanged during session.initialize(), so this is available now - caps = session.get_server_capabilities() - self._server_task_capable = ( - caps is not None - and caps.tasks is not None - and caps.tasks.requests is not None - and caps.tasks.requests.tools is not None - and caps.tasks.requests.tools.call is not None - ) - self._log_debug_with_thread( - "server_task_capable=<%s> | cached server task capability", self._server_task_capable - ) - - # Signal that the session has been created and is ready for use - self._init_future.set_result(None) - - self._log_debug_with_thread("waiting for close signal") - # Keep background thread running until signaled to close. - # Thread is not blocked as this a future - await self._close_future - - self._log_debug_with_thread("close signal received") - except Exception as e: - # If we encounter an exception and the future is still running, - # it means it was encountered during the initialization phase. - if not self._init_future.done(): - self._init_future.set_exception(e) - else: - # _close_future is automatically cancelled by the framework which doesn't provide us with the useful - # exception, so instead we store the exception in a different field where stop() can read it - self._close_exception = e - if self._close_future and not self._close_future.done(): - self._close_future.set_result(None) - - self._log_debug_with_thread( - "encountered exception on background thread after initialization %s", str(e) - ) - - # Raise an exception if the underlying client raises an exception in a message - # This happens when the underlying client has an http timeout error - async def _handle_error_message(self, message: Exception | Any) -> None: - if isinstance(message, Exception): - error_msg = str(message).lower() - if any(pattern in error_msg for pattern in _NON_FATAL_ERROR_PATTERNS): - self._log_debug_with_thread("ignoring non-fatal MCP session error: %s", message) - else: - raise message - await anyio.lowlevel.checkpoint() - - def _background_task(self) -> None: - """Sets up and runs the event loop in the background thread. - - This method creates a new event loop for the background thread, - sets it as the current event loop, and runs the async_background_thread - coroutine until completion. In this case "until completion" means until the _close_future is resolved. - This allows for a long-running event loop. - """ - self._log_debug_with_thread("setting up background task event loop") - # Clear any running-loop state leaked by OpenTelemetry's ThreadingInstrumentor, which wraps Thread.run() - # and can propagate the parent thread's event loop reference, causing run_until_complete() to fail. - asyncio._set_running_loop(None) - self._background_thread_event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._background_thread_event_loop) - self._background_thread_event_loop.run_until_complete(self._async_background_thread()) - - def _map_mcp_content_to_tool_result_content( - self, - content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, - ) -> ToolResultContent | None: - """Maps MCP content types to tool result content types. - - This method converts MCP-specific content types to the generic - ToolResultContent format used by the agent framework. - - Args: - content: The MCP content to convert - - Returns: - ToolResultContent or None: The converted content, or None if the content type is not supported - """ - if isinstance(content, MCPTextContent): - self._log_debug_with_thread("mapping MCP text content") - return {"text": content.text} - elif isinstance(content, MCPImageContent): - self._log_debug_with_thread("mapping MCP image content with mime type: %s", content.mimeType) - return { - "image": { - "format": MIME_TO_FORMAT[content.mimeType], - "source": {"bytes": base64.b64decode(content.data)}, - } - } - elif isinstance(content, MCPEmbeddedResource): - """ - TODO: Include URI information in results. - Models may find it useful to be aware not only of the information, - but the location of the information too. - - This may be difficult without taking an opinionated position. For example, - a content block may need to indicate that the following Image content block - is of particular URI. - """ - - self._log_debug_with_thread("mapping MCP embedded resource content") - - resource = content.resource - if isinstance(resource, TextResourceContents): - return {"text": resource.text} - elif isinstance(resource, BlobResourceContents): - try: - raw_bytes = base64.b64decode(resource.blob) - except Exception: - self._log_debug_with_thread("embedded resource blob could not be decoded - dropping") - return None - - if resource.mimeType and ( - resource.mimeType.startswith("text/") - or resource.mimeType - in ( - "application/json", - "application/xml", - "application/javascript", - "application/yaml", - "application/x-yaml", - ) - or resource.mimeType.endswith(("+json", "+xml")) - ): - try: - return {"text": raw_bytes.decode("utf-8", errors="replace")} - except Exception: - pass - - if resource.mimeType in MIME_TO_FORMAT: - return { - "image": { - "format": MIME_TO_FORMAT[resource.mimeType], - "source": {"bytes": raw_bytes}, - } - } - - self._log_debug_with_thread("embedded resource blob with non-textual/unknown mimeType - dropping") - return None - - return None # type: ignore[unreachable] # Defensive: future MCP resource types - else: - self._log_debug_with_thread("unhandled content type: %s - dropping content", content.__class__.__name__) - return None - - def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: - """Logger helper to help differentiate logs coming from MCPClient background thread.""" - formatted_msg = msg % args if args else msg - logger.debug( - "[Thread: %s, Session: %s] %s", threading.current_thread().name, self._session_id, formatted_msg, **kwargs - ) - - def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]: - # save a reference to this so that even if it's reset we have the original - close_future = self._close_future - - if ( - self._background_thread_session is None - or self._background_thread_event_loop is None - or close_future is None - ): - raise MCPClientInitializationError("the client session was not initialized") - - async def run_async() -> T: - # Fix for strands-agents/sdk-python/issues/995 - cancel all pending invocations if/when the session closes - invoke_event = asyncio.create_task(coro) - tasks: list[asyncio.Task | asyncio.Future] = [ - invoke_event, - close_future, - ] - - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - - if done.pop() == close_future: - self._log_debug_with_thread("event loop for the server closed before the invoke completed") - raise RuntimeError("Connection to the MCP server was closed") - else: - return await invoke_event - - invoke_future = asyncio.run_coroutine_threadsafe(coro=run_async(), loop=self._background_thread_event_loop) - return invoke_future - - def _should_include_tool(self, tool: MCPAgentTool) -> bool: - """Check if a tool should be included based on constructor filters.""" - return self._should_include_tool_with_filters(tool, self._tool_filters) - - def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: ToolFilters | None) -> bool: - """Check if a tool should be included based on provided filters.""" - if not filters: - return True - - # Apply allowed filter - if "allowed" in filters: - if not self._matches_patterns(tool, filters["allowed"]): - return False - - # Apply rejected filter - if "rejected" in filters: - if self._matches_patterns(tool, filters["rejected"]): - return False - - return True - - def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> bool: - """Check if tool matches any of the given patterns.""" - for pattern in patterns: - if callable(pattern): - if pattern(tool): - return True - elif isinstance(pattern, Pattern): - if pattern.match(tool.mcp_tool.name): - return True - elif isinstance(pattern, str): - if pattern == tool.mcp_tool.name: - return True - return False - - def _is_session_active(self) -> bool: - if self._background_thread is None or not self._background_thread.is_alive(): - return False - - if self._close_future is not None and self._close_future.done(): - return False - - return True - - def _is_tasks_enabled(self) -> bool: - """Check if tasks feature is enabled. - - Tasks are enabled if tasks config is defined and not None. - - Returns: - True if task-augmented execution is enabled, False otherwise. - """ - return self._tasks_config is not None - - def _get_task_config(self) -> TasksConfig: - """Returns the task execution configuration, configured with defaults if not specified.""" - task_config = self._tasks_config or DEFAULT_TASK_CONFIG - return TasksConfig( - ttl=task_config.get("ttl", DEFAULT_TASK_TTL), - poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), - ) - - def _has_server_task_support(self) -> bool: - """Check if the MCP server supports task-augmented tool calls. - - Returns the capability value that was cached immediately after session initialization. - Server capabilities are exchanged during the MCP handshake, so this is available - as soon as start() completes. - - Returns: - True if server supports task-augmented tool calls, False otherwise. - """ - return self._server_task_capable or False - - def _should_use_task(self, tool_name: str) -> bool: - """Determine if task-augmented execution should be used for a tool. - - Task-augmented execution requires: - 1. tasks config is enabled (opt-in check) - 2. Server supports tasks (capability check) - 3. Tool taskSupport is 'required' or 'optional' - - Args: - tool_name: Name of the tool to check. - - Returns: - True if task-augmented execution should be used, False otherwise. - """ - # Opt-in check: tasks must be explicitly enabled via tasks config - if not self._is_tasks_enabled(): - return False - - # Local import to avoid errors on old SDK versions that don't support Tasks - from mcp.types import TASK_OPTIONAL, TASK_REQUIRED - - # Server capability check (per MCP spec) - if not self._has_server_task_support(): - return False - - # Tool-level capability check (cached during list_tools_sync) - task_support = self._tool_task_support_cache.get(tool_name) - - # Use tasks for TASK_REQUIRED or TASK_OPTIONAL when server supports - if task_support == TASK_REQUIRED or task_support == TASK_OPTIONAL: - return True - - # Default: 'forbidden', None, or unknown -> don't use tasks - return False - - def _create_task_error_result(self, message: str) -> MCPCallToolResult: - """Create an error MCPCallToolResult with consistent formatting. - - This helper reduces duplication in task error handling paths. - - Args: - message: The error message to include in the result. - - Returns: - MCPCallToolResult with isError=True and the message as text content. - """ - return MCPCallToolResult( - isError=True, - content=[MCPTextContent(type="text", text=message)], - ) - - # ================================================================================== - # Task-Augmented Tool Execution - # ================================================================================== - # - # The MCP spec defines task-augmented execution for long-running tools. The flow is: - # - # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) - # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() - # 3. If not using tasks: call_tool() directly - # - # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks - # ================================================================================== - - async def _call_tool_as_task_and_poll_async( - self, - name: str, - arguments: dict[str, Any] | None = None, - ttl: timedelta | None = None, - poll_timeout: timedelta | None = None, - meta: dict[str, Any] | None = None, - ) -> MCPCallToolResult: - """Call a tool using task-augmented execution and poll until completion. - - This method implements the MCP task workflow: - 1. Creates a task via call_tool_as_task - 2. Polls using poll_task until terminal status (with timeout protection) - 3. Gets the final result using get_task_result - - Args: - name: Name of the tool to call. - arguments: Optional arguments to pass to the tool. - ttl: Task time-to-live. Uses configured value if not specified. - poll_timeout: Timeout for polling. Uses configured value if not specified. - meta: Optional metadata to pass to the tool call per MCP spec (_meta). - - Returns: - MCPCallToolResult: The final tool result after task completion. - """ - # Local import to avoid errors on old SDK versions that don't support Tasks - from mcp.types import TASK_STATUS_CANCELLED, TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, GetTaskResult - - session = cast(ClientSession, self._background_thread_session) - - # Precedence: arg > config > default - timeout = poll_timeout or self._get_task_config().get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT) - ttl = ttl or self._get_task_config().get("ttl", DEFAULT_TASK_TTL) - ttl_ms = int(ttl.total_seconds() * 1000) - - # Step 1: Create the task - self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl_ms) - create_result = await session.experimental.call_tool_as_task( - name=name, - arguments=arguments, - ttl=ttl_ms, - meta=meta, - ) - task_id = create_result.task.taskId - self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) - - # Step 2: Poll until terminal status (with timeout protection) - # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility - async def _poll_until_terminal() -> GetTaskResult | None: - """Inner function to poll task status until terminal state.""" - final = None - async for task in session.experimental.poll_task(task_id): - self._log_debug_with_thread( - "tool=<%s>, task_id=<%s>, status=<%s> | task status update", - name, - task_id, - task.status, - ) - final = task - return final - - try: - final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout.total_seconds()) - except asyncio.TimeoutError: - self._log_debug_with_thread( - "tool=<%s>, task_id=<%s>, timeout_seconds=<%s> | task polling timed out", - name, - task_id, - timeout.total_seconds(), - ) - return self._create_task_error_result( - f"Task {task_id} polling timed out after {timeout.total_seconds()} seconds" - ) - - # Step 3: Handle terminal status - if final_status is None: - self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) - return self._create_task_error_result(f"Task {task_id} polling completed without status") - - if final_status.status == TASK_STATUS_FAILED: - error_msg = final_status.statusMessage or "Task failed" - self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) - return self._create_task_error_result(error_msg) - - if final_status.status == TASK_STATUS_CANCELLED: - self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) - return self._create_task_error_result("Task was cancelled") - - # Step 4: Get the actual result for completed tasks (with error handling for race conditions) - if final_status.status == TASK_STATUS_COMPLETED: - self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) - try: - result = await session.experimental.get_task_result(task_id, MCPCallToolResult) - self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) - return result - except Exception as e: - # Handle race condition: task completed but result retrieval failed - # (e.g., result expired, network error, server restarted) - self._log_debug_with_thread( - "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) - ) - return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") - - # Unexpected status - return as error - self._log_debug_with_thread( - "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", - name, - task_id, - final_status.status, - ) - return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index 5e64cc3d5..9ec796c6e 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -1,337 +1,13 @@ -"""OpenTelemetry instrumentation for Model Context Protocol (MCP) tracing. +"""Deprecated: moved to ``strands.mcp.mcp_instrumentation``.""" -Enables distributed tracing across MCP client-server boundaries by injecting -OpenTelemetry context into MCP request metadata (_meta field) and extracting -it on the server side, creating unified traces that span from agent calls -through MCP tool executions. +import warnings -Based on: https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-mcp -Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 -""" +from ...mcp.mcp_instrumentation import * # noqa: F401, F403 +from ...mcp.mcp_instrumentation import mcp_instrumentation # noqa: F401 -from collections.abc import AsyncGenerator, Callable -from contextlib import _AsyncGeneratorContextManager, asynccontextmanager -from dataclasses import dataclass -from typing import Any - -from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest -from opentelemetry import context, propagate -from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper - -# Module-level flag to ensure instrumentation is applied only once -_instrumentation_applied = False - - -@dataclass(slots=True, frozen=True) -class ItemWithContext: - """Wrapper for items that need to carry OpenTelemetry context. - - Used to preserve tracing context across async boundaries in MCP sessions, - ensuring that distributed traces remain connected even when messages are - processed asynchronously. - - Attributes: - item: The original item being wrapped - ctx: The OpenTelemetry context associated with the item - """ - - item: Any - ctx: context.Context - - -def mcp_instrumentation() -> None: - """Apply OpenTelemetry instrumentation patches to MCP components. - - This function instruments three key areas of MCP communication: - 1. Client-side: Injects tracing context into tool call requests - 2. Transport-level: Extracts context from incoming messages - 3. Session-level: Manages bidirectional context flow - - The patches enable distributed tracing by: - - Adding OpenTelemetry context to the _meta field of MCP requests - - Extracting and activating context on the server side - - Preserving context across async message processing boundaries - - This function is idempotent - multiple calls will not accumulate wrappers. - """ - global _instrumentation_applied - - # Return early if instrumentation has already been applied - if _instrumentation_applied: - return - - def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any) -> Any: - """Patch MCP client to inject OpenTelemetry context into tool calls. - - Intercepts outgoing MCP requests and injects the current OpenTelemetry - context into the request's _meta field for tools/call methods. This - enables server-side context extraction and trace continuation. - - Args: - wrapped: The original function being wrapped - instance: The instance the method is being called on - args: Positional arguments to the wrapped function - kwargs: Keyword arguments to the wrapped function - - Returns: - Result of the wrapped function call - """ - if len(args) < 1: - return wrapped(*args, **kwargs) - - request = args[0] - method = getattr(request.root, "method", None) - - if method != "tools/call": - return wrapped(*args, **kwargs) - - try: - if hasattr(request.root, "params") and request.root.params: - # Handle Pydantic models - if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"): - params_dict = request.root.params.model_dump(by_alias=True) - # Add _meta with tracing context - meta = params_dict.get("_meta") if params_dict.get("_meta") is not None else {} - params_dict["_meta"] = meta - propagate.get_global_textmap().inject(meta) - - # Recreate the Pydantic model with the updated data - # This preserves the original model type and avoids serialization warnings - params_class = type(request.root.params) - try: - request.root.params = params_class.model_validate(params_dict) - except Exception: - # Fallback to dict if model recreation fails - request.root.params = params_dict - - elif isinstance(request.root.params, dict): - # Handle dict params directly - meta = request.root.params.setdefault("_meta", {}) - propagate.get_global_textmap().inject(meta) - - return wrapped(*args, **kwargs) - - except Exception: - return wrapped(*args, **kwargs) - - def transport_wrapper() -> Callable[ - [Callable[..., Any], Any, Any, Any], _AsyncGeneratorContextManager[tuple[Any, Any]] - ]: - """Create a wrapper for MCP transport connections. - - Returns a context manager that wraps transport read/write streams - with context extraction capabilities. The wrapped reader will - automatically extract OpenTelemetry context from incoming messages. - - Returns: - An async context manager that yields wrapped transport streams - """ - - @asynccontextmanager - async def traced_method( - wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any - ) -> AsyncGenerator[tuple[Any, Any], None]: - async with wrapped(*args, **kwargs) as result: - try: - read_stream, write_stream = result - except ValueError: - read_stream, write_stream, _ = result - yield TransportContextExtractingReader(read_stream), write_stream - - return traced_method - - def session_init_wrapper() -> Callable[[Any, Any, tuple[Any, ...], dict[str, Any]], None]: - """Create a wrapper for MCP session initialization. - - Wraps session message streams to enable bidirectional context flow. - The reader extracts and activates context, while the writer preserves - context for async processing. - - Returns: - A function that wraps session initialization - """ - - def traced_method( - wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> None: - wrapped(*args, **kwargs) - reader = getattr(instance, "_incoming_message_stream_reader", None) - writer = getattr(instance, "_incoming_message_stream_writer", None) - if reader and writer: - instance._incoming_message_stream_reader = SessionContextAttachingReader(reader) - instance._incoming_message_stream_writer = SessionContextSavingWriter(writer) - - return traced_method - - # Apply patches - wrap_function_wrapper("mcp.shared.session", "BaseSession.send_request", patch_mcp_client) - - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.server.streamable_http", "StreamableHTTPServerTransport.connect", transport_wrapper() - ), - "mcp.server.streamable_http", - ) - - register_post_import_hook( - lambda _: wrap_function_wrapper("mcp.server.session", "ServerSession.__init__", session_init_wrapper()), - "mcp.server.session", - ) - - # Mark instrumentation as applied - _instrumentation_applied = True - - -class TransportContextExtractingReader(ObjectProxy): - """A proxy reader that extracts OpenTelemetry context from MCP messages. - - Wraps an async message stream reader to automatically extract and activate - OpenTelemetry context from the _meta field of incoming MCP requests. This - enables server-side trace continuation from client-injected context. - - The reader handles both SessionMessage and JSONRPCMessage formats, and - supports both dict and Pydantic model parameter structures. - """ - - def __init__(self, wrapped: Any) -> None: - """Initialize the context-extracting reader. - - Args: - wrapped: The original async stream reader to wrap - """ - super().__init__(wrapped) - - async def __aenter__(self) -> Any: - """Enter the async context manager by delegating to the wrapped object.""" - return await self.__wrapped__.__aenter__() - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: - """Exit the async context manager by delegating to the wrapped object.""" - return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) - - async def __aiter__(self) -> AsyncGenerator[Any, None]: - """Iterate over messages, extracting and activating context as needed. - - For each incoming message, checks if it contains tracing context in - the _meta field. If found, extracts and activates the context for - the duration of message processing, then properly detaches it. - - Yields: - Messages from the wrapped stream, processed under the appropriate - OpenTelemetry context - """ - async for item in self.__wrapped__: - if isinstance(item, SessionMessage): - request = item.message.root - elif type(item) is JSONRPCMessage: - request = item.root - else: - yield item - continue - - if isinstance(request, JSONRPCRequest) and request.params: - # Handle both dict and Pydantic model params - if hasattr(request.params, "get"): - # Dict-like access - meta = request.params.get("_meta") - elif hasattr(request.params, "_meta"): - # Direct attribute access for Pydantic models - meta = getattr(request.params, "_meta", None) - else: - meta = None - - if meta: - extracted_context = propagate.extract(meta) - restore = context.attach(extracted_context) - try: - yield item - continue - finally: - context.detach(restore) - yield item - - -class SessionContextSavingWriter(ObjectProxy): - """A proxy writer that preserves OpenTelemetry context with outgoing items. - - Wraps an async message stream writer to capture the current OpenTelemetry - context and associate it with outgoing items. This enables context - preservation across async boundaries in MCP session processing. - """ - - def __init__(self, wrapped: Any) -> None: - """Initialize the context-saving writer. - - Args: - wrapped: The original async stream writer to wrap - """ - super().__init__(wrapped) - - async def __aenter__(self) -> Any: - """Enter the async context manager by delegating to the wrapped object.""" - return await self.__wrapped__.__aenter__() - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: - """Exit the async context manager by delegating to the wrapped object.""" - return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) - - async def send(self, item: Any) -> Any: - """Send an item while preserving the current OpenTelemetry context. - - Captures the current context and wraps the item with it, enabling - the receiving side to restore the appropriate tracing context. - - Args: - item: The item to send through the stream - - Returns: - Result of sending the wrapped item - """ - ctx = context.get_current() - return await self.__wrapped__.send(ItemWithContext(item, ctx)) - - -class SessionContextAttachingReader(ObjectProxy): - """A proxy reader that restores OpenTelemetry context from wrapped items. - - Wraps an async message stream reader to detect ItemWithContext instances - and restore their associated OpenTelemetry context during processing. - This completes the context preservation cycle started by SessionContextSavingWriter. - """ - - def __init__(self, wrapped: Any) -> None: - """Initialize the context-attaching reader. - - Args: - wrapped: The original async stream reader to wrap - """ - super().__init__(wrapped) - - async def __aenter__(self) -> Any: - """Enter the async context manager by delegating to the wrapped object.""" - return await self.__wrapped__.__aenter__() - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: - """Exit the async context manager by delegating to the wrapped object.""" - return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) - - async def __aiter__(self) -> AsyncGenerator[Any, None]: - """Iterate over items, restoring context for ItemWithContext instances. - - For items wrapped with context, temporarily activates the associated - OpenTelemetry context during processing, then properly detaches it. - Regular items are yielded without context modification. - - Yields: - Unwrapped items processed under their associated OpenTelemetry context - """ - async for item in self.__wrapped__: - if isinstance(item, ItemWithContext): - restore = context.attach(item.ctx) - try: - yield item.item - finally: - context.detach(restore) - else: - yield item +warnings.warn( + "strands.tools.mcp.mcp_instrumentation has moved to strands.mcp.mcp_instrumentation. " + "Import from strands.mcp.mcp_instrumentation instead; strands.tools.mcp will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) diff --git a/src/strands/tools/mcp/mcp_tasks.py b/src/strands/tools/mcp/mcp_tasks.py index 36537f7df..d7761141b 100644 --- a/src/strands/tools/mcp/mcp_tasks.py +++ b/src/strands/tools/mcp/mcp_tasks.py @@ -1,33 +1,18 @@ -"""Task-augmented tool execution configuration for MCP. - -This module provides configuration types and defaults for the experimental MCP Tasks feature. -""" - -from datetime import timedelta - -from typing_extensions import TypedDict - - -class TasksConfig(TypedDict, total=False): - """Configuration for MCP Tasks (task-augmented tool execution). - - When enabled, supported tool calls use the MCP task workflow: - create task -> poll for completion -> get result. - - Warning: - This is an experimental feature in the 2025-11-25 MCP specification and - both the specification and the Strands Agents implementation of this - feature are subject to change. - - Attributes: - ttl: Task time-to-live. Defaults to 1 minute. - poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. - """ - - ttl: timedelta - poll_timeout: timedelta - - -DEFAULT_TASK_TTL = timedelta(minutes=1) -DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) -DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) +"""Deprecated: moved to ``strands.mcp.mcp_tasks``.""" + +import warnings + +from ...mcp.mcp_tasks import * # noqa: F401, F403 +from ...mcp.mcp_tasks import ( # noqa: F401 + DEFAULT_TASK_CONFIG, + DEFAULT_TASK_POLL_TIMEOUT, + DEFAULT_TASK_TTL, + TasksConfig, +) + +warnings.warn( + "strands.tools.mcp.mcp_tasks has moved to strands.mcp.mcp_tasks. " + "Import from strands.mcp.mcp_tasks instead; strands.tools.mcp will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 8fbf573be..187b8f2eb 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,67 +1,13 @@ -"""Type definitions for MCP integration.""" +"""Deprecated: moved to ``strands.mcp.mcp_types``.""" -from contextlib import AbstractAsyncContextManager -from typing import Any +import warnings -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.streamable_http import GetSessionIdCallback -from mcp.shared.memory import MessageStream -from mcp.shared.message import SessionMessage -from typing_extensions import NotRequired +from ...mcp.mcp_types import * # noqa: F401, F403 +from ...mcp.mcp_types import MCPToolResult, MCPTransport # noqa: F401 -from ...types.tools import ToolResult - -""" -MCPTransport defines the interface for MCP transport implementations. This abstracts -communication with an MCP server, hiding details of the underlying transport mechanism (WebSocket, stdio, etc.). - -It represents an async context manager that yields a tuple of read and write streams for MCP communication. -When used with `async with`, it should establish the connection and yield the streams, then clean up -when the context is exited. - -The read stream receives messages from the client (or exceptions if parsing fails), while the write -stream sends messages to the client. - -Example implementation (simplified): -```python -@contextlib.asynccontextmanager -async def my_transport_implementation(): - # Set up connection - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - # Start background tasks to handle actual I/O - async with anyio.create_task_group() as tg: - tg.start_soon(reader_task, read_stream_writer) - tg.start_soon(writer_task, write_stream_reader) - - # Yield the streams to the caller - yield (read_stream, write_stream) -``` -""" -# GetSessionIdCallback was added for HTTP Streaming but was not applied to the MessageStream type -# https://github.com/modelcontextprotocol/python-sdk/blob/ed25167fa5d715733437996682e20c24470e8177/src/mcp/client/streamable_http.py#L418 -_MessageStreamWithGetSessionIdCallback = tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback -] -MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] - - -class MCPToolResult(ToolResult): - """Result of an MCP tool execution. - - Extends the base ToolResult with MCP-specific structured content support. - The structuredContent field contains optional JSON data returned by MCP tools - that provides structured results beyond the standard text/image/document content. - - Attributes: - structuredContent: Optional JSON object containing structured data returned - by the MCP tool. This allows MCP tools to return complex data structures - that can be processed programmatically by agents or other tools. - metadata: Optional arbitrary metadata returned by the MCP tool. This field allows - MCP servers to attach custom metadata to tool results (e.g., token usage, - performance metrics, or business-specific tracking information). - """ - - structuredContent: NotRequired[dict[str, Any]] - metadata: NotRequired[dict[str, Any]] +warnings.warn( + "strands.tools.mcp.mcp_types has moved to strands.mcp.mcp_types. " + "Import from strands.mcp.mcp_types instead; strands.tools.mcp will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) diff --git a/tests/strands/mcp/__init__.py b/tests/strands/mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/mcp/conftest.py similarity index 95% rename from tests/strands/tools/mcp/conftest.py rename to tests/strands/mcp/conftest.py index d0ac46bdc..30aad8941 100644 --- a/tests/strands/tools/mcp/conftest.py +++ b/tests/strands/mcp/conftest.py @@ -37,7 +37,7 @@ def mock_session(): mock_session_cm.__aenter__.return_value = mock_session # Patch ClientSession to return our mock session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + with patch("strands.mcp.mcp_client.ClientSession", return_value=mock_session_cm): yield mock_session diff --git a/tests/strands/mcp/test_canonical_import_path.py b/tests/strands/mcp/test_canonical_import_path.py new file mode 100644 index 000000000..0194d3bee --- /dev/null +++ b/tests/strands/mcp/test_canonical_import_path.py @@ -0,0 +1,25 @@ +"""Tests for the canonical ``strands.mcp`` import path. + +The implementation currently lives in ``strands.tools.mcp``. This test +locks in the contract that ``strands.mcp`` re-exports the same objects so +that users can migrate imports ahead of the follow-up refactor that +moves the implementation. +""" + + +def test_strands_mcp_reexports_public_api() -> None: + import strands.mcp as new + import strands.tools.mcp as old + + assert new.MCPClient is old.MCPClient + assert new.MCPAgentTool is old.MCPAgentTool + assert new.MCPTransport is old.MCPTransport + assert new.TasksConfig is old.TasksConfig + assert new.ToolFilters is old.ToolFilters + + +def test_strands_mcp_all_matches_tools_mcp_all() -> None: + import strands.mcp as new + import strands.tools.mcp as old + + assert sorted(new.__all__) == sorted(old.__all__) diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/mcp/test_mcp_agent_tool.py similarity index 98% rename from tests/strands/tools/mcp/test_mcp_agent_tool.py rename to tests/strands/mcp/test_mcp_agent_tool.py index 81a2d9afb..fe11bd04a 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/mcp/test_mcp_agent_tool.py @@ -4,7 +4,7 @@ import pytest from mcp.types import Tool as MCPTool -from strands.tools.mcp import MCPAgentTool, MCPClient +from strands.mcp import MCPAgentTool, MCPClient from strands.types._events import ToolResultEvent diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/mcp/test_mcp_client.py similarity index 99% rename from tests/strands/tools/mcp/test_mcp_client.py rename to tests/strands/mcp/test_mcp_client.py index bf0e7ce8e..1de113d3c 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/mcp/test_mcp_client.py @@ -21,8 +21,8 @@ from mcp.types import Tool as MCPTool from pydantic import AnyUrl -from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_types import MCPToolResult +from strands.mcp import MCPClient +from strands.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError # Fixtures mock_transport and mock_session are imported from conftest.py diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/mcp/test_mcp_client_contextvar.py similarity index 97% rename from tests/strands/tools/mcp/test_mcp_client_contextvar.py rename to tests/strands/mcp/test_mcp_client_contextvar.py index 1770a050a..4d0f65a14 100644 --- a/tests/strands/tools/mcp/test_mcp_client_contextvar.py +++ b/tests/strands/mcp/test_mcp_client_contextvar.py @@ -12,7 +12,7 @@ import pytest -from strands.tools.mcp import MCPClient +from strands.mcp import MCPClient @pytest.fixture @@ -43,7 +43,7 @@ def mock_session(): mock_session_cm = AsyncMock() mock_session_cm.__aenter__.return_value = mock_session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + with patch("strands.mcp.mcp_client.ClientSession", return_value=mock_session_cm): yield mock_session diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/mcp/test_mcp_client_tasks.py similarity index 98% rename from tests/strands/tools/mcp/test_mcp_client_tasks.py rename to tests/strands/mcp/test_mcp_client_tasks.py index d566ac6f5..36d1a705b 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/mcp/test_mcp_client_tasks.py @@ -11,8 +11,8 @@ from mcp.types import Tool as MCPTool from mcp.types import ToolExecution -from strands.tools.mcp import MCPClient, TasksConfig -from strands.tools.mcp.mcp_tasks import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL +from strands.mcp import MCPClient, TasksConfig +from strands.mcp.mcp_tasks import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL from .conftest import create_server_capabilities diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/mcp/test_mcp_client_tool_provider.py similarity index 96% rename from tests/strands/tools/mcp/test_mcp_client_tool_provider.py rename to tests/strands/mcp/test_mcp_client_tool_provider.py index 9cb90167d..006b41cb3 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/mcp/test_mcp_client_tool_provider.py @@ -6,9 +6,9 @@ import pytest from mcp.types import Tool as MCPTool -from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_agent_tool import MCPAgentTool -from strands.tools.mcp.mcp_client import ToolFilters +from strands.mcp import MCPClient +from strands.mcp.mcp_agent_tool import MCPAgentTool +from strands.mcp.mcp_client import ToolFilters from strands.types import PaginatedList from strands.types.exceptions import ToolProviderException @@ -257,7 +257,7 @@ async def test_prefix_renames_tools(mock_transport): with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -391,7 +391,7 @@ def test_list_tools_sync_prefix_override_constructor_default(mock_transport): with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -434,7 +434,7 @@ def test_list_tools_sync_prefix_override_with_empty_string(mock_transport): with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -477,7 +477,7 @@ def test_list_tools_sync_prefix_uses_constructor_default_when_none(mock_transpor with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -521,7 +521,7 @@ def test_list_tools_sync_tool_filters_override_constructor_default(mock_transpor with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -562,7 +562,7 @@ def test_list_tools_sync_tool_filters_override_with_empty_dict(mock_transport): with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -603,7 +603,7 @@ def test_list_tools_sync_tool_filters_uses_constructor_default_when_none(mock_tr with ( patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): # Mock the MCP server response mock_list_tools_result = MagicMock() @@ -645,7 +645,7 @@ def test_list_tools_sync_combined_prefix_and_filter_overrides(mock_transport): with ( patch.object(client, "_is_session_active", return_value=True), patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): mock_future = MagicMock() mock_future.result.return_value = mock_result @@ -700,7 +700,7 @@ def test_list_tools_sync_direct_usage_without_constructor_defaults(mock_transpor with ( patch.object(client, "_is_session_active", return_value=True), patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): mock_future = MagicMock() mock_future.result.return_value = mock_result @@ -754,7 +754,7 @@ def test_list_tools_sync_regex_filter_override(mock_transport): with ( patch.object(client, "_is_session_active", return_value=True), patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): mock_future = MagicMock() mock_future.result.return_value = mock_result @@ -798,7 +798,7 @@ def test_list_tools_sync_callable_filter_override(mock_transport): with ( patch.object(client, "_is_session_active", return_value=True), patch.object(client, "_invoke_on_background_thread") as mock_invoke, - patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + patch("strands.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): mock_future = MagicMock() mock_future.result.return_value = mock_result diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/mcp/test_mcp_instrumentation.py similarity index 94% rename from tests/strands/tools/mcp/test_mcp_instrumentation.py rename to tests/strands/mcp/test_mcp_instrumentation.py index 9d44bba0c..9ef5face1 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/mcp/test_mcp_instrumentation.py @@ -5,8 +5,8 @@ from mcp.types import JSONRPCMessage, JSONRPCRequest from opentelemetry import context, propagate -from strands.tools.mcp.mcp_client import MCPClient -from strands.tools.mcp.mcp_instrumentation import ( +from strands.mcp.mcp_client import MCPClient +from strands.mcp.mcp_instrumentation import ( ItemWithContext, SessionContextAttachingReader, SessionContextSavingWriter, @@ -18,7 +18,7 @@ @pytest.fixture(autouse=True) def reset_mcp_instrumentation(): """Reset MCP instrumentation state before each test.""" - import strands.tools.mcp.mcp_instrumentation as mcp_inst + import strands.mcp.mcp_instrumentation as mcp_inst mcp_inst._instrumentation_applied = False yield @@ -342,7 +342,7 @@ def __getattr__(self, name): class TestMCPInstrumentation: def test_mcp_instrumentation_called_on_client_init(self): """Test that mcp_instrumentation is called when MCPClient is initialized.""" - with patch("strands.tools.mcp.mcp_client.mcp_instrumentation") as mock_instrumentation: + with patch("strands.mcp.mcp_client.mcp_instrumentation") as mock_instrumentation: # Mock transport def mock_transport(): read_stream = AsyncMock() @@ -359,7 +359,7 @@ def test_mcp_instrumentation_idempotent_with_multiple_clients(self): """Test that mcp_instrumentation is only called once even with multiple MCPClient instances.""" # Mock the wrap_function_wrapper to count calls - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: # Mock transport def mock_transport(): read_stream = AsyncMock() @@ -379,8 +379,8 @@ def mock_transport(): def test_mcp_instrumentation_calls_wrap_function_wrapper(self): """Test that mcp_instrumentation calls the expected wrapper functions.""" with ( - patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap, - patch("strands.tools.mcp.mcp_instrumentation.register_post_import_hook") as mock_register, + patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap, + patch("strands.mcp.mcp_instrumentation.register_post_import_hook") as mock_register, ): mcp_instrumentation() @@ -410,7 +410,7 @@ def test_patch_mcp_client_injects_context_pydantic_model(self): mock_request.root.params = mock_params # Create the patch function - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: mcp_instrumentation() patch_function = mock_wrap.call_args_list[0][0][2] @@ -440,7 +440,7 @@ def test_patch_mcp_client_preserves_existing_meta_pydantic(self): mock_params = MockPydanticParams(_meta={"com.example/request_id": "abc-123"}, name="echo") mock_request.root.params = mock_params - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: mcp_instrumentation() patch_function = mock_wrap.call_args_list[0][0][2] @@ -465,7 +465,7 @@ def test_patch_mcp_client_injects_context_dict_params(self): mock_request.root.params = {"existing": "param"} # Create the patch function - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: mcp_instrumentation() patch_function = mock_wrap.call_args_list[0][0][2] @@ -491,7 +491,7 @@ def test_patch_mcp_client_skips_non_tools_call(self): mock_request = MagicMock() mock_request.root.method = "other/method" - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: mcp_instrumentation() patch_function = mock_wrap.call_args_list[0][0][2] @@ -515,7 +515,7 @@ def test_patch_mcp_client_handles_exception_gracefully(self): mock_request.root.params = MagicMock() mock_request.root.params.model_dump.side_effect = Exception("Test exception") - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: mcp_instrumentation() patch_function = mock_wrap.call_args_list[0][0][2] @@ -546,7 +546,7 @@ def model_validate(self, data): failing_params = FailingMockPydanticParams(existing="param") mock_request.root.params = failing_params - with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + with patch("strands.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: mcp_instrumentation() patch_function = mock_wrap.call_args_list[0][0][2] diff --git a/tests/strands/tools/mcp/test_deprecated_aliases.py b/tests/strands/tools/mcp/test_deprecated_aliases.py new file mode 100644 index 000000000..b09f46b9e --- /dev/null +++ b/tests/strands/tools/mcp/test_deprecated_aliases.py @@ -0,0 +1,86 @@ +"""Tests for the backwards-compatible aliases at ``strands.tools.mcp``. + +The MCP integration moved to ``strands.mcp`` but the old import paths must +continue to work and emit ``DeprecationWarning`` until a future release +removes them. +""" + +import importlib +import sys +import warnings + + +def _reimport(module_name: str) -> None: + """Force a fresh import so the module-level warnings fire again.""" + sys.modules.pop(module_name, None) + importlib.import_module(module_name) + + +def test_package_import_emits_deprecation_warning() -> None: + _reimport("strands.tools.mcp") + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + _reimport("strands.tools.mcp") + + deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)] + assert len(deprecations) == 1 + assert "strands.mcp" in str(deprecations[0].message) + + +def test_package_reexports_public_names() -> None: + from strands import mcp as new_mcp + from strands.tools import mcp as old_mcp + + for name in ("MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"): + assert getattr(old_mcp, name) is getattr(new_mcp, name) + + +def test_submodule_imports_still_work() -> None: + """Legacy submodule paths like ``strands.tools.mcp.mcp_client`` must resolve.""" + from strands.mcp import mcp_agent_tool as new_agent_tool + from strands.mcp import mcp_client as new_client + from strands.mcp import mcp_instrumentation as new_instrumentation + from strands.mcp import mcp_tasks as new_tasks + from strands.mcp import mcp_types as new_types + from strands.tools.mcp import mcp_agent_tool as old_agent_tool + from strands.tools.mcp import mcp_client as old_client + from strands.tools.mcp import mcp_instrumentation as old_instrumentation + from strands.tools.mcp import mcp_tasks as old_tasks + from strands.tools.mcp import mcp_types as old_types + + assert old_client.MCPClient is new_client.MCPClient + assert old_client.ToolFilters is new_client.ToolFilters + assert old_agent_tool.MCPAgentTool is new_agent_tool.MCPAgentTool + assert old_types.MCPTransport is new_types.MCPTransport + assert old_types.MCPToolResult is new_types.MCPToolResult + assert old_tasks.TasksConfig is new_tasks.TasksConfig + assert old_tasks.DEFAULT_TASK_POLL_TIMEOUT is new_tasks.DEFAULT_TASK_POLL_TIMEOUT + assert old_tasks.DEFAULT_TASK_TTL is new_tasks.DEFAULT_TASK_TTL + assert old_instrumentation.mcp_instrumentation is new_instrumentation.mcp_instrumentation + + +def test_new_path_does_not_emit_deprecation_warning() -> None: + for name in ( + "strands.mcp", + "strands.mcp.mcp_client", + "strands.mcp.mcp_agent_tool", + "strands.mcp.mcp_types", + "strands.mcp.mcp_tasks", + "strands.mcp.mcp_instrumentation", + ): + sys.modules.pop(name, None) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + _reimport("strands.mcp") + _reimport("strands.mcp.mcp_client") + _reimport("strands.mcp.mcp_agent_tool") + _reimport("strands.mcp.mcp_types") + _reimport("strands.mcp.mcp_tasks") + _reimport("strands.mcp.mcp_instrumentation") + + from_strands_mcp = [ + w for w in caught if "strands.mcp" in str(w.message) and "strands.tools.mcp" not in str(w.message) + ] + deprecations = [w for w in from_strands_mcp if issubclass(w.category, DeprecationWarning)] + assert deprecations == [] diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 3723f381b..87ee6181e 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -7,9 +7,9 @@ import pytest import strands +from strands.mcp import MCPClient from strands.tools import PythonAgentTool, ToolProvider from strands.tools.decorator import DecoratedFunctionTool, tool -from strands.tools.mcp import MCPClient from strands.tools.registry import ToolRegistry diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index fe2b10df3..ef7ef3955 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -12,8 +12,8 @@ from mcp.types import ImageContent as MCPImageContent from strands import Agent -from strands.tools.mcp.mcp_client import MCPClient -from strands.tools.mcp.mcp_types import MCPTransport +from strands.mcp.mcp_client import MCPClient +from strands.mcp.mcp_types import MCPTransport from strands.types.content import Message from strands.types.exceptions import MCPClientInitializationError from strands.types.tools import ToolUse diff --git a/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py b/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py index 3e6132b38..e95a2d0c9 100644 --- a/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py +++ b/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py @@ -10,7 +10,7 @@ from strands import Agent from strands.hooks import AfterToolCallEvent, HookProvider, HookRegistry -from strands.tools.mcp.mcp_client import MCPClient +from strands.mcp.mcp_client import MCPClient class ToolResultCapture(HookProvider): diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index 751fb655f..3ed64e825 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -9,7 +9,7 @@ import pytest from mcp.client.streamable_http import streamablehttp_client -from strands.tools.mcp import MCPClient, MCPTransport, TasksConfig +from strands.mcp import MCPClient, MCPTransport, TasksConfig def _find_available_port() -> int: diff --git a/tests_integ/mcp/test_mcp_elicitation.py b/tests_integ/mcp/test_mcp_elicitation.py index 794ecbb98..2e4cd44a2 100644 --- a/tests_integ/mcp/test_mcp_elicitation.py +++ b/tests_integ/mcp/test_mcp_elicitation.py @@ -5,7 +5,7 @@ from mcp.types import ElicitResult from strands import Agent -from strands.tools.mcp import MCPClient +from strands.mcp import MCPClient @pytest.fixture diff --git a/tests_integ/mcp/test_mcp_output_schema.py b/tests_integ/mcp/test_mcp_output_schema.py index 69ef3cd3c..b522296cb 100644 --- a/tests_integ/mcp/test_mcp_output_schema.py +++ b/tests_integ/mcp/test_mcp_output_schema.py @@ -2,7 +2,7 @@ from mcp import StdioServerParameters, stdio_client -from strands.tools.mcp.mcp_client import MCPClient +from strands.mcp.mcp_client import MCPClient from .echo_server import EchoResponse diff --git a/tests_integ/mcp/test_mcp_resources.py b/tests_integ/mcp/test_mcp_resources.py index dccf3b808..cd734133a 100644 --- a/tests_integ/mcp/test_mcp_resources.py +++ b/tests_integ/mcp/test_mcp_resources.py @@ -18,7 +18,7 @@ from mcp.types import BlobResourceContents, TextResourceContents from pydantic import AnyUrl -from strands.tools.mcp.mcp_client import MCPClient +from strands.mcp.mcp_client import MCPClient def test_mcp_resources_list_and_read(): diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py index 7914bb326..812e1de4f 100644 --- a/tests_integ/mcp/test_mcp_tool_provider.py +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -7,8 +7,8 @@ from mcp import StdioServerParameters, stdio_client from strands import Agent -from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_client import ToolFilters +from strands.mcp import MCPClient +from strands.mcp.mcp_client import ToolFilters logging.basicConfig(level=logging.DEBUG)