diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c633..2aa8ee977 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -421,9 +421,15 @@ async def _handle_reconnection( await event_source.response.aclose() return - # Stream ended again without response - reconnect again (reset attempt counter) + # Stream ended again without response - reconnect again logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) + # Reset attempt counter only if the stream delivered new events + # (i.e. made forward progress). If no new events arrived, the + # server is connecting then dropping immediately — count that + # towards the retry budget to avoid infinite loops (#2393). + made_progress = reconnect_last_event_id != last_event_id + next_attempt = 0 if made_progress else attempt + 1 + await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, next_attempt) except Exception as e: # pragma: no cover logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb6..f6e3afe9a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -14,7 +14,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from urllib.parse import urlparse import anyio @@ -29,7 +29,14 @@ from mcp import MCPError, types from mcp.client.session import ClientSession -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client +from mcp.client.streamable_http import ( + MAX_RECONNECTION_ATTEMPTS, + StreamableHTTPTransport, + streamable_http_client, +) +from mcp.client.streamable_http import ( + RequestContext as ClientRequestContext, +) from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -2318,3 +2325,76 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_handle_reconnection_does_not_retry_infinitely(): + """Reconnection must give up when no forward progress is made. + + Regression test for #2393: when a stream connects successfully but drops + before delivering a response, the attempt counter was reset to 0 on the + recursive call, allowing an infinite retry loop. + + This test simulates a stream that connects, yields one non-completing SSE + event with the SAME event ID each time (no new data), then ends — + repeatedly. Without forward progress the loop must terminate within + MAX_RECONNECTION_ATTEMPTS. + """ + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + transport.session_id = "test-session" + + # Track how many times aconnect_sse is called + connect_count = 0 + + @asynccontextmanager + async def fake_aconnect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[Any]: + """Simulate a stream that connects OK, yields one event, then ends.""" + nonlocal connect_count + connect_count += 1 + + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + + # Yield a single non-completing notification SSE event with the SAME + # event ID every time, then end the stream. Because the ID never + # changes, the transport sees no forward progress and should count + # each reconnection towards the retry budget. + async def aiter_sse() -> AsyncIterator[ServerSentEvent]: + yield ServerSentEvent( + event="message", + data='{"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"tok","progress":1,"total":10}}', + id="evt-static", + retry=None, + ) + + event_source = MagicMock() + event_source.response = mock_response + event_source.aiter_sse = aiter_sse + yield event_source + + # Build a minimal RequestContext for _handle_reconnection + write_stream, read_stream = create_context_streams[SessionMessage | Exception](32) + + async with write_stream, read_stream: + request_message = JSONRPCRequest(jsonrpc="2.0", id="req-1", method="tools/call", params={}) + session_message = SessionMessage(request_message) + ctx = ClientRequestContext( + client=MagicMock(), + session_id="test-session", + session_message=session_message, + metadata=None, + read_stream_writer=write_stream, + ) + + with patch("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse): + # Use a short sleep override so the test doesn't wait on reconnection delays + with patch("mcp.client.streamable_http.DEFAULT_RECONNECTION_DELAY_MS", 0): + await transport._handle_reconnection(ctx, last_event_id="evt-0", retry_interval_ms=0) + + # The method should have connected at most MAX_RECONNECTION_ATTEMPTS times + # (one for the initial call at attempt=0, then up to MAX-1 more) + assert connect_count <= MAX_RECONNECTION_ATTEMPTS, ( + f"Expected at most {MAX_RECONNECTION_ATTEMPTS} reconnection attempts, " + f"but aconnect_sse was called {connect_count} times — " + f"the attempt counter is not being incremented across reconnections" + )