From f0920476c84d9aab9cbb80fdd76fbbc3352153bf Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 15 May 2026 22:10:11 +0800 Subject: [PATCH] fix: isolate streamable http request errors --- src/mcp/client/streamable_http.py | 17 +++++++-- tests/shared/test_streamable_http.py | 57 ++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c6338..ac317de08c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -468,10 +468,19 @@ async def _handle_message(session_message: SessionMessage) -> None: ) async def handle_request_async(): - if is_resumption: - await self._handle_resumption_request(ctx) - else: - await self._handle_post_request(ctx) + try: + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + except Exception as exc: + if not isinstance(message, JSONRPCRequest): + raise + + logger.exception("Error handling streamable HTTP request") + error_data = ErrorData(code=INTERNAL_ERROR, message=f"Request failed: {exc}") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) + await ctx.read_stream_writer.send(error_msg) # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..0ff10fc4b6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -57,7 +57,10 @@ CallToolRequestParams, CallToolResult, InitializeResult, + JSONRPCError, + JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, ListToolsResult, PaginatedRequestParams, ReadResourceRequestParams, @@ -1105,6 +1108,60 @@ async def test_streamable_http_client_error_handling(initialized_client_session: assert "Unknown resource: unknown://test-error" in exc_info.value.error.message +@pytest.mark.anyio +async def test_streamable_http_request_error_does_not_close_writer(): + async def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + if body["method"] == "tools/list": + raise httpx.ConnectError("boom", request=request) + + return httpx.Response( + 200, + headers={"content-type": "application/json"}, + json={"jsonrpc": "2.0", "id": body["id"], "result": {}}, + request=request, + ) + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with streamable_http_client("http://testserver/mcp", http_client=client) as (read_stream, write_stream): + await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id="bad", method="tools/list"))) + + with anyio.fail_after(1): + error_message = await read_stream.receive() + + assert isinstance(error_message, SessionMessage) + assert isinstance(error_message.message, JSONRPCError) + assert error_message.message.id == "bad" + assert error_message.message.error.code == types.INTERNAL_ERROR + + await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id="ok", method="ping"))) + + with anyio.fail_after(1): + response_message = await read_stream.receive() + + assert isinstance(response_message, SessionMessage) + assert isinstance(response_message.message, JSONRPCResponse) + assert response_message.message.id == "ok" + + +@pytest.mark.anyio +async def test_streamable_http_notification_error_still_closes_writer(): + request_seen = anyio.Event() + + async def handler(request: httpx.Request) -> httpx.Response: + request_seen.set() + raise httpx.ConnectError("boom", request=request) + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + async with streamable_http_client("http://testserver/mcp", http_client=client) as (_, write_stream): + await write_stream.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled")) + ) + + with anyio.fail_after(1): # pragma: no branch + await request_seen.wait() + + @pytest.mark.anyio async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): """Test that session ID persists across requests."""