Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
Expand Down Expand Up @@ -92,6 +92,10 @@
}
)
# Key under which the stdio command allow-list is configured.
# NOTE(review): presumably an environment variable name, going by the _ENV
# suffix -- confirm at the place where it is read.
_STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS"
# Substrings that _is_mcp_reconnect_error searches for in str(exc).lower().
# Entries must stay lower-case: only the exception message is lower-cased
# before the substring comparison, not these markers.
_MCP_RECONNECT_ERROR_MESSAGES = (
"session terminated",
"session was terminated",
)

try:
import anyio
Expand All @@ -110,6 +114,22 @@
)


def _is_mcp_reconnect_error(exc: BaseException) -> bool:
    """Decide whether *exc* is an error that should trigger an MCP reconnect.

    An exception qualifies when either rule holds:

    * it is an instance of ``anyio.ClosedResourceError`` (only checked when
      the module-level name ``anyio`` is bound and exposes that attribute
      as a class), or
    * its lower-cased string form contains one of the markers listed in
      ``_MCP_RECONNECT_ERROR_MESSAGES``.

    Args:
        exc: The exception to classify.

    Returns:
        True if the exception matches either rule, otherwise False.
    """
    # ``anyio`` is imported inside a module-level ``try`` block, so the name
    # may be unbound here; treat that the same as "no such exception class".
    try:
        closed_cls = getattr(anyio, "ClosedResourceError", None)
    except NameError:
        closed_cls = None

    if isinstance(closed_cls, type) and isinstance(exc, closed_cls):
        return True

    # Fall back to case-insensitive substring matching on the message.
    lowered = str(exc).lower()
    for marker in _MCP_RECONNECT_ERROR_MESSAGES:
        if marker in lowered:
            return True
    return False
Comment on lines +117 to +130

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Since anyio is an optional dependency (imported inside a try...except block on line 100), referencing anyio.ClosedResourceError directly will raise a NameError at runtime if anyio is not installed and _is_mcp_reconnect_error is executed (for example, when call_tool_with_reconnect raises a ValueError because self.session is None). Guarding the check with "anyio" in globals() prevents this runtime error.

Suggested change
def _is_mcp_reconnect_error(exc: BaseException) -> bool:
if isinstance(exc, anyio.ClosedResourceError):
return True
message = str(exc).lower()
return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES)
def _is_mcp_reconnect_error(exc: BaseException) -> bool:
if "anyio" in globals() and isinstance(exc, anyio.ClosedResourceError):
return True
message = str(exc).lower()
return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES)



def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format"""
if config.get("mcpServers"):
Expand Down Expand Up @@ -605,7 +625,7 @@ async def call_tool_with_reconnect(
"""

@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
retry=retry_if_exception(_is_mcp_reconnect_error),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(logger, logging.WARNING),
Expand All @@ -621,9 +641,15 @@ async def _call_with_retry():
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
except Exception as exc:
if not _is_mcp_reconnect_error(exc):
raise

logger.warning(
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
"MCP tool %s call failed (%s: %s), attempting to reconnect...",
tool_name,
type(exc).__name__,
exc,
)
# Attempt to reconnect
await self._reconnect()
Expand Down
103 changes: 103 additions & 0 deletions tests/unit/test_mcp_client_reconnect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from datetime import timedelta

import anyio
import pytest
from tenacity import wait_none

from astrbot.core.agent import mcp_client


class FlakyMcpSession:
def __init__(self, first_error: Exception | None = None) -> None:
self.calls = 0
self.first_error = first_error or RuntimeError("Session terminated")

async def call_tool(
self,
*,
name: str,
arguments: dict,
read_timeout_seconds: timedelta,
) -> dict[str, object]:
self.calls += 1
if self.calls == 1:
raise self.first_error
return {
"name": name,
"arguments": arguments,
"timeout": read_timeout_seconds.total_seconds(),
}


@pytest.mark.parametrize(
    ("error", "expected"),
    [
        # Marker phrases match regardless of letter case.
        (RuntimeError("Session terminated"), True),
        (RuntimeError("SESSION TERMINATED"), True),
        (RuntimeError("session was terminated"), True),
        # The anyio closed-resource error matches by type, not by message.
        (anyio.ClosedResourceError(), True),
        # A bare "terminated" without the full marker phrase must not match.
        (RuntimeError("business flow terminated normally"), False),
        (RuntimeError("terminated"), False),
    ],
)
def test_mcp_reconnect_error_detection_is_narrow(
    error: BaseException, expected: bool
) -> None:
    """Check which exceptions ``_is_mcp_reconnect_error`` classifies as reconnectable."""
    verdict = mcp_client._is_mcp_reconnect_error(error)
    assert verdict is expected


@pytest.mark.asyncio
async def test_call_tool_reconnects_on_session_terminated(monkeypatch) -> None:
monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none())

client = mcp_client.MCPClient()
session = FlakyMcpSession()
reconnects = 0

async def reconnect() -> None:
nonlocal reconnects
Comment on lines +49 to +58

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a test ensuring non-reconnectable exceptions are not retried and propagate instead of triggering reconnect.

To exercise the new narrow error handling, please add a test where the first call raises a non-reconnectable exception (e.g., ValueError("business logic failed")) and assert that no reconnects occur (e.g., reconnects == 0) and the exception is re-raised (e.g., with pytest.raises(ValueError): call_tool_with_reconnect(...)). This will cover the if not _is_mcp_reconnect_error(exc): raise path and guard against masking unrelated failures.

reconnects += 1
client.session = session

client.session = session
client._reconnect = reconnect

result = await client.call_tool_with_reconnect(
tool_name="lookup",
arguments={"url": "https://example.com"},
read_timeout_seconds=timedelta(seconds=5),
)

assert result == {
"name": "lookup",
"arguments": {"url": "https://example.com"},
"timeout": 5.0,
}
assert session.calls == 2
assert reconnects == 1


@pytest.mark.asyncio
async def test_call_tool_does_not_reconnect_on_business_error(monkeypatch) -> None:
    """A non-reconnect error propagates after one call with no reconnect attempt."""
    # Replace mcp_client.wait_exponential with a factory that returns
    # tenacity's wait_none(), whatever keyword arguments it is given.
    monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none())

    client = mcp_client.MCPClient()
    session = FlakyMcpSession(first_error=ValueError("business logic failed"))
    reconnects = 0

    async def reconnect() -> None:
        # Only counts invocations; the test expects this never to run.
        nonlocal reconnects
        reconnects += 1

    client.session = session
    client._reconnect = reconnect

    call_kwargs = {
        "tool_name": "lookup",
        "arguments": {"url": "https://example.com"},
        "read_timeout_seconds": timedelta(seconds=5),
    }
    with pytest.raises(ValueError, match="business logic failed"):
        await client.call_tool_with_reconnect(**call_kwargs)

    # Exactly one tool call was made and the reconnect hook stayed untouched.
    assert session.calls == 1
    assert reconnects == 0
Loading