diff --git a/src/scenarios/server/dns-rebinding.ts b/src/scenarios/server/dns-rebinding.ts index e93f810..ab4ac8b 100644 --- a/src/scenarios/server/dns-rebinding.ts +++ b/src/scenarios/server/dns-rebinding.ts @@ -87,6 +87,15 @@ function probeBody(specVersion: string): { }; } +function getResponseHeader( + headers: Record, + name: string +): string | undefined { + const value = headers[name] ?? headers[name.toLowerCase()]; + if (Array.isArray(value)) return value[0]; + return value; +} + /** * Send an MCP request with custom Host and Origin headers. * Both headers are set to the same value so that servers checking either @@ -96,7 +105,7 @@ async function sendRequestWithHostAndOrigin( serverUrl: string, hostOrOrigin: string, specVersion: string -): Promise<{ statusCode: number; body: unknown }> { +): Promise<{ statusCode: number; body: unknown; sessionId?: string }> { // Build the SEP-2243 standard headers (Mcp-Method, Accept, ...) for the // probe's JSON-RPC method so a strictly-conformant server only rejects the // request for the Host/Origin values under test, then layer the @@ -121,6 +130,43 @@ async function sendRequestWithHostAndOrigin( body = null; } + return { + statusCode: response.statusCode, + body, + sessionId: getResponseHeader(response.headers, 'mcp-session-id') + }; +} + +async function sendInitializedNotification( + serverUrl: string, + hostOrOrigin: string, + specVersion: string, + sessionId?: string +): Promise<{ statusCode: number; body: unknown }> { + const method = 'notifications/initialized'; + const response = await request(serverUrl, { + method: 'POST', + headers: buildStandardHeaders(method, undefined, { + headers: { + Host: hostOrOrigin, + Origin: `http://${hostOrOrigin}`, + 'MCP-Protocol-Version': specVersion, + ...(sessionId ? { 'Mcp-Session-Id': sessionId } : {}) + } + }), + body: JSON.stringify({ + jsonrpc: '2.0', + method + }) + }); + + let body: unknown; + try { + body = await response.body.json(); + } catch { + body = null; + } + return { statusCode: response.statusCode, body @@ -253,25 +299,50 @@ See: https://github.com/modelcontextprotocol/typescript-sdk/security/advisories/ ); const isAccepted = response.statusCode >= 200 && response.statusCode < 300; + const initializedNotification = + isAccepted && specVersion !== DRAFT_PROTOCOL_VERSION + ? await sendInitializedNotification( + serverUrl, + validHost, + specVersion, + response.sessionId + ) + : undefined; + const initializedAccepted = + !initializedNotification || + (initializedNotification.statusCode >= 200 && + initializedNotification.statusCode < 300); const details = { hostHeader: validHost, originHeader: `http://${validHost}`, statusCode: response.statusCode, - body: response.body + body: response.body, + ...(initializedNotification + ? { + initializedNotification: { + statusCode: initializedNotification.statusCode, + body: initializedNotification.body, + sessionIdSent: response.sessionId ?? null + } + } + : {}) }; - if (isAccepted) { + if (isAccepted && initializedAccepted) { checks.push({ ...acceptedCheckBase, status: 'SUCCESS', details }); } else { + const errorMessage = !isAccepted + ? `Expected HTTP 2xx for valid localhost Host/Origin headers, got ${response.statusCode}` + : `Expected HTTP 2xx for initialized notification after valid localhost initialize, got ${initializedNotification?.statusCode}`; checks.push({ ...acceptedCheckBase, status: 'FAILURE', - errorMessage: `Expected HTTP 2xx for valid localhost Host/Origin headers, got ${response.statusCode}`, + errorMessage, details }); } diff --git a/src/scenarios/server/negative.test.ts b/src/scenarios/server/negative.test.ts index 583c5eb..86ca54f 100644 --- a/src/scenarios/server/negative.test.ts +++ b/src/scenarios/server/negative.test.ts @@ -1,5 +1,7 @@ import { testContext } from '../../connection/testing'; import { spawn, ChildProcess } from 'child_process'; +import { createServer, type IncomingMessage, type Server } from 'http'; +import type { AddressInfo } from 'net'; import path from 'path'; import { DNSRebindingProtectionScenario } from './dns-rebinding'; import { ResourcesNotFoundErrorScenario } from './resources'; @@ -54,6 +56,47 @@ function stopServer(proc: ChildProcess | null): Promise { }); } +async function readJsonBody( + req: IncomingMessage +): Promise> { + const chunks: Buffer[] = []; + for await (const chunk of req) { + chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)); + } + return JSON.parse(Buffer.concat(chunks).toString('utf8')) as Record< + string, + unknown + >; +} + +function listenOnRandomPort(server: Server): Promise { + return new Promise((resolve, reject) => { + const onError = (error: Error) => reject(error); + server.once('error', onError); + server.listen(0, () => { + server.off('error', onError); + const address = server.address(); + if (!address || typeof address === 'string') { + reject(new Error('Server did not listen on a TCP port')); + return; + } + resolve(`http://localhost:${(address as AddressInfo).port}/mcp`); + }); + }); +} + +function closeHttpServer(server: Server): Promise { + return new Promise((resolve, reject) => { + server.close((error) => { + if (error) { + reject(error); + return; + } + resolve(); + }); + }); +} + describe('Server scenario negative tests', () => { describe('dns-rebinding-protection', () => { let serverProcess: ChildProcess | null = null; @@ -84,6 +127,174 @@ describe('Server scenario negative tests', () => { ); expect(rebindingCheck?.status).toBe('FAILURE'); }, 10000); + + it('sends initialized notification after a successful dated initialize', async () => { + const requests: Array<{ + method?: string; + host?: string; + origin?: string; + sessionId?: string; + }> = []; + const sessionId = 'session-338'; + const server = createServer(async (req, res) => { + if (req.method !== 'POST' || req.url !== '/mcp') { + res.writeHead(404).end(); + return; + } + + const body = await readJsonBody(req); + const method = + typeof body.method === 'string' ? body.method : undefined; + requests.push({ + method, + host: req.headers.host, + origin: + typeof req.headers.origin === 'string' + ? req.headers.origin + : undefined, + sessionId: + typeof req.headers['mcp-session-id'] === 'string' + ? req.headers['mcp-session-id'] + : undefined + }); + + if ( + req.headers.host === 'evil.example.com' || + req.headers.origin === 'http://evil.example.com' + ) { + res + .writeHead(403, { 'Content-Type': 'application/json' }) + .end(JSON.stringify({ error: 'forbidden' })); + return; + } + + if (method === 'initialize') { + res + .writeHead(200, { + 'Content-Type': 'application/json', + 'Mcp-Session-Id': sessionId + }) + .end( + JSON.stringify({ + jsonrpc: '2.0', + id: body.id, + result: { + protocolVersion: LATEST_SPEC_VERSION, + capabilities: {}, + serverInfo: { + name: 'dns-rebinding-lifecycle-test', + version: '1.0.0' + } + } + }) + ); + return; + } + + if (method === 'notifications/initialized') { + res.writeHead( + req.headers['mcp-session-id'] === sessionId ? 202 : 400 + ); + res.end(); + return; + } + + res.writeHead(500).end(); + }); + + const serverUrl = await listenOnRandomPort(server); + try { + const checks = await new DNSRebindingProtectionScenario().run( + testContext(serverUrl, LATEST_SPEC_VERSION) + ); + + expect( + checks.find((c) => c.id === 'localhost-host-rebinding-rejected') + ?.status + ).toBe('SUCCESS'); + expect( + checks.find((c) => c.id === 'localhost-host-valid-accepted')?.status + ).toBe('SUCCESS'); + + const validHost = new URL(serverUrl).host; + expect(requests).toEqual([ + expect.objectContaining({ + method: 'initialize', + host: 'evil.example.com', + origin: 'http://evil.example.com' + }), + expect.objectContaining({ + method: 'initialize', + host: validHost, + origin: `http://${validHost}` + }), + expect.objectContaining({ + method: 'notifications/initialized', + host: validHost, + origin: `http://${validHost}`, + sessionId + }) + ]); + } finally { + await closeHttpServer(server); + } + }, 10000); + + it('does not send initialized notification for the draft discover probe', async () => { + const methods: Array = []; + const server = createServer(async (req, res) => { + if (req.method !== 'POST' || req.url !== '/mcp') { + res.writeHead(404).end(); + return; + } + + const body = await readJsonBody(req); + const method = + typeof body.method === 'string' ? body.method : undefined; + methods.push(method); + + if ( + req.headers.host === 'evil.example.com' || + req.headers.origin === 'http://evil.example.com' + ) { + res + .writeHead(403, { 'Content-Type': 'application/json' }) + .end(JSON.stringify({ error: 'forbidden' })); + return; + } + + if (method === 'server/discover') { + res.writeHead(200, { 'Content-Type': 'application/json' }).end( + JSON.stringify({ + jsonrpc: '2.0', + id: body.id, + result: {} + }) + ); + return; + } + + res.writeHead(500).end(); + }); + + const serverUrl = await listenOnRandomPort(server); + try { + const checks = await new DNSRebindingProtectionScenario().run( + testContext(serverUrl, DRAFT_PROTOCOL_VERSION) + ); + + expect( + checks.find((c) => c.id === 'localhost-host-rebinding-rejected') + ?.status + ).toBe('SUCCESS'); + expect( + checks.find((c) => c.id === 'localhost-host-valid-accepted')?.status + ).toBe('SUCCESS'); + expect(methods).toEqual(['server/discover', 'server/discover']); + } finally { + await closeHttpServer(server); + } + }, 10000); }); describe('sep-2164-resource-not-found', () => {