diff --git a/google/genai/_mcp_utils.py b/google/genai/_mcp_utils.py index 424727f06..7bd2c5fc9 100644 --- a/google/genai/_mcp_utils.py +++ b/google/genai/_mcp_utils.py @@ -15,10 +15,17 @@ """Utils for working with MCP tools.""" +import asyncio +import contextlib +import httpx + from importlib.metadata import PackageNotFoundError, version import typing from typing import Any +import google.auth +from google.auth.transport.requests import Request + from . import _common from . import types @@ -28,12 +35,19 @@ else: McpClientSession: typing.Type = Any McpTool: typing.Type = Any + streamable_http_client: Any = None + create_mcp_http_client: Any = None + try: from mcp.types import Tool as McpTool from mcp import ClientSession as McpClientSession + from mcp.client.streamable_http import streamable_http_client + from mcp.shared._httpx_utils import create_mcp_http_client except ImportError: McpTool = None McpClientSession = None + streamable_http_client = None + create_mcp_http_client = None def mcp_to_gemini_tool(tool: McpTool) -> types.Tool: @@ -146,3 +160,84 @@ def _filter_to_supported_schema( return filtered_schema + +def _fetch_agent_platform_token_sync() -> str: + """Synchronously fetches and refreshes the Google Cloud credentials.""" + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + credentials, _ = google.auth.default(scopes=scopes) # type: ignore + credentials.refresh(Request()) # type: ignore + return credentials.token # type: ignore + + +@contextlib.asynccontextmanager +async def _connect_agent_platform_mcp(api_client: Any, toolset_name: str): + """Internal helper to manage the Vertex MCP lifecycle per request.""" + if streamable_http_client is None: + raise ImportError( + "The 'mcp' package is required to use Vertex MCP servers." + ) + + base_url = None + if hasattr(api_client, '_http_options') and hasattr(api_client._http_options, 'base_url'): + base_url = api_client._http_options.base_url + + if base_url: + if base_url.endswith('/'): + base_url = base_url[:-1] + mcp_url = f"{base_url}/mcp/{toolset_name}" + else: + location = getattr(api_client, 'location', 'us-central1') + if location == 'global': + mcp_url = f"https://aiplatform.googleapis.com/mcp/{toolset_name}" + else: + mcp_url = f"https://{location}-aiplatform.googleapis.com/mcp/{toolset_name}" + + + token = await asyncio.to_thread(_fetch_agent_platform_token_sync) + project = getattr(api_client, "project", None) + + headers = { + "Authorization": f"Bearer {token}", + } + if project: + headers["X-Goog-User-Project"] = project + + set_mcp_usage_header(headers) + + timeout = httpx.Timeout(30.0, read=300.0) + http_client = httpx.AsyncClient(headers=headers, timeout=timeout) + + try: + async with http_client: + async with streamable_http_client(url=mcp_url, http_client=http_client) as streams: + read_stream, write_stream, _ = streams + async with McpClientSession(read_stream, write_stream) as session: + await session.initialize() + try: + yield session + except GeneratorExit: + return + + except BaseExceptionGroup as eg: + + error_messages = [] + + def _extract_errors(exc): + # Handle potentially nested ExceptionGroups + if hasattr(exc, 'exceptions'): + for e in exc.exceptions: + _extract_errors(e) + else: + msg = f"{type(exc).__name__}: {str(exc)}" + if hasattr(exc, 'response') and exc.response is not None: + msg += f" (HTTP {exc.response.status_code}: {exc.response.text})" + error_messages.append(msg) + + _extract_errors(eg) + + raise ValueError( + f"Failed to connect to Vertex MCP Server at {mcp_url}.\n" + f"Underlying errors: {error_messages}" + ) from eg + +