diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..eb8559f9d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -57,6 +57,8 @@ async def __call__( async def _default_message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: + if isinstance(message, Exception): + logger.warning("Unhandled exception in message handler", exc_info=message) await anyio.lowlevel.checkpoint() diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f0..b2a1227a3 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -705,3 +705,142 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_default_message_handler_logs_exceptions(caplog: pytest.LogCaptureFixture): + """Test that the default message handler logs exceptions instead of silently swallowing them. + + When an exception (e.g. a transport error) is delivered through the read stream, + the default handler should log it at WARNING level so the error is observable. + Previously, exceptions were silently discarded, making transport failures + impossible to diagnose. + """ + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + async def mock_server(): + # Receive the initialization request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request, JSONRPCRequest) + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) + + # Send init response + await server_to_client_send.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Inject an exception into the read stream (simulating a transport error) + await server_to_client_send.send(RuntimeError("SSE stream read timeout")) + + # Close the stream so the session can exit cleanly + await server_to_client_send.aclose() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + # Use the default message_handler (no override) + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Wait for the receive loop to process the exception + await anyio.sleep(0.1) + + # Verify the exception was logged instead of silently swallowed + warning_records = [r for r in caplog.records if "Unhandled exception in message handler" in r.message] + assert len(warning_records) >= 1 + # The exception details are attached via exc_info, visible in the formatted output + assert warning_records[0].exc_info is not None + assert warning_records[0].exc_info[1] is not None + assert "SSE stream read timeout" in str(warning_records[0].exc_info[1]) + + +@pytest.mark.anyio +async def test_custom_message_handler_can_suppress_exceptions(): + """Test that a custom message handler can suppress exceptions if desired.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + suppressed_exceptions: list[Exception] = [] + + async def suppressing_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): # pragma: no branch + suppressed_exceptions.append(message) + # Intentionally NOT re-raising — old silent behavior + + async def mock_server(): + # Receive the initialization request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request, JSONRPCRequest) + + result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) + + # Send init response + await server_to_client_send.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Inject an exception, then close the stream + await server_to_client_send.send(RuntimeError("transport error")) + await server_to_client_send.aclose() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=suppressing_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Give the receive loop time to process the exception + await anyio.sleep(0.1) + + # The custom handler captured the exception instead of crashing + assert len(suppressed_exceptions) == 1 + assert str(suppressed_exceptions[0]) == "transport error"