Skip to content

Commit 3179143

Browse files
committed
Add graceful SSE drain on session manager shutdown
Terminate all active transports (stateful and stateless) before cancelling the task group during shutdown. This closes in-memory streams so EventSourceResponse can send a final more_body=False chunk — a clean HTTP close instead of a connection reset. Without this, the request-scoped task groups introduced for the memory leak fix isolate stateless transports from the manager's cancel scope, so nothing ever closes their streams on shutdown.
1 parent 02547cb commit 3179143

File tree

2 files changed

+197
-9
lines changed

2 files changed

+197
-9
lines changed

src/mcp/server/streamable_http_manager.py

Lines changed: 33 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,9 @@ def __init__(
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
9292

93+
# Track in-flight stateless transports for graceful shutdown
94+
self._stateless_transports: set[StreamableHTTPServerTransport] = set()
95+
9396
# The task group will be set during lifespan
9497
self._task_group = None
9598
# Thread-safe tracking of run() calls
@@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
130133
yield # Let the application run
131134
finally:
132135
logger.info("StreamableHTTP session manager shutting down")
136+
137+
# Terminate all active transports before cancelling the task
138+
# group. This closes their in-memory streams, which lets
139+
# EventSourceResponse send a final ``more_body=False`` chunk
140+
# — a clean HTTP close instead of a connection reset.
141+
for transport in list(self._server_instances.values()):
142+
try:
143+
await transport.terminate()
144+
except Exception: # pragma: no cover
145+
logger.debug("Error terminating transport during shutdown", exc_info=True)
146+
for transport in list(self._stateless_transports):
147+
try:
148+
await transport.terminate()
149+
except Exception: # pragma: no cover
150+
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)
151+
133152
# Cancel task group to stop all spawned tasks
134153
tg.cancel_scope.cancel()
135154
self._task_group = None
136155
# Clear any remaining server instances
137156
self._server_instances.clear()
157+
self._stateless_transports.clear()
138158

