From ca4105ff2430154284649e21f7fa67499f6f6a3e Mon Sep 17 00:00:00 2001 From: stevessr Date: Wed, 22 Apr 2026 12:31:33 +0800 Subject: [PATCH 1/2] feat: add OAuth 2.0 support for MCP servers in WebUI - Implemented OAuth 2.0 authorization flow for MCP servers, allowing users to log in directly through the WebUI. - Added new API endpoints for starting OAuth authorization, checking flow status, and handling callbacks. - Updated the MCP server configuration to include OAuth 2.0 settings, supporting both authorization code and client credentials grant types. - Enhanced the dashboard UI to display OAuth status and provide a QR code for mobile login. - Updated localization files to include new strings related to OAuth functionality. - Added documentation for configuring OAuth 2.0 for MCP servers. --- astrbot/core/agent/mcp_client.py | 12 +- astrbot/core/agent/mcp_oauth.py | 709 ++++++++++++++++++ astrbot/core/provider/func_tool_manager.py | 45 +- astrbot/dashboard/routes/tools.py | 145 +++- .../extension/McpServersSection.vue | 344 ++++++++- .../i18n/locales/en-US/features/tool-use.json | 22 +- .../i18n/locales/ru-RU/features/tool-use.json | 22 +- .../i18n/locales/zh-CN/features/tool-use.json | 22 +- docs/en/use/mcp.md | 42 ++ docs/zh/use/mcp.md | 42 ++ 10 files changed, 1355 insertions(+), 50 deletions(-) create mode 100644 astrbot/core/agent/mcp_oauth.py diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index b75999ea65..296cbcc392 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -21,6 +21,7 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe +from .mcp_oauth import create_mcp_http_auth, has_mcp_oauth_config from .run_context import TContext from .tool import FunctionTool @@ -428,9 +429,12 @@ def logging_callback( self.server_errlogs.append(log_msg) if "url" in cfg: - success, error_msg = await _quick_test_mcp_connection(cfg) - if not success: - raise Exception(error_msg) + auth = await create_mcp_http_auth(cfg) + + if not has_mcp_oauth_config(cfg): + success, error_msg = await _quick_test_mcp_connection(cfg) + if not success: + raise Exception(error_msg) if "transport" in cfg: transport_type = cfg["transport"] @@ -446,6 +450,7 @@ def logging_callback( headers=cfg.get("headers", {}), timeout=cfg.get("timeout", 5), sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), + auth=auth, ) streams = await self.exit_stack.enter_async_context( self._streams_context, @@ -471,6 +476,7 @@ def logging_callback( timeout=timeout, sse_read_timeout=sse_read_timeout, terminate_on_close=cfg.get("terminate_on_close", True), + auth=auth, ) read_s, write_s, _ = await self.exit_stack.enter_async_context( self._streams_context, diff --git a/astrbot/core/agent/mcp_oauth.py b/astrbot/core/agent/mcp_oauth.py new file mode 100644 index 0000000000..3682cfd206 --- /dev/null +++ b/astrbot/core/agent/mcp_oauth.py @@ -0,0 +1,709 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +import time +import uuid +from collections.abc import Mapping +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal +from urllib.parse import parse_qs, urlparse + +import httpx +from pydantic import BaseModel, ConfigDict + +from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +try: + from mcp.client.auth import OAuthClientProvider, TokenStorage + from mcp.client.auth.extensions.client_credentials import ( + ClientCredentialsOAuthProvider, + ) + from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthToken, + ) +except (ModuleNotFoundError, ImportError): + OAuthClientProvider = None # type: ignore[assignment] + ClientCredentialsOAuthProvider = None # type: ignore[assignment] + TokenStorage = object # type: ignore[assignment] + OAuthClientInformationFull = None # type: ignore[assignment] + OAuthClientMetadata = None # type: ignore[assignment] + OAuthToken = None # type: ignore[assignment] + + +class MCPOAuthError(Exception): + """Base exception for MCP OAuth flows.""" + + +class MCPOAuthAuthorizationRequiredError(MCPOAuthError): + """Raised when interactive OAuth authorization is required.""" + + +class MCPOAuthConfig(BaseModel): + model_config = ConfigDict(extra="ignore") + + grant_type: Literal["authorization_code", "client_credentials"] = ( + "authorization_code" + ) + client_id: str | None = None + client_secret: str | None = None + token_endpoint_auth_method: ( + Literal["none", "client_secret_post", "client_secret_basic"] | None + ) = None + scope: str | None = None + redirect_uri: str | None = None + timeout: float = 300.0 + client_name: str | None = "AstrBot MCP Client" + client_uri: str | None = None + logo_uri: str | None = None + contacts: list[str] | None = None + tos_uri: str | None = None + policy_uri: str | None = None + software_id: str | None = None + software_version: str | None = None + client_metadata_url: str | None = None + + +def _prepare_config(config: Mapping[str, Any]) -> dict[str, Any]: + prepared = dict(config) + if prepared.get("mcpServers"): + first_key = next(iter(prepared["mcpServers"])) + prepared = dict(prepared["mcpServers"][first_key]) + prepared.pop("active", None) + return prepared + + +def get_mcp_oauth_config(config: Mapping[str, Any]) -> MCPOAuthConfig | None: + prepared = _prepare_config(config) + oauth_config = prepared.get("oauth2") or prepared.get("oauth") + if not isinstance(oauth_config, dict): + return None + return MCPOAuthConfig.model_validate(oauth_config) + + +def has_mcp_oauth_config(config: Mapping[str, Any]) -> bool: + return get_mcp_oauth_config(config) is not None + + +def _get_storage_fingerprint(config: Mapping[str, Any]) -> str: + prepared = _prepare_config(config) + oauth_config = get_mcp_oauth_config(prepared) + if oauth_config is None: + raise MCPOAuthError("OAuth 2.0 is not configured for this MCP server.") + + fingerprint_payload = { + "url": prepared.get("url"), + "transport": prepared.get("transport") or prepared.get("type"), + "grant_type": oauth_config.grant_type, + "client_id": oauth_config.client_id, + "client_secret": oauth_config.client_secret, + "token_endpoint_auth_method": oauth_config.token_endpoint_auth_method, + "scope": oauth_config.scope, + "redirect_uri": oauth_config.redirect_uri, + "client_metadata_url": oauth_config.client_metadata_url, + } + canonical = json.dumps( + fingerprint_payload, + sort_keys=True, + ensure_ascii=False, + separators=(",", ":"), + ) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + + +def get_mcp_oauth_storage_path(config: Mapping[str, Any]) -> Path: + data_dir = Path(get_astrbot_data_path()) / "mcp_oauth" + return data_dir / f"{_get_storage_fingerprint(config)}.json" + + +class MCPFileTokenStorage(TokenStorage): + def __init__(self, storage_path: Path) -> None: + self.storage_path = storage_path + self._lock = asyncio.Lock() + + @classmethod + def from_mcp_config(cls, config: Mapping[str, Any]) -> MCPFileTokenStorage: + return cls(get_mcp_oauth_storage_path(config)) + + def _load_unlocked(self) -> dict[str, Any]: + if not self.storage_path.exists(): + return {} + try: + return json.loads(self.storage_path.read_text(encoding="utf-8")) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to load MCP OAuth storage %s: %s", + self.storage_path, + exc, + ) + return {} + + def _save_unlocked(self, payload: dict[str, Any]) -> None: + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + self.storage_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + try: + os.chmod(self.storage_path, 0o600) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to set permissions on MCP OAuth storage %s: %s", + self.storage_path, + exc, + ) + + async def get_tokens(self) -> OAuthToken | None: + async with self._lock: + payload = self._load_unlocked() + token_payload = payload.get("tokens") + if not token_payload: + return None + return OAuthToken.model_validate(token_payload) + + async def set_tokens(self, tokens: OAuthToken) -> None: + async with self._lock: + payload = self._load_unlocked() + payload["tokens"] = tokens.model_dump(mode="json", exclude_none=True) + if tokens.expires_in is not None: + payload["token_expires_at"] = time.time() + float(tokens.expires_in) + else: + payload.pop("token_expires_at", None) + self._save_unlocked(payload) + + async def clear_tokens(self) -> None: + async with self._lock: + payload = self._load_unlocked() + payload.pop("tokens", None) + payload.pop("token_expires_at", None) + self._save_unlocked(payload) + + async def get_client_info(self) -> OAuthClientInformationFull | None: + async with self._lock: + payload = self._load_unlocked() + client_info_payload = payload.get("client_info") + if not client_info_payload: + return None + return OAuthClientInformationFull.model_validate(client_info_payload) + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + async with self._lock: + payload = self._load_unlocked() + payload["client_info"] = client_info.model_dump( + mode="json", + exclude_none=True, + ) + self._save_unlocked(payload) + + async def get_redirect_uri(self) -> str | None: + async with self._lock: + payload = self._load_unlocked() + redirect_uri = payload.get("redirect_uri") + return str(redirect_uri) if isinstance(redirect_uri, str) else None + + async def set_redirect_uri(self, redirect_uri: str) -> None: + async with self._lock: + payload = self._load_unlocked() + payload["redirect_uri"] = redirect_uri + self._save_unlocked(payload) + + async def get_token_expires_at(self) -> float | None: + async with self._lock: + payload = self._load_unlocked() + expires_at = payload.get("token_expires_at") + if isinstance(expires_at, (int, float)): + return float(expires_at) + return None + + +def _get_token_endpoint_auth_method(oauth_config: MCPOAuthConfig) -> str: + if oauth_config.token_endpoint_auth_method: + return oauth_config.token_endpoint_auth_method + if oauth_config.client_secret: + return "client_secret_basic" + return "none" + + +async def _raise_interactive_redirect_required(_: str) -> None: + raise MCPOAuthAuthorizationRequiredError( + "OAuth 2.0 authorization is required. Complete authorization in the MCP server dialog first.", + ) + + +async def _raise_interactive_callback_required() -> tuple[str, str | None]: + raise MCPOAuthAuthorizationRequiredError( + "OAuth 2.0 authorization is required. Complete authorization in the MCP server dialog first.", + ) + + +if OAuthClientProvider is not None: + + class AstrBotOAuthClientProvider(OAuthClientProvider): + async def _initialize(self) -> None: + await super()._initialize() + + storage = self.context.storage + if not isinstance(storage, MCPFileTokenStorage): + return + + expires_at = await storage.get_token_expires_at() + if expires_at is not None: + self.context.token_expiry_time = expires_at + + if ( + expires_at is not None + and time.time() > expires_at + and not self.context.can_refresh_token() + ): + raise MCPOAuthAuthorizationRequiredError( + "The stored OAuth 2.0 token has expired. Complete authorization in the MCP server dialog again.", + ) + +else: + AstrBotOAuthClientProvider = None # type: ignore[assignment] + + +if ClientCredentialsOAuthProvider is not None: + + class AstrBotClientCredentialsOAuthProvider(ClientCredentialsOAuthProvider): + async def _initialize(self) -> None: + await super()._initialize() + + storage = self.context.storage + if not isinstance(storage, MCPFileTokenStorage): + return + + expires_at = await storage.get_token_expires_at() + if expires_at is not None: + self.context.token_expiry_time = expires_at + +else: + AstrBotClientCredentialsOAuthProvider = None # type: ignore[assignment] + + +def _build_client_metadata( + oauth_config: MCPOAuthConfig, + *, + redirect_uri: str, +) -> OAuthClientMetadata: + return OAuthClientMetadata( + redirect_uris=[redirect_uri], + token_endpoint_auth_method=_get_token_endpoint_auth_method(oauth_config), + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope=oauth_config.scope, + client_name=oauth_config.client_name, + client_uri=oauth_config.client_uri, + logo_uri=oauth_config.logo_uri, + contacts=oauth_config.contacts, + tos_uri=oauth_config.tos_uri, + policy_uri=oauth_config.policy_uri, + software_id=oauth_config.software_id, + software_version=oauth_config.software_version, + ) + + +async def _seed_client_info_if_needed( + storage: MCPFileTokenStorage, + oauth_config: MCPOAuthConfig, + *, + redirect_uri: str, +) -> None: + if not oauth_config.client_id: + return + + client_info = OAuthClientInformationFull( + redirect_uris=[redirect_uri], + client_id=oauth_config.client_id, + client_secret=oauth_config.client_secret, + grant_types=["authorization_code", "refresh_token"], + token_endpoint_auth_method=_get_token_endpoint_auth_method(oauth_config), + response_types=["code"], + scope=oauth_config.scope, + client_name=oauth_config.client_name, + client_uri=oauth_config.client_uri, + logo_uri=oauth_config.logo_uri, + contacts=oauth_config.contacts, + tos_uri=oauth_config.tos_uri, + policy_uri=oauth_config.policy_uri, + software_id=oauth_config.software_id, + software_version=oauth_config.software_version, + ) + await storage.set_client_info(client_info) + + +@dataclass(slots=True) +class MCPOAuthPendingFlow: + flow_id: str + config: dict[str, Any] + redirect_uri: str + created_at: float = field(default_factory=time.time) + status: Literal[ + "initializing", + "awaiting_user", + "authorizing", + "completed", + "failed", + ] = "initializing" + authorization_url: str | None = None + error: str | None = None + callback_code: str | None = None + callback_state: str | None = None + callback_error: str | None = None + oauth_state: str | None = None + url_ready_event: asyncio.Event = field(default_factory=asyncio.Event) + callback_ready_event: asyncio.Event = field(default_factory=asyncio.Event) + done_event: asyncio.Event = field(default_factory=asyncio.Event) + task: asyncio.Task[None] | None = None + + async def handle_redirect(self, authorization_url: str) -> None: + self.authorization_url = authorization_url + parsed_url = urlparse(authorization_url) + self.oauth_state = parse_qs(parsed_url.query).get("state", [None])[0] + self.status = "awaiting_user" + self.url_ready_event.set() + + async def wait_for_callback(self) -> tuple[str, str | None]: + await self.callback_ready_event.wait() + if self.callback_error: + raise MCPOAuthError(self.callback_error) + self.status = "authorizing" + return self.callback_code or "", self.callback_state + + def submit_callback( + self, + *, + code: str | None, + state: str | None, + error: str | None, + ) -> None: + self.callback_code = code + self.callback_state = state + self.callback_error = error + self.callback_ready_event.set() + + +async def create_mcp_http_auth( + config: Mapping[str, Any], + *, + interactive_flow: MCPOAuthPendingFlow | None = None, +) -> httpx.Auth | None: + prepared = _prepare_config(config) + if "url" not in prepared: + return None + + oauth_config = get_mcp_oauth_config(prepared) + if oauth_config is None: + return None + + if OAuthClientProvider is None or OAuthClientMetadata is None: + raise MCPOAuthError("The installed MCP dependency does not support OAuth 2.0.") + + storage = MCPFileTokenStorage.from_mcp_config(prepared) + + if oauth_config.grant_type == "client_credentials": + if not oauth_config.client_id or not oauth_config.client_secret: + raise MCPOAuthError( + "OAuth client_credentials requires both client_id and client_secret.", + ) + if AstrBotClientCredentialsOAuthProvider is None: + raise MCPOAuthError( + "The installed MCP dependency does not support OAuth 2.0 client_credentials.", + ) + return AstrBotClientCredentialsOAuthProvider( + server_url=str(prepared["url"]), + storage=storage, + client_id=oauth_config.client_id, + client_secret=oauth_config.client_secret, + token_endpoint_auth_method=_get_token_endpoint_auth_method(oauth_config), + scopes=oauth_config.scope, + ) + + if oauth_config.grant_type != "authorization_code": + raise MCPOAuthError( + f"Unsupported MCP OAuth grant_type: {oauth_config.grant_type}", + ) + + if interactive_flow is None: + stored_tokens = await storage.get_tokens() + if stored_tokens is None: + raise MCPOAuthAuthorizationRequiredError( + "OAuth 2.0 authorization is required. Complete authorization in the MCP server dialog first.", + ) + + expires_at = await storage.get_token_expires_at() + if ( + expires_at is not None + and time.time() > expires_at + and not stored_tokens.refresh_token + ): + raise MCPOAuthAuthorizationRequiredError( + "The stored OAuth 2.0 token has expired and no refresh token is available. Complete authorization in the MCP server dialog again.", + ) + + redirect_uri = ( + interactive_flow.redirect_uri + if interactive_flow is not None + else oauth_config.redirect_uri + or await storage.get_redirect_uri() + or "http://127.0.0.1/astrbot/mcp/oauth/callback/pending" + ) + + await storage.set_redirect_uri(redirect_uri) + await _seed_client_info_if_needed(storage, oauth_config, redirect_uri=redirect_uri) + + redirect_handler = ( + interactive_flow.handle_redirect + if interactive_flow is not None + else _raise_interactive_redirect_required + ) + callback_handler = ( + interactive_flow.wait_for_callback + if interactive_flow is not None + else _raise_interactive_callback_required + ) + + if AstrBotOAuthClientProvider is None: + raise MCPOAuthError("The installed MCP dependency does not support OAuth 2.0.") + + return AstrBotOAuthClientProvider( + server_url=str(prepared["url"]), + client_metadata=_build_client_metadata( + oauth_config, + redirect_uri=redirect_uri, + ), + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=oauth_config.timeout, + client_metadata_url=oauth_config.client_metadata_url, + ) + + +async def get_mcp_oauth_state(config: Mapping[str, Any]) -> dict[str, Any]: + oauth_config = get_mcp_oauth_config(config) + if oauth_config is None: + return { + "oauth2_enabled": False, + "oauth2_authorized": False, + "oauth2_grant_type": None, + } + + if oauth_config.grant_type == "client_credentials": + return { + "oauth2_enabled": True, + "oauth2_authorized": True, + "oauth2_grant_type": oauth_config.grant_type, + } + + storage = MCPFileTokenStorage.from_mcp_config(config) + tokens = await storage.get_tokens() + return { + "oauth2_enabled": True, + "oauth2_authorized": tokens is not None, + "oauth2_grant_type": oauth_config.grant_type, + } + + +async def _probe_http_oauth_connection( + config: Mapping[str, Any], + auth: httpx.Auth, +) -> None: + prepared = _prepare_config(config) + url = str(prepared["url"]) + headers = { + str(key): str(value) for key, value in dict(prepared.get("headers", {})).items() + } + timeout_value = float(prepared.get("timeout", 30)) + transport_type = prepared.get("transport") or prepared.get("type") or "sse" + + async with httpx.AsyncClient( + follow_redirects=True, + timeout=timeout_value, + ) as client: + if transport_type == "streamable_http": + response = await client.post( + url, + headers={ + **headers, + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "id": 0, + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "astrbot-oauth-probe", + "version": "1.0.0", + }, + }, + }, + auth=auth, + ) + else: + response = await client.get( + url, + headers={ + **headers, + "Accept": "application/json, text/event-stream", + }, + auth=auth, + ) + + if response.status_code != 200: + raise MCPOAuthError( + f"OAuth authorization probe failed: HTTP {response.status_code} {response.reason_phrase}", + ) + + +class MCPOAuthManager: + _FLOW_TTL_SECONDS = 900 + + def __init__(self) -> None: + self._flows: dict[str, MCPOAuthPendingFlow] = {} + self._state_to_flow_id: dict[str, str] = {} + self._lock = asyncio.Lock() + + async def _prune_flows(self) -> None: + threshold = time.time() - self._FLOW_TTL_SECONDS + async with self._lock: + expired_ids = [ + flow_id + for flow_id, flow in self._flows.items() + if flow.created_at < threshold + ] + for flow_id in expired_ids: + expired_states = [ + state + for state, state_flow_id in self._state_to_flow_id.items() + if state_flow_id == flow_id + ] + for state in expired_states: + self._state_to_flow_id.pop(state, None) + self._flows.pop(flow_id, None) + + async def _run_flow(self, flow: MCPOAuthPendingFlow) -> None: + try: + auth = await create_mcp_http_auth(flow.config, interactive_flow=flow) + if auth is None: + raise MCPOAuthError("OAuth 2.0 is not configured for this MCP server.") + await _probe_http_oauth_connection(flow.config, auth) + flow.status = "completed" + except Exception as exc: # noqa: BLE001 + flow.error = str(exc) + flow.status = "failed" + flow.url_ready_event.set() + finally: + flow.done_event.set() + + async def start_authorization( + self, + config: Mapping[str, Any], + *, + callback_base_url: str, + server_name: str | None = None, + force: bool = False, + ) -> MCPOAuthPendingFlow: + prepared = _prepare_config(config) + oauth_config = get_mcp_oauth_config(prepared) + if oauth_config is None: + raise MCPOAuthError("OAuth 2.0 is not configured for this MCP server.") + if oauth_config.grant_type != "authorization_code": + raise MCPOAuthError( + "Interactive login is only available for authorization_code flows.", + ) + if "url" not in prepared: + raise MCPOAuthError("OAuth 2.0 is only supported for HTTP MCP transports.") + + await self._prune_flows() + + storage = MCPFileTokenStorage.from_mcp_config(prepared) + if force: + await storage.clear_tokens() + + flow_id = uuid.uuid4().hex + redirect_uri = f"{callback_base_url.rstrip('/')}/mcp/oauth/callback" + + flow = MCPOAuthPendingFlow( + flow_id=flow_id, + config=prepared, + redirect_uri=redirect_uri, + ) + flow.task = asyncio.create_task( + self._run_flow(flow), + name=f"mcp-oauth:{flow_id}", + ) + + async with self._lock: + self._flows[flow_id] = flow + + wait_url_task = asyncio.create_task(flow.url_ready_event.wait()) + wait_done_task = asyncio.create_task(flow.done_event.wait()) + try: + done, pending = await asyncio.wait( + {wait_url_task, wait_done_task}, + timeout=15, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + if not done: + raise MCPOAuthError( + "Timed out while preparing the OAuth 2.0 authorization flow.", + ) + finally: + if not wait_url_task.done(): + wait_url_task.cancel() + if not wait_done_task.done(): + wait_done_task.cancel() + + if flow.status == "failed": + raise MCPOAuthError(flow.error or "Failed to start OAuth 2.0 flow.") + + if flow.oauth_state: + async with self._lock: + self._state_to_flow_id[flow.oauth_state] = flow.flow_id + + return flow + + async def submit_callback( + self, + flow_id: str | None = None, + *, + code: str | None, + state: str | None, + error: str | None, + ) -> None: + resolved_flow_id = flow_id + if resolved_flow_id is None and state: + resolved_flow_id = self._state_to_flow_id.get(state) + + async with self._lock: + flow = self._flows.get(resolved_flow_id or "") + if flow is None: + raise KeyError(flow_id or state or "") + flow.submit_callback(code=code, state=state, error=error) + + def get_flow_status(self, flow_id: str) -> dict[str, Any]: + flow = self._flows.get(flow_id) + if flow is None: + raise KeyError(flow_id) + return { + "flow_id": flow.flow_id, + "status": flow.status, + "authorization_url": flow.authorization_url, + "redirect_uri": flow.redirect_uri, + "error": flow.error, + } diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index ab6dd037f4..02da2ae85b 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -16,6 +16,11 @@ from astrbot import logger from astrbot.core import sp from astrbot.core.agent.mcp_client import MCPClient, MCPTool +from astrbot.core.agent.mcp_oauth import ( + MCPOAuthManager, + get_mcp_oauth_state, + has_mcp_oauth_config, +) from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.tools.registry import ( ensure_builtin_tools_loaded, @@ -225,6 +230,7 @@ def __init__(self) -> None: self._timeout_warn_lock = threading.Lock() self._runtime_lock = asyncio.Lock() self._mcp_starting: set[str] = set() + self._mcp_oauth_manager = MCPOAuthManager() self._init_timeout_default = _resolve_timeout( timeout=None, env_name=MCP_INIT_TIMEOUT_ENV, @@ -711,7 +717,7 @@ async def _terminate_mcp_client(self, name: str) -> None: @staticmethod async def test_mcp_server_connection(config: dict) -> list[str]: - if "url" in config: + if "url" in config and not has_mcp_oauth_config(config): success, error_msg = await _quick_test_mcp_connection(config) if not success: raise Exception(error_msg) @@ -727,6 +733,43 @@ async def test_mcp_server_connection(config: dict) -> list[str]: await mcp_client.cleanup() return tool_names + async def get_mcp_oauth_state(self, config: dict) -> dict[str, Any]: + return await get_mcp_oauth_state(config) + + async def start_mcp_oauth_authorization( + self, + config: dict, + *, + callback_base_url: str, + server_name: str | None = None, + force: bool = False, + ) -> dict[str, Any]: + flow = await self._mcp_oauth_manager.start_authorization( + config, + callback_base_url=callback_base_url, + server_name=server_name, + force=force, + ) + return self._mcp_oauth_manager.get_flow_status(flow.flow_id) + + def get_mcp_oauth_flow_status(self, flow_id: str) -> dict[str, Any]: + return self._mcp_oauth_manager.get_flow_status(flow_id) + + async def submit_mcp_oauth_callback( + self, + flow_id: str | None, + *, + code: str | None, + state: str | None, + error: str | None, + ) -> None: + await self._mcp_oauth_manager.submit_callback( + flow_id, + code=code, + state=state, + error=error, + ) + async def enable_mcp_server( self, name: str, diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 157b4d75bf..841c92b510 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -4,6 +4,7 @@ from astrbot.core import logger from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config +from astrbot.core.agent.mcp_oauth import MCPOAuthAuthorizationRequiredError from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star import star_map from astrbot.core.tools.registry import get_builtin_tool_config_statuses @@ -53,12 +54,19 @@ def __init__( "/tools/mcp/update": ("POST", self.update_mcp_server), "/tools/mcp/delete": ("POST", self.delete_mcp_server), "/tools/mcp/test": ("POST", self.test_mcp_connection), + "/tools/mcp/oauth/start": ("POST", self.start_mcp_oauth_authorization), + "/tools/mcp/oauth/status": ("GET", self.get_mcp_oauth_status), "/tools/list": ("GET", self.get_tool_list), "/tools/toggle-tool": ("POST", self.toggle_tool), "/tools/mcp/sync-provider": ("POST", self.sync_provider), } self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools + self.app.add_url_rule( + "/mcp/oauth/callback", + view_func=self.handle_mcp_oauth_callback, + methods=["GET"], + ) def _rollback_mcp_server(self, name: str) -> bool: try: @@ -101,7 +109,11 @@ async def get_mcp_servers(self): if key != "active": # active 已经处理 server_info[key] = value - # 如果MCP客户端已初始化,从客户端获取工具名称 + server_info.update( + await self.tool_mgr.get_mcp_oauth_state(server_config) + ) + + # 如果 MCP 客户端已初始化,从客户端获取工具名称 for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items(): if name_key == name: mcp_client = runtime.client @@ -134,7 +146,15 @@ async def add_mcp_server(self): # 复制所有配置字段 for key, value in server_data.items(): - if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段 + if key not in [ + "name", + "active", + "tools", + "errlogs", + "oauth2_enabled", + "oauth2_authorized", + "oauth2_grant_type", + ]: # 排除特殊字段 if key == "mcpServers": try: server_config = _extract_mcp_server_config( @@ -165,6 +185,8 @@ async def add_mcp_server(self): try: await self.tool_mgr.test_mcp_server_connection(server_config) + except MCPOAuthAuthorizationRequiredError as e: + return Response().error(f"{e!s}").__dict__ except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"MCP connection test failed: {e!s}").__dict__ @@ -178,6 +200,12 @@ async def add_mcp_server(self): server_config, timeout=30, ) + except MCPOAuthAuthorizationRequiredError as e: + rollback_ok = self._rollback_mcp_server(name) + err_msg = f"{e!s}" + if not rollback_ok: + err_msg += " Configuration rollback failed. Please check the config manually." + return Response().error(err_msg).__dict__ except TimeoutError: rollback_ok = self._rollback_mcp_server(name) err_msg = f"Timed out while enabling MCP server {name}." @@ -243,6 +271,9 @@ async def update_mcp_server(self): "tools", "errlogs", "oldName", + "oauth2_enabled", + "oauth2_authorized", + "oauth2_grant_type", ]: # 排除特殊字段 if key == "mcpServers": try: @@ -258,7 +289,7 @@ async def update_mcp_server(self): # 如果只更新活动状态,保留原始配置 if only_update_active and isinstance(old_config, dict): for key, value in old_config.items(): - if key != "active": # 除了active之外的所有字段都保留 + if key != "active": # 除了 active 之外的所有字段都保留 server_config[key] = value try: @@ -274,7 +305,7 @@ async def update_mcp_server(self): config["mcpServers"][name] = server_config if self.tool_mgr.save_mcp_config(config): - # 处理MCP客户端状态变化 + # 处理 MCP 客户端状态变化 if active: if ( old_name in self.tool_mgr.mcp_server_runtime_view @@ -306,6 +337,8 @@ async def update_mcp_server(self): config["mcpServers"][name], timeout=30, ) + except MCPOAuthAuthorizationRequiredError as e: + return Response().error(f"{e!s}").__dict__ except TimeoutError: return ( Response() @@ -436,11 +469,115 @@ async def test_mcp_connection(self): .ok(data=tools_name, message="🎉 MCP server is available!") .__dict__ ) + except MCPOAuthAuthorizationRequiredError as e: + return Response().error(f"{e!s}").__dict__ except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"Failed to test MCP connection: {e!s}").__dict__ + async def start_mcp_oauth_authorization(self): + try: + name = request.args.get("name") + payload = await request.json + if not isinstance(payload, dict): + return Response().error("Invalid JSON body: expected object").__dict__ + + config = payload.get("mcp_server_config") + if not isinstance(config, dict) or not config: + return Response().error("Invalid MCP server configuration").__dict__ + + if "mcpServers" in config: + try: + config = _extract_mcp_server_config(config["mcpServers"]) + except ValueError as e: + return Response().error(f"{e!s}").__dict__ + + # 优先使用配置中的对外可达的回调接口地址 + callback_api_base = self.config.get("callback_api_base") + callback_base_url = ( + callback_api_base + or payload.get("callback_base_url") + or request.url_root.rstrip("/") + ) + + flow_status = await self.tool_mgr.start_mcp_oauth_authorization( + config, + callback_base_url=callback_base_url, + server_name=name, + force=bool(payload.get("force", False)), + ) + return ( + Response() + .ok( + data=flow_status, + message="OAuth 2.0 authorization flow is ready.", + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"Failed to start MCP OAuth authorization: {e!s}") + .__dict__ + ) + + async def get_mcp_oauth_status(self): + try: + flow_id = request.args.get("flow_id", "").strip() + if not flow_id: + return Response().error("Missing required parameter: flow_id").__dict__ + + flow_status = self.tool_mgr.get_mcp_oauth_flow_status(flow_id) + return Response().ok(data=flow_status).__dict__ + except KeyError: + return Response().error("OAuth flow not found or expired").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"Failed to get MCP OAuth status: {e!s}").__dict__ + + async def handle_mcp_oauth_callback(self): + error = request.args.get("error") + error_description = request.args.get("error_description") + if error_description: + error = f"{error or 'oauth_error'}: {error_description}" + + state = request.args.get("state") + try: + await self.tool_mgr.submit_mcp_oauth_callback( + None, + code=request.args.get("code"), + state=state, + error=error, + ) + except KeyError: + return ( + "

OAuth flow not found or expired.

", + 404, + {"Content-Type": "text/html; charset=utf-8"}, + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + f"

OAuth callback failed: {e!s}

", + 500, + {"Content-Type": "text/html; charset=utf-8"}, + ) + + html = """ + + +

OAuth authorization completed.

+

You can return to AstrBot and wait for the status to update.

+ + + +""" + return html, 200, {"Content-Type": "text/html; charset=utf-8"} + async def get_tool_list(self): """Get all registered tools.""" try: diff --git a/dashboard/src/components/extension/McpServersSection.vue b/dashboard/src/components/extension/McpServersSection.vue index ae697b8015..f97e47119b 100644 --- a/dashboard/src/components/extension/McpServersSection.vue +++ b/dashboard/src/components/extension/McpServersSection.vue @@ -28,31 +28,38 @@
-