diff --git a/.changeset/session-hydration-csharp-parity.md b/.changeset/session-hydration-csharp-parity.md new file mode 100644 index 000000000..843f68b61 --- /dev/null +++ b/.changeset/session-hydration-csharp-parity.md @@ -0,0 +1,10 @@ +--- +'@modelcontextprotocol/server': minor +--- + +Add multi-node session hydration support, aligned with the C# SDK pattern: + +- `WebStandardStreamableHTTPServerTransport` now accepts a `sessionId` constructor option. When set, the transport validates incoming `mcp-session-id` headers against that value and rejects re-initialization, without requiring a fresh initialize handshake on this node. +- `Server.restoreInitializeState(params)` restores negotiated client capabilities and version from persisted `InitializeRequest` params, so capability-gated server-initiated features (sampling, elicitation, roots) work on hydrated instances. + +Internal refactor: the private `_initialized` flag is removed. Its checks are replaced by equivalent `sessionId === undefined` checks, so observable behavior (error codes and messages) is unchanged. diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index f1a1851f4..3b3e726f4 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -449,6 +449,21 @@ export class Server extends Protocol { return this._clientCapabilities; } + /** + * Restores the initialize-handshake state (client capabilities and version) from + * persisted `InitializeRequest` params. Use this in multi-node deployments when + * reconstructing a server for a session that was initialized on a different node, + * so that capability-gated server-initiated features (sampling, elicitation, roots) + * work correctly on the hydrated instance. + * + * Pair with the `sessionId` option on `WebStandardStreamableHTTPServerTransport` + * to restore transport-layer session validation alongside protocol-layer state. + */ + restoreInitializeState(params: InitializeRequest['params']): void { + this._clientCapabilities = params.capabilities; + this._clientVersion = params.clientInfo; + } + /** * After initialization has completed, this will be populated with information about the client's name and version. */ diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index 31053f35c..04f6103cc 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -75,10 +75,24 @@ export interface WebStandardStreamableHTTPServerTransportOptions { * Function that generates a session ID for the transport. * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) * - * If not provided, session management is disabled (stateless mode). + * If neither this nor {@link sessionId} is provided, session management is disabled (stateless mode). */ sessionIdGenerator?: () => string; + /** + * Pre-sets the session ID at construction time. Use this when reconstructing a + * transport for an existing session, e.g. in multi-node deployments where a request + * arrives at a node that did not handle the original initialize handshake. + * + * The transport will validate incoming `mcp-session-id` headers against this value + * and reject re-initialization attempts. + * + * **Note:** This restores transport-layer session validation only. To restore + * protocol-layer state (negotiated client capabilities), call + * `Server.restoreInitializeState()` with persisted `InitializeRequest` params. + */ + sessionId?: string; + /** * A callback for session initialization events * This is called when the server initializes a new session. @@ -228,7 +242,6 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { private _streamMapping: Map = new Map(); private _requestToStreamMapping: Map = new Map(); private _requestResponseMap: Map = new Map(); - private _initialized: boolean = false; private _enableJsonResponse: boolean = false; private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; @@ -247,6 +260,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { constructor(options: WebStandardStreamableHTTPServerTransportOptions = {}) { this.sessionIdGenerator = options.sessionIdGenerator; + this.sessionId = options.sessionId; this._enableJsonResponse = options.enableJsonResponse ?? false; this._eventStore = options.eventStore; this._onsessioninitialized = options.onsessioninitialized; @@ -667,9 +681,9 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ const isInitializationRequest = messages.some(element => isInitializeRequest(element)); if (isInitializationRequest) { - // If it's a server with session management and the session ID is already set we should reject the request - // to avoid re-initialization. - if (this._initialized && this.sessionId !== undefined) { + // If the session ID is already set (via prior initialize or constructor) + // reject the request to avoid re-initialization. + if (this.sessionId !== undefined) { this.onerror?.(new Error('Invalid Request: Server already initialized')); return this.createJsonErrorResponse(400, -32_600, 'Invalid Request: Server already initialized'); } @@ -678,7 +692,6 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { return this.createJsonErrorResponse(400, -32_600, 'Invalid Request: Only one initialization request is allowed'); } this.sessionId = this.sessionIdGenerator?.(); - this._initialized = true; // If we have a session ID and an onsessioninitialized handler, call it immediately // This is needed in cases where the server needs to keep track of multiple sessions @@ -847,13 +860,14 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { * Returns `Response` error if invalid, `undefined` otherwise */ private validateSession(req: Request): Response | undefined { - if (this.sessionIdGenerator === undefined) { - // If the sessionIdGenerator ID is not set, the session management is disabled - // and we don't need to validate the session ID + if (this.sessionIdGenerator === undefined && this.sessionId === undefined) { + // No generator and no pre-set session ID means session management is + // disabled (stateless mode); skip validation. return undefined; } - if (!this._initialized) { - // If the server has not been initialized yet, reject all requests + if (this.sessionId === undefined) { + // Stateful mode but no session ID yet: initialize has not run on this + // transport instance. this.onerror?.(new Error('Bad Request: Server not initialized')); return this.createJsonErrorResponse(400, -32_000, 'Bad Request: Server not initialized'); } diff --git a/test/integration/test/server/streamableHttp.sessionHydration.test.ts b/test/integration/test/server/streamableHttp.sessionHydration.test.ts new file mode 100644 index 000000000..87e9700ee --- /dev/null +++ b/test/integration/test/server/streamableHttp.sessionHydration.test.ts @@ -0,0 +1,218 @@ +import type { CallToolResult, JSONRPCErrorResponse, JSONRPCMessage } from '@modelcontextprotocol/core'; +import { McpServer, Server, WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import * as z from 'zod/v4'; + +const PROTOCOL_VERSION = '2025-11-25'; + +const TEST_MESSAGES = { + initialize: { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: PROTOCOL_VERSION, + capabilities: {} + }, + id: 'init-1' + } as JSONRPCMessage, + toolsList: { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'tools-1' + } as JSONRPCMessage +}; + +function createRequest( + method: string, + body?: JSONRPCMessage | JSONRPCMessage[], + options?: { sessionId?: string; extraHeaders?: Record } +): Request { + const headers: Record = {}; + + if (method === 'POST') { + headers.Accept = 'application/json, text/event-stream'; + } else if (method === 'GET') { + headers.Accept = 'text/event-stream'; + } + + if (body) { + headers['Content-Type'] = 'application/json'; + } + + if (options?.sessionId) { + headers['mcp-session-id'] = options.sessionId; + headers['mcp-protocol-version'] = PROTOCOL_VERSION; + } + + if (options?.extraHeaders) { + Object.assign(headers, options.extraHeaders); + } + + return new Request('http://localhost/mcp', { + method, + headers, + body: body ? JSON.stringify(body) : undefined + }); +} + +async function readSSEEvent(response: Response): Promise { + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + return new TextDecoder().decode(value); +} + +function parseSSEData(text: string): unknown { + const dataLine = text.split('\n').find(line => line.startsWith('data:')); + if (!dataLine) throw new Error('No data line found in SSE event'); + return JSON.parse(dataLine.slice(5).trim()); +} + +function expectErrorResponse(data: unknown, expectedCode: number, expectedMessagePattern: RegExp): void { + expect(data).toMatchObject({ + jsonrpc: '2.0', + error: expect.objectContaining({ + code: expectedCode, + message: expect.stringMatching(expectedMessagePattern) + }) + }); +} + +describe('WebStandardStreamableHTTPServerTransport session hydration', () => { + let transport: WebStandardStreamableHTTPServerTransport; + let mcpServer: McpServer; + + beforeEach(() => { + mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + mcpServer.registerTool( + 'greet', + { + description: 'A simple greeting tool', + inputSchema: z.object({ name: z.string().describe('Name to greet') }) + }, + async ({ name }): Promise => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + }); + + afterEach(async () => { + await transport?.close(); + }); + + async function connectTransport(options?: ConstructorParameters[0]) { + transport = new WebStandardStreamableHTTPServerTransport(options); + await mcpServer.connect(transport); + } + + describe('transport-layer hydration (sessionId option)', () => { + it('processes requests without initialize when constructed with sessionId', async () => { + const sessionId = 'persisted-session-id'; + await connectTransport({ sessionId }); + + const response = await transport.handleRequest(createRequest('POST', TEST_MESSAGES.toolsList, { sessionId })); + + expect(response.status).toBe(200); + expect(response.headers.get('mcp-session-id')).toBe(sessionId); + + const eventData = parseSSEData(await readSSEEvent(response)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'greet' })]) + }), + id: 'tools-1' + }); + }); + + it('rejects requests with a mismatched session ID', async () => { + await connectTransport({ sessionId: 'persisted-session-id' }); + + const response = await transport.handleRequest( + createRequest('POST', TEST_MESSAGES.toolsList, { sessionId: 'wrong-session-id' }) + ); + + expect(response.status).toBe(404); + expectErrorResponse(await response.json(), -32_001, /Session not found/); + }); + + it('rejects requests without a session ID header', async () => { + await connectTransport({ sessionId: 'persisted-session-id' }); + + const response = await transport.handleRequest(createRequest('POST', TEST_MESSAGES.toolsList)); + + expect(response.status).toBe(400); + const errorData = (await response.json()) as JSONRPCErrorResponse; + expectErrorResponse(errorData, -32_000, /Mcp-Session-Id header is required/); + expect(errorData.id).toBeNull(); + }); + + it('rejects re-initialize on a hydrated transport', async () => { + await connectTransport({ sessionId: 'persisted-session-id' }); + + const response = await transport.handleRequest(createRequest('POST', TEST_MESSAGES.initialize)); + + expect(response.status).toBe(400); + expectErrorResponse(await response.json(), -32_600, /Server already initialized/); + }); + + it('leaves default initialize flow unchanged when sessionId is not provided', async () => { + await connectTransport({ sessionIdGenerator: () => 'generated-session-id' }); + + const initResponse = await transport.handleRequest(createRequest('POST', TEST_MESSAGES.initialize)); + + expect(initResponse.status).toBe(200); + expect(initResponse.headers.get('mcp-session-id')).toBe('generated-session-id'); + + const toolsResponse = await transport.handleRequest( + createRequest('POST', TEST_MESSAGES.toolsList, { sessionId: 'generated-session-id' }) + ); + + expect(toolsResponse.status).toBe(200); + const eventData = parseSSEData(await readSSEEvent(toolsResponse)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'greet' })]) + }), + id: 'tools-1' + }); + }); + }); + + describe('Server.restoreInitializeState', () => { + it('restores client capabilities without an initialize round-trip', () => { + const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: {} }); + + expect(server.getClientCapabilities()).toBeUndefined(); + expect(server.getClientVersion()).toBeUndefined(); + + server.restoreInitializeState({ + protocolVersion: PROTOCOL_VERSION, + capabilities: { sampling: {}, elicitation: { form: {} } }, + clientInfo: { name: 'persisted-client', version: '2.0.0' } + }); + + expect(server.getClientCapabilities()).toEqual({ sampling: {}, elicitation: { form: {} } }); + expect(server.getClientVersion()).toEqual({ name: 'persisted-client', version: '2.0.0' }); + }); + + it('enables capability-gated methods after restoration', async () => { + const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: {} }); + + // Before restoration, server thinks client has no sampling capability + expect(server.getClientCapabilities()?.sampling).toBeUndefined(); + + server.restoreInitializeState({ + protocolVersion: PROTOCOL_VERSION, + capabilities: { sampling: {} }, + clientInfo: { name: 'c', version: '1' } + }); + + // After restoration, sampling capability is visible + expect(server.getClientCapabilities()?.sampling).toEqual({}); + }); + }); +});