Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions src/scenarios/server/dns-rebinding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ function probeBody(specVersion: string): {
};
}

function getResponseHeader(
headers: Record<string, string | string[] | undefined>,
name: string
): string | undefined {
const value = headers[name] ?? headers[name.toLowerCase()];
if (Array.isArray(value)) return value[0];
return value;
}

/**
* Send an MCP request with custom Host and Origin headers.
* Both headers are set to the same value so that servers checking either
Expand All @@ -96,7 +105,7 @@ async function sendRequestWithHostAndOrigin(
serverUrl: string,
hostOrOrigin: string,
specVersion: string
): Promise<{ statusCode: number; body: unknown }> {
): Promise<{ statusCode: number; body: unknown; sessionId?: string }> {
// Build the SEP-2243 standard headers (Mcp-Method, Accept, ...) for the
// probe's JSON-RPC method so a strictly-conformant server only rejects the
// request for the Host/Origin values under test, then layer the
Expand All @@ -121,6 +130,43 @@ async function sendRequestWithHostAndOrigin(
body = null;
}

return {
statusCode: response.statusCode,
body,
sessionId: getResponseHeader(response.headers, 'mcp-session-id')
};
}

async function sendInitializedNotification(
serverUrl: string,
hostOrOrigin: string,
specVersion: string,
sessionId?: string
): Promise<{ statusCode: number; body: unknown }> {
const method = 'notifications/initialized';
const response = await request(serverUrl, {
method: 'POST',
headers: buildStandardHeaders(method, undefined, {
headers: {
Host: hostOrOrigin,
Origin: `http://${hostOrOrigin}`,
'MCP-Protocol-Version': specVersion,
...(sessionId ? { 'Mcp-Session-Id': sessionId } : {})
}
}),
body: JSON.stringify({
jsonrpc: '2.0',
method
})
});

let body: unknown;
try {
body = await response.body.json();
} catch {
body = null;
}

return {
statusCode: response.statusCode,
body
Expand Down Expand Up @@ -253,25 +299,50 @@ See: https://github.com/modelcontextprotocol/typescript-sdk/security/advisories/
);
const isAccepted =
response.statusCode >= 200 && response.statusCode < 300;
const initializedNotification =
isAccepted && specVersion !== DRAFT_PROTOCOL_VERSION
? await sendInitializedNotification(
serverUrl,
validHost,
specVersion,
response.sessionId
)
: undefined;
const initializedAccepted =
!initializedNotification ||
(initializedNotification.statusCode >= 200 &&
initializedNotification.statusCode < 300);

const details = {
hostHeader: validHost,
originHeader: `http://${validHost}`,
statusCode: response.statusCode,
body: response.body
body: response.body,
...(initializedNotification
? {
initializedNotification: {
statusCode: initializedNotification.statusCode,
body: initializedNotification.body,
sessionIdSent: response.sessionId ?? null
}
}
: {})
};

if (isAccepted) {
if (isAccepted && initializedAccepted) {
checks.push({
...acceptedCheckBase,
status: 'SUCCESS',
details
});
} else {
const errorMessage = !isAccepted
? `Expected HTTP 2xx for valid localhost Host/Origin headers, got ${response.statusCode}`
: `Expected HTTP 2xx for initialized notification after valid localhost initialize, got ${initializedNotification?.statusCode}`;
checks.push({
...acceptedCheckBase,
status: 'FAILURE',
errorMessage: `Expected HTTP 2xx for valid localhost Host/Origin headers, got ${response.statusCode}`,
errorMessage,
details
});
}
Expand Down
211 changes: 211 additions & 0 deletions src/scenarios/server/negative.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { testContext } from '../../connection/testing';
import { spawn, ChildProcess } from 'child_process';
import { createServer, type IncomingMessage, type Server } from 'http';
import type { AddressInfo } from 'net';
import path from 'path';
import { DNSRebindingProtectionScenario } from './dns-rebinding';
import { ResourcesNotFoundErrorScenario } from './resources';
Expand Down Expand Up @@ -54,6 +56,47 @@ function stopServer(proc: ChildProcess | null): Promise<void> {
});
}

async function readJsonBody(
req: IncomingMessage
): Promise<Record<string, unknown>> {
const chunks: Buffer[] = [];
for await (const chunk of req) {
chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk));
}
return JSON.parse(Buffer.concat(chunks).toString('utf8')) as Record<
string,
unknown
>;
}

function listenOnRandomPort(server: Server): Promise<string> {
return new Promise((resolve, reject) => {
const onError = (error: Error) => reject(error);
server.once('error', onError);
server.listen(0, () => {
server.off('error', onError);
const address = server.address();
if (!address || typeof address === 'string') {
reject(new Error('Server did not listen on a TCP port'));
return;
}
resolve(`http://localhost:${(address as AddressInfo).port}/mcp`);
});
});
}

function closeHttpServer(server: Server): Promise<void> {
return new Promise((resolve, reject) => {
server.close((error) => {
if (error) {
reject(error);
return;
}
resolve();
});
});
}

describe('Server scenario negative tests', () => {
describe('dns-rebinding-protection', () => {
let serverProcess: ChildProcess | null = null;
Expand Down Expand Up @@ -84,6 +127,174 @@ describe('Server scenario negative tests', () => {
);
expect(rebindingCheck?.status).toBe('FAILURE');
}, 10000);

it('sends initialized notification after a successful dated initialize', async () => {
const requests: Array<{
method?: string;
host?: string;
origin?: string;
sessionId?: string;
}> = [];
const sessionId = 'session-338';
const server = createServer(async (req, res) => {
if (req.method !== 'POST' || req.url !== '/mcp') {
res.writeHead(404).end();
return;
}

const body = await readJsonBody(req);
const method =
typeof body.method === 'string' ? body.method : undefined;
requests.push({
method,
host: req.headers.host,
origin:
typeof req.headers.origin === 'string'
? req.headers.origin
: undefined,
sessionId:
typeof req.headers['mcp-session-id'] === 'string'
? req.headers['mcp-session-id']
: undefined
});

if (
req.headers.host === 'evil.example.com' ||
req.headers.origin === 'http://evil.example.com'
) {
res
.writeHead(403, { 'Content-Type': 'application/json' })
.end(JSON.stringify({ error: 'forbidden' }));
return;
}

if (method === 'initialize') {
res
.writeHead(200, {
'Content-Type': 'application/json',
'Mcp-Session-Id': sessionId
})
.end(
JSON.stringify({
jsonrpc: '2.0',
id: body.id,
result: {
protocolVersion: LATEST_SPEC_VERSION,
capabilities: {},
serverInfo: {
name: 'dns-rebinding-lifecycle-test',
version: '1.0.0'
}
}
})
);
return;
}

if (method === 'notifications/initialized') {
res.writeHead(
req.headers['mcp-session-id'] === sessionId ? 202 : 400
);
res.end();
return;
}

res.writeHead(500).end();
});

const serverUrl = await listenOnRandomPort(server);
try {
const checks = await new DNSRebindingProtectionScenario().run(
testContext(serverUrl, LATEST_SPEC_VERSION)
);

expect(
checks.find((c) => c.id === 'localhost-host-rebinding-rejected')
?.status
).toBe('SUCCESS');
expect(
checks.find((c) => c.id === 'localhost-host-valid-accepted')?.status
).toBe('SUCCESS');

const validHost = new URL(serverUrl).host;
expect(requests).toEqual([
expect.objectContaining({
method: 'initialize',
host: 'evil.example.com',
origin: 'http://evil.example.com'
}),
expect.objectContaining({
method: 'initialize',
host: validHost,
origin: `http://${validHost}`
}),
expect.objectContaining({
method: 'notifications/initialized',
host: validHost,
origin: `http://${validHost}`,
sessionId
})
]);
} finally {
await closeHttpServer(server);
}
}, 10000);

it('does not send initialized notification for the draft discover probe', async () => {
const methods: Array<string | undefined> = [];
const server = createServer(async (req, res) => {
if (req.method !== 'POST' || req.url !== '/mcp') {
res.writeHead(404).end();
return;
}

const body = await readJsonBody(req);
const method =
typeof body.method === 'string' ? body.method : undefined;
methods.push(method);

if (
req.headers.host === 'evil.example.com' ||
req.headers.origin === 'http://evil.example.com'
) {
res
.writeHead(403, { 'Content-Type': 'application/json' })
.end(JSON.stringify({ error: 'forbidden' }));
return;
}

if (method === 'server/discover') {
res.writeHead(200, { 'Content-Type': 'application/json' }).end(
JSON.stringify({
jsonrpc: '2.0',
id: body.id,
result: {}
})
);
return;
}

res.writeHead(500).end();
});

const serverUrl = await listenOnRandomPort(server);
try {
const checks = await new DNSRebindingProtectionScenario().run(
testContext(serverUrl, DRAFT_PROTOCOL_VERSION)
);

expect(
checks.find((c) => c.id === 'localhost-host-rebinding-rejected')
?.status
).toBe('SUCCESS');
expect(
checks.find((c) => c.id === 'localhost-host-valid-accepted')?.status
).toBe('SUCCESS');
expect(methods).toEqual(['server/discover', 'server/discover']);
} finally {
await closeHttpServer(server);
}
}, 10000);
});

describe('sep-2164-resource-not-found', () => {
Expand Down