diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index 5c1689ca6..6affc1c28 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -38,6 +38,9 @@ export { checkResourceAllowed, resourceUrlFromServerUrl } from '../../shared/aut // Metadata utilities export { getDisplayName } from '../../shared/metadataUtils.js'; +// Dispatcher types referenced by public `use()` method (NOT the Dispatcher class itself) +export type { DispatchMiddleware } from '../../shared/dispatcher.js'; + // Protocol types (NOT the Protocol class itself or mergeCapabilities) export type { BaseContext, diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 7a96cddf4..1abbd160b 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -2,7 +2,6 @@ export * from './auth/errors.js'; export * from './errors/sdkErrors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; -export * from './shared/context.js'; export * from './shared/dispatcher.js'; export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts index 2b878ed5b..8608224fc 100644 --- a/packages/core/src/shared/context.ts +++ b/packages/core/src/shared/context.ts @@ -1,6 +1,238 @@ -import type { AuthInfo, JSONRPCMessage, Notification, Request, RequestId, Result } from '../types/index.js'; +import type { + AuthInfo, + ClientCapabilities, + CreateMessageRequest, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + LoggingLevel, + Notification, + Progress, + Request, + RequestId, + RequestMeta, + RequestMethod, + Result, + ResultTypeMap, + ServerCapabilities +} from '../types/index.js'; import type { StandardSchemaV1 } from '../util/standardSchema.js'; -import type { NotificationOptions, RequestOptions } from './protocol.js'; +import type { TransportSendOptions } from './transport.js'; + +/** + * Callback for progress notifications. + */ +export type ProgressCallback = (progress: Progress) => void; + +/** + * Additional initialization options. + */ +export type ProtocolOptions = { + /** + * Protocol versions supported. First version is preferred (sent by client, + * used as fallback by server). Passed to transport during `connect()`. + * + * @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS} + */ + supportedProtocolVersions?: string[]; + + /** + * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. + * + * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. + * + * Currently this defaults to `false`, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to `true`. + */ + enforceStrictCapabilities?: boolean; + /** + * An array of notification method names that should be automatically debounced. + * Any notifications with a method in this list will be coalesced if they + * occur in the same tick of the event loop. + * e.g., `['notifications/tools/list_changed']` + */ + debouncedNotificationMethods?: string[]; +}; + +/** + * The default request timeout, in milliseconds. + */ +export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; + +/** + * Options that can be given per request. + */ +export type RequestOptions = { + /** + * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + */ + onprogress?: ProgressCallback; + + /** + * Can be used to cancel an in-flight request. This will cause an `AbortError` to be raised from `request()`. + */ + signal?: AbortSignal; + + /** + * A timeout (in milliseconds) for this request. If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised from `request()`. + * + * If not specified, {@linkcode DEFAULT_REQUEST_TIMEOUT_MSEC} will be used as the timeout. + */ + timeout?: number; + + /** + * If `true`, receiving a progress notification will reset the request timeout. + * This is useful for long-running operations that send periodic progress updates. + * Default: `false` + */ + resetTimeoutOnProgress?: boolean; + + /** + * Maximum total time (in milliseconds) to wait for a response. + * If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised, regardless of progress notifications. + * If not specified, there is no maximum total timeout. + */ + maxTotalTimeout?: number; +} & TransportSendOptions; + +/** + * Options that can be given per notification. + */ +export type NotificationOptions = { + /** + * May be used to indicate to the transport which incoming request to associate this outgoing notification with. + */ + relatedRequestId?: RequestId; +}; + +/** + * Base context provided to all request handlers. + */ +export type BaseContext = { + /** + * The session ID from the transport, if available. + */ + sessionId?: string; + + /** + * Information about the MCP request being handled. + */ + mcpReq: { + /** + * The JSON-RPC ID of the request being handled. + */ + id: RequestId; + + /** + * The method name of the request (e.g., 'tools/call', 'ping'). + */ + method: string; + + /** + * Metadata from the original request. + */ + _meta?: RequestMeta; + + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + signal: AbortSignal; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + * + * For spec methods the result type is inferred from the method name. + * For custom (non-spec) methods, pass a result schema as the second argument. + */ + send: { + ( + request: { method: M; params?: Record }, + options?: RequestOptions + ): Promise; + ( + request: Request, + resultSchema: T, + options?: RequestOptions + ): Promise>; + }; + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + notify: (notification: Notification) => Promise; + }; + + /** + * HTTP transport information, only available when using an HTTP-based transport. + */ + http?: { + /** + * Information about a validated access token, provided to request handlers. + */ + authInfo?: AuthInfo; + }; + + /** + * Extension slot. Adapters and middleware populate keys here; handlers cast to the + * extension's declared type to read them. Core never reads or writes this field. + */ + ext?: Record; +}; + +/** + * Context provided to server-side request handlers, extending {@linkcode BaseContext} with server-specific fields. + */ +export type ServerContext = BaseContext & { + mcpReq: { + /** + * Send a log message notification to the client. + * Respects the client's log level filter set via logging/setLevel. + */ + log: (level: LoggingLevel, data: unknown, logger?: string) => Promise; + + /** + * Send an elicitation request to the client, requesting user input. + */ + elicitInput: (params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions) => Promise; + + /** + * Request LLM sampling from the client. + */ + requestSampling: ( + params: CreateMessageRequest['params'], + options?: RequestOptions + ) => Promise; + }; + + http?: { + /** + * The original HTTP request. + */ + req?: globalThis.Request; + + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using a StreamableHTTPServerTransport with eventStore configured. + */ + closeSSE?: () => void; + + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using a StreamableHTTPServerTransport with eventStore configured. + */ + closeStandaloneSSE?: () => void; + }; +}; + +/** + * Context provided to client-side request handlers. + */ +export type ClientContext = BaseContext; /** * Per-request environment a transport adapter passes to {@linkcode Dispatcher.dispatch}. @@ -17,6 +249,13 @@ export type RequestEnv = { */ send?: (request: Request, options?: RequestOptions) => Promise; + /** + * Sends a notification back to the peer, related to the request being dispatched. + * When supplied, `ctx.mcpReq.notify` calls this; when undefined, the dispatcher + * yields the notification inline. + */ + notify?: (notification: Notification) => Promise; + /** Validated auth token info for HTTP transports. */ authInfo?: AuthInfo; @@ -29,6 +268,12 @@ export type RequestEnv = { /** Transport session identifier (legacy `Mcp-Session-Id`). */ sessionId?: string; + /** + * The originating request id, when the dispatch is on behalf of an inbound request. + * Adapters propagate this so wrapped `send`/`notify` carry `relatedRequestId`. + */ + relatedRequestId?: RequestId; + /** Extension slot. Adapters and middleware populate keys here; copied onto `BaseContext.ext`. */ ext?: Record; }; @@ -48,10 +293,42 @@ export interface Outbound { notification(notification: Notification, options?: NotificationOptions): Promise; /** Close the underlying connection. */ close(): Promise; - /** Clear a registered progress callback by its message id. Optional; pipe-channels expose this. */ - removeProgressHandler?(messageId: number): void; /** Inform the channel which protocol version was negotiated (for header echoing etc.). Optional. */ setProtocolVersion?(version: string): void; - /** Write a raw JSON-RPC message on the same stream as a prior request. Optional; pipe-only. */ - sendRaw?(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; +} + +/** + * Schema bundle accepted by `setRequestHandler`'s 3-arg form. + * + * `params` is required and validates the inbound `request.params`. `result` is optional; + * when supplied it types the handler's return value (no runtime validation is performed + * on the result). + */ +export interface RequestHandlerSchemas< + P extends StandardSchemaV1 = StandardSchemaV1, + R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined +> { + params: P; + result?: R; +} + +function isPlainObject(value: unknown): value is Record { + return value !== null && typeof value === 'object' && !Array.isArray(value); +} + +export function mergeCapabilities(base: ServerCapabilities, additional: Partial): ServerCapabilities; +export function mergeCapabilities(base: ClientCapabilities, additional: Partial): ClientCapabilities; +export function mergeCapabilities(base: T, additional: Partial): T { + const result: T = { ...base }; + for (const key in additional) { + const k = key as keyof T; + const addValue = additional[k]; + if (addValue === undefined) continue; + const baseValue = result[k]; + result[k] = + isPlainObject(baseValue) && isPlainObject(addValue) + ? ({ ...(baseValue as Record), ...(addValue as Record) } as T[typeof k]) + : (addValue as T[typeof k]); + } + return result; } diff --git a/packages/core/src/shared/dispatcher.ts b/packages/core/src/shared/dispatcher.ts index e1709ae8a..959011728 100644 --- a/packages/core/src/shared/dispatcher.ts +++ b/packages/core/src/shared/dispatcher.ts @@ -17,8 +17,7 @@ import type { import { getNotificationSchema, getRequestSchema, getResultSchema, ProtocolError, ProtocolErrorCode } from '../types/index.js'; import type { StandardSchemaV1 } from '../util/standardSchema.js'; import { validateStandardSchema } from '../util/standardSchema.js'; -import type { RequestEnv } from './context.js'; -import type { BaseContext, RequestOptions } from './protocol.js'; +import type { BaseContext, RequestEnv, RequestOptions } from './context.js'; /** * One yielded item from {@linkcode Dispatcher.dispatch}. A dispatch yields zero or more @@ -184,11 +183,13 @@ export class Dispatcher { } return parsed.data; }) as BaseContext['mcpReq']['send'], - notify: async (n: Notification) => { - if (done) return; - queue.push({ jsonrpc: '2.0', method: n.method, params: n.params } as JSONRPCNotification); - wake?.(); - } + notify: + env.notify ?? + (async (n: Notification) => { + if (done) return; + queue.push({ jsonrpc: '2.0', method: n.method, params: n.params } as JSONRPCNotification); + wake?.(); + }) }, http: env.authInfo || env.httpReq ? { authInfo: env.authInfo } : undefined, ext: env.ext diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index cdebf43d9..fc60d8eb6 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,276 +1,55 @@ +/** + * The {@linkcode Protocol} class composes a stateless {@linkcode Dispatcher} + * (handler registry + dispatch) and a per-connection {@linkcode StreamDriver} + * (id correlation, timeouts, progress) under the same public surface as before. + * Handler-context types live in `./context.ts` and are re-exported here for + * backward-compatible import paths. + */ + import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; import type { - AuthInfo, - CancelledNotification, - ClientCapabilities, - CreateMessageRequest, - CreateMessageResult, - CreateMessageResultWithTools, - ElicitRequestFormParams, - ElicitRequestURLParams, - ElicitResult, - JSONRPCErrorResponse, - JSONRPCNotification, JSONRPCRequest, - JSONRPCResponse, - JSONRPCResultResponse, - LoggingLevel, MessageExtraInfo, Notification, NotificationMethod, NotificationTypeMap, - Progress, - ProgressNotification, Request, - RequestId, - RequestMeta, RequestMethod, RequestTypeMap, Result, - ResultTypeMap, - ServerCapabilities -} from '../types/index.js'; -import { - getNotificationSchema, - getRequestSchema, - getResultSchema, - isJSONRPCErrorResponse, - isJSONRPCNotification, - isJSONRPCRequest, - isJSONRPCResultResponse, - ProtocolError, - ProtocolErrorCode, - SUPPORTED_PROTOCOL_VERSIONS + ResultTypeMap } from '../types/index.js'; +import { getResultSchema, SUPPORTED_PROTOCOL_VERSIONS } from '../types/index.js'; import type { StandardSchemaV1 } from '../util/standardSchema.js'; -import { isStandardSchema, validateStandardSchema } from '../util/standardSchema.js'; -import type { Transport, TransportSendOptions } from './transport.js'; - -/** - * Callback for progress notifications. - */ -export type ProgressCallback = (progress: Progress) => void; - -/** - * Additional initialization options. - */ -export type ProtocolOptions = { - /** - * Protocol versions supported. First version is preferred (sent by client, - * used as fallback by server). Passed to transport during {@linkcode Protocol.connect | connect()}. - * - * @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS} - */ - supportedProtocolVersions?: string[]; - - /** - * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. - * - * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. - * - * Currently this defaults to `false`, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to `true`. - */ - enforceStrictCapabilities?: boolean; - /** - * An array of notification method names that should be automatically debounced. - * Any notifications with a method in this list will be coalesced if they - * occur in the same tick of the event loop. - * e.g., `['notifications/tools/list_changed']` - */ - debouncedNotificationMethods?: string[]; -}; - -/** - * The default request timeout, in milliseconds. - */ -export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; - -/** - * Options that can be given per request. - */ -export type RequestOptions = { - /** - * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. - * - * For task-augmented requests: progress notifications continue after {@linkcode CreateTaskResult} is returned and stop automatically when the task reaches a terminal status. - */ - onprogress?: ProgressCallback; - - /** - * Can be used to cancel an in-flight request. This will cause an `AbortError` to be raised from {@linkcode Protocol.request | request()}. - */ - signal?: AbortSignal; - - /** - * A timeout (in milliseconds) for this request. If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised from {@linkcode Protocol.request | request()}. - * - * If not specified, {@linkcode DEFAULT_REQUEST_TIMEOUT_MSEC} will be used as the timeout. - */ - timeout?: number; - - /** - * If `true`, receiving a progress notification will reset the request timeout. - * This is useful for long-running operations that send periodic progress updates. - * Default: `false` - */ - resetTimeoutOnProgress?: boolean; - - /** - * Maximum total time (in milliseconds) to wait for a response. - * If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised, regardless of progress notifications. - * If not specified, there is no maximum total timeout. - */ - maxTotalTimeout?: number; -} & TransportSendOptions; - -/** - * Options that can be given per notification. - */ -export type NotificationOptions = { - /** - * May be used to indicate to the transport which incoming request to associate this outgoing notification with. - */ - relatedRequestId?: RequestId; -}; - -/** - * Base context provided to all request handlers. - */ -export type BaseContext = { - /** - * The session ID from the transport, if available. - */ - sessionId?: string; - - /** - * Information about the MCP request being handled. - */ - mcpReq: { - /** - * The JSON-RPC ID of the request being handled. - */ - id: RequestId; - - /** - * The method name of the request (e.g., 'tools/call', 'ping'). - */ - method: string; - - /** - * Metadata from the original request. - */ - _meta?: RequestMeta; - - /** - * An abort signal used to communicate if the request was cancelled from the sender's side. - */ - signal: AbortSignal; - - /** - * Sends a request that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - * - * For spec methods the result type is inferred from the method name. - * For custom (non-spec) methods, pass a result schema as the second argument. - */ - send: { - ( - request: { method: M; params?: Record }, - options?: RequestOptions - ): Promise; - ( - request: Request, - resultSchema: T, - options?: RequestOptions - ): Promise>; - }; - - /** - * Sends a notification that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - notify: (notification: Notification) => Promise; - }; - - /** - * HTTP transport information, only available when using an HTTP-based transport. - */ - http?: { - /** - * Information about a validated access token, provided to request handlers. - */ - authInfo?: AuthInfo; - }; - - /** - * Extension slot. Adapters and middleware populate keys here; handlers cast to the - * extension's declared type to read them. Core never reads or writes this field. - */ - ext?: Record; -}; - -/** - * Context provided to server-side request handlers, extending {@linkcode BaseContext} with server-specific fields. - */ -export type ServerContext = BaseContext & { - mcpReq: { - /** - * Send a log message notification to the client. - * Respects the client's log level filter set via logging/setLevel. - */ - log: (level: LoggingLevel, data: unknown, logger?: string) => Promise; - - /** - * Send an elicitation request to the client, requesting user input. - */ - elicitInput: (params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions) => Promise; - - /** - * Request LLM sampling from the client. - */ - requestSampling: ( - params: CreateMessageRequest['params'], - options?: RequestOptions - ) => Promise; - }; +import { isStandardSchema } from '../util/standardSchema.js'; +import type { BaseContext, NotificationOptions, ProtocolOptions, RequestEnv, RequestHandlerSchemas, RequestOptions } from './context.js'; +import type { DispatchMiddleware, DispatchOutput, RawHandler } from './dispatcher.js'; +import { Dispatcher } from './dispatcher.js'; +import { RAW_RESULT_SCHEMA, StreamDriver } from './streamDriver.js'; +import type { Transport } from './transport.js'; - http?: { - /** - * The original HTTP request. - */ - req?: globalThis.Request; +export * from './context.js'; - /** - * Closes the SSE stream for this request, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - */ - closeSSE?: () => void; +type InferHandlerResult = R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result; - /** - * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - */ - closeStandaloneSSE?: () => void; - }; -}; - -/** - * Context provided to client-side request handlers. - */ -export type ClientContext = BaseContext; - -/** - * Information about a request's timeout state - */ -type TimeoutInfo = { - timeoutId: ReturnType; - startTime: number; - timeout: number; - maxTotalTimeout?: number; - resetTimeoutOnProgress: boolean; - onTimeout: () => void; +/** Type-erased forwarding signature for the {@linkcode Dispatcher.setRequestHandler} overload set. */ +type SetRequestHandlerImpl = ( + method: string, + schemaOrHandler: unknown, + maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise +) => void; + +/** Type-erased forwarding signature for the {@linkcode Dispatcher.setNotificationHandler} overload set. */ +type SetNotificationHandlerImpl = ( + method: string, + schemaOrHandler: unknown, + maybeHandler?: (params: unknown, notification: Notification) => void | Promise +) => void; + +/** RequestEnv augmented with the per-request fields {@linkcode Protocol} threads through. */ +type ProtocolEnv = RequestEnv & { + /** {@linkcode MessageExtraInfo} captured by the transport, forwarded to {@linkcode Protocol.buildContext}. */ + _transportExtra?: MessageExtraInfo; }; /** @@ -281,15 +60,8 @@ type TimeoutInfo = { * implementations most code should use. */ export abstract class Protocol { - private _transport?: Transport; - private _requestMessageId = 0; - private _requestHandlers: Map Promise> = new Map(); - private _requestHandlerAbortControllers: Map = new Map(); - private _notificationHandlers: Map Promise> = new Map(); - private _responseHandlers: Map void> = new Map(); - private _progressHandlers: Map = new Map(); - private _timeoutInfo: Map = new Map(); - private _pendingDebouncedNotifications = new Set(); + private readonly _dispatcher: Dispatcher; + private _driver?: StreamDriver; protected _supportedProtocolVersions: string[]; @@ -307,343 +79,43 @@ export abstract class Protocol { */ onerror?: (error: Error) => void; - /** - * A handler to invoke for any request types that do not have their own handler installed. - */ - fallbackRequestHandler?: (request: JSONRPCRequest, ctx: ContextT) => Promise; - - /** - * A handler to invoke for any notification types that do not have their own handler installed. - */ - fallbackNotificationHandler?: (notification: Notification) => Promise; - constructor(private _options?: ProtocolOptions) { - this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; - - this.setNotificationHandler('notifications/cancelled', notification => { - this._oncancel(notification); - }); - - this.setNotificationHandler('notifications/progress', notification => { - this._onprogress(notification); + this._dispatcher = new Dispatcher({ + buildContext: (base, env) => { + // StreamDriver populates `_transportExtra`; direct `dispatch(req, env)` callers + // pass `httpReq`/`authInfo` on RequestEnv. Synthesize MessageExtraInfo from the + // latter when the former is absent so `ctx.http.req`/`ctx.http.authInfo` are + // populated on both paths. + const pe = env as ProtocolEnv; + const extra = + pe._transportExtra ?? + (env.httpReq !== undefined || env.authInfo !== undefined + ? { request: env.httpReq, authInfo: env.authInfo } + : undefined); + return this.buildContext(base, extra); + }, + wrapHandler: (method, handler) => this._wrapHandler(method, handler) }); + this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; - this.setRequestHandler( - 'ping', - // Automatic pong by default. - _request => ({}) as Result - ); + // notifications/{cancelled,progress} are handled by StreamDriver internally + // (abort + progress callback) and then also dispatched here so user-registered + // handlers still fire (parity with the pre-decomposition Protocol). + this.setNotificationHandler('notifications/cancelled', () => {}); + this.setNotificationHandler('notifications/progress', () => {}); + this.setRequestHandler('ping', _request => ({}) as Result); } + // ─────────────────────────────────────────────────────────────────────── + // Subclass hooks (abstract) + // ─────────────────────────────────────────────────────────────────────── + /** * Builds the context object for request handlers. Subclasses must override * to return the appropriate context type (e.g., ServerContext adds HTTP request info). */ protected abstract buildContext(ctx: BaseContext, transportInfo?: MessageExtraInfo): ContextT; - private async _oncancel(notification: CancelledNotification): Promise { - if (!notification.params.requestId) { - return; - } - // Handle request cancellation - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); - controller?.abort(notification.params.reason); - } - - private _setupTimeout( - messageId: number, - timeout: number, - maxTotalTimeout: number | undefined, - onTimeout: () => void, - resetTimeoutOnProgress: boolean = false - ) { - this._timeoutInfo.set(messageId, { - timeoutId: setTimeout(onTimeout, timeout), - startTime: Date.now(), - timeout, - maxTotalTimeout, - resetTimeoutOnProgress, - onTimeout - }); - } - - private _resetTimeout(messageId: number): boolean { - const info = this._timeoutInfo.get(messageId); - if (!info) return false; - - const totalElapsed = Date.now() - info.startTime; - if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { - this._timeoutInfo.delete(messageId); - throw new SdkError(SdkErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { - maxTotalTimeout: info.maxTotalTimeout, - totalElapsed - }); - } - - clearTimeout(info.timeoutId); - info.timeoutId = setTimeout(info.onTimeout, info.timeout); - return true; - } - - private _cleanupTimeout(messageId: number) { - const info = this._timeoutInfo.get(messageId); - if (info) { - clearTimeout(info.timeoutId); - this._timeoutInfo.delete(messageId); - } - } - - /** - * Attaches to the given transport, starts it, and starts listening for messages. - * - * The caller assumes ownership of the {@linkcode Transport}, replacing any callbacks that have already been set, and expects that it is the only user of the {@linkcode Transport} instance going forward. - */ - async connect(transport: Transport): Promise { - this._transport = transport; - const _onclose = this.transport?.onclose; - this._transport.onclose = () => { - try { - _onclose?.(); - } finally { - this._onclose(); - } - }; - - const _onerror = this.transport?.onerror; - this._transport.onerror = (error: Error) => { - _onerror?.(error); - this._onerror(error); - }; - - const _onmessage = this._transport?.onmessage; - this._transport.onmessage = (message, extra) => { - _onmessage?.(message, extra); - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - this._onresponse(message); - } else if (isJSONRPCRequest(message)) { - this._onrequest(message, extra); - } else if (isJSONRPCNotification(message)) { - this._onnotification(message); - } else { - this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); - } - }; - - // Pass supported protocol versions to transport for header validation - transport.setSupportedProtocolVersions?.(this._supportedProtocolVersions); - - await this._transport.start(); - } - - private _onclose(): void { - const responseHandlers = this._responseHandlers; - this._responseHandlers = new Map(); - this._progressHandlers.clear(); - this._pendingDebouncedNotifications.clear(); - - for (const info of this._timeoutInfo.values()) { - clearTimeout(info.timeoutId); - } - this._timeoutInfo.clear(); - - const requestHandlerAbortControllers = this._requestHandlerAbortControllers; - this._requestHandlerAbortControllers = new Map(); - - const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed'); - - this._transport = undefined; - - try { - this.onclose?.(); - } finally { - for (const handler of responseHandlers.values()) { - handler(error); - } - - for (const controller of requestHandlerAbortControllers.values()) { - controller.abort(error); - } - } - } - - private _onerror(error: Error): void { - this.onerror?.(error); - } - - private _onnotification(notification: JSONRPCNotification): void { - const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; - - // Ignore notifications not being subscribed to. - if (handler === undefined) { - return; - } - - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => handler(notification)) - .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); - } - - private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; - - // Capture the current transport at request time to ensure responses go to the correct client - const capturedTransport = this._transport; - - const sendNotification = (notification: Notification, options?: NotificationOptions) => - this.notification(notification, { ...options, relatedRequestId: request.id }); - const sendRequest = (r: Request, resultSchema: U, options?: RequestOptions) => - this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }); - - if (handler === undefined) { - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: ProtocolErrorCode.MethodNotFound, - message: 'Method not found' - } - }; - capturedTransport?.send(errorResponse).catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - return; - } - - const abortController = new AbortController(); - this._requestHandlerAbortControllers.set(request.id, abortController); - - const baseCtx: BaseContext = { - sessionId: capturedTransport?.sessionId, - mcpReq: { - id: request.id, - method: request.method, - _meta: request.params?._meta, - signal: abortController.signal, - // BaseContext.mcpReq.send is declared with two overloads (spec-method-keyed and explicit-schema). Arrow - // literals can't carry overload signatures, so the inferred single-signature type isn't assignable to - // that overloaded property type. The cast is sound: this impl dispatches both overload paths via the - // isStandardSchema guard, and sendRequest validates the result against the resolved schema either way. - send: ((r: Request, schemaOrOptions?: StandardSchemaV1 | RequestOptions, maybeOptions?: RequestOptions) => { - if (isStandardSchema(schemaOrOptions)) { - return sendRequest(r, schemaOrOptions, maybeOptions); - } - const resultSchema = getResultSchema(r.method); - if (!resultSchema) { - throw new TypeError( - `'${r.method}' is not a spec method; pass a result schema as the second argument to ctx.mcpReq.send().` - ); - } - return sendRequest(r, resultSchema, schemaOrOptions); - }) as BaseContext['mcpReq']['send'], - notify: sendNotification - }, - http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined - }; - const ctx = this.buildContext(baseCtx, extra); - - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => handler(request, ctx)) - .then( - async result => { - if (abortController.signal.aborted) { - // Request was cancelled - return; - } - - const response: JSONRPCResponse = { - result, - jsonrpc: '2.0', - id: request.id - }; - await capturedTransport?.send(response); - }, - async error => { - if (abortController.signal.aborted) { - // Request was cancelled - return; - } - - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: Number.isSafeInteger(error['code']) ? error['code'] : ProtocolErrorCode.InternalError, - message: error.message ?? 'Internal error', - ...(error['data'] !== undefined && { data: error['data'] }) - } - }; - await capturedTransport?.send(errorResponse); - } - ) - .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) - .finally(() => { - if (this._requestHandlerAbortControllers.get(request.id) === abortController) { - this._requestHandlerAbortControllers.delete(request.id); - } - }); - } - - private _onprogress(notification: ProgressNotification): void { - const { progressToken, ...params } = notification.params; - const messageId = Number(progressToken); - - const handler = this._progressHandlers.get(messageId); - if (!handler) { - this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); - return; - } - - const responseHandler = this._responseHandlers.get(messageId); - const timeoutInfo = this._timeoutInfo.get(messageId); - - if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { - try { - this._resetTimeout(messageId); - } catch (error) { - // Clean up if maxTotalTimeout was exceeded - this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); - this._cleanupTimeout(messageId); - responseHandler(error as Error); - return; - } - } - - handler(params); - } - - private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { - const messageId = Number(response.id); - - const handler = this._responseHandlers.get(messageId); - if (handler === undefined) { - this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); - return; - } - - this._responseHandlers.delete(messageId); - this._cleanupTimeout(messageId); - this._progressHandlers.delete(messageId); - - if (isJSONRPCResultResponse(response)) { - handler(response); - } else { - const error = ProtocolError.fromError(response.error.code, response.error.message, response.error.data); - handler(error); - } - } - - get transport(): Transport | undefined { - return this._transport; - } - - /** - * Closes the connection. - */ - async close(): Promise { - await this._transport?.close(); - } - /** * A method to check if a capability is supported by the remote side, for the given method to be called. * @@ -665,213 +137,46 @@ export abstract class Protocol { */ protected abstract assertRequestHandlerCapability(method: string): void; + // ─────────────────────────────────────────────────────────────────────── + // Handler registration (delegates to Dispatcher) + // ─────────────────────────────────────────────────────────────────────── + /** - * Sends a request and waits for a response. - * - * For spec methods the result schema is resolved automatically from the method name - * and the return type is method-keyed. For custom (non-spec) methods, pass a - * `resultSchema` as the second argument; the response is validated against it and - * the return type is inferred from the schema. - * - * Do not use this method to emit notifications! Use {@linkcode Protocol.notification | notification()} instead. + * A handler to invoke for any request types that do not have their own handler installed. */ - request( - request: { method: M; params?: Record }, - options?: RequestOptions - ): Promise; - request( - request: Request, - resultSchema: T, - options?: RequestOptions - ): Promise>; - request(request: Request, schemaOrOptions?: StandardSchemaV1 | RequestOptions, maybeOptions?: RequestOptions): Promise { - if (isStandardSchema(schemaOrOptions)) { - return this._requestWithSchema(request, schemaOrOptions, maybeOptions); - } - const resultSchema = getResultSchema(request.method); - if (!resultSchema) { - throw new TypeError(`'${request.method}' is not a spec method; pass a result schema as the second argument to request().`); - } - return this._requestWithSchema(request, resultSchema, schemaOrOptions); + get fallbackRequestHandler(): ((request: JSONRPCRequest, ctx: ContextT) => Promise) | undefined { + return this._dispatcher.fallbackRequestHandler; + } + set fallbackRequestHandler(h) { + this._dispatcher.fallbackRequestHandler = h; } /** - * Sends a request and waits for a response, using the provided schema for validation. - * - * This is the internal implementation used by SDK methods that need to specify - * a particular result schema (e.g., for compatibility or task-specific schemas). + * A handler to invoke for any notification types that do not have their own handler installed. */ - protected _requestWithSchema( - request: Request, - resultSchema: T, - options?: RequestOptions - ): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; - - let onAbort: (() => void) | undefined; - let cleanupMessageId: number | undefined; - - // Send the request - return new Promise>((resolve, reject) => { - const earlyReject = (error: unknown) => { - reject(error); - }; - - if (!this._transport) { - earlyReject(new Error('Not connected')); - return; - } - - if (this._options?.enforceStrictCapabilities === true) { - try { - this.assertCapabilityForMethod(request.method); - } catch (error) { - earlyReject(error); - return; - } - } - - options?.signal?.throwIfAborted(); - - const messageId = this._requestMessageId++; - cleanupMessageId = messageId; - const jsonrpcRequest: JSONRPCRequest = { - ...request, - jsonrpc: '2.0', - id: messageId - }; - - if (options?.onprogress) { - this._progressHandlers.set(messageId, options.onprogress); - jsonrpcRequest.params = { - ...request.params, - _meta: { - ...request.params?._meta, - progressToken: messageId - } - }; - } - - let responseReceived = false; - - const cancel = (reason: unknown) => { - if (responseReceived) { - return; - } - this._progressHandlers.delete(messageId); - - this._transport - ?.send( - { - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { - requestId: messageId, - reason: String(reason) - } - }, - { relatedRequestId, resumptionToken, onresumptiontoken } - ) - .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`))); - - // Wrap the reason in an SdkError if it isn't already - const error = reason instanceof SdkError ? reason : new SdkError(SdkErrorCode.RequestTimeout, String(reason)); - reject(error); - }; - - this._responseHandlers.set(messageId, response => { - if (options?.signal?.aborted) { - return; - } - responseReceived = true; - - if (response instanceof Error) { - return reject(response); - } - - validateStandardSchema(resultSchema, response.result).then(parseResult => { - if (parseResult.success) { - resolve(parseResult.data); - } else { - reject(new SdkError(SdkErrorCode.InvalidResult, `Invalid result for ${request.method}: ${parseResult.error}`)); - } - }, reject); - }); - - onAbort = () => cancel(options?.signal?.reason); - options?.signal?.addEventListener('abort', onAbort, { once: true }); - - const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; - const timeoutHandler = () => cancel(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })); - - this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - - this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { - this._progressHandlers.delete(messageId); - reject(error); - }); - }).finally(() => { - // Per-request cleanup that must run on every exit path. Consolidated - // here so new exit paths added to the promise body can't forget it. - // _progressHandlers is NOT cleaned up here: _onresponse deletes it - // on resolution, and error paths above delete it inline. - if (onAbort) { - options?.signal?.removeEventListener('abort', onAbort); - } - if (cleanupMessageId !== undefined) { - this._responseHandlers.delete(cleanupMessageId); - this._cleanupTimeout(cleanupMessageId); - } - }); + get fallbackNotificationHandler(): ((notification: Notification) => Promise) | undefined { + return this._dispatcher.fallbackNotificationHandler; + } + set fallbackNotificationHandler(h) { + this._dispatcher.fallbackNotificationHandler = h; } /** - * Emits a notification, which is a one-way message that does not expect a response. + * Registers a {@linkcode DispatchMiddleware} around the inbound request path. + * Registration order is outer-to-inner: the first `use()` call wraps all later ones. */ - async notification(notification: Notification, options?: NotificationOptions): Promise { - if (!this._transport) { - throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); - } - - this.assertNotificationCapability(notification.method); - - const jsonrpcNotification: JSONRPCNotification = { jsonrpc: '2.0', ...notification }; - - const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; - // A notification can only be debounced if it's in the list AND it's "simple" - // (i.e., has no parameters and no related request ID that could be lost). - const canDebounce = debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId; - - if (canDebounce) { - // If a notification of this type is already scheduled, do nothing. - if (this._pendingDebouncedNotifications.has(notification.method)) { - return; - } - - // Mark this notification type as pending. - this._pendingDebouncedNotifications.add(notification.method); - - // Schedule the actual send to happen in the next microtask. - // This allows all synchronous calls in the current event loop tick to be coalesced. - Promise.resolve().then(() => { - // Un-mark the notification so the next one can be scheduled. - this._pendingDebouncedNotifications.delete(notification.method); - - // SAFETY CHECK: If the connection was closed while this was pending, abort. - if (!this._transport) { - return; - } - - // Send the notification, but don't await it here to avoid blocking. - // Handle potential errors with a .catch(). - this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); - }); - - // Return immediately. - return; - } + use(mw: DispatchMiddleware): this { + this._dispatcher.use(mw); + return this; + } - await this._transport.send(jsonrpcNotification, options); + /** + * Dispatch one inbound request through the middleware chain and registered handler, + * yielding any handler-emitted notifications then exactly one terminal response. + * Transport-free entry point; `env` carries per-request input from the caller. + */ + dispatch(request: JSONRPCRequest, env?: RequestEnv): AsyncGenerator { + return this._dispatcher.dispatch(request, env); } /** @@ -910,32 +215,7 @@ export abstract class Protocol { maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise ): void { this.assertRequestHandlerCapability(method); - - let stored: (request: JSONRPCRequest, ctx: ContextT) => Promise; - - if (typeof schemasOrHandler === 'function') { - const schema = getRequestSchema(method); - if (!schema) { - throw new TypeError( - `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` - ); - } - stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx)); - } else if (maybeHandler) { - stored = async (request, ctx) => { - const userParams = { ...request.params }; - delete userParams._meta; - const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); - if (!parsed.success) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); - } - return maybeHandler(parsed.data, ctx); - }; - } else { - throw new TypeError('setRequestHandler: handler is required'); - } - - this._requestHandlers.set(method, this._wrapHandler(method, stored)); + (this._dispatcher.setRequestHandler as SetRequestHandlerImpl)(method, schemasOrHandler, maybeHandler); } /** @@ -946,10 +226,7 @@ export abstract class Protocol { * * Subclasses overriding this hook avoid redeclaring `setRequestHandler`'s overload set. */ - protected _wrapHandler( - _method: string, - handler: (request: JSONRPCRequest, ctx: ContextT) => Promise - ): (request: JSONRPCRequest, ctx: ContextT) => Promise { + protected _wrapHandler(_method: string, handler: RawHandler): RawHandler { return handler; } @@ -957,16 +234,14 @@ export abstract class Protocol { * Removes the request handler for the given method. */ removeRequestHandler(method: RequestMethod | string): void { - this._requestHandlers.delete(method); + this._dispatcher.removeRequestHandler(method); } /** * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. */ assertCanSetRequestHandler(method: RequestMethod | string): void { - if (this._requestHandlers.has(method)) { - throw new Error(`A request handler for ${method} already exists, which would be overridden`); - } + this._dispatcher.assertCanSetRequestHandler(method); } /** @@ -994,73 +269,126 @@ export abstract class Protocol { schemasOrHandler: { params: StandardSchemaV1 } | ((notification: unknown) => void | Promise), maybeHandler?: (params: unknown, notification: Notification) => void | Promise ): void { - if (typeof schemasOrHandler === 'function') { - const schema = getNotificationSchema(method); - if (!schema) { - throw new TypeError( - `'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().` - ); - } - this._notificationHandlers.set(method, notification => Promise.resolve(schemasOrHandler(schema.parse(notification)))); - return; - } - - if (!maybeHandler) { - throw new TypeError('setNotificationHandler: handler is required'); - } - this._notificationHandlers.set(method, async notification => { - const userParams = { ...notification.params }; - delete userParams._meta; - const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); - if (!parsed.success) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for notification ${method}: ${parsed.error}`); - } - await maybeHandler(parsed.data, notification); - }); + (this._dispatcher.setNotificationHandler as SetNotificationHandlerImpl)(method, schemasOrHandler, maybeHandler); } /** * Removes the notification handler for the given method. */ removeNotificationHandler(method: NotificationMethod | string): void { - this._notificationHandlers.delete(method); + this._dispatcher.removeNotificationHandler(method); } -} -/** - * Schema bundle accepted by {@linkcode Protocol.setRequestHandler | setRequestHandler}'s 3-arg form. - * - * `params` is required and validates the inbound `request.params`. `result` is optional; - * when supplied it types the handler's return value (no runtime validation is performed - * on the result). - */ -export interface RequestHandlerSchemas< - P extends StandardSchemaV1 = StandardSchemaV1, - R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined -> { - params: P; - result?: R; -} + // ─────────────────────────────────────────────────────────────────────── + // Connection (delegates to StreamDriver) + // ─────────────────────────────────────────────────────────────────────── -type InferHandlerResult = R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result; + get transport(): Transport | undefined { + return this._driver?.pipe; + } -function isPlainObject(value: unknown): value is Record { - return value !== null && typeof value === 'object' && !Array.isArray(value); -} + /** + * Attaches to the given transport, starts it, and starts listening for messages. + * + * The caller assumes ownership of the {@linkcode Transport}, replacing any callbacks that have already been set, and expects that it is the only user of the {@linkcode Transport} instance going forward. + */ + async connect(transport: Transport): Promise { + const driver = new StreamDriver(this._dispatcher, transport, { + supportedProtocolVersions: this._supportedProtocolVersions, + debouncedNotificationMethods: this._options?.debouncedNotificationMethods, + buildEnv: (extra, base) => { + const env: ProtocolEnv = { + ...base, + _transportExtra: extra, + notify: (n: Notification) => this.notification(n, { relatedRequestId: base.relatedRequestId }), + send: (r, opts) => this._requestWithSchema(r, RAW_RESULT_SCHEMA, { ...opts, relatedRequestId: base.relatedRequestId }) + }; + return env; + } + }); + this._driver = driver; + + driver.onclose = () => { + if (this._driver === driver) { + this._driver = undefined; + this.onclose?.(); + } + }; + driver.onerror = error => this.onerror?.(error); -export function mergeCapabilities(base: ServerCapabilities, additional: Partial): ServerCapabilities; -export function mergeCapabilities(base: ClientCapabilities, additional: Partial): ClientCapabilities; -export function mergeCapabilities(base: T, additional: Partial): T { - const result: T = { ...base }; - for (const key in additional) { - const k = key as keyof T; - const addValue = additional[k]; - if (addValue === undefined) continue; - const baseValue = result[k]; - result[k] = - isPlainObject(baseValue) && isPlainObject(addValue) - ? ({ ...(baseValue as Record), ...(addValue as Record) } as T[typeof k]) - : (addValue as T[typeof k]); + await driver.start(); + } + + /** + * Closes the connection. + */ + async close(): Promise { + await this._driver?.close(); + } + + /** + * Sends a request and waits for a response. + * + * For spec methods the result schema is resolved automatically from the method name + * and the return type is method-keyed. For custom (non-spec) methods, pass a + * `resultSchema` as the second argument; the response is validated against it and + * the return type is inferred from the schema. + * + * Do not use this method to emit notifications! Use {@linkcode Protocol.notification | notification()} instead. + */ + request( + request: { method: M; params?: Record }, + options?: RequestOptions + ): Promise; + request( + request: Request, + resultSchema: T, + options?: RequestOptions + ): Promise>; + request(request: Request, schemaOrOptions?: StandardSchemaV1 | RequestOptions, maybeOptions?: RequestOptions): Promise { + if (isStandardSchema(schemaOrOptions)) { + return this._requestWithSchema(request, schemaOrOptions, maybeOptions); + } + const resultSchema = getResultSchema(request.method); + if (!resultSchema) { + throw new TypeError(`'${request.method}' is not a spec method; pass a result schema as the second argument to request().`); + } + return this._requestWithSchema(request, resultSchema, schemaOrOptions); + } + + /** + * Sends a request and waits for a response, using the provided schema for validation. + * + * This is the internal implementation used by SDK methods that need to specify + * a particular result schema (e.g., for compatibility or task-specific schemas). + */ + protected _requestWithSchema( + request: Request, + resultSchema: T, + options?: RequestOptions + ): Promise> { + if (!this._driver) { + return Promise.reject(new Error('Not connected')); + } + if (this._options?.enforceStrictCapabilities === true) { + try { + this.assertCapabilityForMethod(request.method as RequestMethod); + } catch (error) { + return Promise.reject(error); + } + } + return this._driver.request(request, resultSchema, options); + } + + /** + * Emits a notification, which is a one-way message that does not expect a response. + */ + async notification(notification: Notification, options?: NotificationOptions): Promise { + const driver = this._driver; + if (!driver) { + throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); + } + this.assertNotificationCapability(notification.method as NotificationMethod); + await driver.notification(notification, options); } - return result; } diff --git a/packages/core/src/shared/streamDriver.ts b/packages/core/src/shared/streamDriver.ts index 84c851e95..bb9f61bfa 100644 --- a/packages/core/src/shared/streamDriver.ts +++ b/packages/core/src/shared/streamDriver.ts @@ -23,10 +23,9 @@ import { } from '../types/index.js'; import type { StandardSchemaV1 } from '../util/standardSchema.js'; import { validateStandardSchema } from '../util/standardSchema.js'; -import type { Outbound, RequestEnv } from './context.js'; +import type { NotificationOptions, Outbound, ProgressCallback, RequestEnv, RequestOptions } from './context.js'; +import { DEFAULT_REQUEST_TIMEOUT_MSEC } from './context.js'; import type { Dispatcher } from './dispatcher.js'; -import type { NotificationOptions, ProgressCallback, RequestOptions } from './protocol.js'; -import { DEFAULT_REQUEST_TIMEOUT_MSEC } from './protocol.js'; import type { Transport } from './transport.js'; /** @@ -62,18 +61,14 @@ export type StreamDriverOptions = { /** * Options for {@linkcode attachChannelTransport}. Mirrors {@linkcode StreamDriverOptions} - * plus the `onclose`/`onerror`/`onresponse` callbacks the caller would otherwise set - * on the driver instance. + * plus the `onclose`/`onerror` callbacks the caller would otherwise set on the driver + * instance. * * @internal */ export type AttachOptions = StreamDriverOptions & { onclose?: () => void; onerror?: (error: Error) => void; - onresponse?: ( - response: JSONRPCResultResponse | JSONRPCErrorResponse, - messageId: number - ) => { consumed: boolean; preserveProgress?: boolean }; }; /** @@ -97,12 +92,6 @@ export class StreamDriver implements Outbound { onclose?: () => void; onerror?: (error: Error) => void; - /** - * Tap for every inbound response. Return `consumed: true` to claim it (suppresses the - * matched-handler dispatch / unknown-id error). Return `preserveProgress: true` to keep - * the progress handler registered after the matched handler runs. Set by the owner. - */ - onresponse?: (response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number) => { consumed: boolean; preserveProgress?: boolean }; constructor( // eslint-disable-next-line @typescript-eslint/no-explicit-any -- driver is context-agnostic; owner supplies ContextT via dispatcher @@ -113,11 +102,6 @@ export class StreamDriver implements Outbound { this._supportedProtocolVersions = _options.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; } - /** Exposed so an owner can clear progress callbacks. See {@linkcode Outbound.removeProgressHandler}. */ - removeProgressHandler(token: number): void { - this._progressHandlers.delete(token); - } - /** * Wires the pipe's callbacks and starts it. After this resolves, inbound * requests are dispatched and {@linkcode StreamDriver.request | request()} works. @@ -165,11 +149,6 @@ export class StreamDriver implements Outbound { this.pipe.setProtocolVersion?.(version); } - /** {@linkcode Outbound.sendRaw} — write a raw JSON-RPC message to the pipe. */ - async sendRaw(message: Parameters[0], options?: { relatedRequestId?: RequestId }): Promise { - await this.pipe.send(message, options); - } - /** * Sends a request over the pipe and resolves with the parsed result. */ @@ -284,6 +263,7 @@ export class StreamDriver implements Outbound { authInfo: extra?.authInfo, httpReq: extra?.request, sessionId: this.pipe.sessionId, + relatedRequestId: request.id, send: (r, opts) => this.request(r, RAW_RESULT_SCHEMA, { ...opts, relatedRequestId: request.id }) as Promise }; const env = this._options.buildEnv ? this._options.buildEnv(extra, baseEnv) : baseEnv; @@ -293,7 +273,7 @@ export class StreamDriver implements Outbound { if (out.kind === 'notification') { await this.pipe.send(out.message, { relatedRequestId: request.id }); } else if (!abort.signal.aborted) { - await this.pipe.send(out.message, { relatedRequestId: request.id }); + await this.pipe.send(out.message); } } }; @@ -307,16 +287,23 @@ export class StreamDriver implements Outbound { } private _onnotification(notification: JSONRPCNotification): void { + // Dispatch on a microtask so callers see the same timing as the + // notification-handler path (which dispatches via Promise.resolve()). if (notification.method === 'notifications/cancelled') { - const requestId = (notification.params as { requestId?: RequestId } | undefined)?.requestId; - if (requestId !== undefined) - this._requestHandlerAbortControllers.get(requestId)?.abort((notification.params as { reason?: unknown })?.reason); - return; - } - if (notification.method === 'notifications/progress') { - this._onprogress(notification.params as ProgressNotification['params']); - return; + Promise.resolve() + .then(() => { + const requestId = (notification.params as { requestId?: RequestId } | undefined)?.requestId; + if (requestId !== undefined) + this._requestHandlerAbortControllers.get(requestId)?.abort((notification.params as { reason?: unknown })?.reason); + }) + .catch(error => this._onerror(error instanceof Error ? error : new Error(String(error)))); + } else if (notification.method === 'notifications/progress') { + Promise.resolve() + .then(() => this._onprogress(notification.params as ProgressNotification['params'])) + .catch(error => this._onerror(error instanceof Error ? error : new Error(String(error)))); } + // Always dispatch (including cancelled/progress) so user-registered handlers fire. + // The driver-internal handling above is in addition to, not instead of, user handlers. this.dispatcher .dispatchNotification(notification) .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); @@ -348,8 +335,6 @@ export class StreamDriver implements Outbound { private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { const messageId = Number(response.id); - const tap = this.onresponse?.(response, messageId); - if (tap?.consumed) return; const handler = this._responseHandlers.get(messageId); if (handler === undefined) { this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); @@ -357,7 +342,7 @@ export class StreamDriver implements Outbound { } this._responseHandlers.delete(messageId); this._cleanupTimeout(messageId); - if (!tap?.preserveProgress) this._progressHandlers.delete(messageId); + this._progressHandlers.delete(messageId); if (isJSONRPCResultResponse(response)) { handler(response); } else { @@ -436,11 +421,8 @@ export async function attachChannelTransport( options?: AttachOptions ): Promise { const driver = new StreamDriver(dispatcher, pipe, options); - if (options?.onclose || options?.onerror || options?.onresponse) { - driver.onclose = options.onclose; - driver.onerror = options.onerror; - driver.onresponse = options.onresponse; - } + driver.onclose = options?.onclose; + driver.onerror = options?.onerror; await driver.start(); return driver; } diff --git a/typedoc.config.mjs b/typedoc.config.mjs index f2a4e50f5..7cfedbff0 100644 --- a/typedoc.config.mjs +++ b/typedoc.config.mjs @@ -50,7 +50,15 @@ export default { externalSymbolLinkMappings: { '@modelcontextprotocol/core': { StandardSchemaV1: 'https://standardschema.dev/', - StandardJSONSchemaV1: 'https://standardschema.dev/' + StandardJSONSchemaV1: 'https://standardschema.dev/', + // Internal types referenced from public-facing JSDoc. Dispatcher/StreamDriver/ + // RequestEnv stay private; mapping to '#' renders the link as plain text without + // typedoc treating the unresolved reference as an error. + Dispatcher: '#', + 'Dispatcher.dispatch': '#', + 'Dispatcher.use': '#', + RequestEnv: '#', + StreamDriver: '#' } } };