|
1 | 1 | import io |
2 | 2 | import sys |
3 | 3 | from io import TextIOWrapper |
| 4 | +from typing import cast |
4 | 5 |
|
5 | 6 | import anyio |
6 | 7 | import pytest |
7 | 8 |
|
8 | 9 | from mcp.server.stdio import stdio_server |
9 | 10 | from mcp.shared.message import SessionMessage |
10 | | -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter |
| 11 | +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter |
| 12 | + |
| 13 | + |
| 14 | +class BlockingStdout: |
| 15 | + def __init__(self) -> None: |
| 16 | + self.lines: list[str] = [] |
| 17 | + self.write_started = anyio.Event() |
| 18 | + self.release_write = anyio.Event() |
| 19 | + |
| 20 | + async def write(self, text: str) -> None: |
| 21 | + self.lines.append(text) |
| 22 | + self.write_started.set() |
| 23 | + await self.release_write.wait() |
| 24 | + |
| 25 | + async def flush(self) -> None: |
| 26 | + pass |
11 | 27 |
|
12 | 28 |
|
13 | 29 | @pytest.mark.anyio |
@@ -92,3 +108,26 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): |
92 | 108 | second = await read_stream.receive() |
93 | 109 | assert isinstance(second, SessionMessage) |
94 | 110 | assert second.message == valid |
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.anyio |
| 114 | +async def test_stdio_server_write_stream_allows_response_after_slow_notification(): |
| 115 | + """A slow stdout write for a notification must not block the next response.""" |
| 116 | + stdout = BlockingStdout() |
| 117 | + typed_stdout = cast(anyio.AsyncFile[str], stdout) |
| 118 | + |
| 119 | + async with stdio_server(stdin=anyio.AsyncFile(io.StringIO()), stdout=typed_stdout) as ( |
| 120 | + read_stream, |
| 121 | + write_stream, |
| 122 | + ): |
| 123 | + notification = JSONRPCNotification(jsonrpc="2.0", method="notifications/progress") |
| 124 | + await write_stream.send(SessionMessage(notification)) |
| 125 | + await stdout.write_started.wait() |
| 126 | + |
| 127 | + with anyio.move_on_after(0.1) as scope: |
| 128 | + await write_stream.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=2, result={}))) |
| 129 | + |
| 130 | + assert not scope.cancel_called |
| 131 | + stdout.release_write.set() |
| 132 | + await write_stream.aclose() |
| 133 | + await read_stream.aclose() |
0 commit comments