diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274..d893b31f5 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -274,6 +274,10 @@ async def send_request( class_name = request.__class__.__name__ message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." raise MCPError(code=REQUEST_TIMEOUT, message=message) + except (anyio.EndOfStream, anyio.ClosedResourceError) as e: + class_name = request.__class__.__name__ + message = f"Connection closed while waiting for response to {class_name}: {e}" + raise MCPError(code=CONNECTION_CLOSED, message=message) if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 182b4671d..53c24d51f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -304,3 +304,59 @@ async def mock_server(): await ev_closed.wait() with anyio.fail_after(1): # pragma: no branch await ev_response.wait() + + +@pytest.mark.anyio +async def test_response_stream_closed_raises_mcp_error(): + """Test that EndOfStream on the per-request response stream raises MCPError. + + Reproduces the race from #1717: if the per-request response stream is closed + (e.g. receive loop calls aclose() during shutdown) before send_request reads + from it, receive() raises EndOfStream. Without the fix, this propagates as an + unhandled EndOfStream (or causes UnboundLocalError). + + Simulates this by closing the response stream's send side directly while the + server connection stays open (so the receive loop never enters its finally block). + """ + + ev_result = anyio.Event() + caught_error: list[MCPError] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, _server_write = server_streams + + async def make_request(client_session: ClientSession): + nonlocal caught_error + try: + await client_session.send_ping() + pytest.fail("Expected MCPError") # pragma: no cover + except MCPError as e: + caught_error.append(e) + ev_result.set() + + async def close_response_stream(client_session: ClientSession): + # Consume the request so the client's send completes + await server_read.receive() + + # Wait for send_request to register its response stream + while not client_session._response_streams: # pragma: no branch + await anyio.sleep(0.01) # pragma: no cover + + # Close the send side directly, bypassing the receive loop's + # graceful error injection. This triggers EndOfStream on receive(). + for stream in client_session._response_streams.values(): + await stream.aclose() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(make_request, client_session) + tg.start_soon(close_response_stream, client_session) + + with anyio.fail_after(2): # pragma: no branch + await ev_result.wait() + + assert len(caught_error) == 1 + assert "Connection closed" in str(caught_error[0])