139159
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140160
"""Process ASGI request with proper session handling and transport setup.
@@ -166,6 +186,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
166186
security_settings=self.security_settings,
167187
)
168188

189+
# Track for graceful shutdown
190+
self._stateless_transports.add(http_transport)
191+
169192
# Start server in a new task
170193
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
171194
async with http_transport.connect() as streams:
@@ -185,13 +208,16 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
185208
# This ensures the server task is cancelled when the request
186209
# finishes, preventing zombie tasks from accumulating.
187210
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1764
188-
async with anyio.create_task_group() as request_tg:
189-
await request_tg.start(run_stateless_server)
190-
# Handle the HTTP request directly in the caller's context
191-
# (not as a child task) so execution flows back naturally.
192-
await http_transport.handle_request(scope, receive, send)
193-
# Cancel the request-scoped task group to stop the server task.
194-
request_tg.cancel_scope.cancel()
211+
try:
212+
async with anyio.create_task_group() as request_tg:
213+
await request_tg.start(run_stateless_server)
214+
# Handle the HTTP request directly in the caller's context
215+
# (not as a child task) so execution flows back naturally.
216+
await http_transport.handle_request(scope, receive, send)
217+
# Cancel the request-scoped task group to stop the server task.
218+
request_tg.cancel_scope.cancel()
219+
finally:
220+
self._stateless_transports.discard(http_transport)
195221

196222
# Terminate after the task group exits — the server task is already
197223
# cancelled at this point, so this is just cleanup (sets _terminated

tests/server/test_streamable_http_manager.py

Lines changed: 164 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,11 @@
1010
import pytest
1111
from starlette.types import Message
1212

13-
from mcp import Client
13+
from mcp import Client, types
1414
from mcp.client.streamable_http import streamable_http_client
1515
from mcp.server import Server, ServerRequestContext, streamable_http_manager
1616
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
17-
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
17+
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
1818
from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams
1919

2020

@@ -490,3 +490,165 @@ def test_session_idle_timeout_rejects_non_positive():
490490
def test_session_idle_timeout_rejects_stateless():
491491
with pytest.raises(RuntimeError, match="not supported in stateless"):
492492
StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True)
493+
494+
495+
MCP_HEADERS = {
496+
"Accept": "application/json, text/event-stream",
497+
"Content-Type": "application/json",
498+
}
499+
500+
_INITIALIZE_REQUEST = {
501+
"jsonrpc": "2.0",
502+
"id": 1,
503+
"method": "initialize",
504+
"params": {
505+
"protocolVersion": "2025-03-26",
506+
"capabilities": {},
507+
"clientInfo": {"name": "test", "version": "0.1"},
508+
},
509+
}
510+
511+
_INITIALIZED_NOTIFICATION = {
512+
"jsonrpc": "2.0",
513+
"method": "notifications/initialized",
514+
}
515+
516+
_TOOL_CALL_REQUEST = {
517+
"jsonrpc": "2.0",
518+
"id": 2,
519+
"method": "tools/call",
520+
"params": {"name": "slow_tool", "arguments": {"message": "hello"}},
521+
}
522+
523+
524+
def _make_slow_tool_server() -> tuple[Server, anyio.Event]:
525+
"""Create an MCP server with a tool that blocks forever, returning
526+
the server and an event that fires when the tool starts executing."""
527+
tool_started = anyio.Event()
528+
529+
async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
530+
tool_started.set()
531+
await anyio.sleep_forever()
532+
return types.CallToolResult( # pragma: no cover
533+
content=[types.TextContent(type="text", text="never reached")]
534+
)
535+
536+
async def handle_list_tools(
537+
ctx: ServerRequestContext, params: PaginatedRequestParams | None
538+
) -> ListToolsResult: # pragma: no cover
539+
return ListToolsResult(
540+
tools=[
541+
types.Tool(
542+
name="slow_tool",
543+
description="A tool that blocks forever",
544+
input_schema={"type": "object", "properties": {"message": {"type": "string"}}},
545+
)
546+
]
547+
)
548+
549+
app = Server("test-graceful-shutdown", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)
550+
return app, tool_started
551+
552+
553+
class SSECloseTracker:
554+
"""ASGI middleware that tracks whether SSE responses close cleanly.
555+
556+
In HTTP, a clean close means sending a final empty chunk (``0\\r\\n\\r\\n``).
557+
At the ASGI protocol level this corresponds to a
558+
``{"type": "http.response.body", "more_body": False}`` message.
559+
560+
Without graceful drain, the server task is cancelled but nothing closes
561+
the stateless transport's streams — the SSE response hangs indefinitely
562+
and never sends the final body. A reverse proxy (e.g. nginx) would log
563+
"upstream prematurely closed connection while reading upstream".
564+
"""
565+
566+
def __init__(self, app: StreamableHTTPASGIApp) -> None:
567+
self.app = app
568+
self.sse_streams_opened = 0
569+
self.sse_streams_closed_cleanly = 0
570+
571+
async def __call__(self, scope: dict[str, Any], receive: Any, send: Any) -> None:
572+
is_sse = False
573+
574+
async def tracking_send(message: dict[str, Any]) -> None:
575+
nonlocal is_sse
576+
if message["type"] == "http.response.start":
577+
for name, value in message.get("headers", []):
578+
if name == b"content-type" and b"text/event-stream" in value:
579+
is_sse = True
580+
self.sse_streams_opened += 1
581+
break
582+
elif message["type"] == "http.response.body" and is_sse:
583+
if not message.get("more_body", False):
584+
self.sse_streams_closed_cleanly += 1
585+
await send(message)
586+
587+
await self.app(scope, receive, tracking_send)
588+
589+
590+
@pytest.mark.anyio
591+
async def test_graceful_shutdown_closes_sse_streams_cleanly():
592+
"""Verify that shutting down the session manager closes in-flight SSE
593+
streams with a proper ``more_body=False`` ASGI message.
594+
595+
This is the ASGI equivalent of sending the final HTTP chunk — the signal
596+
that reverse proxies like nginx use to distinguish a clean close from a
597+
connection reset ("upstream prematurely closed connection").
598+
599+
Without the graceful-drain fix, stateless transports are not tracked by
600+
the session manager. On shutdown nothing calls ``terminate()`` on them,
601+
so SSE responses hang indefinitely and never send the final body. With
602+
the fix, ``run()``'s finally block iterates ``_stateless_transports`` and
603+
terminates each one, closing the underlying memory streams and letting
604+
``EventSourceResponse`` complete normally.
605+
"""
606+
app, tool_started = _make_slow_tool_server()
607+
manager = StreamableHTTPSessionManager(app=app, stateless=True)
608+
609+
tracker = SSECloseTracker(StreamableHTTPASGIApp(manager))
610+
611+
manager_ready = anyio.Event()
612+
613+
with anyio.fail_after(10):
614+
async with anyio.create_task_group() as tg:
615+
616+
async def run_lifespan_and_shutdown() -> None:
617+
async with manager.run():
618+
manager_ready.set()
619+
with anyio.fail_after(5):
620+
await tool_started.wait()
621+
# manager.run() exits — graceful shutdown runs here
622+
623+
async def make_requests() -> None:
624+
with anyio.fail_after(5):
625+
await manager_ready.wait()
626+
async with (
627+
httpx.ASGITransport(tracker, raise_app_exceptions=False) as transport,
628+
httpx.AsyncClient(transport=transport, base_url="http://testserver") as client,
629+
):
630+
# Initialize
631+
resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
632+
resp.raise_for_status()
633+
634+
# Send initialized notification
635+
resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS)
636+
assert resp.status_code == 202
637+
638+
# Send slow tool call — returns an SSE stream that blocks
639+
# until shutdown terminates it
640+
await client.post(
641+
"/mcp/",
642+
json=_TOOL_CALL_REQUEST,
643+
headers=MCP_HEADERS,
644+
timeout=httpx.Timeout(10, connect=5),
645+
)
646+
647+
tg.start_soon(run_lifespan_and_shutdown)
648+
tg.start_soon(make_requests)
649+
650+
assert tracker.sse_streams_opened > 0, "Test should have opened at least one SSE stream"
651+
assert tracker.sse_streams_closed_cleanly == tracker.sse_streams_opened, (
652+
f"All {tracker.sse_streams_opened} SSE stream(s) should have closed with "
653+
f"more_body=False, but only {tracker.sse_streams_closed_cleanly} did"
654+
)

0 commit comments

Comments (0)