From 6aa731b0cef233970277e73323d84469b10d33b3 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Tue, 12 May 2026 11:29:49 +0000 Subject: [PATCH] refactor(core): add internal Dispatcher and StreamDriver classes Dispatcher is a stateless handler registry with a dispatch(req, env) async-generator that yields notifications then one terminal response. StreamDriver owns per-connection state (id correlation, timeouts, progress, cancellation) over a persistent transport. Additive: new files, not yet wired. Protocol composes them in R2. --- packages/core/src/index.ts | 3 + packages/core/src/shared/context.ts | 57 +++ packages/core/src/shared/dispatcher.ts | 396 ++++++++++++++++ packages/core/src/shared/protocol.ts | 6 + packages/core/src/shared/streamDriver.ts | 446 ++++++++++++++++++ packages/core/test/shared/dispatcher.test.ts | 322 +++++++++++++ .../core/test/shared/streamDriver.test.ts | 202 ++++++++ 7 files changed, 1432 insertions(+) create mode 100644 packages/core/src/shared/context.ts create mode 100644 packages/core/src/shared/dispatcher.ts create mode 100644 packages/core/src/shared/streamDriver.ts create mode 100644 packages/core/test/shared/dispatcher.test.ts create mode 100644 packages/core/test/shared/streamDriver.test.ts diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 8bcc9c959..7a96cddf4 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -2,10 +2,13 @@ export * from './auth/errors.js'; export * from './errors/sdkErrors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; +export * from './shared/context.js'; +export * from './shared/dispatcher.js'; export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; +export * from './shared/streamDriver.js'; export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from './shared/taskManager.js'; export { extractTaskManagerOptions, NullTaskManager, TaskManager } from './shared/taskManager.js'; export * from './shared/toolNameValidation.js'; diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts new file mode 100644 index 000000000..2b878ed5b --- /dev/null +++ b/packages/core/src/shared/context.ts @@ -0,0 +1,57 @@ +import type { AuthInfo, JSONRPCMessage, Notification, Request, RequestId, Result } from '../types/index.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import type { NotificationOptions, RequestOptions } from './protocol.js'; + +/** + * Per-request environment a transport adapter passes to {@linkcode Dispatcher.dispatch}. + * Everything is optional; a bare `dispatch()` call works with no transport at all. + * + * @internal + */ +export type RequestEnv = { + /** + * Sends a request back to the peer (server→client elicitation/sampling, or + * client→server nested calls). Supplied by {@linkcode StreamDriver} when running + * over a persistent pipe, or by an HTTP adapter that has a backchannel. When + * undefined, `ctx.mcpReq.send` throws {@linkcode SdkErrorCode.NotConnected}. + */ + send?: (request: Request, options?: RequestOptions) => Promise; + + /** Validated auth token info for HTTP transports. */ + authInfo?: AuthInfo; + + /** Original HTTP `Request` (Fetch API), if any. */ + httpReq?: globalThis.Request; + + /** Abort signal for the inbound request. If omitted, a fresh controller is created. */ + signal?: AbortSignal; + + /** Transport session identifier (legacy `Mcp-Session-Id`). */ + sessionId?: string; + + /** Extension slot. Adapters and middleware populate keys here; copied onto `BaseContext.ext`. */ + ext?: Record; +}; + +/** + * The minimal contract a {@linkcode Dispatcher} owner needs to send outbound + * requests/notifications to the connected peer. Implemented by + * {@linkcode StreamDriver} for persistent pipes; request-shaped paths can supply + * their own. + * + * @internal + */ +export interface Outbound { + /** Send a request to the peer and resolve with the parsed result. */ + request(req: Request, resultSchema: T, options?: RequestOptions): Promise>; + /** Send a notification to the peer. */ + notification(notification: Notification, options?: NotificationOptions): Promise; + /** Close the underlying connection. */ + close(): Promise; + /** Clear a registered progress callback by its message id. Optional; pipe-channels expose this. */ + removeProgressHandler?(messageId: number): void; + /** Inform the channel which protocol version was negotiated (for header echoing etc.). Optional. */ + setProtocolVersion?(version: string): void; + /** Write a raw JSON-RPC message on the same stream as a prior request. Optional; pipe-only. */ + sendRaw?(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; +} diff --git a/packages/core/src/shared/dispatcher.ts b/packages/core/src/shared/dispatcher.ts new file mode 100644 index 000000000..e1709ae8a --- /dev/null +++ b/packages/core/src/shared/dispatcher.ts @@ -0,0 +1,396 @@ +import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + Notification, + NotificationMethod, + NotificationTypeMap, + Request, + RequestId, + RequestMethod, + RequestTypeMap, + Result, + ResultTypeMap +} from '../types/index.js'; +import { getNotificationSchema, getRequestSchema, getResultSchema, ProtocolError, ProtocolErrorCode } from '../types/index.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import { validateStandardSchema } from '../util/standardSchema.js'; +import type { RequestEnv } from './context.js'; +import type { BaseContext, RequestOptions } from './protocol.js'; + +/** + * One yielded item from {@linkcode Dispatcher.dispatch}. A dispatch yields zero or more + * notifications followed by exactly one terminal response. + * + * @internal + */ +export type DispatchOutput = + | { kind: 'notification'; message: JSONRPCNotification } + | { kind: 'response'; message: JSONRPCResponse | JSONRPCErrorResponse }; + +/** @internal */ +export type RawHandler = (request: JSONRPCRequest, ctx: ContextT) => Promise; + +/** Signature of {@linkcode Dispatcher.dispatch}. Target type for {@linkcode DispatchMiddleware}. @internal */ +export type DispatchFn = (req: JSONRPCRequest, env?: RequestEnv) => AsyncGenerator; + +/** + * Onion-style middleware around {@linkcode Dispatcher.dispatch}. Registered via + * {@linkcode Dispatcher.use}; composed outermost-first (registration order). + * + * A middleware may transform `req`/`env` before delegating, transform or filter + * yielded outputs, or short-circuit by yielding a response without calling `next`. + * + * @internal + */ +export type DispatchMiddleware = (next: DispatchFn) => DispatchFn; + +/** + * Derives the handler return type for the 3-arg `setRequestHandler` form from its + * `result` schema, defaulting to {@linkcode Result} when no schema is supplied. + * + * @internal + */ +export type DispatcherInferResult = R extends StandardSchemaV1 + ? StandardSchemaV1.InferOutput + : Result; + +/** + * Options for constructing a {@linkcode Dispatcher}. + * + * @internal + */ +export type DispatcherOptions = { + /** + * Enriches the base context for the owner's specific role (server/client). + * Receives the base context the dispatcher built plus the {@linkcode RequestEnv} + * the adapter supplied. Default returns the base unchanged. + */ + buildContext?: (base: BaseContext, env: RequestEnv) => ContextT; + + /** + * Wraps every registered request handler with role-specific validation + * (e.g. capability checks). Runs for both the 2-arg and 3-arg registration + * paths. Default is identity. + */ + wrapHandler?: (method: string, handler: RawHandler) => RawHandler; +}; + +/** + * Stateless JSON-RPC handler registry with a request-in / messages-out + * {@linkcode Dispatcher.dispatch | dispatch()} entry point. + * + * Holds no transport, no correlation state, no timers. One instance can serve + * any number of concurrent requests from any driver. + * + * @internal + */ +export class Dispatcher { + protected _requestHandlers: Map> = new Map(); + protected _notificationHandlers: Map Promise> = new Map(); + private _dispatchMw: DispatchMiddleware[] = []; + + /** + * A handler to invoke for any request types that do not have their own handler installed. + */ + fallbackRequestHandler?: (request: JSONRPCRequest, ctx: ContextT) => Promise; + + /** + * A handler to invoke for any notification types that do not have their own handler installed. + */ + fallbackNotificationHandler?: (notification: Notification) => Promise; + + constructor(private _options: DispatcherOptions = {}) {} + + /** + * Registers a {@linkcode DispatchMiddleware}. Registration order is outer-to-inner: + * the first middleware registered sees the rawest request and the final yields. + */ + use(mw: DispatchMiddleware): this { + this._dispatchMw.push(mw); + return this; + } + + /** + * Dispatch one inbound request through the registered middleware chain, then the + * core handler lookup. Yields any notifications the handler emits via + * `ctx.mcpReq.notify()`, then yields exactly one terminal response. + * + * Never throws for handler errors; they are wrapped as JSON-RPC error responses. + * May throw if iteration itself is misused. + */ + dispatch(request: JSONRPCRequest, env: RequestEnv = {}): AsyncGenerator { + // eslint-disable-next-line unicorn/consistent-function-scoping -- closes over `this` + let chain: DispatchFn = (r, e) => this._dispatchCore(r, e); + // eslint-disable-next-line unicorn/no-array-reverse -- toReversed() requires ES2023 lib; consumers may target ES2022 + for (const mw of [...this._dispatchMw].reverse()) chain = mw(chain); + return chain(request, env); + } + + /** + * The handler lookup + invocation. Middleware composes around this; owners do + * not override `dispatch()` directly — use {@linkcode Dispatcher.use | use()} instead. + */ + private async *_dispatchCore(request: JSONRPCRequest, env: RequestEnv = {}): AsyncGenerator { + const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; + + if (handler === undefined) { + yield errorResponse(request.id, ProtocolErrorCode.MethodNotFound, 'Method not found'); + return; + } + + const queue: JSONRPCNotification[] = []; + let wake: (() => void) | undefined; + let done = false; + let final: JSONRPCResponse | JSONRPCErrorResponse | undefined; + + const localAbort = new AbortController(); + const onEnvAbort = () => localAbort.abort(env.signal!.reason); + if (env.signal) { + if (env.signal.aborted) localAbort.abort(env.signal.reason); + else env.signal.addEventListener('abort', onEnvAbort, { once: true }); + } + + const send = + env.send ?? + (async () => { + throw new SdkError(SdkErrorCode.NotConnected, 'No outbound channel: ctx.mcpReq.send requires a connected peer'); + }); + + const base: BaseContext = { + sessionId: env.sessionId, + mcpReq: { + id: request.id, + method: request.method, + _meta: request.params?._meta, + signal: localAbort.signal, + send: (async (r: Request, schemaOrOptions?: unknown, maybeOptions?: RequestOptions) => { + const isSchema = schemaOrOptions != null && typeof schemaOrOptions === 'object' && '~standard' in schemaOrOptions; + const options = isSchema ? maybeOptions : (schemaOrOptions as RequestOptions | undefined); + const resultSchema = isSchema ? (schemaOrOptions as StandardSchemaV1) : getResultSchema(r.method as RequestMethod); + if (!resultSchema) { + throw new TypeError( + `'${r.method}' is not a spec method; pass a result schema as the second argument to ctx.mcpReq.send().` + ); + } + // Thread the dispatch-local abort so env.send sinks (e.g. BackchannelCompat) see + // cancellation when the inbound request is aborted, instead of waiting for their own timeout. + const result = await send(r, { ...options, signal: options?.signal ?? localAbort.signal }); + const parsed = await validateStandardSchema(resultSchema, result); + if (!parsed.success) { + throw new SdkError(SdkErrorCode.InvalidResult, `Invalid result for ${r.method}: ${parsed.error}`); + } + return parsed.data; + }) as BaseContext['mcpReq']['send'], + notify: async (n: Notification) => { + if (done) return; + queue.push({ jsonrpc: '2.0', method: n.method, params: n.params } as JSONRPCNotification); + wake?.(); + } + }, + http: env.authInfo || env.httpReq ? { authInfo: env.authInfo } : undefined, + ext: env.ext + }; + const ctx = this._options.buildContext ? this._options.buildContext(base, env) : (base as ContextT); + + Promise.resolve() + .then(() => handler(request, ctx)) + .then( + result => { + final = localAbort.signal.aborted + ? errorResponse(request.id, ProtocolErrorCode.InternalError, 'Request cancelled').message + : { jsonrpc: '2.0', id: request.id, result }; + }, + error => { + final = localAbort.signal.aborted + ? errorResponse(request.id, ProtocolErrorCode.InternalError, 'Request cancelled').message + : toErrorResponse(request.id, error); + } + ) + .finally(() => { + done = true; + wake?.(); + }); + + try { + while (true) { + while (queue.length > 0) { + yield { kind: 'notification', message: queue.shift()! }; + } + if (done) break; + await new Promise(resolve => { + wake = resolve; + }); + wake = undefined; + } + // Drain anything pushed between done=true and the wake. + while (queue.length > 0) { + yield { kind: 'notification', message: queue.shift()! }; + } + yield { kind: 'response', message: final! }; + } finally { + // Consumer broke early (generator.return()): tell the still-running handler to stop + // and detach the env.signal listener so a long-lived caller signal doesn't leak. + if (!done) localAbort.abort(new SdkError(SdkErrorCode.ConnectionClosed, 'dispatch consumer closed')); + env.signal?.removeEventListener('abort', onEnvAbort); + } + } + + /** + * Dispatch one inbound notification to its handler. Errors are reported via the + * returned promise; unknown methods are silently ignored. + */ + async dispatchNotification(notification: JSONRPCNotification): Promise { + const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; + if (handler === undefined) return; + await Promise.resolve().then(() => handler(notification)); + } + + /** + * Registers a handler to invoke when this dispatcher receives a request with the given method. + * + * Note that this will replace any previous request handler for the same method. + * + * For spec methods, pass `(method, handler)`; the request is parsed with the spec + * schema and the handler receives the typed `Request`. For custom (non-spec) + * methods, pass `(method, schemas, handler)`; `params` are validated against + * `schemas.params` and the handler receives the parsed params object directly. + * Supplying `schemas.result` types the handler's return value. + */ + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise + ): void; + setRequestHandler

( + method: string, + schemas: { params: P; result?: R }, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ContextT) => DispatcherInferResult | Promise> + ): void; + setRequestHandler(method: string, schemasOrHandler: unknown, maybeHandler?: unknown): void { + let stored: RawHandler; + + if (maybeHandler === undefined) { + const handler = schemasOrHandler as (request: unknown, ctx: ContextT) => Result | Promise; + const schema = getRequestSchema(method as RequestMethod); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` + ); + } + stored = (request, ctx) => { + const parsed = schema.parse(request); + return Promise.resolve(handler(parsed, ctx)); + }; + } else { + const schemas = schemasOrHandler as { params: StandardSchemaV1 }; + const handler = maybeHandler as (params: unknown, ctx: ContextT) => Result | Promise; + stored = async (request, ctx) => { + const userParams = { ...((request.params ?? {}) as Record) }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemas.params, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); + } + return handler(parsed.data, ctx); + }; + } + + const wrap = this._options.wrapHandler ?? ((_m, h) => h); + this._requestHandlers.set(method, wrap(method, stored)); + } + + /** Registers a raw handler with no schema parsing. Used for compat shims. */ + setRawRequestHandler(method: string, handler: RawHandler): void { + this._requestHandlers.set(method, handler); + } + + removeRequestHandler(method: string): void { + this._requestHandlers.delete(method); + } + + assertCanSetRequestHandler(method: string): void { + if (this._requestHandlers.has(method)) { + throw new Error(`A request handler for ${method} already exists, which would be overridden`); + } + } + + /** + * Registers a handler to invoke when this dispatcher receives a notification with the given method. + * + * Note that this will replace any previous notification handler for the same method. + * + * For spec methods, pass `(method, handler)`; the notification is parsed with the + * spec schema. For custom (non-spec) methods, pass `(method, schemas, handler)`; + * `params` are validated against `schemas.params` and the handler receives the + * parsed params object directly. + */ + setNotificationHandler( + method: M, + handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void; + setNotificationHandler

( + method: string, + schemas: { params: P }, + handler: (params: StandardSchemaV1.InferOutput

, notification: Notification) => void | Promise + ): void; + setNotificationHandler(method: string, schemasOrHandler: unknown, maybeHandler?: unknown): void { + if (maybeHandler !== undefined) { + const schemas = schemasOrHandler as { params: StandardSchemaV1 }; + const handler = maybeHandler as (params: unknown, notification: Notification) => void | Promise; + this._notificationHandlers.set(method, async notification => { + const userParams = { ...((notification.params ?? {}) as Record) }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemas.params, userParams); + if (!parsed.success) { + throw new Error(`Invalid params for notification ${method}: ${parsed.error}`); + } + await handler(parsed.data, notification); + }); + return; + } + + const handler = schemasOrHandler as (notification: unknown) => void | Promise; + const schema = getNotificationSchema(method as NotificationMethod); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().` + ); + } + this._notificationHandlers.set(method, notification => { + const parsed = schema.parse(notification); + return Promise.resolve(handler(parsed)); + }); + } + + removeNotificationHandler(method: string): void { + this._notificationHandlers.delete(method); + } + + /** Convenience: collect a full dispatch into a single response, discarding notifications. */ + async dispatchToResponse(request: JSONRPCRequest, env?: RequestEnv): Promise { + let resp: JSONRPCResponse | JSONRPCErrorResponse | undefined; + for await (const out of this.dispatch(request, env)) { + if (out.kind === 'response') resp = out.message; + } + return resp!; + } +} + +function errorResponse(id: RequestId, code: number, message: string): { kind: 'response'; message: JSONRPCErrorResponse } { + return { kind: 'response', message: { jsonrpc: '2.0', id, error: { code, message } } }; +} + +function toErrorResponse(id: RequestId, error: unknown): JSONRPCErrorResponse { + const e = error as { code?: unknown; message?: unknown; data?: unknown }; + return { + jsonrpc: '2.0', + id, + error: { + code: Number.isSafeInteger(e?.code) ? (e.code as number) : ProtocolErrorCode.InternalError, + message: typeof e?.message === 'string' ? e.message : 'Internal error', + ...(e?.data !== undefined && { data: e.data }) + } + }; +} diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index d91c80412..cdebf43d9 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -203,6 +203,12 @@ export type BaseContext = { */ authInfo?: AuthInfo; }; + + /** + * Extension slot. Adapters and middleware populate keys here; handlers cast to the + * extension's declared type to read them. Core never reads or writes this field. + */ + ext?: Record; }; /** diff --git a/packages/core/src/shared/streamDriver.ts b/packages/core/src/shared/streamDriver.ts new file mode 100644 index 000000000..84c851e95 --- /dev/null +++ b/packages/core/src/shared/streamDriver.ts @@ -0,0 +1,446 @@ +import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + MessageExtraInfo, + Notification, + Progress, + ProgressNotification, + Request, + RequestId, + Result +} from '../types/index.js'; +import { + isJSONRPCErrorResponse, + isJSONRPCNotification, + isJSONRPCRequest, + isJSONRPCResultResponse, + ProtocolError, + SUPPORTED_PROTOCOL_VERSIONS +} from '../types/index.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import { validateStandardSchema } from '../util/standardSchema.js'; +import type { Outbound, RequestEnv } from './context.js'; +import type { Dispatcher } from './dispatcher.js'; +import type { NotificationOptions, ProgressCallback, RequestOptions } from './protocol.js'; +import { DEFAULT_REQUEST_TIMEOUT_MSEC } from './protocol.js'; +import type { Transport } from './transport.js'; + +/** + * Pass-through Standard Schema. Used where a schema-typed call needs to deliver the raw + * `Result` and let the caller (e.g. {@linkcode Dispatcher}'s `mcpReq.send`) apply the + * user-supplied or spec result schema. + * + * @internal + */ +export const RAW_RESULT_SCHEMA: StandardSchemaV1 = { + '~standard': { version: 1, vendor: 'mcp-passthrough', validate: value => ({ value: value as Result }) } +}; + +type TimeoutInfo = { + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + resetTimeoutOnProgress: boolean; + onTimeout: () => void; +}; + +/** @internal */ +export type StreamDriverOptions = { + supportedProtocolVersions?: string[]; + debouncedNotificationMethods?: string[]; + /** + * Hook to enrich the per-request {@linkcode RequestEnv} from transport-supplied + * {@linkcode MessageExtraInfo} (e.g. auth, http req). + */ + buildEnv?: (extra: MessageExtraInfo | undefined, base: RequestEnv) => RequestEnv; +}; + +/** + * Options for {@linkcode attachChannelTransport}. Mirrors {@linkcode StreamDriverOptions} + * plus the `onclose`/`onerror`/`onresponse` callbacks the caller would otherwise set + * on the driver instance. + * + * @internal + */ +export type AttachOptions = StreamDriverOptions & { + onclose?: () => void; + onerror?: (error: Error) => void; + onresponse?: ( + response: JSONRPCResultResponse | JSONRPCErrorResponse, + messageId: number + ) => { consumed: boolean; preserveProgress?: boolean }; +}; + +/** + * Runs a {@linkcode Dispatcher} over a persistent bidirectional {@linkcode Transport} + * (stdio, WebSocket, InMemory). Owns all per-connection state: outbound request + * id correlation, timeouts, progress callbacks, cancellation, debouncing. + * + * One driver per pipe. The dispatcher it wraps may be shared. + * + * @internal + */ +export class StreamDriver implements Outbound { + private _requestMessageId = 0; + private _responseHandlers: Map void> = new Map(); + private _progressHandlers: Map = new Map(); + private _timeoutInfo: Map = new Map(); + private _requestHandlerAbortControllers: Map = new Map(); + private _pendingDebouncedNotifications = new Set(); + private _closed = false; + private _supportedProtocolVersions: string[]; + + onclose?: () => void; + onerror?: (error: Error) => void; + /** + * Tap for every inbound response. Return `consumed: true` to claim it (suppresses the + * matched-handler dispatch / unknown-id error). Return `preserveProgress: true` to keep + * the progress handler registered after the matched handler runs. Set by the owner. + */ + onresponse?: (response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number) => { consumed: boolean; preserveProgress?: boolean }; + + constructor( + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- driver is context-agnostic; owner supplies ContextT via dispatcher + readonly dispatcher: Dispatcher, + readonly pipe: Transport, + private _options: StreamDriverOptions = {} + ) { + this._supportedProtocolVersions = _options.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + } + + /** Exposed so an owner can clear progress callbacks. See {@linkcode Outbound.removeProgressHandler}. */ + removeProgressHandler(token: number): void { + this._progressHandlers.delete(token); + } + + /** + * Wires the pipe's callbacks and starts it. After this resolves, inbound + * requests are dispatched and {@linkcode StreamDriver.request | request()} works. + */ + async start(): Promise { + const prevClose = this.pipe.onclose; + this.pipe.onclose = () => { + try { + prevClose?.(); + } finally { + this._onclose(); + } + }; + + const prevError = this.pipe.onerror; + this.pipe.onerror = (error: Error) => { + prevError?.(error); + this._onerror(error); + }; + + const prevMessage = this.pipe.onmessage; + this.pipe.onmessage = (message, extra) => { + prevMessage?.(message, extra); + if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { + this._onresponse(message); + } else if (isJSONRPCRequest(message)) { + this._onrequest(message, extra); + } else if (isJSONRPCNotification(message)) { + this._onnotification(message); + } else { + this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); + } + }; + + this.pipe.setSupportedProtocolVersions?.(this._supportedProtocolVersions); + await this.pipe.start(); + } + + async close(): Promise { + await this.pipe.close(); + } + + /** {@linkcode Outbound.setProtocolVersion} — delegates to the pipe. */ + setProtocolVersion(version: string): void { + this.pipe.setProtocolVersion?.(version); + } + + /** {@linkcode Outbound.sendRaw} — write a raw JSON-RPC message to the pipe. */ + async sendRaw(message: Parameters[0], options?: { relatedRequestId?: RequestId }): Promise { + await this.pipe.send(message, options); + } + + /** + * Sends a request over the pipe and resolves with the parsed result. + */ + request(req: Request, resultSchema: T, options?: RequestOptions): Promise> { + const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + let onAbort: (() => void) | undefined; + let cleanupId: number | undefined; + + let responseReceived = false; + + return new Promise>((resolve, reject) => { + if (options?.signal?.aborted) { + const reason = options.signal.reason; + throw reason instanceof Error ? reason : new DOMException('Request was aborted before send', 'AbortError'); + } + + const messageId = this._requestMessageId++; + cleanupId = messageId; + const jsonrpcRequest: JSONRPCRequest = { ...req, jsonrpc: '2.0', id: messageId }; + + if (options?.onprogress) { + this._progressHandlers.set(messageId, options.onprogress); + jsonrpcRequest.params = { + ...req.params, + _meta: { ...(req.params?._meta as Record | undefined), progressToken: messageId } + }; + } + + const cancel = (reason: unknown) => { + if (responseReceived) return; + this._progressHandlers.delete(messageId); + this.pipe + .send( + { + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { requestId: messageId, reason: String(reason) } + }, + { relatedRequestId, resumptionToken, onresumptiontoken } + ) + .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`))); + const error = reason instanceof SdkError ? reason : new SdkError(SdkErrorCode.RequestTimeout, String(reason)); + reject(error); + }; + + this._responseHandlers.set(messageId, response => { + responseReceived = true; + if (options?.signal?.aborted) return; + if (response instanceof Error) return reject(response); + validateStandardSchema(resultSchema, response.result).then( + parsed => + parsed.success + ? resolve(parsed.data) + : reject(new SdkError(SdkErrorCode.InvalidResult, `Invalid result for ${req.method}: ${parsed.error}`)), + reject + ); + }); + + onAbort = () => cancel(options?.signal?.reason); + options?.signal?.addEventListener('abort', onAbort, { once: true }); + + const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; + this._setupTimeout( + messageId, + timeout, + options?.maxTotalTimeout, + () => cancel(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })), + options?.resetTimeoutOnProgress ?? false + ); + + this.pipe.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + this._progressHandlers.delete(messageId); + reject(error); + }); + }).finally(() => { + if (onAbort) options?.signal?.removeEventListener('abort', onAbort); + if (cleanupId !== undefined) { + this._responseHandlers.delete(cleanupId); + this._cleanupTimeout(cleanupId); + } + }); + } + + /** + * Sends a notification over the pipe. Supports debouncing per the constructor option. + */ + async notification(notification: Notification, options?: NotificationOptions): Promise { + if (this._closed) return; + const jsonrpc: JSONRPCNotification = { jsonrpc: '2.0', method: notification.method, params: notification.params }; + + const debounced = this._options.debouncedNotificationMethods ?? []; + const canDebounce = debounced.includes(notification.method) && !notification.params && !options?.relatedRequestId; + if (canDebounce) { + if (this._pendingDebouncedNotifications.has(notification.method)) return; + this._pendingDebouncedNotifications.add(notification.method); + Promise.resolve().then(() => { + // If the entry was already removed (by _onclose), skip the send. + if (!this._pendingDebouncedNotifications.delete(notification.method)) return; + this.pipe.send(jsonrpc, options).catch(error => this._onerror(error)); + }); + return; + } + await this.pipe.send(jsonrpc, options); + } + + private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { + const abort = new AbortController(); + this._requestHandlerAbortControllers.set(request.id, abort); + + const baseEnv: RequestEnv = { + signal: abort.signal, + authInfo: extra?.authInfo, + httpReq: extra?.request, + sessionId: this.pipe.sessionId, + send: (r, opts) => this.request(r, RAW_RESULT_SCHEMA, { ...opts, relatedRequestId: request.id }) as Promise + }; + const env = this._options.buildEnv ? this._options.buildEnv(extra, baseEnv) : baseEnv; + + const drain = async () => { + for await (const out of this.dispatcher.dispatch(request, env)) { + if (out.kind === 'notification') { + await this.pipe.send(out.message, { relatedRequestId: request.id }); + } else if (!abort.signal.aborted) { + await this.pipe.send(out.message, { relatedRequestId: request.id }); + } + } + }; + drain() + .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) + .finally(() => { + if (this._requestHandlerAbortControllers.get(request.id) === abort) { + this._requestHandlerAbortControllers.delete(request.id); + } + }); + } + + private _onnotification(notification: JSONRPCNotification): void { + if (notification.method === 'notifications/cancelled') { + const requestId = (notification.params as { requestId?: RequestId } | undefined)?.requestId; + if (requestId !== undefined) + this._requestHandlerAbortControllers.get(requestId)?.abort((notification.params as { reason?: unknown })?.reason); + return; + } + if (notification.method === 'notifications/progress') { + this._onprogress(notification.params as ProgressNotification['params']); + return; + } + this.dispatcher + .dispatchNotification(notification) + .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); + } + + private _onprogress(progressParams: ProgressNotification['params']): void { + const { progressToken, ...params } = progressParams; + const messageId = Number(progressToken); + const handler = this._progressHandlers.get(messageId); + if (!handler) { + this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(progressParams)}`)); + return; + } + const responseHandler = this._responseHandlers.get(messageId); + const info = this._timeoutInfo.get(messageId); + if (info && responseHandler && info.resetTimeoutOnProgress) { + try { + this._resetTimeout(messageId); + } catch (error) { + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); + responseHandler(error as Error); + return; + } + } + handler(params as Progress); + } + + private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { + const messageId = Number(response.id); + const tap = this.onresponse?.(response, messageId); + if (tap?.consumed) return; + const handler = this._responseHandlers.get(messageId); + if (handler === undefined) { + this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); + return; + } + this._responseHandlers.delete(messageId); + this._cleanupTimeout(messageId); + if (!tap?.preserveProgress) this._progressHandlers.delete(messageId); + if (isJSONRPCResultResponse(response)) { + handler(response); + } else { + handler(ProtocolError.fromError(response.error.code, response.error.message, response.error.data)); + } + } + + private _onclose(): void { + this._closed = true; + const responseHandlers = this._responseHandlers; + this._responseHandlers = new Map(); + this._progressHandlers.clear(); + this._pendingDebouncedNotifications.clear(); + for (const info of this._timeoutInfo.values()) clearTimeout(info.timeoutId); + this._timeoutInfo.clear(); + const aborts = this._requestHandlerAbortControllers; + this._requestHandlerAbortControllers = new Map(); + const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed'); + try { + this.onclose?.(); + } finally { + for (const handler of responseHandlers.values()) handler(error); + for (const c of aborts.values()) c.abort(error); + } + } + + private _onerror(error: Error): void { + this.onerror?.(error); + } + + private _setupTimeout(id: number, timeout: number, maxTotal: number | undefined, onTimeout: () => void, reset: boolean): void { + this._timeoutInfo.set(id, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout: maxTotal, + resetTimeoutOnProgress: reset, + onTimeout + }); + } + + private _resetTimeout(id: number): boolean { + const info = this._timeoutInfo.get(id); + if (!info) return false; + const elapsed = Date.now() - info.startTime; + if (info.maxTotalTimeout && elapsed >= info.maxTotalTimeout) { + this._timeoutInfo.delete(id); + throw new SdkError(SdkErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { + maxTotalTimeout: info.maxTotalTimeout, + totalElapsed: elapsed + }); + } + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + return true; + } + + private _cleanupTimeout(id: number): void { + const info = this._timeoutInfo.get(id); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(id); + } + } +} + +/** + * Wraps a {@linkcode Transport} in a {@linkcode StreamDriver} and starts it. + * + * @internal + */ +export async function attachChannelTransport( + pipe: Transport, + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- adapter is context-agnostic + dispatcher: Dispatcher, + options?: AttachOptions +): Promise { + const driver = new StreamDriver(dispatcher, pipe, options); + if (options?.onclose || options?.onerror || options?.onresponse) { + driver.onclose = options.onclose; + driver.onerror = options.onerror; + driver.onresponse = options.onresponse; + } + await driver.start(); + return driver; +} diff --git a/packages/core/test/shared/dispatcher.test.ts b/packages/core/test/shared/dispatcher.test.ts new file mode 100644 index 000000000..4f36e16b7 --- /dev/null +++ b/packages/core/test/shared/dispatcher.test.ts @@ -0,0 +1,322 @@ +import { describe, expect, expectTypeOf, test } from 'vitest'; +import { z } from 'zod/v4'; + +import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; +import type { DispatchOutput } from '../../src/shared/dispatcher.js'; +import { Dispatcher } from '../../src/shared/dispatcher.js'; +import type { JSONRPCErrorResponse, JSONRPCRequest, JSONRPCResultResponse, Result } from '../../src/types/index.js'; +import { ProtocolError, ProtocolErrorCode } from '../../src/types/index.js'; + +const req = (method: string, params?: Record, id = 1): JSONRPCRequest => ({ jsonrpc: '2.0', id, method, params }); + +async function collect(gen: AsyncIterable): Promise { + const out: DispatchOutput[] = []; + for await (const o of gen) out.push(o); + return out; +} + +describe('Dispatcher', () => { + test('dispatch yields a single response for a registered handler', async () => { + const d = new Dispatcher(); + d.setRequestHandler('ping', async () => ({})); + const out = await collect(d.dispatch(req('ping'))); + expect(out).toHaveLength(1); + expect(out[0]!.kind).toBe('response'); + expect((out[0]!.message as JSONRPCResultResponse).result).toEqual({}); + }); + + test('dispatch yields MethodNotFound for an unregistered method', async () => { + const d = new Dispatcher(); + const out = await collect(d.dispatch(req('tools/list'))); + expect(out).toHaveLength(1); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.code).toBe(ProtocolErrorCode.MethodNotFound); + }); + + test('handler throw is wrapped as InternalError', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('boom', async () => { + throw new Error('kaboom'); + }); + const out = await collect(d.dispatch(req('boom'))); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.code).toBe(ProtocolErrorCode.InternalError); + expect(msg.error.message).toBe('kaboom'); + }); + + test('handler throwing ProtocolError preserves code and data', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('boom', async () => { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'bad', { hint: 'x' }); + }); + const out = await collect(d.dispatch(req('boom'))); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.code).toBe(ProtocolErrorCode.InvalidParams); + expect(msg.error.data).toEqual({ hint: 'x' }); + }); + + test('ctx.mcpReq.notify yields notifications before the final response', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('work', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 1, progress: 0.5 } }); + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: 'hi' } }); + return { ok: true } as Result; + }); + const out = await collect(d.dispatch(req('work'))); + expect(out.map(o => o.kind)).toEqual(['notification', 'notification', 'response']); + const first = out[0]!; + expect(first.kind === 'notification' && (first.message.params as { progress: number }).progress).toBe(0.5); + expect((out[2]!.message as JSONRPCResultResponse).result).toEqual({ ok: true }); + }); + + test('notifications interleave with async handler work', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('work', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: '1' } }); + await new Promise(r => setTimeout(r, 1)); + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: '2' } }); + return {} as Result; + }); + const seen: string[] = []; + for await (const o of d.dispatch(req('work'))) { + seen.push(o.kind === 'notification' ? `n:${(o.message.params as { data: string }).data}` : 'response'); + } + expect(seen).toEqual(['n:1', 'n:2', 'response']); + }); + + test('ctx.mcpReq.send with no env.send throws SdkError(NotConnected)', async () => { + const d = new Dispatcher(); + let caught: unknown; + d.setRawRequestHandler('elicit', async (_r, ctx) => { + try { + await ctx.mcpReq.send({ method: 'elicitation/create', params: {} }); + } catch (e) { + caught = e; + throw e; + } + return {} as Result; + }); + const out = await collect(d.dispatch(req('elicit'))); + expect(caught).toBeInstanceOf(SdkError); + expect((caught as SdkError).code).toBe(SdkErrorCode.NotConnected); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.message).toMatch(/No outbound channel/); + }); + + test('ctx.mcpReq.send delegates to env.send and validates with the supplied result schema', async () => { + const d = new Dispatcher(); + let sent: unknown; + d.setRawRequestHandler('ask', async (_r, ctx) => { + const r = await ctx.mcpReq.send({ method: 'acme/ping' }, z.object({ pong: z.boolean() })); + return { got: r } as Result; + }); + const out = await collect( + d.dispatch(req('ask'), { + send: async r => { + sent = r; + return { pong: true } as Result; + } + }) + ); + expect(sent).toEqual({ method: 'acme/ping' }); + expect((out[0]!.message as JSONRPCResultResponse).result).toEqual({ got: { pong: true } }); + }); + + test('env.signal abort yields a cancelled error response', async () => { + const d = new Dispatcher(); + const ac = new AbortController(); + d.setRawRequestHandler('slow', async (_r, ctx) => { + if (ctx.mcpReq.signal.aborted) return {} as Result; + await new Promise(resolve => ctx.mcpReq.signal.addEventListener('abort', () => resolve(), { once: true })); + return {} as Result; + }); + const gen = d.dispatch(req('slow'), { signal: ac.signal }); + const p = collect(gen); + await Promise.resolve(); + ac.abort('stop'); + const out = await p; + const msg = out[out.length - 1]!.message as JSONRPCErrorResponse; + expect(msg.error.message).toBe('Request cancelled'); + }); + + test('env values surface on context', async () => { + const d = new Dispatcher(); + let seen: { sid: unknown; auth: unknown; ext: unknown } | undefined; + d.setRawRequestHandler('echo', async (_r, ctx) => { + seen = { sid: ctx.sessionId, auth: ctx.http?.authInfo, ext: ctx.ext }; + return {} as Result; + }); + await collect( + d.dispatch(req('echo'), { sessionId: 's1', authInfo: { token: 't', clientId: 'c', scopes: [] }, ext: { mark: 'x' } }) + ); + expect(seen?.sid).toBe('s1'); + expect((seen?.auth as { token: string }).token).toBe('t'); + expect((seen?.ext as { mark: string }).mark).toBe('x'); + }); + + test('dispatchNotification routes to handler and ignores unknown', async () => { + const d = new Dispatcher(); + let got: unknown; + d.setNotificationHandler('notifications/initialized', n => { + got = n.method; + }); + await d.dispatchNotification({ jsonrpc: '2.0', method: 'notifications/initialized' }); + expect(got).toBe('notifications/initialized'); + await expect(d.dispatchNotification({ jsonrpc: '2.0', method: 'unknown/thing' })).resolves.toBeUndefined(); + }); + + test('fallbackRequestHandler is used when no specific handler matches', async () => { + const d = new Dispatcher(); + d.fallbackRequestHandler = async r => ({ echoed: r.method }) as Result; + const out = await collect(d.dispatch(req('whatever/method'))); + expect((out[0]!.message as JSONRPCResultResponse).result).toEqual({ echoed: 'whatever/method' }); + }); + + test('assertCanSetRequestHandler throws on collision', () => { + const d = new Dispatcher(); + d.setRequestHandler('ping', async () => ({})); + expect(() => d.assertCanSetRequestHandler('ping')).toThrow(/already exists/); + }); + + test('setRequestHandler parses request via schema', async () => { + const d = new Dispatcher(); + let parsed: unknown; + d.setRequestHandler('ping', r => { + parsed = r; + return {}; + }); + await collect(d.dispatch(req('ping'))); + expect(parsed).toMatchObject({ method: 'ping' }); + }); + + test('dispatchToResponse returns the terminal response', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('x', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: 'n' } }); + return { v: 1 } as Result; + }); + const r = (await d.dispatchToResponse(req('x'))) as JSONRPCResultResponse; + expect(r.result).toEqual({ v: 1 }); + }); + + test('options.buildContext enriches the context the handler receives', async () => { + type ExtCtx = import('../../src/shared/protocol.js').BaseContext & { role: string }; + const d = new Dispatcher({ buildContext: (base, _env) => ({ ...base, role: 'server' }) }); + let role: string | undefined; + d.setRawRequestHandler('echo', async (_r, ctx) => { + role = ctx.role; + return {} as Result; + }); + await collect(d.dispatch(req('echo'))); + expect(role).toBe('server'); + }); + + test('options.wrapHandler is applied to registered handlers', async () => { + const calls: string[] = []; + const d = new Dispatcher({ + wrapHandler: (method, h) => async (r, ctx) => { + calls.push(method); + return h(r, ctx); + } + }); + d.setRequestHandler('ping', async () => ({})); + await collect(d.dispatch(req('ping'))); + expect(calls).toEqual(['ping']); + }); +}); + +describe('Dispatcher.use middleware composition', () => { + test('middleware composes outermost-first and can transform output', async () => { + const d = new Dispatcher(); + const order: string[] = []; + d.use( + next => + async function* (r, e) { + order.push('a-in'); + yield* next(r, e); + order.push('a-out'); + } + ); + d.use( + next => + async function* (r, e) { + order.push('b-in'); + yield* next(r, e); + order.push('b-out'); + } + ); + d.setRawRequestHandler('x', async () => ({}) as Result); + await collect(d.dispatch(req('x'))); + expect(order).toEqual(['a-in', 'b-in', 'b-out', 'a-out']); + }); + + test('middleware can short-circuit by yielding a response without calling next', async () => { + const d = new Dispatcher(); + d.use( + () => + async function* (r) { + yield { kind: 'response', message: { jsonrpc: '2.0', id: r.id, result: { mw: true } } }; + } + ); + d.setRawRequestHandler('x', async () => ({ handler: true }) as Result); + const r = (await d.dispatchToResponse(req('x'))) as JSONRPCResultResponse; + expect(r.result).toEqual({ mw: true }); + }); +}); + +describe('Dispatcher.setRequestHandler 3-arg (custom method + {params, result})', () => { + test('parses params, strips _meta, types handler arg', async () => { + const d = new Dispatcher(); + const params = z.object({ q: z.string(), limit: z.number().optional() }); + d.setRequestHandler('acme/search', { params }, async p => { + return { hits: [p.q], limit: p.limit ?? 10 } as Result; + }); + const r = (await d.dispatchToResponse(req('acme/search', { q: 'foo', _meta: { progressToken: 1 } }))) as JSONRPCResultResponse; + expect(r.result).toEqual({ hits: ['foo'], limit: 10 }); + }); + + test('schema validation failure becomes InvalidParams error response', async () => { + const d = new Dispatcher(); + d.setRequestHandler('acme/search', { params: z.object({ q: z.string() }) }, async () => ({}) as Result); + const r = (await d.dispatchToResponse(req('acme/search', { q: 123 }))) as JSONRPCErrorResponse; + expect(r.error.code).toBe(ProtocolErrorCode.InvalidParams); + expect(r.error.message).toMatch(/Invalid params for acme\/search/); + }); + + test('result schema types the handler return value', () => { + const d = new Dispatcher(); + const params = z.object({ q: z.string() }); + const result = z.object({ hits: z.array(z.string()) }); + d.setRequestHandler('acme/search', { params, result }, async p => { + expectTypeOf(p.q).toBeString(); + return { hits: [p.q] }; + }); + // @ts-expect-error -- result schema enforces shape + d.setRequestHandler('acme/search', { params, result }, async () => ({ wrong: 1 })); + }); + + test('arbitrary string method accepted via 3-arg overload with loose Result', () => { + const d = new Dispatcher(); + d.setRequestHandler('x/custom', { params: z.object({}) }, async () => ({ anything: 1 }) as Result); + d.setRequestHandler('tools/list', { params: z.object({}) }, async () => ({ tools: [] })); + }); + + test('3-arg setNotificationHandler validates params and dispatches', async () => { + const d = new Dispatcher(); + let seen: number | undefined; + d.setNotificationHandler('ui/ping', { params: z.object({ ts: z.number() }) }, p => { + expectTypeOf(p.ts).toBeNumber(); + seen = p.ts; + }); + await d.dispatchNotification({ jsonrpc: '2.0', method: 'ui/ping', params: { ts: 42 } }); + expect(seen).toBe(42); + }); + + test('2-arg form rejects non-spec methods at runtime', () => { + const d = new Dispatcher(); + const setReq = d.setRequestHandler.bind(d) as (m: string, h: () => Promise) => void; + const setNotif = d.setNotificationHandler.bind(d) as (m: string, h: () => void) => void; + expect(() => setReq('acme/search', async () => ({}))).toThrow(/not a spec request method/); + expect(() => setNotif('acme/ping', () => {})).toThrow(/not a spec notification method/); + }); +}); diff --git a/packages/core/test/shared/streamDriver.test.ts b/packages/core/test/shared/streamDriver.test.ts new file mode 100644 index 000000000..b40d85226 --- /dev/null +++ b/packages/core/test/shared/streamDriver.test.ts @@ -0,0 +1,202 @@ +import { describe, expect, test, vi } from 'vitest'; +import { z } from 'zod/v4'; + +import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; +import { Dispatcher } from '../../src/shared/dispatcher.js'; +import { RAW_RESULT_SCHEMA, StreamDriver } from '../../src/shared/streamDriver.js'; +import type { JSONRPCMessage, JSONRPCRequest, Result } from '../../src/types/index.js'; +import { ProtocolError, ProtocolErrorCode } from '../../src/types/index.js'; +import { InMemoryTransport } from '../../src/util/inMemory.js'; + +async function tick(): Promise { + await new Promise(r => setTimeout(r, 0)); +} + +function pair(): { + near: StreamDriver; + farPipe: InMemoryTransport; + farSent: JSONRPCMessage[]; + dispatcher: Dispatcher; +} { + const [a, b] = InMemoryTransport.createLinkedPair(); + const dispatcher = new Dispatcher(); + const near = new StreamDriver(dispatcher, a); + const farSent: JSONRPCMessage[] = []; + b.onmessage = m => farSent.push(m); + return { near, farPipe: b, farSent, dispatcher }; +} + +describe('StreamDriver — outbound request/response correlation', () => { + test('request() resolves when matching response arrives', async () => { + const { near, farPipe, farSent } = pair(); + await near.start(); + await farPipe.start(); + + const p = near.request({ method: 'ping' }, z.object({})); + await tick(); + const sent = farSent[0] as JSONRPCRequest; + expect(sent.method).toBe('ping'); + await farPipe.send({ jsonrpc: '2.0', id: sent.id, result: {} }); + await expect(p).resolves.toEqual({}); + }); + + test('request() rejects with ProtocolError on JSON-RPC error response', async () => { + const { near, farPipe, farSent } = pair(); + await near.start(); + await farPipe.start(); + + const p = near.request({ method: 'ping' }, z.object({})); + await tick(); + const sent = farSent[0] as JSONRPCRequest; + await farPipe.send({ jsonrpc: '2.0', id: sent.id, error: { code: ProtocolErrorCode.InvalidParams, message: 'bad' } }); + await expect(p).rejects.toBeInstanceOf(ProtocolError); + }); + + test('request() times out and sends notifications/cancelled', async () => { + const { near, farPipe, farSent } = pair(); + await near.start(); + await farPipe.start(); + + const p = near.request({ method: 'ping' }, z.object({}), { timeout: 10 }); + await expect(p).rejects.toMatchObject({ code: SdkErrorCode.RequestTimeout }); + await tick(); + expect(farSent.some(m => 'method' in m && m.method === 'notifications/cancelled')).toBe(true); + }); + + test('caller signal abort rejects and sends notifications/cancelled', async () => { + const { near, farPipe, farSent } = pair(); + await near.start(); + await farPipe.start(); + + const ac = new AbortController(); + const p = near.request({ method: 'ping' }, z.object({}), { signal: ac.signal }); + await tick(); + ac.abort('stop'); + await expect(p).rejects.toBeInstanceOf(SdkError); + await tick(); + expect(farSent.some(m => 'method' in m && m.method === 'notifications/cancelled')).toBe(true); + }); + + test('progress notification invokes onprogress and resetTimeoutOnProgress extends the timeout', async () => { + const { near, farPipe, farSent } = pair(); + await near.start(); + await farPipe.start(); + + const onprogress = vi.fn(); + const p = near.request({ method: 'ping' }, z.object({}), { onprogress, timeout: 50, resetTimeoutOnProgress: true }); + await tick(); + const sent = farSent[0] as JSONRPCRequest; + const token = (sent.params as { _meta: { progressToken: number } })._meta.progressToken; + await farPipe.send({ jsonrpc: '2.0', method: 'notifications/progress', params: { progressToken: token, progress: 0.5 } }); + await tick(); + expect(onprogress).toHaveBeenCalledWith({ progress: 0.5 }); + await farPipe.send({ jsonrpc: '2.0', id: sent.id, result: {} }); + await expect(p).resolves.toEqual({}); + }); + + test('close rejects in-flight requests with ConnectionClosed', async () => { + const { near, farPipe } = pair(); + await near.start(); + await farPipe.start(); + + const p = near.request({ method: 'ping' }, z.object({})); + await tick(); + await near.close(); + await expect(p).rejects.toMatchObject({ code: SdkErrorCode.ConnectionClosed }); + }); +}); + +describe('StreamDriver — inbound dispatch', () => { + test('inbound request is dispatched and response sent on the pipe', async () => { + const { near, farPipe, farSent, dispatcher } = pair(); + dispatcher.setRequestHandler('ping', async () => ({})); + await near.start(); + await farPipe.start(); + + await farPipe.send({ jsonrpc: '2.0', id: 7, method: 'ping' }); + await tick(); + const resp = farSent.find(m => 'id' in m && m.id === 7 && 'result' in m); + expect(resp).toBeDefined(); + }); + + test('inbound notification is routed to dispatchNotification', async () => { + const { near, farPipe, dispatcher } = pair(); + let seen = false; + dispatcher.setNotificationHandler('notifications/initialized', () => { + seen = true; + }); + await near.start(); + await farPipe.start(); + + await farPipe.send({ jsonrpc: '2.0', method: 'notifications/initialized' }); + await tick(); + expect(seen).toBe(true); + }); + + test('notifications/cancelled aborts the matching in-flight handler', async () => { + const { near, farPipe, dispatcher } = pair(); + let aborted = false; + dispatcher.setRawRequestHandler('slow', async (_r, ctx) => { + await new Promise(resolve => + ctx.mcpReq.signal.addEventListener('abort', () => { + aborted = true; + resolve(); + }) + ); + return {} as Result; + }); + await near.start(); + await farPipe.start(); + + await farPipe.send({ jsonrpc: '2.0', id: 9, method: 'slow' }); + await tick(); + await farPipe.send({ jsonrpc: '2.0', method: 'notifications/cancelled', params: { requestId: 9, reason: 'x' } }); + await tick(); + expect(aborted).toBe(true); + }); + + test('handler ctx.mcpReq.send routes through driver.request with relatedRequestId', async () => { + const { near, farPipe, farSent, dispatcher } = pair(); + dispatcher.setRawRequestHandler('outer', async (_r, ctx) => { + const r = await ctx.mcpReq.send({ method: 'ping' }); + return { inner: r } as Result; + }); + await near.start(); + await farPipe.start(); + + await farPipe.send({ jsonrpc: '2.0', id: 1, method: 'outer' }); + await tick(); + const innerReq = farSent.find(m => 'method' in m && m.method === 'ping') as JSONRPCRequest; + expect(innerReq).toBeDefined(); + await farPipe.send({ jsonrpc: '2.0', id: innerReq.id, result: {} }); + await tick(); + const outerResp = farSent.find(m => 'id' in m && m.id === 1 && 'result' in m); + expect(outerResp).toMatchObject({ result: { inner: {} } }); + }); +}); + +describe('StreamDriver — debounced notifications', () => { + test('notifications in debouncedNotificationMethods coalesce within a tick', async () => { + const [a, b] = InMemoryTransport.createLinkedPair(); + const driver = new StreamDriver(new Dispatcher(), a, { debouncedNotificationMethods: ['notifications/tools/list_changed'] }); + const farSent: JSONRPCMessage[] = []; + b.onmessage = m => farSent.push(m); + await driver.start(); + await b.start(); + + void driver.notification({ method: 'notifications/tools/list_changed' }); + void driver.notification({ method: 'notifications/tools/list_changed' }); + void driver.notification({ method: 'notifications/tools/list_changed' }); + await tick(); + await tick(); + const sent = farSent.filter(m => 'method' in m && m.method === 'notifications/tools/list_changed'); + expect(sent).toHaveLength(1); + }); +}); + +describe('RAW_RESULT_SCHEMA', () => { + test('passes any value through unchanged', async () => { + const v = await RAW_RESULT_SCHEMA['~standard'].validate({ anything: 1 }); + expect('value' in v && v.value).toEqual({ anything: 1 }); + }); +});