diff --git a/ts/docs/architecture/agent-patterns.md b/ts/docs/architecture/agent-patterns.md index b8a7d34d6a..e813988442 100644 --- a/ts/docs/architecture/agent-patterns.md +++ b/ts/docs/architecture/agent-patterns.md @@ -54,11 +54,24 @@ export function instantiate(): AppAgent { } ``` +Agents that allocate expensive resources (browser instances, network +connections, child processes) on `initializeAgentContext` should also implement +`closeAgentContext` to release them when the agent is disabled or the session +ends. The dispatcher calls `closeAgentContext` automatically; omitting it leaks +the resource for the lifetime of the process. + +```typescript +// Example: agent that holds a long-lived resource +export function instantiate(): AppAgent { + return { initializeAgentContext, closeAgentContext, executeAction }; +} +``` + +**Examples:** `weather`, `photo`, `list`, `image`, `video`, `utility` + **When to choose:** any integration with a well-defined, enumerable set of actions — REST APIs, CLI tools, file operations, data queries. -**Examples:** `weather`, `photo`, `list`, `image`, `video` - --- ### 2. `external-api` — REST / OAuth Bridge diff --git a/ts/docs/architecture/dispatcher.md b/ts/docs/architecture/dispatcher.md index d172bee7d4..18eefb9c9e 100644 --- a/ts/docs/architecture/dispatcher.md +++ b/ts/docs/architecture/dispatcher.md @@ -681,9 +681,15 @@ The dispatcher uses structured error handling at several levels: - **Command lock** — A `Limiter` ensures only one command executes at a time. Concurrent requests queue behind the lock. -- **Cancellation** — Each request gets an `AbortController`. Calling - `cancelCommand(requestId)` signals the abort, which propagates through - the translation and execution pipeline. +- **Cancellation** — Each request gets an `AbortController` whose signal + propagates through the translation and execution pipeline (LLM fetch, + streaming chunks, cache validation). Two cancellation paths exist: + - `cancelCommand(requestId)` — cancel by the server-assigned UUID, available + after `setUserRequest()` fires. Used by Escape/Ctrl+C once a request is running. + - `cancelCommandByClientId(clientRequestId)` — cancel by the client-assigned + id passed as the second argument to `processCommand()`. This AbortController + is created before the command lock is acquired, so it can abort a command + that is queued behind another in-flight command before `setUserRequest()` fires. - **Unknown actions** — When no agent matches, the dispatcher displays an error and uses semantic search to suggest the closest matching agents/schemas. diff --git a/ts/packages/agentRpc/src/client.ts b/ts/packages/agentRpc/src/client.ts index b0ae732189..a83af1eab2 100644 --- a/ts/packages/agentRpc/src/client.ts +++ b/ts/packages/agentRpc/src/client.ts @@ -32,6 +32,41 @@ import { ChannelProvider } from "./common.js"; import { getObjectProperty, uint8ArrayToBase64 } from "@typeagent/common-utils"; import { AgentInterfaceFunctionName } from "./server.js"; +/** + * Race a promise against an AbortSignal. If the signal fires before the + * promise settles, throw an AbortError immediately (the underlying work + * continues in the background — the caller does not wait for it). + */ +function raceWithSignal( + promise: Promise, + signal: AbortSignal | undefined, +): Promise { + if (!signal) { + return promise; + } + return new Promise((resolve, reject) => { + const onAbort = () => + reject( + signal.reason ?? + new DOMException( + "The operation was aborted.", + "AbortError", + ), + ); + signal.addEventListener("abort", onAbort, { once: true }); + promise.then( + (v) => { + signal.removeEventListener("abort", onAbort); + resolve(v); + }, + (e) => { + signal.removeEventListener("abort", onAbort); + reject(e); + }, + ); + }); +} + type ShimContext = | { contextId: number; @@ -472,12 +507,32 @@ export async function createAgentRpcClient( action: TypeAgentAction, context: ActionContext, ) { - return withActionContextAsync(context, (contextParams) => - rpc.invoke("executeAction", { - ...contextParams, - action, - }), - ); + return withActionContextAsync(context, (contextParams) => { + const signal = context.abortSignal; + if (signal) { + const onAbort = () => + rpc.send("cancelAction", { + actionContextId: contextParams.actionContextId, + }); + signal.addEventListener("abort", onAbort, { once: true }); + return raceWithSignal( + rpc.invoke("executeAction", { + ...contextParams, + action, + }), + signal, + ).finally(() => { + signal.removeEventListener("abort", onAbort); + }); + } + return raceWithSignal( + rpc.invoke("executeAction", { + ...contextParams, + action, + }), + signal, + ); + }); }, validateWildcardMatch( action: AppAction, diff --git a/ts/packages/agentRpc/src/server.ts b/ts/packages/agentRpc/src/server.ts index efb91a6465..8b47d8fd6f 100644 --- a/ts/packages/agentRpc/src/server.ts +++ b/ts/packages/agentRpc/src/server.ts @@ -132,10 +132,14 @@ export function createAgentRpcServer( if (agent.executeAction === undefined) { throw new Error("Invalid invocation of executeAction"); } - return agent.executeAction( - param.action, - getActionContextShim(param), - ); + const shim = getActionContextShim(param); + try { + return await agent.executeAction(param.action, shim); + } finally { + if (param.actionContextId !== undefined) { + actionAbortControllers.delete(param.actionContextId); + } + } }, async validateWildcardMatch(param): Promise { if (agent.validateWildcardMatch === undefined) { @@ -264,6 +268,8 @@ export function createAgentRpcServer( }, }; + const actionAbortControllers = new Map(); + const agentCallHandlers: AgentCallFunctions = { async streamPartialAction( param: Partial & { @@ -284,6 +290,9 @@ export function createAgentRpcServer( getActionContextShim(param), ); }, + cancelAction(param: { actionContextId: number }): void { + actionAbortControllers.get(param.actionContextId)?.abort(); + }, }; const rpc = createRpc< @@ -600,6 +609,18 @@ export function createAgentRpcServer( "Invalid action context param: missing actionContextId", ); } + // Reuse the controller already registered for this actionContextId if + // one exists. Multiple RPC entry points (executeAction, + // streamPartialAction, executeCommand, handleChoice) can build a shim + // for the same id, and overwriting the controller would orphan the + // signal handed to the in-flight executeAction — cancelAction would + // then abort the wrong one. The owning call (executeAction) is + // responsible for deleting the entry when it completes. + let abortController = actionAbortControllers.get(actionContextId); + if (abortController === undefined) { + abortController = new AbortController(); + actionAbortControllers.set(actionContextId, abortController); + } const sessionContext = getSessionContextShim(param); const actionIO: ActionIO = { setDisplay(content: DisplayContent): void { @@ -634,6 +655,9 @@ export function createAgentRpcServer( streamingContext: undefined, activityContext: param.activityContext, isFromReasoningLoop: param.isFromReasoningLoop ?? false, + get abortSignal() { + return abortController.signal; + }, get sessionContext() { return sessionContext; }, @@ -672,6 +696,6 @@ export function createAgentRpcServer( export type AgentInterfaceFunctionName = | keyof AgentInvokeFunctions - | keyof AgentCallFunctions; + | Exclude; export type AgentControlMessage = "exit"; diff --git a/ts/packages/agentRpc/src/types.ts b/ts/packages/agentRpc/src/types.ts index 0c580c6e2e..83e386b508 100644 --- a/ts/packages/agentRpc/src/types.ts +++ b/ts/packages/agentRpc/src/types.ts @@ -150,6 +150,7 @@ export type AgentCallFunctions = { delta: string | undefined; }, ) => void; + cancelAction: (param: { actionContextId: number }) => void; }; export type AgentInvokeFunctions = { diff --git a/ts/packages/agents/utility/jest.config.cjs b/ts/packages/agents/utility/jest.config.cjs new file mode 100644 index 0000000000..f475768a12 --- /dev/null +++ b/ts/packages/agents/utility/jest.config.cjs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +module.exports = require("../../../jest.config.js"); diff --git a/ts/packages/agents/utility/package.json b/ts/packages/agents/utility/package.json index 012ae46572..462c8fedfe 100644 --- a/ts/packages/agents/utility/package.json +++ b/ts/packages/agents/utility/package.json @@ -21,9 +21,12 @@ "asc": "asc -i ./src/utilitySchema.mts -o ./dist/utilitySchema.pas.json -t UtilityAction", "build": "concurrently npm:tsc npm:asc npm:agc", "clean": "rimraf --glob dist *.tsbuildinfo *.done.build.log", + "jest-esm": "node --no-warnings --experimental-vm-modules ./node_modules/jest/bin/jest.js", "prettier": "prettier --check . --ignore-path ../../../.prettierignore", "prettier:fix": "prettier --write . --ignore-path ../../../.prettierignore", - "tsc": "tsc -b" + "test": "npm run test:local", + "test:local": "pnpm run jest-esm --testPathPattern=\".*[.]spec[.]js\"", + "tsc": "tsc -b src test" }, "dependencies": { "@anthropic-ai/claude-agent-sdk": "^0.2.101", @@ -36,9 +39,12 @@ "puppeteer-extra-plugin-stealth": "^2.11.2" }, "devDependencies": { + "@jest/globals": "^29.7.0", "@typeagent/action-schema-compiler": "workspace:*", + "@types/jest": "^29.5.7", "action-grammar-compiler": "workspace:*", "concurrently": "^9.1.2", + "jest": "^29.7.0", "prettier": "^3.5.3", "rimraf": "^6.0.1", "typescript": "~5.4.5" diff --git a/ts/packages/agents/utility/src/actionHandler.mts b/ts/packages/agents/utility/src/actionHandler.mts index fc93d85259..3b0eb6c5e5 100644 --- a/ts/packages/agents/utility/src/actionHandler.mts +++ b/ts/packages/agents/utility/src/actionHandler.mts @@ -4,6 +4,7 @@ import type { ActionContext, AppAgent, + SessionContext, TypeAgentAction, } from "@typeagent/agent-sdk"; import { @@ -27,24 +28,57 @@ import type { UtilityAction } from "./utilitySchema.mjs"; (puppeteer as any).use(StealthPlugin()); -export type UtilityAgentContext = {}; +export type UtilityAgentContext = { + // Lazily-launched browser instance. Held here so closeAgentContext() can + // shut it down when the agent is disabled, preventing orphaned Chrome processes. + browserPromise: Promise | null; +}; async function initializeUtilityContext(): Promise { - getBrowser().catch(() => {}); // pre-warm: launch Chrome in background, ignore startup errors - return {}; + const agentContext: UtilityAgentContext = { browserPromise: null }; + // Pre-warm: launch Chrome in background, ignore startup errors. + getBrowser(agentContext).catch(() => {}); + return agentContext; +} + +async function closeUtilityContext( + context: SessionContext, +): Promise { + const agentContext = context.agentContext; + if (agentContext.browserPromise !== null) { + const browserPromise = agentContext.browserPromise; + // Clear immediately so any concurrent getBrowser() call starts fresh. + agentContext.browserPromise = null; + try { + const browser = await browserPromise; + await browser.close(); + } catch { + // Browser may have already crashed or been closed — ignore. + } + } } async function executeUtilityAction( action: TypeAgentAction, - _context: ActionContext, + context: ActionContext, ) { const a = action as UtilityAction; + const signal = context.abortSignal; + const agentContext = context.sessionContext.agentContext; try { switch (a.actionName) { case "webSearch": - return await handleWebSearch(a.parameters.query); + return await handleWebSearch( + a.parameters.query, + agentContext, + signal, + ); case "webFetch": - return await handleWebFetch(a.parameters.url); + return await handleWebFetch( + a.parameters.url, + agentContext, + signal, + ); case "readFile": return await handleReadFile(a.parameters.path); case "writeFile": @@ -59,6 +93,7 @@ async function executeUtilityAction( a.parameters.parseJson, a.parameters.htmlOutput, a.parameters.model, + signal, ); case "claudeTask": return await handleClaudeTask( @@ -66,6 +101,7 @@ async function executeUtilityAction( a.parameters.parseJson, a.parameters.model, a.parameters.maxTurns, + signal, ); default: return createActionResultFromError( @@ -73,41 +109,55 @@ async function executeUtilityAction( ); } } catch (error) { + if ( + signal?.aborted || + (error instanceof Error && error.name === "AbortError") + ) { + throw error; // let the dispatcher handle it as a cancellation + } return createActionResultFromError( `Utility action failed: ${error instanceof Error ? error.message : String(error)}`, ); } } -// Singleton browser — launched once, reused across calls -let _browserPromise: Promise | null = null; +// Singleton browser — launched once per agent context, reused across calls. +// Owned by UtilityAgentContext so closeAgentContext() can shut it down cleanly. -function getBrowser(): Promise { - if (!_browserPromise) { +function getBrowser(agentContext: UtilityAgentContext): Promise { + if (!agentContext.browserPromise) { const p = (puppeteer as any).launch({ headless: true, args: ["--no-sandbox", "--disable-setuid-sandbox"], }) as Promise; - _browserPromise = p.then((browser) => { + agentContext.browserPromise = p.then((browser) => { browser.on("disconnected", () => { - _browserPromise = null; + agentContext.browserPromise = null; }); return browser; }); - _browserPromise.catch(() => { - _browserPromise = null; + agentContext.browserPromise.catch(() => { + agentContext.browserPromise = null; }); } - return _browserPromise; + return agentContext.browserPromise; } const MAX_PAGE_CHARS = 50_000; -async function getPageContent(url: string): Promise { - const browser = await getBrowser(); +async function getPageContent( + url: string, + agentContext: UtilityAgentContext, + signal?: AbortSignal, +): Promise { + const browser = await getBrowser(agentContext); const page = await browser.newPage(); + // Cancel the navigation if the abort signal fires — Puppeteer doesn't + // accept AbortSignal directly, so close the page to interrupt goto(). + const onAbort = () => page.close().catch(() => {}); + signal?.addEventListener("abort", onAbort, { once: true }); try { await page.goto(url, { waitUntil: "networkidle2", @@ -130,20 +180,29 @@ async function getPageContent(url: string): Promise { `\n\n[Content truncated at ${MAX_PAGE_CHARS} characters]` : text; } finally { + signal?.removeEventListener("abort", onAbort); await page.close(); } } -async function handleWebSearch(searchQuery: string) { +async function handleWebSearch( + searchQuery: string, + agentContext: UtilityAgentContext, + signal?: AbortSignal, +) { const url = `https://html.duckduckgo.com/html/?q=${encodeURIComponent(searchQuery)}`; - const html = await getPageContent(url); + const html = await getPageContent(url, agentContext, signal); return createActionResultFromTextDisplay( `Search results for "${searchQuery}":\n\n${html}`, `Search results for "${searchQuery}":\n\n${html}`, ); } -async function handleWebFetch(url: string) { +async function handleWebFetch( + url: string, + agentContext: UtilityAgentContext, + signal?: AbortSignal, +) { // Normalize bare domain/path (no scheme) to https:// // Also replace spaces with hyphens in the path (genre slugs etc.) const withScheme = @@ -151,7 +210,7 @@ async function handleWebFetch(url: string) { ? url : `https://${url}`; const normalized = withScheme.replace(/ /g, "-"); - const html = await getPageContent(normalized); + const html = await getPageContent(normalized, agentContext, signal); return createActionResultFromTextDisplay( `Content from ${normalized}:\n\n${html}`, `Content from ${normalized}:\n\n${html}`, @@ -174,15 +233,29 @@ async function handleLlmTransform( parseJson?: boolean, htmlOutput?: boolean, model: string = "claude-haiku-4-5-20251001", + signal?: AbortSignal, ) { + const abortController = new AbortController(); const fullPrompt = `${prompt}\n\n${input}`; - const queryInstance = query({ prompt: fullPrompt, options: { model } }); + const queryInstance = query({ + prompt: fullPrompt, + options: { model, abortController }, + }); + const onAbort = () => { + abortController.abort(signal?.reason); + queryInstance.return(); + }; + signal?.addEventListener("abort", onAbort, { once: true }); let responseText = ""; - for await (const message of queryInstance) { - if (message.type === "result" && message.subtype === "success") { - responseText = (message as any).result || ""; - break; + try { + for await (const message of queryInstance) { + if (message.type === "result" && message.subtype === "success") { + responseText = (message as any).result || ""; + break; + } } + } finally { + signal?.removeEventListener("abort", onAbort); } if (parseJson) { @@ -225,12 +298,15 @@ async function handleClaudeTask( parseJson?: boolean, model: string = "claude-haiku-4-5-20251001", maxTurns: number = 10, + signal?: AbortSignal, ) { + const abortController = new AbortController(); const queryInstance = query({ prompt: goal, options: { model, maxTurns, + abortController, permissionMode: "acceptEdits", canUseTool: async () => ({ behavior: "allow" as const }), allowedTools: ["WebSearch", "WebFetch"], @@ -243,12 +319,21 @@ async function handleClaudeTask( }, }, }); + const onAbort = () => { + abortController.abort(signal?.reason); + queryInstance.return(); + }; + signal?.addEventListener("abort", onAbort, { once: true }); let responseText = ""; - for await (const message of queryInstance) { - if (message.type === "result" && message.subtype === "success") { - responseText = (message as any).result || ""; - break; + try { + for await (const message of queryInstance) { + if (message.type === "result" && message.subtype === "success") { + responseText = (message as any).result || ""; + break; + } } + } finally { + signal?.removeEventListener("abort", onAbort); } if (parseJson) { @@ -276,6 +361,7 @@ async function handleClaudeTask( export function instantiate(): AppAgent { return { initializeAgentContext: initializeUtilityContext, + closeAgentContext: closeUtilityContext, executeAction: executeUtilityAction, }; } diff --git a/ts/packages/agents/utility/test/browserLifecycle.spec.ts b/ts/packages/agents/utility/test/browserLifecycle.spec.ts new file mode 100644 index 0000000000..1e5b98925d --- /dev/null +++ b/ts/packages/agents/utility/test/browserLifecycle.spec.ts @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { jest, describe, test, expect, beforeEach } from "@jest/globals"; + +// --- Fake browser & page --------------------------------------------------- + +let disconnectedHandler: (() => void) | undefined; + +function makeFakeBrowser() { + disconnectedHandler = undefined; + return { + on: jest.fn((event: string, handler: () => void) => { + if (event === "disconnected") { + disconnectedHandler = handler; + } + }), + newPage: jest.fn(() => + Promise.resolve({ + goto: jest.fn(), + content: jest.fn(), + close: jest.fn(), + }), + ), + close: jest.fn(() => Promise.resolve()), + }; +} + +let fakeBrowser = makeFakeBrowser(); + +// Mutable reference updated each beforeEach so the mock factory always sees +// the latest fake browser without needing to re-import puppeteer-extra. +const launchMock = jest.fn(() => Promise.resolve(fakeBrowser)); + +// --- ESM mocks (must precede dynamic import) -------------------------------- + +jest.unstable_mockModule("puppeteer-extra", () => ({ + default: { + use: jest.fn(), + launch: launchMock, + }, +})); + +jest.unstable_mockModule("puppeteer-extra-plugin-stealth", () => ({ + default: jest.fn(), +})); + +jest.unstable_mockModule("browser-typeagent/htmlReducer", () => ({ + createNodeHtmlReducer: jest.fn(), +})); + +jest.unstable_mockModule("html-to-text", () => ({ + convert: jest.fn(), +})); + +jest.unstable_mockModule("@anthropic-ai/claude-agent-sdk", () => ({ + query: jest.fn(), +})); + +// --- Import module under test ----------------------------------------------- + +const { instantiate } = await import("../src/actionHandler.mjs"); + +describe("browser lifecycle", () => { + beforeEach(() => { + fakeBrowser = makeFakeBrowser(); + launchMock.mockImplementation(() => Promise.resolve(fakeBrowser)); + }); + + test("instantiate() exports closeAgentContext", () => { + const agent = instantiate(); + expect(agent.closeAgentContext).toBeDefined(); + expect(typeof agent.closeAgentContext).toBe("function"); + }); + + test("initializeAgentContext pre-warms browserPromise after a tick", async () => { + const agent = instantiate(); + const ctx = await agent.initializeAgentContext!(); + // Pre-warm is fire-and-forget; let microtasks settle + await new Promise((r) => setTimeout(r, 0)); + expect((ctx as any).browserPromise).not.toBeNull(); + }); + + test("closeAgentContext closes browser and nulls browserPromise", async () => { + const agent = instantiate(); + const agentContext = await agent.initializeAgentContext!(); + await new Promise((r) => setTimeout(r, 0)); + expect((agentContext as any).browserPromise).not.toBeNull(); + + const sessionContext = { agentContext } as any; + await agent.closeAgentContext!(sessionContext); + + expect(fakeBrowser.close).toHaveBeenCalledTimes(1); + expect((agentContext as any).browserPromise).toBeNull(); + }); + + test("double close is a no-op", async () => { + const agent = instantiate(); + const agentContext = await agent.initializeAgentContext!(); + await new Promise((r) => setTimeout(r, 0)); + + const sessionContext = { agentContext } as any; + await agent.closeAgentContext!(sessionContext); + await agent.closeAgentContext!(sessionContext); + + expect(fakeBrowser.close).toHaveBeenCalledTimes(1); + }); + + test("browser disconnected event clears browserPromise", async () => { + const agent = instantiate(); + const agentContext = await agent.initializeAgentContext!(); + await new Promise((r) => setTimeout(r, 0)); + + expect((agentContext as any).browserPromise).not.toBeNull(); + expect(disconnectedHandler).toBeDefined(); + + // Simulate browser crash/disconnect + disconnectedHandler!(); + + expect((agentContext as any).browserPromise).toBeNull(); + }); +}); diff --git a/ts/packages/agents/utility/test/tsconfig.json b/ts/packages/agents/utility/test/tsconfig.json new file mode 100644 index 0000000000..a9e01ce1d7 --- /dev/null +++ b/ts/packages/agents/utility/test/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../../../../tsconfig.base.json", + "compilerOptions": { + "composite": true, + "rootDir": ".", + "outDir": "../dist/test" + }, + "references": [{ "path": "../src" }], + "include": ["./**/*"] +} diff --git a/ts/packages/agents/utility/tsconfig.json b/ts/packages/agents/utility/tsconfig.json index cd81e5c355..76b9db86e7 100644 --- a/ts/packages/agents/utility/tsconfig.json +++ b/ts/packages/agents/utility/tsconfig.json @@ -2,5 +2,5 @@ "extends": "../../../tsconfig.base.json", "compilerOptions": { "composite": true }, "include": [], - "references": [{ "path": "./src" }] + "references": [{ "path": "./src" }, { "path": "./test" }] } diff --git a/ts/packages/aiclient/src/models.ts b/ts/packages/aiclient/src/models.ts index dbfa047166..c463a721c6 100644 --- a/ts/packages/aiclient/src/models.ts +++ b/ts/packages/aiclient/src/models.ts @@ -71,6 +71,7 @@ export interface ChatModel extends TypeChatLanguageModel { usageCallback?: CompleteUsageStatsCallback, jsonSchema?: CompletionJsonSchema, promptLogFn?: (msg: any) => void, + signal?: AbortSignal, ): Promise>; } @@ -85,6 +86,7 @@ export interface ChatModelWithStreaming extends ChatModel { usageCallback?: CompleteUsageStatsCallback, jsonSchema?: CompletionJsonSchema, promptLogFn?: (msg: any) => void, + signal?: AbortSignal, ): Promise>>; } diff --git a/ts/packages/aiclient/src/ollamaModels.ts b/ts/packages/aiclient/src/ollamaModels.ts index 1a95c071bf..4e6048dc01 100644 --- a/ts/packages/aiclient/src/ollamaModels.ts +++ b/ts/packages/aiclient/src/ollamaModels.ts @@ -187,6 +187,7 @@ export function createOllamaChatModel( async function complete( prompt: string | PromptSection[], + // TODO: thread AbortSignal through to fetch() for cancellation support (see interrupt-design.md) ): Promise> { const messages = typeof prompt === "string" @@ -233,6 +234,7 @@ export function createOllamaChatModel( async function completeStream( prompt: string | PromptSection[], + // TODO: thread AbortSignal through to fetch() for cancellation support (see interrupt-design.md) ): Promise>> { const messages: PromptSection[] = typeof prompt === "string" diff --git a/ts/packages/aiclient/src/openai.ts b/ts/packages/aiclient/src/openai.ts index f80832f62f..3d377011f9 100644 --- a/ts/packages/aiclient/src/openai.ts +++ b/ts/packages/aiclient/src/openai.ts @@ -553,6 +553,7 @@ function createAzureOpenAIChatModel( usageCallback?: CompleteUsageStatsCallback, jsonSchema?: CompletionJsonSchema, logFn?: (msg: any) => void, + signal?: AbortSignal, ): Promise> { verifyPromptLength(settings, prompt); @@ -564,6 +565,7 @@ function createAzureOpenAIChatModel( const params = getParams(messages, jsonSchema); const result = await callJsonApiWithPool(pool, buildRequest(params), { retryPauseMs: settings.retryPauseMs, + signal, }); if (!result.success) { return result; @@ -619,6 +621,7 @@ function createAzureOpenAIChatModel( usageCallback?: CompleteUsageStatsCallback, jsonSchema?: CompletionJsonSchema, logFn?: (msg: any) => void, + signal?: AbortSignal, ): Promise>> { verifyPromptLength(settings, prompt); @@ -647,6 +650,7 @@ function createAzureOpenAIChatModel( }); const result = await callApiWithPool(pool, buildRequest(params), { retryPauseMs: settings.retryPauseMs, + signal, }); if (!result.success) { return result; @@ -657,7 +661,11 @@ function createAzureOpenAIChatModel( return { success: true, data: (async function* () { - for await (const evt of readServerEventStream(result.data)) { + for await (const evt of readServerEventStream( + result.data, + signal, + )) { + if (signal?.aborted) break; if (evt.data === "[DONE]") { try { if (settings.enableModelRequestLogging && logFn) { diff --git a/ts/packages/aiclient/src/restClient.ts b/ts/packages/aiclient/src/restClient.ts index 43bed13e87..484821f5cf 100644 --- a/ts/packages/aiclient/src/restClient.ts +++ b/ts/packages/aiclient/src/restClient.ts @@ -46,6 +46,7 @@ export function callApi( retryPauseMs?: number, timeout?: number, throttler?: FetchThrottler, + signal?: AbortSignal, ): Promise> { const options: RequestInit = { method: "POST", @@ -56,6 +57,7 @@ export function callApi( "content-type": "application/json", ...headers, }, + signal: signal ?? null, }; return fetchWithRetry( url, @@ -86,6 +88,7 @@ export async function callJsonApi( retryPauseMs?: number, timeout?: number, throttler?: FetchThrottler, + signal?: AbortSignal, ): Promise> { const result = await callApi( headers, @@ -95,6 +98,7 @@ export async function callJsonApi( retryPauseMs, timeout, throttler, + signal, ); if (result.success) { try { @@ -199,19 +203,31 @@ export async function getBlob( */ export async function* readResponseStream( response: Response, + signal?: AbortSignal, ): AsyncIterableIterator { const reader = response.body?.getReader(); if (reader) { - const utf8Decoder = new TextDecoder("utf-8"); - const options = { stream: true }; - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; + const onAbort = () => reader.cancel(); + signal?.addEventListener("abort", onAbort, { once: true }); + try { + const utf8Decoder = new TextDecoder("utf-8"); + const options = { stream: true }; + while (true) { + if (signal?.aborted) throw toAbortError(signal.reason); + const { done, value } = await reader.read(); + if (done) { + break; + } + const text = utf8Decoder.decode(value, options); + if (text.length > 0) { + yield text; + } } - const text = utf8Decoder.decode(value, options); - if (text.length > 0) { - yield text; + } finally { + signal?.removeEventListener("abort", onAbort); + // Only cancel if the abort handler hasn't already done so. + if (!signal?.aborted) { + reader.cancel(); } } } @@ -336,12 +352,26 @@ export async function fetchWithRetry( // wait before retrying // wait at least as long as the Retry-After header, plus a back-off factor that increases with each retry // plus a random delay to avoid thundering herd - await sleep(totalWait); + const signal = options?.signal; + if (signal?.aborted) { + throw toAbortError(signal.reason); + } + await new Promise((resolve, reject) => { + const onAbort = () => { + clearTimeout(id); + reject(toAbortError(signal!.reason)); + }; + const id = setTimeout(() => { + signal?.removeEventListener("abort", onAbort); + resolve(); + }, totalWait); + signal?.addEventListener("abort", onAbort, { once: true }); + }); retryCount++; } } catch (e: any) { - if (e.name === "AbortError") { + if (isAbortError(e)) { throw e; } return error(`fetch error: ${e.cause?.message ?? e.message}`); @@ -416,10 +446,7 @@ async function fetchWithTimeout( if (externalSignal) { if (externalSignal.aborted) { clearTimeout(id); - throw ( - externalSignal.reason ?? - new DOMException("The operation was aborted.", "AbortError") - ); + throw toAbortError(externalSignal.reason); } onExternalAbort = () => controller.abort(externalSignal.reason); externalSignal.addEventListener("abort", onExternalAbort); @@ -441,10 +468,10 @@ async function fetchWithTimeout( } catch (e) { const ex = e as Error; // If the external signal caused the abort, re-throw as AbortError - if (ex.name === "AbortError" && externalSignal?.aborted) { + if (isAbortError(ex) && externalSignal?.aborted) { throw e; } - if (ex.name === "AbortError") { + if (isAbortError(ex)) { throw new Error(`fetch timeout ${timeoutMs}ms`); } throw e; @@ -483,6 +510,23 @@ function sleep(ms: number): Promise { return new Promise((resolve) => setTimeout(resolve, ms)); } +function isAbortError(e: unknown): e is Error { + return e instanceof Error && e.name === "AbortError"; +} + +/** + * Normalize an arbitrary abort reason into a real AbortError. `AbortSignal.reason` + * is whatever was passed to `controller.abort(reason)` — it may be a string, + * a plain object, or any non-Error value. Upstream code keys off `name === "AbortError"`, + * so coerce anything that isn't already an AbortError into a DOMException. + */ +function toAbortError(reason: unknown): Error { + if (isAbortError(reason)) { + return reason; + } + return new DOMException("The operation was aborted.", "AbortError"); +} + // --------------------------------------------------------------------------- // Endpoint-pool aware variants // @@ -510,6 +554,7 @@ export type BuildPoolRequest = ( export type PoolCallOptions = { overallBudgetMs?: number | undefined; retryPauseMs?: number | undefined; + signal?: AbortSignal | undefined; }; async function fetchWithPool( @@ -531,6 +576,7 @@ async function fetchWithPool( "content-type": "application/json", ...req.data.headers, }, + signal: options?.signal ?? null, }; if (method === "POST" && req.data.body !== undefined) { init.body = JSON.stringify(req.data.body); @@ -569,6 +615,12 @@ async function fetchWithPool( ) ); } + + // Respect caller-provided abort signal + if (options?.signal?.aborted) { + throw toAbortError(options.signal.reason); + } + const budgetLeft = overallBudgetMs - elapsed; const pick = pickEndpoint(pool); @@ -594,6 +646,7 @@ async function fetchWithPool( "content-type": "application/json", ...req.data.headers, }, + signal: options?.signal ?? null, }; if (method === "POST" && req.data.body !== undefined) { init.body = JSON.stringify(req.data.body); @@ -618,6 +671,11 @@ async function fetchWithPool( member.settings.throttler, ); } catch (e: any) { + // If the caller's abort signal fired, propagate immediately + // instead of rotating to another pool member. + if (isAbortError(e)) { + throw e; + } // Network error, timeout (AbortError re-thrown as "fetch timeout"), // or connection reset. Treat as transient and rotate. debugPool( diff --git a/ts/packages/aiclient/src/serverEvents.ts b/ts/packages/aiclient/src/serverEvents.ts index e79b647055..0b552bedc2 100644 --- a/ts/packages/aiclient/src/serverEvents.ts +++ b/ts/packages/aiclient/src/serverEvents.ts @@ -11,9 +11,11 @@ export type ServerEvent = { export async function* readServerEventStream( response: Response, + signal?: AbortSignal, ): AsyncIterableIterator { - const textStream = readResponseStream(response); + const textStream = readResponseStream(response, signal); for await (const message of readMessages(textStream)) { + if (signal?.aborted) break; yield readEvent(message); } } diff --git a/ts/packages/cli/src/commands/connect.ts b/ts/packages/cli/src/commands/connect.ts index db9b36b4b6..72b0e98d8f 100644 --- a/ts/packages/cli/src/commands/connect.ts +++ b/ts/packages/cli/src/commands/connect.ts @@ -401,8 +401,15 @@ export default class Connect extends Command { { showPrimaryName: false }, ), ), - async (command: string, _dispatcher: Dispatcher) => { - return activeDispatcher.processCommand(command); + async ( + command: string, + _dispatcher: Dispatcher, + clientRequestId: string, + ) => { + return activeDispatcher.processCommand( + command, + clientRequestId, + ); }, activeDispatcher, undefined, diff --git a/ts/packages/cli/src/enhancedConsole.ts b/ts/packages/cli/src/enhancedConsole.ts index 4e50d6812e..f15bd8d480 100644 --- a/ts/packages/cli/src/enhancedConsole.ts +++ b/ts/packages/cli/src/enhancedConsole.ts @@ -56,6 +56,7 @@ import { PromptRenderer, } from "./debugInterceptor.js"; import { stopAgentServer } from "@typeagent/agent-server-client"; +import { randomUUID } from "crypto"; // Track current processing state let currentSpinner: EnhancedSpinner | null = null; @@ -236,6 +237,9 @@ let terminalLayout: TerminalLayout | null = null; // Track the active request for cancellation support let currentRequestId: string | undefined; +// Client-assigned ID for the current request; available immediately before +// processCommand() returns, so we can cancel before setUserRequest() arrives. +let currentClientRequestId: string | undefined; let isProcessing = false; let lastCtrlCTime = 0; @@ -969,7 +973,7 @@ export function createEnhancedClientIO( chalk.gray("[answered by another client]"), ); } else { - displayContent(chalk.yellow("Cancelled!")); + displayContent(chalk.yellow("⚠ Cancelled")); } if (wasSpinning) { @@ -1688,7 +1692,7 @@ function initializeEnhancedConsole( dispatcherRef?: { current?: Dispatcher }, ) { process.on("SIGINT", () => { - if (isProcessing && dispatcherRef?.current && currentRequestId) { + if (isProcessing && dispatcherRef?.current) { const now = Date.now(); if (now - lastCtrlCTime < 1000) { // Double Ctrl+C — force exit @@ -1699,7 +1703,13 @@ function initializeEnhancedConsole( process.exit(0); } lastCtrlCTime = now; - dispatcherRef.current.cancelCommand(currentRequestId); + if (currentRequestId) { + dispatcherRef.current.cancelCommand(currentRequestId); + } else if (currentClientRequestId) { + dispatcherRef.current.cancelCommandByClientId( + currentClientRequestId, + ); + } return; } if (currentSpinner) { @@ -1840,16 +1850,21 @@ function startExecutionKeyListener( // stdin encoding may be set to utf8 by questionWithCompletion, // so data can be a string. Use charCodeAt for consistent handling. const code = typeof data === "string" ? data.charCodeAt(0) : data[0]; - if (data.length === 1 && code === 0x1b && currentRequestId) { + if (data.length === 1 && code === 0x1b) { // Escape key — cancel current command - dispatcher.cancelCommand(currentRequestId); + if (currentRequestId) { + dispatcher.cancelCommand(currentRequestId); + } else if (isProcessing && currentClientRequestId) { + // setUserRequest not yet received — cancel by client-assigned ID + dispatcher.cancelCommandByClientId(currentClientRequestId); + } } else if (data.length === 1 && code === 4) { // Ctrl+D — toggle debug panel expand/collapse const panel = getDebugPanel(); if (panel && panel.lineCount > 0) { panel.toggle(); } - } else if (data.length === 1 && code === 3 && currentRequestId) { + } else if (data.length === 1 && code === 3) { // Ctrl+C during raw mode — cancel or double-press exit const now = Date.now(); if (now - lastCtrlCTime < 1000) { @@ -1860,7 +1875,11 @@ function startExecutionKeyListener( process.exit(0); } lastCtrlCTime = now; - dispatcher.cancelCommand(currentRequestId); + if (currentRequestId) { + dispatcher.cancelCommand(currentRequestId); + } else if (isProcessing && currentClientRequestId) { + dispatcher.cancelCommandByClientId(currentClientRequestId); + } } }; @@ -1880,7 +1899,11 @@ function startExecutionKeyListener( */ export async function processCommandsEnhanced( interactivePrompt: string | ((context: T) => string | Promise), - processCommand: (request: string, context: T) => Promise, + processCommand: ( + request: string, + context: T, + clientRequestId: string, + ) => Promise, context: T, inputs?: string[], completionController?: CompletionController, @@ -1963,7 +1986,7 @@ export async function processCommandsEnhanced( if (isSlashCommand(request)) { try { const slashResult = await handleSlashCommand(request, (cmd) => - processCommand(cmd, context), + processCommand(cmd, context, randomUUID()), ); if (slashResult.exit) { break; @@ -1992,14 +2015,20 @@ export async function processCommandsEnhanced( isProcessing = true; currentRequestId = undefined; + currentClientRequestId = randomUUID(); const stopKeyListener = startExecutionKeyListener(dispatcherForCancel); let result: any; try { - result = await processCommand(request, context); + result = await processCommand( + request, + context, + currentClientRequestId, + ); } finally { stopKeyListener(); isProcessing = false; + currentClientRequestId = undefined; } // Stop live panel rendering but keep the buffer for post-command review @@ -2017,7 +2046,7 @@ export async function processCommandsEnhanced( } if (result?.cancelled) { - stopSpinner("warn", "Cancelled"); + stopSpinner("warn", " Cancelled"); } else { stopSpinner("success", "Complete"); } diff --git a/ts/packages/dispatcher/dispatcher/src/command/command.ts b/ts/packages/dispatcher/dispatcher/src/command/command.ts index 3536432313..195600453f 100644 --- a/ts/packages/dispatcher/dispatcher/src/command/command.ts +++ b/ts/packages/dispatcher/dispatcher/src/command/command.ts @@ -352,8 +352,8 @@ export async function processCommandNoLock( attachments, ); } catch (e: any) { - if (e.name === "AbortError") { - throw e; + if (e.name === "AbortError" || context.currentAbortSignal?.aborted) { + throw new DOMException("The operation was aborted.", "AbortError"); } context.clientIO.appendDisplay( makeClientIOMessage( @@ -442,31 +442,46 @@ export async function processCommand( attachments?: string[], options?: ProcessCommandOptions, ): Promise { - // Process one command at at time. - return context.commandLock(async () => { - const abortController = new AbortController(); - const requestIdStr = requestId.requestId; - context.activeRequests.set(requestIdStr, abortController); - beginProcessCommand( - requestId, - context, - options, - abortController.signal, + // Create the AbortController *before* acquiring the lock so that a + // cancelCommandByClientId() call that arrives while we are queued can + // already abort the controller that will drive this command. + const abortController = new AbortController(); + if (requestId.clientRequestId !== undefined) { + context.activeRequestsByClientId.set( + requestId.clientRequestId, + abortController, ); - context.clientIO.setUserRequest(requestId, originalInput); - try { - await processCommandNoLock(originalInput, context, attachments); - } catch (e: any) { - if (e.name === "AbortError") { - ensureCommandResult(context).cancelled = true; - } else { - throw e; + } + try { + // Process one command at a time. + return await context.commandLock(async () => { + const requestIdStr = requestId.requestId; + context.activeRequests.set(requestIdStr, abortController); + beginProcessCommand( + requestId, + context, + options, + abortController.signal, + ); + context.clientIO.setUserRequest(requestId, originalInput); + try { + await processCommandNoLock(originalInput, context, attachments); + } catch (e: any) { + if (e.name === "AbortError") { + ensureCommandResult(context).cancelled = true; + } else { + throw e; + } + } finally { + context.activeRequests.delete(requestIdStr); + return endProcessCommand(requestId, context); } - } finally { - context.activeRequests.delete(requestIdStr); - return endProcessCommand(requestId, context); + }); + } finally { + if (requestId.clientRequestId !== undefined) { + context.activeRequestsByClientId.delete(requestId.clientRequestId); } - }); + } } export const enum unicodeChar { diff --git a/ts/packages/dispatcher/dispatcher/src/context/commandHandlerContext.ts b/ts/packages/dispatcher/dispatcher/src/context/commandHandlerContext.ts index 301db6437a..7f9088891c 100644 --- a/ts/packages/dispatcher/dispatcher/src/context/commandHandlerContext.ts +++ b/ts/packages/dispatcher/dispatcher/src/context/commandHandlerContext.ts @@ -165,6 +165,7 @@ export type CommandHandlerContext = { currentRequestId: RequestId | undefined; currentAbortSignal: AbortSignal | undefined; activeRequests: Map; + activeRequestsByClientId: Map; noReasoning: boolean; isInsideReasoningLoop: boolean; // true while the MCP execute_action handler is dispatching a sub-action commandResult?: CommandResult | undefined; @@ -597,6 +598,7 @@ export async function initializeCommandHandlerContext( currentRequestId: undefined, currentAbortSignal: undefined, activeRequests: new Map(), + activeRequestsByClientId: new Map(), noReasoning: false, isInsideReasoningLoop: false, pendingToggleTransientAgents: [], diff --git a/ts/packages/dispatcher/dispatcher/src/dispatcher.ts b/ts/packages/dispatcher/dispatcher/src/dispatcher.ts index f8422d28c2..a2a8090684 100644 --- a/ts/packages/dispatcher/dispatcher/src/dispatcher.ts +++ b/ts/packages/dispatcher/dispatcher/src/dispatcher.ts @@ -299,6 +299,13 @@ export function createDispatcherFromContext( controller.abort(); } }, + cancelCommandByClientId(clientRequestId: unknown) { + const controller = + context.activeRequestsByClientId.get(clientRequestId); + if (controller) { + controller.abort(); + } + }, async respondToChoice(choiceId: string, response: boolean | number[]) { return context.commandLock(async () => { const pending = context.pendingChoiceRoutes.get(choiceId); diff --git a/ts/packages/dispatcher/dispatcher/src/execute/actionHandlers.ts b/ts/packages/dispatcher/dispatcher/src/execute/actionHandlers.ts index 7e3e569473..d0b9915e91 100644 --- a/ts/packages/dispatcher/dispatcher/src/execute/actionHandlers.ts +++ b/ts/packages/dispatcher/dispatcher/src/execute/actionHandlers.ts @@ -185,11 +185,17 @@ export async function executeAction( ); } } catch (e: any) { - if (e.name === "AbortError") { - throw e; + if ( + e.name === "AbortError" || + systemContext.currentAbortSignal?.aborted + ) { + throw new DOMException("The operation was aborted.", "AbortError"); } result = createActionResultFromError(e.message); } + // If the agent ran to completion but a cancel arrived while it was executing, + // discard the result and treat this as a cancellation. + systemContext.currentAbortSignal?.throwIfAborted(); actionContext.profiler?.stop(); actionContext.profiler = undefined; @@ -584,6 +590,10 @@ export function startStreamPartialAction( context.streamingActionContext = actionContextWithClose; return (name: string, value: any, delta?: string) => { + // Gap 2: Drop streaming chunks after cancellation to avoid + // dispatching to agents once the request has been aborted. + if (context.currentAbortSignal?.aborted) return; + appAgent.streamPartialAction!( actionName, name, @@ -631,8 +641,14 @@ export async function executeCommand( attachments, ); } catch (e: any) { - if (e.name === "AbortError") { - throw e; + if ( + e.name === "AbortError" || + context.currentAbortSignal?.aborted + ) { + throw new DOMException( + "The operation was aborted.", + "AbortError", + ); } displayError(`ERROR: ${e.message}`, actionContext); debugCommandExecError(e.stack); diff --git a/ts/packages/dispatcher/dispatcher/src/translation/agentTranslators.ts b/ts/packages/dispatcher/dispatcher/src/translation/agentTranslators.ts index 374d7ce1f6..62aef9e8cb 100644 --- a/ts/packages/dispatcher/dispatcher/src/translation/agentTranslators.ts +++ b/ts/packages/dispatcher/dispatcher/src/translation/agentTranslators.ts @@ -260,6 +260,7 @@ export type TypeAgentTranslator = { attachments?: CachedImageWithDetails[], cb?: IncrementalJsonValueCallBack, usageCallback?: CompleteUsageStatsCallback, + signal?: AbortSignal, ): Promise>; checkTranslate(request: string): Promise>; getSchemaName(actionName: string): string | undefined; @@ -442,6 +443,7 @@ function createTypeAgentTranslator< attachments?: CachedImageWithDetails[], cb?: IncrementalJsonValueCallBack, usageCallback?: CompleteUsageStatsCallback, + signal?: AbortSignal, ) => { // Expand the request prompt up front with the history and attachments const requestPrompt = createTypeAgentRequestPrompt( @@ -460,6 +462,7 @@ function createTypeAgentTranslator< attachments, cb, usageCallback, + signal, ); }, // No streaming, no history, no attachments. diff --git a/ts/packages/dispatcher/dispatcher/src/translation/interpretRequest.ts b/ts/packages/dispatcher/dispatcher/src/translation/interpretRequest.ts index 355b2b3b59..568d5425e6 100644 --- a/ts/packages/dispatcher/dispatcher/src/translation/interpretRequest.ts +++ b/ts/packages/dispatcher/dispatcher/src/translation/interpretRequest.ts @@ -108,7 +108,13 @@ async function interpretRequestWithActiveSchemas( ); const canUseCacheMatch = cannotUseCacheReason === undefined; const match = canUseCacheMatch - ? await matchRequest(context, request, history, activeSchemaNames) + ? await matchRequest( + context, + request, + history, + activeSchemaNames, + context.sessionContext.agentContext.currentAbortSignal, + ) : undefined; return ( diff --git a/ts/packages/dispatcher/dispatcher/src/translation/matchRequest.ts b/ts/packages/dispatcher/dispatcher/src/translation/matchRequest.ts index 55f598ebcb..0d830b8087 100644 --- a/ts/packages/dispatcher/dispatcher/src/translation/matchRequest.ts +++ b/ts/packages/dispatcher/dispatcher/src/translation/matchRequest.ts @@ -19,9 +19,12 @@ const debugConstValidation = registerDebug("typeagent:const:validation"); async function validateWildcardMatch( match: MatchResult, context: CommandHandlerContext, + signal?: AbortSignal, ) { const actions = match.match.actions; for (const { action } of actions) { + // Check abort signal before processing each action + signal?.throwIfAborted(); const schemaName = action.schemaName; if (schemaName === undefined) { continue; @@ -50,11 +53,16 @@ async function validateWildcardMatch( async function validateEntityWildcardMatch( match: MatchResult, context: CommandHandlerContext, + signal?: AbortSignal, ): Promise { if (match.entityWildcardPropertyNames.length === 0) { // No entity wildcard, nothing to validate. return true; } + + // Check abort signal before starting validation loop + signal?.throwIfAborted(); + debugConstValidation( `Validating entity wildcards: [${match.entityWildcardPropertyNames.join(", ")}] ` + `for actions: [${match.match.actions.map((a) => `${a.action.schemaName}.${a.action.actionName}`).join(", ")}]`, @@ -72,6 +80,9 @@ async function validateEntityWildcardMatch( const agents = context.agents; for (const propertyName of match.entityWildcardPropertyNames) { + // Check abort signal before each property validation + signal?.throwIfAborted(); + const canResolve = await canResolvePropertyEntity( conversationMemory, propertyName, @@ -94,21 +105,26 @@ async function validateEntityWildcardMatch( * * @param matches * @param context + * @param signal Optional AbortSignal to allow cancellation during validation * @returns the validate matched. */ async function getValidatedMatch( matches: MatchResult[], context: CommandHandlerContext, + signal?: AbortSignal, ) { for (const match of matches) { - if (!(await validateEntityWildcardMatch(match, context))) { + // Check abort signal before processing each match + signal?.throwIfAborted(); + + if (!(await validateEntityWildcardMatch(match, context, signal))) { continue; } if (match.wildcardCharCount === 0) { return match; } - if (await validateWildcardMatch(match, context)) { + if (await validateWildcardMatch(match, context, signal)) { debugConstValidation( `Wildcard match accepted: ${match.match.actions}`, ); @@ -188,6 +204,7 @@ export async function matchRequest( request: string, history?: HistoryContext, activeSchemas?: string[], + signal?: AbortSignal, ): Promise { // Bypass grammar cache for recording/reasoning-directed requests. const lower = request.trimStart().toLowerCase(); @@ -195,6 +212,9 @@ export async function matchRequest( return undefined; } + // Check abort signal before expensive grammar matching + signal?.throwIfAborted(); + const systemContext = context.sessionContext.agentContext; const agentCache = systemContext.agentCache; if (!agentCache.isEnabled()) { @@ -230,7 +250,7 @@ export async function matchRequest( const elapsedMs = performance.now() - startTime; - const match = await getValidatedMatch(matches, systemContext); + const match = await getValidatedMatch(matches, systemContext, signal); if (match === undefined) { if (matches.length > 0) { console.log( diff --git a/ts/packages/dispatcher/dispatcher/src/translation/translateRequest.ts b/ts/packages/dispatcher/dispatcher/src/translation/translateRequest.ts index 2118386300..2c35823eb2 100644 --- a/ts/packages/dispatcher/dispatcher/src/translation/translateRequest.ts +++ b/ts/packages/dispatcher/dispatcher/src/translation/translateRequest.ts @@ -449,6 +449,7 @@ async function translateWithTranslator( attachments, onProperty, usageCallback, + systemContext.currentAbortSignal, ); if (!response.success) { @@ -491,7 +492,10 @@ async function findAssistantForRequest( systemContext.promptLogger, ); - const result = await selectTranslator.translate(request); + const result = await selectTranslator.translate( + request, + systemContext.currentAbortSignal, + ); if (!result.success) { displayWarn(`Failed to switch assistant: ${result.message}`, context); return undefined; @@ -515,6 +519,7 @@ async function getNextTranslation( context: ActionContext, forceSearch: boolean, ): Promise { + context.sessionContext.agentContext.currentAbortSignal?.throwIfAborted(); let request: string; if (isAdditionalActionLookupAction(action)) { if (!forceSearch) { @@ -553,6 +558,7 @@ async function finalizeAction( let currentTranslator = translator; let currentSchemaName: string = schemaName; while (true) { + context.sessionContext.agentContext.currentAbortSignal?.throwIfAborted(); const forceSearch = currentAction !== action; // force search if we have switched once const nextTranslation = await getNextTranslation( currentAction, @@ -653,6 +659,7 @@ async function finalizeMultipleActions( const requests = action.parameters.requests; const actions: ExecutableAction[] = []; for (const request of requests) { + context.sessionContext.agentContext.currentAbortSignal?.throwIfAborted(); if (isPendingRequest(request)) { actions.push(createPendingRequestAction(request)); continue; diff --git a/ts/packages/dispatcher/dispatcher/src/translation/unknownSwitcher.ts b/ts/packages/dispatcher/dispatcher/src/translation/unknownSwitcher.ts index e8f0b6e567..eff0024507 100644 --- a/ts/packages/dispatcher/dispatcher/src/translation/unknownSwitcher.ts +++ b/ts/packages/dispatcher/dispatcher/src/translation/unknownSwitcher.ts @@ -130,33 +130,58 @@ export type AssistantSelection = { type TranslatorPartition = { names: string[]; translator: { - translate: (request: string) => Promise>; + translate: ( + request: string, + signal?: AbortSignal, + ) => Promise>; }; }; /** * Run all translator partitions in parallel and return the first non-"unknown" * result in partition order, or "unknown" if all partitions return "unknown". + * + * The provided signal is forwarded to each partition's translator so that an + * abort tears down the in-flight LLM HTTP requests, rather than just stopping + * the loop from awaiting them. */ export async function selectFromPartitions( partitions: TranslatorPartition[], request: string, + signal?: AbortSignal, ): Promise> { partitions.forEach(({ names }) => debugSwitchSearch(`Switch: searching ${names.join(", ")}`), ); // Start all translations in parallel. const promises = partitions.map(({ translator }) => - translator.translate(request).catch( - (err): Result => ({ - success: false, - message: err instanceof Error ? err.message : String(err), + translator + .translate(request, signal) + .catch((err): Result => { + // Let AbortError propagate so callers see cancellation rather + // than a generic failure result. Non-abort errors are converted + // to a failure result so a single bad partition doesn't block + // the others from being checked. + if ( + err instanceof Error && + (err.name === "AbortError" || + (err as DOMException).name === "AbortError") + ) { + throw err; + } + return { + success: false, + message: err instanceof Error ? err.message : String(err), + }; }), - ), ); // Await results in partition order; return as soon as outcome is decided. for (const promise of promises) { + // Check for cancellation before and after awaiting each partition result + // so an abort that fires mid-await is caught promptly. + signal?.throwIfAborted(); const result = await promise; + signal?.throwIfAborted(); if (!result.success) { return result; } @@ -199,10 +224,9 @@ export function loadAssistantSelectionJsonTranslator( current.push(entry); currentLength += schema.length; } - const translators = partitions.map((entries) => { - return { - names: entries.map((entry) => entry.name), - translator: createJsonTranslatorFromSchemaDef( + const translators = partitions.map((entries): TranslatorPartition => { + const innerTranslator = + createJsonTranslatorFromSchemaDef( "AllAssistantSelection", entries .map((entry) => entry.schema) @@ -217,12 +241,22 @@ export function loadAssistantSelectionJsonTranslator( ], promptLogger, }, - ), + ); + return { + names: entries.map((entry) => entry.name), + // Adapt typechat's translate(request, promptPreamble?, signal?) to + // the (request, signal?) shape selectFromPartitions expects. The + // signal is smuggled through to model.complete via the prompt's + // ModelParamSection (see jsonTranslator.ts). + translator: { + translate: (request: string, signal?: AbortSignal) => + innerTranslator.translate(request, undefined, signal), + }, }; }); return { - translate: (request: string) => - selectFromPartitions(translators, request), + translate: (request: string, signal?: AbortSignal) => + selectFromPartitions(translators, request, signal), }; } diff --git a/ts/packages/dispatcher/dispatcher/test/cancel.spec.ts b/ts/packages/dispatcher/dispatcher/test/cancel.spec.ts new file mode 100644 index 0000000000..18a05766b9 --- /dev/null +++ b/ts/packages/dispatcher/dispatcher/test/cancel.spec.ts @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/** + * Smoke test for Escape/Ctrl+C cancellation (AbortSignal wired end-to-end). + * + * Strategy: create a dispatcher whose agent handler blocks indefinitely on a + * slow Promise. Intercept setUserRequest() to capture the internal requestId, + * then immediately call cancelCommand() — simulating Escape/Ctrl+C. Verify: + * 1. processCommand() resolves (does not hang). + * 2. CommandResult.cancelled === true. + * 3. The whole round-trip completes in well under 500 ms. + */ + +import { AppAgent, AppAgentManifest } from "@typeagent/agent-sdk"; +import { AppAgentProvider } from "../src/agentProvider/agentProvider.js"; +import { getCommandInterface } from "@typeagent/agent-sdk/helpers/command"; +import { createDispatcher } from "../src/dispatcher.js"; +import type { Dispatcher } from "@typeagent/dispatcher-types"; +import type { ClientIO, RequestId } from "@typeagent/dispatcher-types"; + +// ── Slow agent that hangs for 30 s (simulates LLM network call) ───────────── + +const slowConfig: AppAgentManifest = { + emojiChar: "🐢", + description: "Slow test agent", +}; + +const slowHandlers = { + description: "Slow Command Table", + commands: { + slow: { + description: "A command that takes 30 seconds", + run: async (): Promise => { + // Never resolves within normal test time + await new Promise((resolve) => + setTimeout(resolve, 30_000), + ); + }, + }, + }, +} as const; + +const slowAgent: AppAgent = { + ...getCommandInterface(slowHandlers), +}; + +const slowAgentProvider: AppAgentProvider = { + getAppAgentNames: () => ["slow"], + getAppAgentManifest: async (name) => { + if (name !== "slow") throw new Error(`Unknown agent: ${name}`); + return slowConfig; + }, + loadAppAgent: async (name) => { + if (name !== "slow") throw new Error(`Unknown agent: ${name}`); + return slowAgent; + }, + unloadAppAgent: async () => {}, +}; + +// ── Helpers ────────────────────────────────────────────────────────────────── + +/** + * Build a ClientIO that captures the first requestId from setUserRequest() + * and resolves a promise so the test can call cancelCommand() immediately. + */ +function makeCapturingClientIO(): { + clientIO: ClientIO; + requestIdPromise: Promise; +} { + let resolveId!: (id: string) => void; + const requestIdPromise = new Promise((r) => { + resolveId = r; + }); + + const clientIO: ClientIO = { + clear: () => {}, + exit: () => process.exit(0), + shutdown: () => process.exit(0), + setUserRequest: (requestId: RequestId) => { + resolveId(requestId.requestId); + }, + setDisplayInfo: () => {}, + setDisplay: () => {}, + appendDisplay: () => {}, + appendDiagnosticData: () => {}, + setDynamicDisplay: () => {}, + question: async (_r, _m, _c, defaultId) => defaultId ?? 0, + proposeAction: async () => undefined, + notify: () => {}, + openLocalView: async () => {}, + closeLocalView: async () => {}, + requestChoice: () => {}, + requestInteraction: () => {}, + interactionResolved: () => {}, + interactionCancelled: () => {}, + takeAction: (_requestId, action) => { + throw new Error(`Action ${action} not supported`); + }, + }; + + return { clientIO, requestIdPromise }; +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +describe("Cancellation (AbortSignal smoke tests)", () => { + let dispatcher: Dispatcher; + const { clientIO, requestIdPromise } = makeCapturingClientIO(); + + beforeAll(async () => { + dispatcher = await createDispatcher("test", { + agents: { + actions: false, + schemas: false, + }, + translation: { enabled: false }, + explainer: { enabled: false }, + cache: { enabled: false }, + appAgentProviders: [slowAgentProvider], + collectCommandResult: true, + clientIO, + }); + }); + + afterAll(async () => { + if (dispatcher) { + await dispatcher.close(); + } + }); + + it("cancels a slow command quickly and sets cancelled:true", async () => { + const start = Date.now(); + + // Start the slow command — it will block for 30 s unless cancelled. + const resultPromise = dispatcher.processCommand("@slow slow"); + + // As soon as the dispatcher registers the request (setUserRequest fires), + // immediately cancel it — simulating the user pressing Escape. + const requestId = await requestIdPromise; + dispatcher.cancelCommand(requestId); + + const result = await resultPromise; + const elapsed = Date.now() - start; + + expect(result).toBeDefined(); + expect(result!.cancelled).toBe(true); + // Should return well within 500 ms — not wait 30 s. + expect(elapsed).toBeLessThan(500); + }, 5_000 /* generous 5 s jest timeout */); +}); +describe("Early cancellation via cancelCommandByClientId", () => { + let dispatcher: Dispatcher; + + // ClientIO that captures setUserRequest so the first (blocking) command + // can be cancelled once it's running, freeing the lock for the second. + function makeBlockingClientIO(): { + clientIO: ClientIO; + firstRequestIdPromise: Promise; + } { + let resolved = false; + let resolveId!: (id: string) => void; + const firstRequestIdPromise = new Promise((r) => { + resolveId = r; + }); + const clientIO: ClientIO = { + clear: () => {}, + exit: () => process.exit(0), + shutdown: () => process.exit(0), + setUserRequest: (requestId: RequestId) => { + if (!resolved) { + resolved = true; + resolveId(requestId.requestId); + } + }, + setDisplayInfo: () => {}, + setDisplay: () => {}, + appendDisplay: () => {}, + appendDiagnosticData: () => {}, + setDynamicDisplay: () => {}, + question: async (_r, _m, _c, defaultId) => defaultId ?? 0, + proposeAction: async () => undefined, + notify: () => {}, + openLocalView: async () => {}, + closeLocalView: async () => {}, + requestChoice: () => {}, + requestInteraction: () => {}, + interactionResolved: () => {}, + interactionCancelled: () => {}, + takeAction: (_requestId, action) => { + throw new Error(`Action ${action} not supported`); + }, + }; + return { clientIO, firstRequestIdPromise }; + } + + const { clientIO, firstRequestIdPromise } = makeBlockingClientIO(); + + beforeAll(async () => { + dispatcher = await createDispatcher("test-early", { + agents: { + actions: false, + schemas: false, + }, + translation: { enabled: false }, + explainer: { enabled: false }, + cache: { enabled: false }, + appAgentProviders: [slowAgentProvider], + collectCommandResult: true, + clientIO, + }); + }); + + afterAll(async () => { + if (dispatcher) { + await dispatcher.close(); + } + }); + + it("cancels a queued command by client-assigned id before setUserRequest fires", async () => { + const clientRequestId = "test-client-id-early-cancel"; + + // Start a slow command that will hold the command lock. + const firstResultPromise = dispatcher.processCommand("@slow slow"); + + // Wait for the first command to be running (setUserRequest fired). + const firstRequestId = await firstRequestIdPromise; + + // Queue a second command with a client-assigned id while the first is still running. + // It is now blocked behind the command lock — setUserRequest has not yet fired for it. + const start = Date.now(); + const secondResultPromise = dispatcher.processCommand( + "@slow slow", + clientRequestId, + ); + + // Cancel the second command by client id — before it acquires the lock. + // This is the early-cancel path: the AbortController exists but setUserRequest + // has not yet been called, so the server requestId is not yet known to the caller. + dispatcher.cancelCommandByClientId(clientRequestId); + + // Also cancel the first command so the test doesn't hang. + dispatcher.cancelCommand(firstRequestId); + + const [firstResult, secondResult] = await Promise.all([ + firstResultPromise, + secondResultPromise, + ]); + const elapsed = Date.now() - start; + + expect(firstResult?.cancelled).toBe(true); + expect(secondResult).toBeDefined(); + expect(secondResult!.cancelled).toBe(true); + // Second command should cancel quickly once the first releases the lock. + expect(elapsed).toBeLessThan(1000); + }, 10_000); +}); diff --git a/ts/packages/dispatcher/rpc/src/dispatcherClient.ts b/ts/packages/dispatcher/rpc/src/dispatcherClient.ts index 6b5557c4b1..718db5359b 100644 --- a/ts/packages/dispatcher/rpc/src/dispatcherClient.ts +++ b/ts/packages/dispatcher/rpc/src/dispatcherClient.ts @@ -64,5 +64,8 @@ export function createDispatcherRpcClient( cancelCommand(...args) { return rpc.send("cancelCommand", ...args); }, + cancelCommandByClientId(...args) { + return rpc.send("cancelCommandByClientId", ...args); + }, }; } diff --git a/ts/packages/dispatcher/rpc/src/dispatcherServer.ts b/ts/packages/dispatcher/rpc/src/dispatcherServer.ts index 22fd33b643..62d1db2ea6 100644 --- a/ts/packages/dispatcher/rpc/src/dispatcherServer.ts +++ b/ts/packages/dispatcher/rpc/src/dispatcherServer.ts @@ -17,6 +17,9 @@ export function createDispatcherRpcServer( cancelCommand(...args) { dispatcher.cancelCommand(...args); }, + cancelCommandByClientId(...args) { + dispatcher.cancelCommandByClientId(...args); + }, cancelInteraction(...args) { dispatcher.cancelInteraction(...args); }, diff --git a/ts/packages/dispatcher/rpc/src/dispatcherTypes.ts b/ts/packages/dispatcher/rpc/src/dispatcherTypes.ts index 8c3f48ee2d..3ad8f0772b 100644 --- a/ts/packages/dispatcher/rpc/src/dispatcherTypes.ts +++ b/ts/packages/dispatcher/rpc/src/dispatcherTypes.ts @@ -68,5 +68,6 @@ export type DispatcherInvokeFunctions = { export type DispatcherCallFunctions = { cancelCommand(requestId: string): void; + cancelCommandByClientId(clientRequestId: unknown): void; cancelInteraction(interactionId: string): void; }; diff --git a/ts/packages/dispatcher/rpc/test/dispatcherRpc.spec.ts b/ts/packages/dispatcher/rpc/test/dispatcherRpc.spec.ts index fb7d94b16d..c4465763fa 100644 --- a/ts/packages/dispatcher/rpc/test/dispatcherRpc.spec.ts +++ b/ts/packages/dispatcher/rpc/test/dispatcherRpc.spec.ts @@ -67,6 +67,9 @@ function makeStubDispatcher(overrides: Partial = {}): Dispatcher & { cancelCommand(...args) { calls.push({ method: "cancelCommand", args }); }, + cancelCommandByClientId(...args) { + calls.push({ method: "cancelCommandByClientId", args }); + }, async respondToInteraction(...args) { calls.push({ method: "respondToInteraction", args }); }, diff --git a/ts/packages/dispatcher/types/src/dispatcher.ts b/ts/packages/dispatcher/types/src/dispatcher.ts index 2c96f37982..2842301c50 100644 --- a/ts/packages/dispatcher/types/src/dispatcher.ts +++ b/ts/packages/dispatcher/types/src/dispatcher.ts @@ -268,4 +268,14 @@ export interface Dispatcher { * @param requestId the requestId string of the command to cancel */ cancelCommand(requestId: string): void; + + /** + * Cancel an in-flight command using the client-assigned id that was passed + * as the second argument to processCommand(). This is the early-cancel + * path: the client can call this immediately after processCommand() returns + * without waiting for setUserRequest() to deliver the server-assigned UUID. + * + * @param clientRequestId the same value passed to processCommand() as clientRequestId + */ + cancelCommandByClientId(clientRequestId: unknown): void; } diff --git a/ts/packages/shell/src/renderer/src/chat/messageGroup.ts b/ts/packages/shell/src/renderer/src/chat/messageGroup.ts index fc30e3f82f..b731d0ebce 100644 --- a/ts/packages/shell/src/renderer/src/chat/messageGroup.ts +++ b/ts/packages/shell/src/renderer/src/chat/messageGroup.ts @@ -107,10 +107,16 @@ export class MessageGroup { private requestCompleted(result: CommandResult | undefined) { this.updateMetrics(result?.metrics); if (result?.cancelled) { - this.addStatusMessage( - { message: "⚠ Cancelled", source: "shell" }, - false, - ); + const lastAgentMessage = this.getLastAgentMessage(); + if (lastAgentMessage !== undefined) { + lastAgentMessage.setMessage("⚠ Cancelled", "shell", "block"); + this.chatView.updateScroll(); + } else { + this.addStatusMessage( + { message: "⚠ Cancelled", source: "shell" }, + false, + ); + } } else if (this.statusMessage === undefined) { this.addStatusMessage( { message: "Command completed", source: "shell" }, diff --git a/ts/packages/utils/typechatUtils/src/index.ts b/ts/packages/utils/typechatUtils/src/index.ts index e6be3fd835..6d4751dd32 100644 --- a/ts/packages/utils/typechatUtils/src/index.ts +++ b/ts/packages/utils/typechatUtils/src/index.ts @@ -11,6 +11,7 @@ export { composeTranslatorSchemas, enableJsonTranslatorStreaming, TypeChatJsonTranslatorWithStreaming, + TypeChatJsonTranslatorWithSignal, createJsonTranslatorWithValidator, TypeAgentJsonValidator, JsonTranslatorOptions, diff --git a/ts/packages/utils/typechatUtils/src/jsonTranslator.ts b/ts/packages/utils/typechatUtils/src/jsonTranslator.ts index 159105c3f2..e2bdb6f794 100644 --- a/ts/packages/utils/typechatUtils/src/jsonTranslator.ts +++ b/ts/packages/utils/typechatUtils/src/jsonTranslator.ts @@ -66,16 +66,40 @@ export interface TypeChatJsonTranslatorWithStreaming attachments?: CachedImageWithDetails[] | undefined, cb?: IncrementalJsonValueCallBack, usageCallback?: CompleteUsageStatsCallback, + signal?: AbortSignal, + ) => Promise>; +} + +/** + * Variant of TypeChatJsonTranslator that exposes an optional AbortSignal on + * translate() so callers can cancel the in-flight LLM request. The signal is + * smuggled to model.complete via a ModelParamSection in the prompt array. + */ +export interface TypeChatJsonTranslatorWithSignal + extends TypeChatJsonTranslator { + translate: ( + request: string, + promptPreamble?: string | PromptSection[], + signal?: AbortSignal, ) => Promise>; } // This rely on the fact that the prompt preamble based to typechat are copied to the final prompt. // Add a internal section so we can pass information from the caller to the model.complete function. +// +// TECH DEBT: This is an intentional workaround for TypeChat's limited translate() +// signature, which only accepts (request, promptPreamble). We smuggle additional +// parameters (parser, usageCallback, signal) through the prompt array using a +// synthetic section with role: "model". This works because TypeChat copies prompt +// sections to the final prompt without inspecting "model" role entries, and our +// model.complete wrapper strips them before the actual API call. This coupling is +// fragile and should be replaced if TypeChat adds extensible translate() options. type ModelParamSection = { role: "model"; content: { parser: IncrementalJsonParser | undefined; usageCallback: CompleteUsageStatsCallback | undefined; + signal: AbortSignal | undefined; }; }; @@ -83,8 +107,13 @@ function addModelParamSection( promptPreamble?: string | PromptSection[], cb?: IncrementalJsonValueCallBack, usageCallback?: CompleteUsageStatsCallback, + signal?: AbortSignal, ) { - if (cb === undefined && usageCallback === undefined) { + if ( + cb === undefined && + usageCallback === undefined && + signal === undefined + ) { return promptPreamble; } const prompts: (PromptSection | ModelParamSection)[] = @@ -103,6 +132,7 @@ function addModelParamSection( content: { parser, usageCallback, + signal, }, }); @@ -125,6 +155,7 @@ function getModelParams( return { parser: internal[0].content.parser, usageCallback: internal[0].content.usageCallback, + signal: internal[0].content.signal, actualPrompt: newPrompt as PromptSection[], }; } @@ -149,17 +180,24 @@ export function enableJsonTranslatorStreaming( promptLogger?.logModelRequest, ); } - const { parser, usageCallback, actualPrompt } = modelParams; + const { parser, usageCallback, signal, actualPrompt } = modelParams; if (parser === undefined) { return originalComplete( actualPrompt, usageCallback, undefined, promptLogger?.logModelRequest, + signal, ); } const chunks = []; - const result = await model.completeStream(actualPrompt, usageCallback); + const result = await model.completeStream( + actualPrompt, + usageCallback, + undefined, + undefined, + signal, + ); if (!result.success) { return result; } @@ -180,11 +218,12 @@ export function enableJsonTranslatorStreaming( attachments?: CachedImageWithDetails[], cb?: IncrementalJsonValueCallBack, usageCallback?: CompleteUsageStatsCallback, + signal?: AbortSignal, ) => { await attachAttachments(attachments, promptPreamble); return originalTranslate( request, - addModelParamSection(promptPreamble, cb, usageCallback), + addModelParamSection(promptPreamble, cb, usageCallback, signal), ); }; @@ -236,7 +275,7 @@ export function createJsonTranslatorFromSchemaDef( typeName: string, schemas: string | TranslatorSchemaDef[], options?: JsonTranslatorOptions, -) { +): TypeChatJsonTranslatorWithSignal { const schema = Array.isArray(schemas) ? composeTranslatorSchemas(typeName, schemas) : schemas; @@ -262,7 +301,7 @@ export function createJsonTranslatorWithValidator( name: string, validator: TypeAgentJsonValidator, options?: JsonTranslatorOptions, -) { +): TypeChatJsonTranslatorWithSignal { const model = ai.createChatModel( options?.model, { @@ -281,17 +320,28 @@ export function createJsonTranslatorWithValidator( model.complete = async ( prompt: string | PromptSection[], usageCallback?: CompleteUsageStatsCallback, + jsonSchemaOverride?: CompletionJsonSchema, + logFnOverride?: (msg: any) => void, + signal?: AbortSignal, ) => { - debugPrompt(prompt); - const jsonSchema = validator.getJsonSchema?.(); + // Extract the smuggled signal (and any other model params) from the + // prompt sections so callers of translator.translate(request, preamble, signal) + // can propagate cancellation all the way down to the HTTP request. + const modelParams = getModelParams(prompt); + const actualPrompt = modelParams?.actualPrompt ?? prompt; + const actualUsageCallback = modelParams?.usageCallback ?? usageCallback; + const actualSignal = modelParams?.signal ?? signal; + debugPrompt(actualPrompt); + const jsonSchema = jsonSchemaOverride ?? validator.getJsonSchema?.(); if (jsonSchema !== undefined) { debugJsonSchema(jsonSchema); } return originalComplete( - prompt, - usageCallback, + actualPrompt, + actualUsageCallback, jsonSchema, - options?.promptLogger?.logModelRequest, + logFnOverride ?? options?.promptLogger?.logModelRequest, + actualSignal, ); }; @@ -300,13 +350,28 @@ export function createJsonTranslatorWithValidator( model.completeStream = async ( prompt: string | PromptSection[], usageCallback?: CompleteUsageStatsCallback, + jsonSchemaOverride?: CompletionJsonSchema, + logFnOverride?: (msg: any) => void, + signal?: AbortSignal, ) => { - debugPrompt(prompt); - const jsonSchema = validator.getJsonSchema?.(); + const modelParams = getModelParams(prompt); + const actualPrompt = modelParams?.actualPrompt ?? prompt; + const actualUsageCallback = + modelParams?.usageCallback ?? usageCallback; + const actualSignal = modelParams?.signal ?? signal; + debugPrompt(actualPrompt); + const jsonSchema = + jsonSchemaOverride ?? validator.getJsonSchema?.(); if (jsonSchema !== undefined) { debugJsonSchema(jsonSchema); } - return originalCompleteStream(prompt, usageCallback, jsonSchema); + return originalCompleteStream( + actualPrompt, + actualUsageCallback, + jsonSchema, + logFnOverride, + actualSignal, + ); }; } @@ -363,23 +428,38 @@ export function createJsonTranslatorWithValidator( translator.translate = async ( request: string, promptPreamble?: string | PromptSection[], + signal?: AbortSignal, ) => { patchStreamCallback(promptPreamble); - const result = await innerFn(request, promptPreamble); + const result = await innerFn( + request, + addModelParamSection( + promptPreamble, + undefined, + undefined, + signal, + ), + ); debugResult(result); return result; }; - return translator; + return translator as TypeChatJsonTranslatorWithSignal; } translator.translate = async ( request: string, promptPreamble?: string | PromptSection[], + signal?: AbortSignal, ) => { patchStreamCallback(promptPreamble); const result = await innerFn( request, - toPromptSections(instructions, promptPreamble), + addModelParamSection( + toPromptSections(instructions, promptPreamble), + undefined, + undefined, + signal, + ), ); debugResult(result); return result; @@ -394,7 +474,7 @@ export function createJsonTranslatorWithValidator( `Based on all available information in our chat history including images previously provided, the following is the latest user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:\n` ); }; - return translator; + return translator as TypeChatJsonTranslatorWithSignal; } /** diff --git a/ts/pnpm-lock.yaml b/ts/pnpm-lock.yaml index 51f8c44da3..94b979da1b 100644 --- a/ts/pnpm-lock.yaml +++ b/ts/pnpm-lock.yaml @@ -150,7 +150,7 @@ importers: version: 8.18.1 jest: specifier: ^29.7.0 - version: 29.7.0(@types/node@20.19.25)(ts-node@10.9.2(@types/node@20.19.25)(typescript@5.4.5)) + version: 29.7.0(@types/node@25.5.2)(ts-node@10.9.2(@types/node@25.5.2)(typescript@5.4.5)) prettier: specifier: ^3.5.3 version: 3.5.3 @@ -2619,15 +2619,24 @@ importers: specifier: ^2.11.2 version: 2.11.2(puppeteer-extra@3.3.6(puppeteer-core@24.37.5)(puppeteer@24.37.5(typescript@5.4.5))) devDependencies: + '@jest/globals': + specifier: ^29.7.0 + version: 29.7.0 '@typeagent/action-schema-compiler': specifier: workspace:* version: link:../../actionSchemaCompiler + '@types/jest': + specifier: ^29.5.7 + version: 29.5.14 action-grammar-compiler: specifier: workspace:* version: link:../../actionGrammarCompiler concurrently: specifier: ^9.1.2 version: 9.1.2 + jest: + specifier: ^29.7.0 + version: 29.7.0(@types/node@20.19.25)(ts-node@10.9.2(@types/node@20.19.25)(typescript@5.4.5)) prettier: specifier: ^3.5.3 version: 3.5.3 @@ -15039,7 +15048,7 @@ snapshots: '@anthropic-ai/claude-agent-sdk@0.2.105': dependencies: - '@anthropic-ai/sdk': 0.81.0(zod@3.25.76) + '@anthropic-ai/sdk': 0.81.0(zod@4.1.13) '@modelcontextprotocol/sdk': 1.29.0 optionalDependencies: '@img/sharp-darwin-arm64': 0.34.5 @@ -18294,7 +18303,7 @@ snapshots: json-schema-typed: 8.0.2 pkce-challenge: 5.0.1 raw-body: 3.0.2 - zod-to-json-schema: 3.25.1(zod@3.25.76) + zod-to-json-schema: 3.25.1(zod@4.1.13) transitivePeerDependencies: - supports-color