diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..fdd595892 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -205,6 +205,7 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S if request_mcp_session_id is None: # New session case logger.debug("Creating new transport") + is_session_migration = False async with self._session_creation_lock: new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( @@ -214,6 +215,32 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S security_settings=self.security_settings, retry_interval=self.retry_interval, ) + else: + # Unknown or expired session ID - server likely restarted. + # Create a transport with the client's session ID and mark as a + # session migration so the transport starts already initialized. + # This allows clients to reconnect transparently without sending + # a new initialize request. + if request_mcp_session_id in self._server_instances: + # Should not happen: already handled above, but check under lock + transport = self._server_instances[request_mcp_session_id] + await transport.handle_request(scope, receive, send) + return + logger.info(f"Unknown session {request_mcp_session_id}, reusing client session ID") + is_session_migration = True + async with self._session_creation_lock: + if request_mcp_session_id in self._server_instances: + transport = self._server_instances[request_mcp_session_id] + await transport.handle_request(scope, receive, send) + return + new_session_id = request_mcp_session_id + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=self.json_response, + event_store=self.event_store, + security_settings=self.security_settings, + retry_interval=self.retry_interval, + ) assert http_transport.mcp_session_id is not None self._server_instances[http_transport.mcp_session_id] = http_transport @@ -235,11 +262,15 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE http_transport.idle_scope = idle_scope with idle_scope: + # For session migration (server restart), use stateless mode + # so the session starts already initialized. Otherwise the + # client receives "Received request before initialization" + # for any non-initialize request sent with the old session ID. await self.app.run( read_stream, write_stream, self.app.create_initialization_options(), - stateless=False, + stateless=is_session_migration, ) if idle_scope.cancelled_caught: @@ -268,22 +299,6 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) - else: - # Unknown or expired session ID - return 404 per MCP spec - # TODO: Align error code once spec clarifies - # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 - logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=None, - error=ErrorData(code=INVALID_REQUEST, message="Session not found"), - ) - response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_unset=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", - ) - await response(scope, receive, send) class StreamableHTTPASGIApp: