Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions google/genai/_mcp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@

"""Utils for working with MCP tools."""

import asyncio
import contextlib
import httpx

from importlib.metadata import PackageNotFoundError, version
import typing
from typing import Any

import google.auth
from google.auth.transport.requests import Request

from . import _common
from . import types

Expand All @@ -28,12 +35,19 @@
else:
McpClientSession: typing.Type = Any
McpTool: typing.Type = Any
streamable_http_client: Any = None
create_mcp_http_client: Any = None

try:
from mcp.types import Tool as McpTool
from mcp import ClientSession as McpClientSession
from mcp.client.streamable_http import streamable_http_client
from mcp.shared._httpx_utils import create_mcp_http_client
except ImportError:
McpTool = None
McpClientSession = None
streamable_http_client = None
create_mcp_http_client = None


def mcp_to_gemini_tool(tool: McpTool) -> types.Tool:
Expand Down Expand Up @@ -146,3 +160,84 @@ def _filter_to_supported_schema(

return filtered_schema


def _fetch_agent_platform_token_sync() -> str:
"""Synchronously fetches and refreshes the Google Cloud credentials."""
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
credentials, _ = google.auth.default(scopes=scopes) # type: ignore
credentials.refresh(Request()) # type: ignore
return credentials.token # type: ignore


@contextlib.asynccontextmanager
async def _connect_agent_platform_mcp(api_client: Any, toolset_name: str):
"""Internal helper to manage the Vertex MCP lifecycle per request."""
if streamable_http_client is None:
raise ImportError(
"The 'mcp' package is required to use Vertex MCP servers."
)

base_url = None
if hasattr(api_client, '_http_options') and hasattr(api_client._http_options, 'base_url'):
base_url = api_client._http_options.base_url

if base_url:
if base_url.endswith('/'):
base_url = base_url[:-1]
mcp_url = f"{base_url}/mcp/{toolset_name}"
else:
location = getattr(api_client, 'location', 'us-central1')
if location == 'global':
mcp_url = f"https://aiplatform.googleapis.com/mcp/{toolset_name}"
else:
mcp_url = f"https://{location}-aiplatform.googleapis.com/mcp/{toolset_name}"


token = await asyncio.to_thread(_fetch_agent_platform_token_sync)
project = getattr(api_client, "project", None)

headers = {
"Authorization": f"Bearer {token}",
}
if project:
headers["X-Goog-User-Project"] = project

set_mcp_usage_header(headers)

timeout = httpx.Timeout(30.0, read=300.0)
http_client = httpx.AsyncClient(headers=headers, timeout=timeout)

try:
async with http_client:
async with streamable_http_client(url=mcp_url, http_client=http_client) as streams:
read_stream, write_stream, _ = streams
async with McpClientSession(read_stream, write_stream) as session:
await session.initialize()
try:
yield session
except GeneratorExit:
return

except BaseExceptionGroup as eg:

error_messages = []

def _extract_errors(exc):
# Handle potentially nested ExceptionGroups
if hasattr(exc, 'exceptions'):
for e in exc.exceptions:
_extract_errors(e)
else:
msg = f"{type(exc).__name__}: {str(exc)}"
if hasattr(exc, 'response') and exc.response is not None:
msg += f" (HTTP {exc.response.status_code}: {exc.response.text})"
error_messages.append(msg)

_extract_errors(eg)

raise ValueError(
f"Failed to connect to Vertex MCP Server at {mcp_url}.\n"
f"Underlying errors: {error_messages}"
) from eg


Loading