diff --git a/client/src/App.tsx b/client/src/App.tsx index 59d15ba06..b32e89e9f 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -124,6 +124,50 @@ const cloneToolParams = ( } }; +const createPollingCancelledError = (): Error | DOMException => { + if (typeof DOMException !== "undefined") { + return new DOMException("Polling cancelled", "AbortError"); + } + + const error = new Error("Polling cancelled"); + error.name = "AbortError"; + return error; +}; + +const isAbortError = (error: unknown): boolean => { + return ( + typeof error === "object" && + error !== null && + "name" in error && + (error as { name?: string }).name === "AbortError" + ); +}; + +const waitForNextTaskPoll = ( + pollInterval: number, + signal: AbortSignal, +): Promise => { + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(createPollingCancelledError()); + return; + } + + function onAbort() { + window.clearTimeout(timeoutId); + signal.removeEventListener("abort", onAbort); + reject(createPollingCancelledError()); + } + + const timeoutId = window.setTimeout(() => { + signal.removeEventListener("abort", onAbort); + resolve(); + }, pollInterval); + + signal.addEventListener("abort", onAbort, { once: true }); + }); +}; + const filterReservedMetadata = ( metadata: Record, ): Record => { @@ -304,6 +348,7 @@ const App = () => { const [selectedTool, setSelectedTool] = useState(null); const [selectedTask, setSelectedTask] = useState(null); const [isPollingTask, setIsPollingTask] = useState(false); + const pollingAbortControllerRef = useRef(null); const [nextResourceCursor, setNextResourceCursor] = useState< string | undefined >(); @@ -318,6 +363,16 @@ const App = () => { const progressTokenRef = useRef(0); const prefilledAppsToolCallIdRef = useRef(0); + const cancelToolPolling = useCallback(() => { + pollingAbortControllerRef.current?.abort(); + }, []); + + useEffect(() => { + return () => { + pollingAbortControllerRef.current?.abort(); + }; + }, []); + const [activeTab, setActiveTab] = useState(() => { const hash = window.location.hash.slice(1); const initialTab = hash || "resources"; @@ -839,9 +894,10 @@ const App = () => { request: ClientRequest, schema: T, tabKey?: keyof typeof errors, + options?: Parameters[2], ): Promise> => { try { - const response = await makeRequest(request, schema); + const response = await makeRequest(request, schema, options); if (tabKey !== undefined) { clearError(tabKey); } @@ -1070,6 +1126,8 @@ const App = () => { if (runAsTask && isTaskResult(response)) { const taskId = response.task.taskId; const pollInterval = response.task.pollInterval; + const abortController = new AbortController(); + pollingAbortControllerRef.current = abortController; // Set polling state BEFORE setting tool result for proper UI update setIsPollingTask(true); // Safely extract any _meta from the original response (if present) @@ -1095,100 +1153,130 @@ const App = () => { // Polling loop let taskCompleted = false; - while (!taskCompleted) { - try { - // Wait for 1 second before polling - await new Promise((resolve) => setTimeout(resolve, pollInterval)); - - const taskStatus = await sendMCPRequest( - { - method: "tasks/get", - params: { taskId }, - }, - GetTaskResultSchema, - ); + try { + while (!taskCompleted && !abortController.signal.aborted) { + try { + await waitForNextTaskPoll(pollInterval, abortController.signal); - if ( - taskStatus.status === "completed" || - taskStatus.status === "failed" || - taskStatus.status === "cancelled" - ) { - taskCompleted = true; - console.log( - `Polling complete for task ${taskId}: ${taskStatus.status}`, + const taskStatus = await sendMCPRequest( + { + method: "tasks/get", + params: { taskId }, + }, + GetTaskResultSchema, + undefined, + { signal: abortController.signal }, ); - if (taskStatus.status === "completed") { - console.log(`Fetching result for task ${taskId}`); - const result = await sendMCPRequest( - { - method: "tasks/result", - params: { taskId }, - }, - CompatibilityCallToolResultSchema, + if ( + taskStatus.status === "completed" || + taskStatus.status === "failed" || + taskStatus.status === "cancelled" + ) { + taskCompleted = true; + console.log( + `Polling complete for task ${taskId}: ${taskStatus.status}`, ); - console.log(`Result received for task ${taskId}:`, result); - latestToolResult = result as CompatibilityCallToolResult; - setToolResult(latestToolResult); - // Refresh tasks list to show completed state - void listTasks(); + if (taskStatus.status === "completed") { + console.log(`Fetching result for task ${taskId}`); + const result = await sendMCPRequest( + { + method: "tasks/result", + params: { taskId }, + }, + CompatibilityCallToolResultSchema, + undefined, + { signal: abortController.signal }, + ); + console.log(`Result received for task ${taskId}:`, result); + latestToolResult = result as CompatibilityCallToolResult; + setToolResult(latestToolResult); + + // Refresh tasks list to show completed state + void listTasks(); + } else { + latestToolResult = { + content: [ + { + type: "text", + text: `Task ${taskStatus.status}: ${taskStatus.statusMessage || "No additional information"}`, + }, + ], + isError: true, + }; + setToolResult(latestToolResult); + // Refresh tasks list to show failed/cancelled state + void listTasks(); + } } else { + // Update status message while polling + // Safely extract any _meta from the original response (if present) + const pollingResponseMeta = + response && + typeof response === "object" && + "_meta" in (response as Record) + ? ((response as { _meta?: Record }) + ._meta ?? {}) + : undefined; latestToolResult = { content: [ { type: "text", - text: `Task ${taskStatus.status}: ${taskStatus.statusMessage || "No additional information"}`, + text: `Task status: ${taskStatus.status}${taskStatus.statusMessage ? ` - ${taskStatus.statusMessage}` : ""}. Polling...`, }, ], - isError: true, + _meta: { + ...(pollingResponseMeta || {}), + "io.modelcontextprotocol/related-task": { taskId }, + }, }; setToolResult(latestToolResult); - // Refresh tasks list to show failed/cancelled state + // Refresh tasks list to show progress void listTasks(); } - } else { - // Update status message while polling - // Safely extract any _meta from the original response (if present) - const pollingResponseMeta = - response && - typeof response === "object" && - "_meta" in (response as Record) - ? ((response as { _meta?: Record })._meta ?? - {}) - : undefined; + } catch (pollingError) { + if ( + abortController.signal.aborted || + isAbortError(pollingError) + ) { + latestToolResult = { + content: [ + { + type: "text", + text: `Polling cancelled for task ${taskId}. The task may still be running on the server.`, + }, + ], + _meta: { + ...(initialResponseMeta || {}), + "io.modelcontextprotocol/related-task": { taskId }, + }, + }; + setToolResult(latestToolResult); + taskCompleted = true; + break; + } + + console.error("Error polling task status:", pollingError); latestToolResult = { content: [ { type: "text", - text: `Task status: ${taskStatus.status}${taskStatus.statusMessage ? ` - ${taskStatus.statusMessage}` : ""}. Polling...`, + text: `Error polling task status: ${pollingError instanceof Error ? pollingError.message : String(pollingError)}`, }, ], - _meta: { - ...(pollingResponseMeta || {}), - "io.modelcontextprotocol/related-task": { taskId }, - }, + isError: true, }; setToolResult(latestToolResult); - // Refresh tasks list to show progress - void listTasks(); + taskCompleted = true; } - } catch (pollingError) { - console.error("Error polling task status:", pollingError); - latestToolResult = { - content: [ - { - type: "text", - text: `Error polling task status: ${pollingError instanceof Error ? pollingError.message : String(pollingError)}`, - }, - ], - isError: true, - }; - setToolResult(latestToolResult); - taskCompleted = true; + } + } finally { + setIsPollingTask(false); + if (pollingAbortControllerRef.current === abortController) { + pollingAbortControllerRef.current = null; } } - setIsPollingTask(false); // Clear any validation errors since tool execution completed setErrors((prev) => ({ ...prev, tools: null })); return latestToolResult; @@ -1587,6 +1675,7 @@ const App = () => { }} toolResult={toolResult} isPollingTask={isPollingTask} + cancelPolling={cancelToolPolling} nextCursor={nextToolCursor} error={errors.tools} resourceContent={resourceContentMap} diff --git a/client/src/__tests__/App.toolsAppsPrefill.test.tsx b/client/src/__tests__/App.toolsAppsPrefill.test.tsx index 9239f29db..4e81c65e5 100644 --- a/client/src/__tests__/App.toolsAppsPrefill.test.tsx +++ b/client/src/__tests__/App.toolsAppsPrefill.test.tsx @@ -1,4 +1,10 @@ -import { fireEvent, render, screen, waitFor } from "@testing-library/react"; +import { + act, + fireEvent, + render, + screen, + waitFor, +} from "@testing-library/react"; import "@testing-library/jest-dom"; import App from "../App"; import { useConnection } from "../lib/hooks/useConnection"; @@ -56,6 +62,7 @@ jest.mock("../utils/configUtils", () => ({ getInitialArgs: jest.fn(() => ""), initializeInspectorConfig: jest.fn(() => ({})), saveInspectorConfig: jest.fn(), + getMCPTaskTtl: jest.fn(() => 60000), })); jest.mock("../lib/hooks/useDraggablePane", () => ({ @@ -135,6 +142,9 @@ jest.mock("../components/ToolsTab", () => ({ default: ({ listTools, callTool, + toolResult, + isPollingTask, + cancelPolling, }: { listTools: () => void; callTool: ( @@ -143,6 +153,11 @@ jest.mock("../components/ToolsTab", () => ({ metadata?: Record, runAsTask?: boolean, ) => Promise; + toolResult?: { + content?: Array<{ type: string; text: string }>; + } | null; + isPollingTask?: boolean; + cancelPolling?: () => void; }) => (
+ + {isPollingTask && ( + + )} +
{JSON.stringify(toolResult ?? null)}
), })); @@ -281,4 +310,100 @@ describe("App - Tools to Apps prefilled handoff", () => { expect(screen.getByTestId("apps-prefilled")).toHaveTextContent("null"); }); }); + + it("can cancel task polling from ToolsTab", async () => { + const makeRequest = jest.fn( + async (request: { method: string; params?: { name?: string } }) => { + if (request.method === "tools/call") { + return { + task: { + taskId: "task-1", + status: "pending", + pollInterval: 10000, + }, + }; + } + + if (request.method === "tasks/get") { + return { + taskId: "task-1", + status: "running", + pollInterval: 10000, + lastUpdatedAt: new Date().toISOString(), + }; + } + + throw new Error(`Unexpected method: ${request.method}`); + }, + ); + + mockUseConnection.mockReturnValue({ + connectionStatus: "connected", + serverCapabilities: { + tools: { listChanged: true }, + tasks: { requests: { tools: { call: true } } }, + }, + serverImplementation: null, + mcpClient: { + request: jest.fn(), + notification: jest.fn(), + close: jest.fn(), + } as unknown as Client, + requestHistory: [], + clearRequestHistory: jest.fn(), + makeRequest, + cancelTask: jest.fn(), + listTasks: jest.fn(), + sendNotification: jest.fn(), + handleCompletion: jest.fn(), + completionsSupported: false, + connect: jest.fn(), + disconnect: jest.fn(), + } as ReturnType); + + render(); + + await act(async () => { + fireEvent.click( + screen.getByRole("button", { name: /mock run task tool/i }), + ); + await Promise.resolve(); + }); + + await waitFor(() => { + expect( + makeRequest.mock.calls.some( + ([request]) => request.method === "tools/call", + ), + ).toBe(true); + }); + + await waitFor(() => { + expect(screen.getByTestId("tools-result")).toHaveTextContent( + "Task created: task-1", + ); + }); + + await waitFor(() => { + expect( + screen.getByRole("button", { name: /mock cancel polling/i }), + ).toBeInTheDocument(); + }); + + fireEvent.click( + screen.getByRole("button", { name: /mock cancel polling/i }), + ); + + await waitFor(() => { + expect(screen.getByTestId("tools-result")).toHaveTextContent( + "Polling cancelled for task task-1", + ); + }); + + expect( + makeRequest.mock.calls.some( + ([request]) => request.method === "tasks/get", + ), + ).toBe(false); + }); }); diff --git a/client/src/components/ToolsTab.tsx b/client/src/components/ToolsTab.tsx index febea1d8f..a8d14e06b 100644 --- a/client/src/components/ToolsTab.tsx +++ b/client/src/components/ToolsTab.tsx @@ -35,6 +35,7 @@ import { AlertCircle, Copy, CheckCheck, + XCircle, } from "lucide-react"; import { useEffect, useState, useRef } from "react"; import ListPane from "./ListPane"; @@ -173,6 +174,7 @@ const ToolsTab = ({ setSelectedTool, toolResult, isPollingTask, + cancelPolling, nextCursor, error, resourceContent, @@ -192,6 +194,7 @@ const ToolsTab = ({ setSelectedTool: (tool: Tool | null) => void; toolResult: CompatibilityCallToolResult | null; isPollingTask?: boolean; + cancelPolling?: () => void; nextCursor: ListToolsResult["nextCursor"]; error: string | null; resourceContent: Record; @@ -855,6 +858,12 @@ const ToolsTab = ({ )} + {isPollingTask && cancelPolling && ( + + )}