diff --git a/src/browser/components/Settings/sections/ProvidersSection.tsx b/src/browser/components/Settings/sections/ProvidersSection.tsx index f6a3ac7eec..d70a8b5314 100644 --- a/src/browser/components/Settings/sections/ProvidersSection.tsx +++ b/src/browser/components/Settings/sections/ProvidersSection.tsx @@ -836,14 +836,18 @@ export function ProvidersSection() { setCopilotLoginError(null); }; - // Cancel any in-flight Copilot login if the component unmounts + // Cancel any in-flight Copilot login if the component unmounts. + // Use a ref for api so this only fires on true unmount, not on api identity + // changes (e.g. reconnection), which would spuriously cancel active flows. + const apiRef = useRef(api); + apiRef.current = api; useEffect(() => { return () => { - if (copilotFlowIdRef.current && api) { - void api.copilotOauth.cancelDeviceFlow({ flowId: copilotFlowIdRef.current }); + if (copilotFlowIdRef.current && apiRef.current) { + void apiRef.current.copilotOauth.cancelDeviceFlow({ flowId: copilotFlowIdRef.current }); } }; - }, [api]); + }, []); const clearCopilotCredentials = () => { if (!api) return; @@ -868,9 +872,21 @@ export function ProvidersSection() { return; } + // Best-effort: cancel any in-progress flow before starting a new one. + if (copilotFlowIdRef.current) { + void api.copilotOauth.cancelDeviceFlow({ flowId: copilotFlowIdRef.current }); + copilotFlowIdRef.current = null; + setCopilotFlowId(null); + } + const startResult = await api.copilotOauth.startDeviceFlow(); - if (attempt !== copilotLoginAttemptRef.current) return; + if (attempt !== copilotLoginAttemptRef.current) { + if (startResult.success) { + void api.copilotOauth.cancelDeviceFlow({ flowId: startResult.data.flowId }); + } + return; + } if (!startResult.success) { setCopilotLoginStatus("error"); diff --git a/src/node/services/codexOauthService.ts b/src/node/services/codexOauthService.ts index a0e2fcebc3..0483b2ca0d 100644 --- a/src/node/services/codexOauthService.ts +++ b/src/node/services/codexOauthService.ts @@ -1,5 +1,4 @@ import * as crypto from "crypto"; -import * as http from "http"; import type { Result } from "@/common/types/result"; import { Err, Ok } from "@/common/types/result"; import { @@ -24,26 +23,14 @@ import { parseCodexOauthAuth, type CodexOauthAuth, } from "@/node/utils/codexOauthAuth"; +import { createDeferred } from "@/node/utils/oauthUtils"; +import { startLoopbackServer } from "@/node/utils/oauthLoopbackServer"; +import { OAuthFlowManager } from "@/node/utils/oauthFlowManager"; const DEFAULT_DESKTOP_TIMEOUT_MS = 5 * 60 * 1000; const DEFAULT_DEVICE_TIMEOUT_MS = 15 * 60 * 1000; const COMPLETED_FLOW_TTL_MS = 60 * 1000; -interface DesktopFlow { - flowId: string; - authorizeUrl: string; - redirectUri: string; - codeVerifier: string; - - server: http.Server; - timeout: ReturnType; - cleanupTimeout: ReturnType | null; - - resultPromise: Promise>; - resolveResult: (result: Result) => void; - settled: boolean; -} - interface DeviceFlow { flowId: string; deviceAuthId: string; @@ -63,20 +50,6 @@ interface DeviceFlow { settled: boolean; } -function closeServer(server: http.Server): Promise { - return new Promise((resolve) => { - server.close(() => resolve()); - }); -} - -function createDeferred(): { promise: Promise; resolve: (value: T) => void } { - let resolve!: (value: T) => void; - const promise = new Promise((res) => { - resolve = res; - }); - return { promise, resolve }; -} - function sha256Base64Url(value: string): string { return crypto.createHash("sha256").update(value).digest().toString("base64url"); } @@ -89,17 +62,6 @@ function isPlainObject(value: unknown): value is Record { return Boolean(value) && typeof value === "object" && !Array.isArray(value); } -function isLoopbackAddress(address: string | undefined): boolean { - if (!address) return false; - - // Node may normalize IPv4 loopback to an IPv6-mapped address. - if (address === "::ffff:127.0.0.1") { - return true; - } - - return address === "127.0.0.1" || address === "::1"; -} - function parseOptionalNumber(value: unknown): number | null { if (typeof value === "number" && Number.isFinite(value)) { return value; @@ -133,7 +95,7 @@ function isInvalidGrantError(errorText: string): boolean { } export class CodexOauthService { - private readonly desktopFlows = new Map(); + private readonly desktopFlows = new OAuthFlowManager(); private readonly deviceFlows = new Map(); private readonly refreshMutex = new AsyncMutex(); @@ -159,81 +121,75 @@ export class CodexOauthService { const codeVerifier = randomBase64Url(); const codeChallenge = sha256Base64Url(codeVerifier); - - const { promise: resultPromise, resolve: resolveResult } = - createDeferred>(); - const redirectUri = CODEX_OAUTH_BROWSER_REDIRECT_URI; - const server = http.createServer((req, res) => { - if (!isLoopbackAddress(req.socket.remoteAddress)) { - res.statusCode = 403; - res.end("Forbidden"); - return; - } - - const reqUrl = req.url ?? "/"; - const url = new URL(reqUrl, "http://localhost"); - - if (req.method !== "GET" || url.pathname !== "/auth/callback") { - res.statusCode = 404; - res.end("Not found"); - return; - } - - const state = url.searchParams.get("state"); - if (!state || state !== flowId) { - res.statusCode = 400; - res.setHeader("Content-Type", "text/html"); - res.end("

Invalid OAuth state

"); - return; - } - - const code = url.searchParams.get("code"); - const error = url.searchParams.get("error"); - const errorDescription = url.searchParams.get("error_description") ?? undefined; - - void this.handleDesktopCallback({ - flowId, - code, - error, - errorDescription, - res, - }); - }); - + let loopback: Awaited>; try { - await new Promise((resolve, reject) => { - server.once("error", reject); - server.listen(1455, "localhost", () => resolve()); + loopback = await startLoopbackServer({ + port: 1455, + host: "localhost", + callbackPath: "/auth/callback", + validateLoopback: true, + expectedState: flowId, + deferSuccessResponse: true, }); } catch (error) { const message = error instanceof Error ? error.message : String(error); return Err(`Failed to start OAuth callback listener: ${message}`); } + const resultDeferred = createDeferred>(); + + this.desktopFlows.register(flowId, { + server: loopback.server, + resultDeferred, + // Keep server-side timeout tied to flow lifetime so abandoned flows + // (e.g. callers that never invoke waitForDesktopFlow) still self-clean. + timeoutHandle: setTimeout(() => { + void this.desktopFlows.finish(flowId, Err("Timed out waiting for OAuth callback")); + }, DEFAULT_DESKTOP_TIMEOUT_MS), + }); + const authorizeUrl = buildCodexAuthorizeUrl({ redirectUri, state: flowId, codeChallenge, }); - const timeout = setTimeout(() => { - void this.finishDesktopFlow(flowId, Err("Timed out waiting for OAuth callback")); - }, DEFAULT_DESKTOP_TIMEOUT_MS); + // Background task: wait for the loopback callback, exchange code for tokens, + // then finish the flow. Races against resultDeferred (which resolves on + // cancel/timeout) so the task exits cleanly if the flow is cancelled. + void (async () => { + const callbackResult = await Promise.race([ + loopback.result, + resultDeferred.promise.then(() => null), + ]); - this.desktopFlows.set(flowId, { - flowId, - authorizeUrl, - redirectUri, - codeVerifier, - server, - timeout, - cleanupTimeout: null, - resultPromise, - resolveResult, - settled: false, - }); + // null means the flow was finished externally (cancel/timeout). + if (!callbackResult) return; + + if (!callbackResult.success) { + await this.desktopFlows.finish(flowId, Err(callbackResult.error)); + return; + } + + const exchangeResult = await this.handleDesktopCallbackAndExchange({ + flowId, + redirectUri, + codeVerifier, + code: callbackResult.data.code, + error: null, + errorDescription: undefined, + }); + + if (exchangeResult.success) { + loopback.sendSuccessResponse(); + } else { + loopback.sendFailureResponse(exchangeResult.error); + } + + await this.desktopFlows.finish(flowId, exchangeResult); + })(); log.debug(`[Codex OAuth] Desktop flow started (flowId=${flowId})`); @@ -244,40 +200,14 @@ export class CodexOauthService { flowId: string, opts?: { timeoutMs?: number } ): Promise> { - const flow = this.desktopFlows.get(flowId); - if (!flow) { - return Err("OAuth flow not found"); - } - - const timeoutMs = opts?.timeoutMs ?? DEFAULT_DESKTOP_TIMEOUT_MS; - - let timeoutHandle: ReturnType | null = null; - const timeoutPromise = new Promise>((resolve) => { - timeoutHandle = setTimeout(() => { - resolve(Err("Timed out waiting for OAuth callback")); - }, timeoutMs); - }); - - const result = await Promise.race([flow.resultPromise, timeoutPromise]); - - if (timeoutHandle !== null) { - clearTimeout(timeoutHandle); - } - - if (!result.success) { - // Ensure listener is closed on timeout/errors. - void this.finishDesktopFlow(flowId, result); - } - - return result; + return this.desktopFlows.waitFor(flowId, opts?.timeoutMs ?? DEFAULT_DESKTOP_TIMEOUT_MS); } async cancelDesktopFlow(flowId: string): Promise { - const flow = this.desktopFlows.get(flowId); - if (!flow) return; - - log.debug(`[Codex OAuth] Desktop flow cancelled (flowId=${flowId})`); - await this.finishDesktopFlow(flowId, Err("OAuth flow cancelled")); + if (this.desktopFlows.has(flowId)) { + log.debug(`[Codex OAuth] Desktop flow cancelled (flowId=${flowId})`); + } + await this.desktopFlows.cancel(flowId); } async startDeviceFlow(): Promise< @@ -414,19 +344,11 @@ export class CodexOauthService { } async dispose(): Promise { - const desktopIds = [...this.desktopFlows.keys()]; - await Promise.all(desktopIds.map((id) => this.finishDesktopFlow(id, Err("App shutting down")))); + await this.desktopFlows.shutdownAll(); const deviceIds = [...this.deviceFlows.keys()]; await Promise.all(deviceIds.map((id) => this.finishDeviceFlow(id, Err("App shutting down")))); - for (const flow of this.desktopFlows.values()) { - clearTimeout(flow.timeout); - if (flow.cleanupTimeout !== null) { - clearTimeout(flow.cleanupTimeout); - } - } - for (const flow of this.deviceFlows.values()) { clearTimeout(flow.timeout); if (flow.cleanupTimeout !== null) { @@ -434,7 +356,6 @@ export class CodexOauthService { } } - this.desktopFlows.clear(); this.deviceFlows.clear(); } @@ -458,67 +379,6 @@ export class CodexOauthService { return result; } - private async handleDesktopCallback(input: { - flowId: string; - code: string | null; - error: string | null; - errorDescription?: string; - res: http.ServerResponse; - }): Promise { - const flow = this.desktopFlows.get(input.flowId); - if (!flow || flow.settled) { - input.res.statusCode = 409; - input.res.setHeader("Content-Type", "text/html"); - input.res.end("

OAuth flow already completed

"); - return; - } - - log.debug(`[Codex OAuth] Desktop callback received (flowId=${input.flowId})`); - - const result = await this.handleDesktopCallbackAndExchange({ - flowId: input.flowId, - redirectUri: flow.redirectUri, - codeVerifier: flow.codeVerifier, - code: input.code, - error: input.error, - errorDescription: input.errorDescription, - }); - - const title = result.success ? "Login complete" : "Login failed"; - const description = result.success - ? "You can return to Mux. You may now close this tab." - : escapeHtml(result.error); - - input.res.setHeader("Content-Type", "text/html"); - if (!result.success) { - input.res.statusCode = 400; - } - - input.res.end(` - - - - - - ${title} - - -

${title}

-

${description}

- - -`); - - await this.finishDesktopFlow(input.flowId, result); - } - private async handleDesktopCallbackAndExchange(input: { flowId: string; redirectUri: string; @@ -845,28 +705,6 @@ export class CodexOauthService { } } - private async finishDesktopFlow(flowId: string, result: Result): Promise { - const flow = this.desktopFlows.get(flowId); - if (!flow || flow.settled) return; - - flow.settled = true; - clearTimeout(flow.timeout); - - try { - flow.resolveResult(result); - await closeServer(flow.server); - } catch (error) { - log.debug("[Codex OAuth] Failed to close OAuth callback listener:", error); - } finally { - if (flow.cleanupTimeout !== null) { - clearTimeout(flow.cleanupTimeout); - } - flow.cleanupTimeout = setTimeout(() => { - this.desktopFlows.delete(flowId); - }, COMPLETED_FLOW_TTL_MS); - } - } - private finishDeviceFlow(flowId: string, result: Result): Promise { const flow = this.deviceFlows.get(flowId); if (!flow || flow.settled) { @@ -920,12 +758,3 @@ async function sleepWithAbort(ms: number, signal: AbortSignal): Promise { signal.addEventListener("abort", onAbort); }); } - -function escapeHtml(input: string): string { - return input - .replaceAll("&", "&") - .replaceAll("<", "<") - .replaceAll(">", ">") - .replaceAll('"', """) - .replaceAll("'", "'"); -} diff --git a/src/node/services/copilotOauthService.ts b/src/node/services/copilotOauthService.ts index 1b95da5817..f7d40966f9 100644 --- a/src/node/services/copilotOauthService.ts +++ b/src/node/services/copilotOauthService.ts @@ -4,6 +4,7 @@ import { Err, Ok } from "@/common/types/result"; import type { ProviderService } from "@/node/services/providerService"; import type { WindowService } from "@/node/services/windowService"; import { log } from "@/node/services/log"; +import { createDeferred } from "@/node/utils/oauthUtils"; const GITHUB_COPILOT_CLIENT_ID = "Ov23liCVKFN3jOo9R7HS"; const SCOPE = "read:user"; @@ -71,11 +72,8 @@ export class CopilotOauthService { return Err("Invalid response from GitHub device code endpoint"); } - // Create deferred promise - let resolveResult!: (result: Result) => void; - const resultPromise = new Promise>((resolve) => { - resolveResult = resolve; - }); + const { promise: resultPromise, resolve: resolveResult } = + createDeferred>(); const timeout = setTimeout(() => { void this.finishFlow(flowId, Err("Timed out waiting for GitHub authorization")); @@ -146,6 +144,9 @@ export class CopilotOauthService { const flow = this.flows.get(flowId); if (!flow) return; + // Skip if the flow already completed (e.g. unmount cleanup after success) + if (flow.cancelled) return; + log.debug(`Copilot OAuth device flow cancelled (flowId=${flowId})`); this.finishFlow(flowId, Err("Device flow cancelled")); } diff --git a/src/node/services/mcpOauthService.ts b/src/node/services/mcpOauthService.ts index e7105f23a3..f104f37e09 100644 --- a/src/node/services/mcpOauthService.ts +++ b/src/node/services/mcpOauthService.ts @@ -21,6 +21,7 @@ import type { } from "@/common/types/mcpOauth"; import { stripTrailingSlashes } from "@/node/utils/pathUtils"; import { MutexMap } from "@/node/utils/concurrency/mutexMap"; +import { closeServer, createDeferred, renderOAuthCallbackHtml } from "@/node/utils/oauthUtils"; const DEFAULT_DESKTOP_TIMEOUT_MS = 5 * 60 * 1000; const DEFAULT_SERVER_TIMEOUT_MS = 10 * 60 * 1000; @@ -114,20 +115,6 @@ interface DesktopFlow extends OAuthFlowBase { type ServerFlow = OAuthFlowBase; -function closeServer(server: http.Server): Promise { - return new Promise((resolve) => { - server.close(() => resolve()); - }); -} - -function createDeferred(): { promise: Promise; resolve: (value: T) => void } { - let resolve!: (value: T) => void; - const promise = new Promise((res) => { - resolve = res; - }); - return { promise, resolve }; -} - function isPlainObject(value: unknown): value is Record { return Boolean(value) && typeof value === "object" && !Array.isArray(value); } @@ -1132,29 +1119,20 @@ export class McpOauthService { errorDescription: input.errorDescription, }); - const title = result.success ? "Login complete" : "Login failed"; - const description = result.success - ? "You can return to Mux. You may now close this tab." - : escapeHtml(result.error); - input.res.setHeader("Content-Type", "text/html"); if (!result.success) { input.res.statusCode = 400; } - input.res.end(` - - - - - - ${title} - - -

${title}

-

${description}

- -`); + input.res.end( + renderOAuthCallbackHtml({ + title: result.success ? "Login complete" : "Login failed", + message: result.success + ? "You can return to Mux. You may now close this tab." + : result.error, + success: result.success, + }) + ); await this.finishDesktopFlow(input.flowId, result); } @@ -1481,12 +1459,3 @@ export class McpOauthService { }); } } - -function escapeHtml(input: string): string { - return input - .replaceAll("&", "&") - .replaceAll("<", "<") - .replaceAll(">", ">") - .replaceAll('"', """) - .replaceAll("'", "'"); -} diff --git a/src/node/services/muxGatewayOauthService.test.ts b/src/node/services/muxGatewayOauthService.test.ts new file mode 100644 index 0000000000..1ca525012c --- /dev/null +++ b/src/node/services/muxGatewayOauthService.test.ts @@ -0,0 +1,233 @@ +import http from "node:http"; +import { describe, it, expect, beforeEach, afterEach } from "bun:test"; + +import type { Result } from "@/common/types/result"; +import { Ok } from "@/common/types/result"; +import { + MUX_GATEWAY_AUTHORIZE_URL, + MUX_GATEWAY_EXCHANGE_URL, +} from "@/common/constants/muxGatewayOAuth"; +import type { ProviderService } from "@/node/services/providerService"; +import type { WindowService } from "@/node/services/windowService"; +import { MuxGatewayOauthService } from "./muxGatewayOauthService"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/** Simple GET helper that returns { status, body }. */ +async function httpGet(url: string): Promise<{ status: number; body: string }> { + return new Promise((resolve, reject) => { + http + .get(url, (res) => { + let body = ""; + res.on("data", (chunk: Buffer | Uint8Array) => { + body += Buffer.from(chunk).toString(); + }); + res.on("end", () => resolve({ status: res.statusCode ?? 0, body })); + }) + .on("error", reject); + }); +} + +/** Build a mock JSON response. */ +function jsonResponse(body: Record, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +/** + * Helper to mock globalThis.fetch without needing the `preconnect` property. + */ +function mockFetch( + fn: (input: RequestInfo | URL, init?: RequestInit) => Response | Promise +): void { + globalThis.fetch = Object.assign(fn, { + preconnect: (_url: string | URL) => { + // no-op in tests + }, + }) as typeof fetch; +} + +interface MockDeps { + setConfigCalls: Array<{ provider: string; keyPath: string[]; value: string }>; + setConfigResult: Result; + focusCalls: number; +} + +function createMockDeps(): MockDeps { + return { + setConfigCalls: [], + setConfigResult: Ok(undefined), + focusCalls: 0, + }; +} + +function createMockProviderService(deps: MockDeps): Pick { + return { + setConfig: (provider: string, keyPath: string[], value: string): Result => { + deps.setConfigCalls.push({ provider, keyPath, value }); + return deps.setConfigResult; + }, + }; +} + +function createMockWindowService(deps: MockDeps): Pick { + return { + focusMainWindow: () => { + deps.focusCalls++; + }, + }; +} + +function createService(deps: MockDeps): MuxGatewayOauthService { + return new MuxGatewayOauthService( + createMockProviderService(deps) as ProviderService, + createMockWindowService(deps) as WindowService + ); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("MuxGatewayOauthService", () => { + let deps: MockDeps; + let service: MuxGatewayOauthService; + const originalFetch = globalThis.fetch; + + beforeEach(() => { + deps = createMockDeps(); + service = createService(deps); + }); + + afterEach(async () => { + globalThis.fetch = originalFetch; + await service.dispose(); + }); + + describe("startDesktopFlow", () => { + it("returns flowId, authorizeUrl, and redirectUri", async () => { + const result = await service.startDesktopFlow(); + expect(result.success).toBe(true); + if (!result.success) return; + + expect(result.data.flowId).toBeTruthy(); + expect(result.data.redirectUri).toMatch(/^http:\/\/127\.0\.0\.1:\d+\/callback$/); + + const authorizeUrl = new URL(result.data.authorizeUrl); + expect(`${authorizeUrl.origin}${authorizeUrl.pathname}`).toBe(MUX_GATEWAY_AUTHORIZE_URL); + expect(authorizeUrl.searchParams.get("response_type")).toBe("code"); + expect(authorizeUrl.searchParams.get("state")).toBe(result.data.flowId); + expect(authorizeUrl.searchParams.get("redirect_uri")).toBe(result.data.redirectUri); + + await service.cancelDesktopFlow(result.data.flowId); + }); + }); + + describe("desktop callback flow", () => { + it("callback with code + successful exchange resolves waitFor success and renders success HTML", async () => { + let capturedUrl = ""; + let capturedBody = ""; + + mockFetch((input, init) => { + capturedUrl = + typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url; + capturedBody = + typeof init?.body === "string" + ? init.body + : init?.body instanceof URLSearchParams + ? init.body.toString() + : ""; + return jsonResponse({ access_token: "gateway-token" }); + }); + + const startResult = await service.startDesktopFlow(); + expect(startResult.success).toBe(true); + if (!startResult.success) return; + + const flowId = startResult.data.flowId; + const callbackUrl = `${startResult.data.redirectUri}?state=${flowId}&code=ok-code`; + + const waitPromise = service.waitForDesktopFlow(flowId, { timeoutMs: 5000 }); + const callbackResponsePromise = httpGet(callbackUrl); + + const [waitResult, callbackResponse] = await Promise.all([ + waitPromise, + callbackResponsePromise, + ]); + + expect(waitResult).toEqual(Ok(undefined)); + expect(callbackResponse.status).toBe(200); + expect(callbackResponse.body).toContain("Login complete"); + + expect(capturedUrl).toBe(MUX_GATEWAY_EXCHANGE_URL); + expect(capturedBody).toContain("code=ok-code"); + + expect(deps.setConfigCalls).toEqual([ + { + provider: "mux-gateway", + keyPath: ["couponCode"], + value: "gateway-token", + }, + ]); + expect(deps.focusCalls).toBe(1); + }); + + it("callback with code + failed exchange resolves waitFor error and renders failure HTML", async () => { + let releaseExchange!: () => void; + const exchangeStarted = new Promise((resolveStarted) => { + const exchangeBlocked = new Promise((resolveBlocked) => { + releaseExchange = () => resolveBlocked(); + }); + + mockFetch(async () => { + resolveStarted(); + await exchangeBlocked; + return new Response("upstream exploded", { status: 500 }); + }); + }); + + const startResult = await service.startDesktopFlow(); + expect(startResult.success).toBe(true); + if (!startResult.success) return; + + const flowId = startResult.data.flowId; + const callbackUrl = `${startResult.data.redirectUri}?state=${flowId}&code=bad-code`; + + const waitPromise = service.waitForDesktopFlow(flowId, { timeoutMs: 5000 }); + const callbackResponsePromise = httpGet(callbackUrl); + + await exchangeStarted; + + const callbackState = await Promise.race([ + callbackResponsePromise.then(() => "settled" as const), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 100)), + ]); + expect(callbackState).toBe("pending"); + + releaseExchange(); + + const [waitResult, callbackResponse] = await Promise.all([ + waitPromise, + callbackResponsePromise, + ]); + + expect(waitResult.success).toBe(false); + if (!waitResult.success) { + expect(waitResult.error).toContain("Mux Gateway exchange failed (500)"); + } + + expect(callbackResponse.status).toBe(400); + expect(callbackResponse.body).toContain("Login failed"); + expect(callbackResponse.body).toContain( + "Mux Gateway exchange failed (500): upstream exploded" + ); + + expect(deps.setConfigCalls).toHaveLength(0); + expect(deps.focusCalls).toBe(0); + }); + }); +}); diff --git a/src/node/services/muxGatewayOauthService.ts b/src/node/services/muxGatewayOauthService.ts index 1b52ebb80e..7d08e9c991 100644 --- a/src/node/services/muxGatewayOauthService.ts +++ b/src/node/services/muxGatewayOauthService.ts @@ -1,5 +1,4 @@ import * as crypto from "crypto"; -import * as http from "http"; import type { Result } from "@/common/types/result"; import { Err, Ok } from "@/common/types/result"; import { @@ -10,44 +9,20 @@ import { import type { ProviderService } from "@/node/services/providerService"; import type { WindowService } from "@/node/services/windowService"; import { log } from "@/node/services/log"; +import { createDeferred, renderOAuthCallbackHtml } from "@/node/utils/oauthUtils"; +import { startLoopbackServer } from "@/node/utils/oauthLoopbackServer"; +import { OAuthFlowManager } from "@/node/utils/oauthFlowManager"; const DEFAULT_DESKTOP_TIMEOUT_MS = 5 * 60 * 1000; const DEFAULT_SERVER_TIMEOUT_MS = 10 * 60 * 1000; -const COMPLETED_DESKTOP_FLOW_TTL_MS = 60 * 1000; - -interface DesktopFlow { - flowId: string; - authorizeUrl: string; - redirectUri: string; - server: http.Server; - timeout: ReturnType; - cleanupTimeout: ReturnType | null; - resultPromise: Promise>; - resolveResult: (result: Result) => void; - settled: boolean; -} interface ServerFlow { state: string; expiresAtMs: number; } -function closeServer(server: http.Server): Promise { - return new Promise((resolve) => { - server.close(() => resolve()); - }); -} - -function createDeferred(): { promise: Promise; resolve: (value: T) => void } { - let resolve!: (value: T) => void; - const promise = new Promise((res) => { - resolve = res; - }); - return { promise, resolve }; -} - export class MuxGatewayOauthService { - private readonly desktopFlows = new Map(); + private readonly desktopFlows = new OAuthFlowManager(); private readonly serverFlows = new Map(); constructor( @@ -59,118 +34,91 @@ export class MuxGatewayOauthService { Result<{ flowId: string; authorizeUrl: string; redirectUri: string }, string> > { const flowId = crypto.randomUUID(); + const resultDeferred = createDeferred>(); - const { promise: resultPromise, resolve: resolveResult } = - createDeferred>(); - - const server = http.createServer((req, res) => { - const reqUrl = req.url ?? "/"; - const url = new URL(reqUrl, "http://localhost"); - - if (req.method !== "GET" || url.pathname !== "/callback") { - res.statusCode = 404; - res.end("Not found"); - return; - } - - const state = url.searchParams.get("state"); - if (!state || state !== flowId) { - res.statusCode = 400; - res.setHeader("Content-Type", "text/html"); - res.end("

Invalid OAuth state

"); - return; - } - - const code = url.searchParams.get("code"); - const error = url.searchParams.get("error"); - const errorDescription = url.searchParams.get("error_description") ?? undefined; - - void this.handleDesktopCallback({ - flowId, - code, - error, - errorDescription, - res, - }); - }); - + let loopback; try { - await new Promise((resolve, reject) => { - server.once("error", reject); - server.listen(0, "127.0.0.1", () => resolve()); + loopback = await startLoopbackServer({ + expectedState: flowId, + deferSuccessResponse: true, + renderHtml: (r) => + renderOAuthCallbackHtml({ + title: r.success ? "Login complete" : "Login failed", + message: r.success + ? "You can return to Mux. You may now close this tab." + : (r.error ?? "Unknown error"), + success: r.success, + extraHead: + '\n ', + }), }); } catch (error) { const message = error instanceof Error ? error.message : String(error); return Err(`Failed to start OAuth callback listener: ${message}`); } - const address = server.address(); - if (!address || typeof address === "string") { - return Err("Failed to determine OAuth callback listener port"); - } + const authorizeUrl = buildAuthorizeUrl({ redirectUri: loopback.redirectUri, state: flowId }); - const redirectUri = `http://127.0.0.1:${address.port}/callback`; - const authorizeUrl = buildAuthorizeUrl({ redirectUri, state: flowId }); - - const timeout = setTimeout(() => { - void this.finishDesktopFlow(flowId, Err("Timed out waiting for OAuth callback")); - }, DEFAULT_DESKTOP_TIMEOUT_MS); - - this.desktopFlows.set(flowId, { - flowId, - authorizeUrl, - redirectUri, - server, - timeout, - cleanupTimeout: null, - resultPromise, - resolveResult, - settled: false, + this.desktopFlows.register(flowId, { + server: loopback.server, + resultDeferred, + // Keep server-side timeout tied to flow lifetime so abandoned flows + // (e.g. callers that never invoke waitForDesktopFlow) still self-clean. + timeoutHandle: setTimeout(() => { + void this.desktopFlows.finish(flowId, Err("Timed out waiting for OAuth callback")); + }, DEFAULT_DESKTOP_TIMEOUT_MS), }); + // Background task: await loopback callback, do token exchange, finish flow. + // Race against resultDeferred so that if the flow is cancelled/timed out + // externally, this task exits cleanly instead of dangling on loopback.result. + void (async () => { + const callbackOrDone = await Promise.race([ + loopback.result, + resultDeferred.promise.then((): null => null), + ]); + + // Flow was already finished externally (timeout or cancel). + if (callbackOrDone === null) return; + + log.debug(`Mux Gateway OAuth callback received (flowId=${flowId})`); + + let result: Result; + if (callbackOrDone.success) { + result = await this.handleCallbackAndExchange({ + state: callbackOrDone.data.state, + code: callbackOrDone.data.code, + error: null, + }); + + if (result.success) { + loopback.sendSuccessResponse(); + } else { + loopback.sendFailureResponse(result.error); + } + } else { + result = Err(`Mux Gateway OAuth error: ${callbackOrDone.error}`); + } + + await this.desktopFlows.finish(flowId, result); + })(); + log.debug(`Mux Gateway OAuth desktop flow started (flowId=${flowId})`); - return Ok({ flowId, authorizeUrl, redirectUri }); + return Ok({ flowId, authorizeUrl, redirectUri: loopback.redirectUri }); } async waitForDesktopFlow( flowId: string, opts?: { timeoutMs?: number } ): Promise> { - const flow = this.desktopFlows.get(flowId); - if (!flow) { - return Err("OAuth flow not found"); - } - - const timeoutMs = opts?.timeoutMs ?? DEFAULT_DESKTOP_TIMEOUT_MS; - - let timeoutHandle: ReturnType | null = null; - const timeoutPromise = new Promise>((resolve) => { - timeoutHandle = setTimeout(() => { - resolve(Err("Timed out waiting for OAuth callback")); - }, timeoutMs); - }); - - const result = await Promise.race([flow.resultPromise, timeoutPromise]); - - if (timeoutHandle !== null) { - clearTimeout(timeoutHandle); - } - - if (!result.success) { - // Ensure listener is closed on timeout/errors. - void this.finishDesktopFlow(flowId, result); - } - - return result; + return this.desktopFlows.waitFor(flowId, opts?.timeoutMs ?? DEFAULT_DESKTOP_TIMEOUT_MS); } async cancelDesktopFlow(flowId: string): Promise { - const flow = this.desktopFlows.get(flowId); - if (!flow) return; - + if (!this.desktopFlows.has(flowId)) return; log.debug(`Mux Gateway OAuth desktop flow cancelled (flowId=${flowId})`); - await this.finishDesktopFlow(flowId, Err("OAuth flow cancelled")); + await this.desktopFlows.cancel(flowId); } startServerFlow(input: { redirectUri: string }): { authorizeUrl: string; state: string } { @@ -228,118 +176,10 @@ export class MuxGatewayOauthService { } async dispose(): Promise { - // Best-effort: cancel all in-flight flows. - const flowIds = [...this.desktopFlows.keys()]; - await Promise.all(flowIds.map((id) => this.finishDesktopFlow(id, Err("App shutting down")))); - - for (const flow of this.desktopFlows.values()) { - clearTimeout(flow.timeout); - if (flow.cleanupTimeout !== null) { - clearTimeout(flow.cleanupTimeout); - } - } - - this.desktopFlows.clear(); + await this.desktopFlows.shutdownAll(); this.serverFlows.clear(); } - private async handleDesktopCallback(input: { - flowId: string; - code: string | null; - error: string | null; - errorDescription?: string; - res: http.ServerResponse; - }): Promise { - const flow = this.desktopFlows.get(input.flowId); - if (!flow || flow.settled) { - input.res.statusCode = 409; - input.res.setHeader("Content-Type", "text/html"); - input.res.end("

OAuth flow already completed

"); - return; - } - - log.debug(`Mux Gateway OAuth callback received (flowId=${input.flowId})`); - - const result = await this.handleCallbackAndExchange({ - state: input.flowId, - code: input.code, - error: input.error, - errorDescription: input.errorDescription, - }); - - const title = result.success ? "Login complete" : "Login failed"; - const description = result.success - ? "You can return to Mux. You may now close this tab." - : escapeHtml(result.error); - - const html = ` - - - - - - - ${title} - - - -
- - -
-
-
-

${title}

-

${description}

- ${ - result.success - ? '

Mux should now be in the foreground. You can close this tab.

' - : '

You can close this tab.

' - } -
-
-
-
- - - -`; - - input.res.setHeader("Content-Type", "text/html"); - if (!result.success) { - input.res.statusCode = 400; - } - - input.res.end(html); - - await this.finishDesktopFlow(input.flowId, result); - } - private async handleCallbackAndExchange(input: { state: string; code: string | null; @@ -406,38 +246,4 @@ export class MuxGatewayOauthService { return Err(`Mux Gateway exchange failed: ${message}`); } } - - private async finishDesktopFlow(flowId: string, result: Result): Promise { - const flow = this.desktopFlows.get(flowId); - if (!flow || flow.settled) return; - - flow.settled = true; - clearTimeout(flow.timeout); - - try { - flow.resolveResult(result); - - // Stop accepting new connections. - await closeServer(flow.server); - } catch (error) { - log.debug("Failed to close OAuth callback listener:", error); - } finally { - // Keep the completed flow around briefly so callers can still await the result. - if (flow.cleanupTimeout !== null) { - clearTimeout(flow.cleanupTimeout); - } - flow.cleanupTimeout = setTimeout(() => { - this.desktopFlows.delete(flowId); - }, COMPLETED_DESKTOP_FLOW_TTL_MS); - } - } -} - -function escapeHtml(input: string): string { - return input - .replaceAll("&", "&") - .replaceAll("<", "<") - .replaceAll(">", ">") - .replaceAll('"', """) - .replaceAll("'", "'"); } diff --git a/src/node/services/muxGovernorOauthService.test.ts b/src/node/services/muxGovernorOauthService.test.ts new file mode 100644 index 0000000000..77c9d258b7 --- /dev/null +++ b/src/node/services/muxGovernorOauthService.test.ts @@ -0,0 +1,254 @@ +import http from "node:http"; +import { describe, it, expect, beforeEach, afterEach } from "bun:test"; + +import type { Result } from "@/common/types/result"; +import { Ok } from "@/common/types/result"; +import type { ProjectsConfig } from "@/common/types/project"; +import type { Config } from "@/node/config"; +import type { PolicyService } from "@/node/services/policyService"; +import type { WindowService } from "@/node/services/windowService"; +import { MuxGovernorOauthService } from "./muxGovernorOauthService"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/** Simple GET helper that returns { status, body }. */ +async function httpGet(url: string): Promise<{ status: number; body: string }> { + return new Promise((resolve, reject) => { + http + .get(url, (res) => { + let body = ""; + res.on("data", (chunk: Buffer | Uint8Array) => { + body += Buffer.from(chunk).toString(); + }); + res.on("end", () => resolve({ status: res.statusCode ?? 0, body })); + }) + .on("error", reject); + }); +} + +/** Build a mock JSON response. */ +function jsonResponse(body: Record, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +/** + * Helper to mock globalThis.fetch without needing the `preconnect` property. + */ +function mockFetch( + fn: (input: RequestInfo | URL, init?: RequestInit) => Response | Promise +): void { + globalThis.fetch = Object.assign(fn, { + preconnect: (_url: string | URL) => { + // no-op in tests + }, + }) as typeof fetch; +} + +interface MockDeps { + configState: ProjectsConfig; + editConfigCalls: number; + focusCalls: number; + refreshCalls: number; + refreshResult: Result; +} + +function createMockDeps(): MockDeps { + return { + configState: { projects: new Map() }, + editConfigCalls: 0, + focusCalls: 0, + refreshCalls: 0, + refreshResult: Ok(undefined), + }; +} + +function createMockConfig(deps: MockDeps): Pick { + return { + editConfig: (fn: (config: ProjectsConfig) => ProjectsConfig) => { + deps.configState = fn(deps.configState); + deps.editConfigCalls++; + return Promise.resolve(); + }, + }; +} + +function createMockWindowService(deps: MockDeps): Pick { + return { + focusMainWindow: () => { + deps.focusCalls++; + }, + }; +} + +function createMockPolicyService(deps: MockDeps): Pick { + return { + refreshNow: () => { + deps.refreshCalls++; + return Promise.resolve(deps.refreshResult); + }, + }; +} + +function createService(deps: MockDeps): MuxGovernorOauthService { + return new MuxGovernorOauthService( + createMockConfig(deps) as Config, + createMockWindowService(deps) as WindowService, + createMockPolicyService(deps) as PolicyService + ); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("MuxGovernorOauthService", () => { + let deps: MockDeps; + let service: MuxGovernorOauthService; + const originalFetch = globalThis.fetch; + + beforeEach(() => { + deps = createMockDeps(); + service = createService(deps); + }); + + afterEach(async () => { + globalThis.fetch = originalFetch; + await service.dispose(); + }); + + describe("startDesktopFlow", () => { + it("returns flowId, authorizeUrl, and redirectUri", async () => { + const result = await service.startDesktopFlow({ + governorOrigin: "https://governor.example.com/admin?from=test", + }); + expect(result.success).toBe(true); + if (!result.success) return; + + expect(result.data.flowId).toBeTruthy(); + expect(result.data.redirectUri).toMatch(/^http:\/\/127\.0\.0\.1:\d+\/callback$/); + + const authorizeUrl = new URL(result.data.authorizeUrl); + expect(`${authorizeUrl.origin}${authorizeUrl.pathname}`).toBe( + "https://governor.example.com/oauth2/authorize" + ); + expect(authorizeUrl.searchParams.get("response_type")).toBe("code"); + expect(authorizeUrl.searchParams.get("state")).toBe(result.data.flowId); + expect(authorizeUrl.searchParams.get("redirect_uri")).toBe(result.data.redirectUri); + + await service.cancelDesktopFlow(result.data.flowId); + }); + }); + + describe("desktop callback flow", () => { + it("callback with code + successful exchange resolves waitFor success and renders success HTML", async () => { + let capturedUrl = ""; + let capturedBody = ""; + + mockFetch((input, init) => { + capturedUrl = + typeof input === "string" ? input : input instanceof URL ? input.toString() : input.url; + capturedBody = + typeof init?.body === "string" + ? init.body + : init?.body instanceof URLSearchParams + ? init.body.toString() + : ""; + return jsonResponse({ access_token: "governor-token" }); + }); + + const startResult = await service.startDesktopFlow({ + governorOrigin: "https://governor.example.com", + }); + expect(startResult.success).toBe(true); + if (!startResult.success) return; + + const flowId = startResult.data.flowId; + const callbackUrl = `${startResult.data.redirectUri}?state=${flowId}&code=ok-code`; + + const waitPromise = service.waitForDesktopFlow(flowId, { timeoutMs: 5000 }); + const callbackResponsePromise = httpGet(callbackUrl); + + const [waitResult, callbackResponse] = await Promise.all([ + waitPromise, + callbackResponsePromise, + ]); + + expect(waitResult).toEqual(Ok(undefined)); + expect(callbackResponse.status).toBe(200); + expect(callbackResponse.body).toContain("Enrollment complete"); + + expect(capturedUrl).toBe("https://governor.example.com/api/v1/oauth2/exchange"); + expect(capturedBody).toContain("code=ok-code"); + + expect(deps.editConfigCalls).toBe(1); + expect(deps.configState.muxGovernorUrl).toBe("https://governor.example.com"); + expect(deps.configState.muxGovernorToken).toBe("governor-token"); + expect(deps.focusCalls).toBe(1); + expect(deps.refreshCalls).toBe(1); + }); + + it("callback with code + failed exchange resolves waitFor error and renders failure HTML", async () => { + let releaseExchange!: () => void; + const exchangeStarted = new Promise((resolveStarted) => { + const exchangeBlocked = new Promise((resolveBlocked) => { + releaseExchange = () => resolveBlocked(); + }); + + mockFetch(async () => { + resolveStarted(); + await exchangeBlocked; + return new Response("governor unavailable", { status: 502 }); + }); + }); + + const startResult = await service.startDesktopFlow({ + governorOrigin: "https://governor.example.com", + }); + expect(startResult.success).toBe(true); + if (!startResult.success) return; + + const flowId = startResult.data.flowId; + const callbackUrl = `${startResult.data.redirectUri}?state=${flowId}&code=bad-code`; + + const waitPromise = service.waitForDesktopFlow(flowId, { timeoutMs: 5000 }); + const callbackResponsePromise = httpGet(callbackUrl); + + await exchangeStarted; + + const callbackState = await Promise.race([ + callbackResponsePromise.then(() => "settled" as const), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 100)), + ]); + expect(callbackState).toBe("pending"); + + releaseExchange(); + + const [waitResult, callbackResponse] = await Promise.all([ + waitPromise, + callbackResponsePromise, + ]); + + expect(waitResult.success).toBe(false); + if (!waitResult.success) { + expect(waitResult.error).toContain("Mux Governor exchange failed (502)"); + } + + expect(callbackResponse.status).toBe(400); + expect(callbackResponse.body).toContain("Enrollment failed"); + expect(callbackResponse.body).toContain( + "Mux Governor exchange failed (502): governor unavailable" + ); + + expect(deps.editConfigCalls).toBe(0); + expect(deps.focusCalls).toBe(0); + expect(deps.refreshCalls).toBe(0); + expect(deps.configState.muxGovernorUrl).toBeUndefined(); + expect(deps.configState.muxGovernorToken).toBeUndefined(); + }); + }); +}); diff --git a/src/node/services/muxGovernorOauthService.ts b/src/node/services/muxGovernorOauthService.ts index 95282cd47b..3964128f21 100644 --- a/src/node/services/muxGovernorOauthService.ts +++ b/src/node/services/muxGovernorOauthService.ts @@ -7,7 +7,6 @@ */ import * as crypto from "crypto"; -import * as http from "http"; import type { Result } from "@/common/types/result"; import { Err, Ok } from "@/common/types/result"; import { @@ -20,23 +19,12 @@ import type { Config } from "@/node/config"; import type { PolicyService } from "@/node/services/policyService"; import type { WindowService } from "@/node/services/windowService"; import { log } from "@/node/services/log"; +import { createDeferred, renderOAuthCallbackHtml } from "@/node/utils/oauthUtils"; +import { startLoopbackServer } from "@/node/utils/oauthLoopbackServer"; +import { OAuthFlowManager } from "@/node/utils/oauthFlowManager"; const DEFAULT_DESKTOP_TIMEOUT_MS = 5 * 60 * 1000; const DEFAULT_SERVER_TIMEOUT_MS = 10 * 60 * 1000; -const COMPLETED_DESKTOP_FLOW_TTL_MS = 60 * 1000; - -interface DesktopFlow { - flowId: string; - governorOrigin: string; - authorizeUrl: string; - redirectUri: string; - server: http.Server; - timeout: ReturnType; - cleanupTimeout: ReturnType | null; - resultPromise: Promise>; - resolveResult: (result: Result) => void; - settled: boolean; -} interface ServerFlow { state: string; @@ -44,22 +32,8 @@ interface ServerFlow { expiresAtMs: number; } -function closeServer(server: http.Server): Promise { - return new Promise((resolve) => { - server.close(() => resolve()); - }); -} - -function createDeferred(): { promise: Promise; resolve: (value: T) => void } { - let resolve!: (value: T) => void; - const promise = new Promise((res) => { - resolve = res; - }); - return { promise, resolve }; -} - export class MuxGovernorOauthService { - private readonly desktopFlows = new Map(); + private readonly desktopFlows = new OAuthFlowManager(); private readonly serverFlows = new Map(); constructor( @@ -82,125 +56,95 @@ export class MuxGovernorOauthService { const flowId = crypto.randomUUID(); - const { promise: resultPromise, resolve: resolveResult } = - createDeferred>(); - - const server = http.createServer((req, res) => { - const reqUrl = req.url ?? "/"; - const url = new URL(reqUrl, "http://localhost"); - - if (req.method !== "GET" || url.pathname !== "/callback") { - res.statusCode = 404; - res.end("Not found"); - return; - } - - const state = url.searchParams.get("state"); - if (!state || state !== flowId) { - res.statusCode = 400; - res.setHeader("Content-Type", "text/html"); - res.end("

Invalid OAuth state

"); - return; - } - - const code = url.searchParams.get("code"); - const error = url.searchParams.get("error"); - const errorDescription = url.searchParams.get("error_description") ?? undefined; - - void this.handleDesktopCallback({ - flowId, - governorOrigin, - code, - error, - errorDescription, - res, - }); - }); - + let loopback: Awaited>; try { - await new Promise((resolve, reject) => { - server.once("error", reject); - server.listen(0, "127.0.0.1", () => resolve()); + loopback = await startLoopbackServer({ + expectedState: flowId, + deferSuccessResponse: true, + renderHtml: (r) => + renderOAuthCallbackHtml({ + title: r.success ? "Enrollment complete" : "Enrollment failed", + message: r.success + ? "You can return to Mux. You may now close this tab." + : (r.error ?? "Unknown error"), + success: r.success, + }), }); } catch (error) { const message = error instanceof Error ? error.message : String(error); return Err(`Failed to start OAuth callback listener: ${message}`); } - const address = server.address(); - if (!address || typeof address === "string") { - return Err("Failed to determine OAuth callback listener port"); - } - - const redirectUri = `http://127.0.0.1:${address.port}/callback`; const authorizeUrl = buildGovernorAuthorizeUrl({ governorOrigin, - redirectUri, + redirectUri: loopback.redirectUri, state: flowId, }); - const timeout = setTimeout(() => { - void this.finishDesktopFlow(flowId, Err("Timed out waiting for OAuth callback")); - }, DEFAULT_DESKTOP_TIMEOUT_MS); + const resultDeferred = createDeferred>(); - this.desktopFlows.set(flowId, { - flowId, - governorOrigin, - authorizeUrl, - redirectUri, - server, - timeout, - cleanupTimeout: null, - resultPromise, - resolveResult, - settled: false, + this.desktopFlows.register(flowId, { + server: loopback.server, + resultDeferred, + // Keep server-side timeout tied to flow lifetime so abandoned flows + // (e.g. callers that never invoke waitForDesktopFlow) still self-clean. + timeoutHandle: setTimeout(() => { + void this.desktopFlows.finish(flowId, Err("Timed out waiting for OAuth callback")); + }, DEFAULT_DESKTOP_TIMEOUT_MS), }); + // Background task: await loopback callback, do token exchange, finish flow. + // Race against resultDeferred so that if the flow is cancelled/timed out + // externally, this task exits cleanly instead of dangling on loopback.result. + void (async () => { + const callbackOrDone = await Promise.race([ + loopback.result, + resultDeferred.promise.then((): null => null), + ]); + + // Flow was already finished externally (timeout or cancel). + if (callbackOrDone === null) return; + + let result: Result; + if (callbackOrDone.success) { + result = await this.handleCallbackAndExchange({ + state: flowId, + governorOrigin, + code: callbackOrDone.data.code, + error: null, + }); + } else { + result = Err(`Mux Governor OAuth error: ${callbackOrDone.error}`); + } + + // Render the final browser response based on exchange outcome. + if (result.success) { + loopback.sendSuccessResponse(); + } else { + loopback.sendFailureResponse(result.error); + } + + await this.desktopFlows.finish(flowId, result); + })(); + log.debug( `Mux Governor OAuth desktop flow started (flowId=${flowId}, origin=${governorOrigin})` ); - return Ok({ flowId, authorizeUrl, redirectUri }); + return Ok({ flowId, authorizeUrl, redirectUri: loopback.redirectUri }); } async waitForDesktopFlow( flowId: string, opts?: { timeoutMs?: number } ): Promise> { - const flow = this.desktopFlows.get(flowId); - if (!flow) { - return Err("OAuth flow not found"); - } - - const timeoutMs = opts?.timeoutMs ?? DEFAULT_DESKTOP_TIMEOUT_MS; - - let timeoutHandle: ReturnType | null = null; - const timeoutPromise = new Promise>((resolve) => { - timeoutHandle = setTimeout(() => { - resolve(Err("Timed out waiting for OAuth callback")); - }, timeoutMs); - }); - - const result = await Promise.race([flow.resultPromise, timeoutPromise]); - - if (timeoutHandle !== null) { - clearTimeout(timeoutHandle); - } - - if (!result.success) { - // Ensure listener is closed on timeout/errors. - void this.finishDesktopFlow(flowId, result); - } - - return result; + return this.desktopFlows.waitFor(flowId, opts?.timeoutMs ?? DEFAULT_DESKTOP_TIMEOUT_MS); } async cancelDesktopFlow(flowId: string): Promise { - const flow = this.desktopFlows.get(flowId); - if (!flow) return; - + if (!this.desktopFlows.has(flowId)) return; log.debug(`Mux Governor OAuth desktop flow cancelled (flowId=${flowId})`); - await this.finishDesktopFlow(flowId, Err("OAuth flow cancelled")); + await this.desktopFlows.cancel(flowId); } startServerFlow(input: { @@ -278,94 +222,10 @@ export class MuxGovernorOauthService { } async dispose(): Promise { - // Best-effort: cancel all in-flight flows. - const flowIds = [...this.desktopFlows.keys()]; - await Promise.all(flowIds.map((id) => this.finishDesktopFlow(id, Err("App shutting down")))); - - for (const flow of this.desktopFlows.values()) { - clearTimeout(flow.timeout); - if (flow.cleanupTimeout !== null) { - clearTimeout(flow.cleanupTimeout); - } - } - - this.desktopFlows.clear(); + await this.desktopFlows.shutdownAll(); this.serverFlows.clear(); } - private async handleDesktopCallback(input: { - flowId: string; - governorOrigin: string; - code: string | null; - error: string | null; - errorDescription?: string; - res: http.ServerResponse; - }): Promise { - const flow = this.desktopFlows.get(input.flowId); - if (!flow || flow.settled) { - input.res.statusCode = 409; - input.res.setHeader("Content-Type", "text/html"); - input.res.end("

OAuth flow already completed

"); - return; - } - - log.debug(`Mux Governor OAuth callback received (flowId=${input.flowId})`); - - const result = await this.handleCallbackAndExchange({ - state: input.flowId, - governorOrigin: input.governorOrigin, - code: input.code, - error: input.error, - errorDescription: input.errorDescription, - }); - - const title = result.success ? "Enrollment complete" : "Enrollment failed"; - const description = result.success - ? "You can return to Mux. You may now close this tab." - : escapeHtml(result.error); - - const html = ` - - - - - - ${title} - - - -

${title}

-

${description}

- ${ - result.success - ? '

Mux should now be in the foreground. You can close this tab.

' - : '

You can close this tab.

' - } - - -`; - - input.res.setHeader("Content-Type", "text/html"); - if (!result.success) { - input.res.statusCode = 400; - } - - input.res.end(html); - - await this.finishDesktopFlow(input.flowId, result); - } - private async handleCallbackAndExchange(input: { state: string; governorOrigin: string; @@ -447,38 +307,4 @@ export class MuxGovernorOauthService { return Err(`Mux Governor exchange failed: ${message}`); } } - - private async finishDesktopFlow(flowId: string, result: Result): Promise { - const flow = this.desktopFlows.get(flowId); - if (!flow || flow.settled) return; - - flow.settled = true; - clearTimeout(flow.timeout); - - try { - flow.resolveResult(result); - - // Stop accepting new connections. - await closeServer(flow.server); - } catch (error) { - log.debug("Failed to close OAuth callback listener:", error); - } finally { - // Keep the completed flow around briefly so callers can still await the result. - if (flow.cleanupTimeout !== null) { - clearTimeout(flow.cleanupTimeout); - } - flow.cleanupTimeout = setTimeout(() => { - this.desktopFlows.delete(flowId); - }, COMPLETED_DESKTOP_FLOW_TTL_MS); - } - } -} - -function escapeHtml(input: string): string { - return input - .replaceAll("&", "&") - .replaceAll("<", "<") - .replaceAll(">", ">") - .replaceAll('"', """) - .replaceAll("'", "'"); } diff --git a/src/node/utils/oauthFlowManager.test.ts b/src/node/utils/oauthFlowManager.test.ts new file mode 100644 index 0000000000..e4a89b41c9 --- /dev/null +++ b/src/node/utils/oauthFlowManager.test.ts @@ -0,0 +1,293 @@ +import type http from "node:http"; +import { describe, it, expect, beforeEach } from "bun:test"; +import type { Result } from "@/common/types/result"; +import { Ok, Err } from "@/common/types/result"; +import { createDeferred } from "@/node/utils/oauthUtils"; +import type { OAuthFlowEntry } from "./oauthFlowManager"; +import { OAuthFlowManager } from "./oauthFlowManager"; + +// --------------------------------------------------------------------------- +// Mock http.Server +// --------------------------------------------------------------------------- + +/** + * Minimal mock that satisfies the `http.Server` contract used by + * OAuthFlowManager: only `close()` is called (via `closeServer`). + */ +function createMockServer(): http.Server { + const mock = { + close: (cb?: (err?: Error) => void) => { + if (cb) cb(); + return mock; + }, + } as unknown as http.Server; + return mock; +} + +/** Create a fresh OAuthFlowEntry backed by a mock server. */ +function createFlowEntry(server?: http.Server | null): OAuthFlowEntry { + return { + server: server === undefined ? createMockServer() : server, + resultDeferred: createDeferred>(), + timeoutHandle: null, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("OAuthFlowManager", () => { + let manager: OAuthFlowManager; + + beforeEach(() => { + manager = new OAuthFlowManager(); + }); + + // ----------------------------------------------------------------------- + // register / get / has + // ----------------------------------------------------------------------- + + describe("register / get / has", () => { + it("registers and retrieves a flow entry", () => { + const entry = createFlowEntry(); + manager.register("f1", entry); + + expect(manager.has("f1")).toBe(true); + expect(manager.get("f1")).toBe(entry); + }); + + it("cleans up the previous active flow when registering a duplicate flowId", async () => { + let serverClosed = false; + const mockServer = { + close: (cb?: (err?: Error) => void) => { + serverClosed = true; + if (cb) cb(); + return mockServer; + }, + } as unknown as http.Server; + + const oldEntry = createFlowEntry(mockServer); + + let timeoutFired = false; + oldEntry.timeoutHandle = setTimeout(() => { + timeoutFired = true; + }, 10); + + manager.register("f1", oldEntry); + + const newEntry = createFlowEntry(null); + manager.register("f1", newEntry); + + const oldResult = await oldEntry.resultDeferred.promise; + expect(oldResult.success).toBe(false); + if (!oldResult.success) { + expect(oldResult.error).toContain("replaced"); + } + + await new Promise((resolve) => setTimeout(resolve, 25)); + expect(timeoutFired).toBe(false); + expect(serverClosed).toBe(true); + + expect(manager.get("f1")).toBe(newEntry); + }); + + it("returns undefined for unregistered flows", () => { + expect(manager.has("nope")).toBe(false); + expect(manager.get("nope")).toBeUndefined(); + }); + }); + + // ----------------------------------------------------------------------- + // waitFor + // ----------------------------------------------------------------------- + + describe("waitFor", () => { + it("resolves when the deferred resolves with Ok", async () => { + const entry = createFlowEntry(); + manager.register("f1", entry); + + // Resolve the deferred in the background. + setTimeout(() => entry.resultDeferred.resolve(Ok(undefined)), 5); + + const result = await manager.waitFor("f1", 5_000); + expect(result.success).toBe(true); + }); + + it("times out when the deferred does not resolve in time", async () => { + const entry = createFlowEntry(); + manager.register("f1", entry); + + // Use a very short timeout so the test is fast. + const result = await manager.waitFor("f1", 20); + + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error).toContain("Timed out"); + } + }); + + it("returns Err when flow ID is not found", async () => { + const result = await manager.waitFor("missing", 1_000); + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error).toContain("not found"); + } + }); + + it("returns the completed result for late waiters", async () => { + const entry = createFlowEntry(); + manager.register("f1", entry); + + // Simulate the common race where the OAuth callback finishes before the + // frontend begins waiting. + const finishPromise = manager.finish("f1", Ok(undefined)); + + const result = await manager.waitFor("f1", 1_000); + expect(result).toEqual(Ok(undefined)); + expect(manager.has("f1")).toBe(false); + + await finishPromise; + }); + + it("expires completed results after a short TTL", async () => { + const ttlMs = 30; + const ttlManager = new OAuthFlowManager(ttlMs); + + const entry = createFlowEntry(); + ttlManager.register("f1", entry); + + await ttlManager.finish("f1", Ok(undefined)); + + const withinTtl = await ttlManager.waitFor("f1", 1_000); + expect(withinTtl).toEqual(Ok(undefined)); + expect(ttlManager.has("f1")).toBe(false); + + await new Promise((resolve) => setTimeout(resolve, ttlMs + 50)); + + const afterTtl = await ttlManager.waitFor("f1", 1_000); + expect(afterTtl.success).toBe(false); + if (!afterTtl.success) { + expect(afterTtl.error).toContain("not found"); + } + }); + }); + + // ----------------------------------------------------------------------- + // cancel + // ----------------------------------------------------------------------- + + describe("cancel", () => { + it("resolves the deferred with Err and removes the flow", async () => { + const entry = createFlowEntry(); + manager.register("f1", entry); + + await manager.cancel("f1"); + + const result = await entry.resultDeferred.promise; + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error).toContain("cancelled"); + } + expect(manager.has("f1")).toBe(false); + }); + + it("is a no-op for non-existent flows", async () => { + // Should not throw. + await manager.cancel("nope"); + }); + }); + + // ----------------------------------------------------------------------- + // finish + // ----------------------------------------------------------------------- + + describe("finish", () => { + it("resolves the deferred, removes the flow, and closes the server", async () => { + let serverClosed = false; + const mockServer = { + close: (cb?: (err?: Error) => void) => { + serverClosed = true; + if (cb) cb(); + return mockServer; + }, + } as unknown as http.Server; + + const entry = createFlowEntry(mockServer); + manager.register("f1", entry); + + await manager.finish("f1", Ok(undefined)); + + const result = await entry.resultDeferred.promise; + expect(result.success).toBe(true); + expect(manager.has("f1")).toBe(false); + expect(serverClosed).toBe(true); + }); + + it("is idempotent — second call is a no-op", async () => { + const entry = createFlowEntry(); + manager.register("f1", entry); + + await manager.finish("f1", Ok(undefined)); + // Second call should not throw. + await manager.finish("f1", Err("should be ignored")); + + const result = await entry.resultDeferred.promise; + // Should still be the first result. + expect(result.success).toBe(true); + }); + + it("works when server is null (device-code flows)", async () => { + const entry = createFlowEntry(null); + manager.register("f1", entry); + + await manager.finish("f1", Ok(undefined)); + + const result = await entry.resultDeferred.promise; + expect(result.success).toBe(true); + expect(manager.has("f1")).toBe(false); + }); + + it("clears the timeout handle", async () => { + const entry = createFlowEntry(); + // Simulate a stored timeout handle. + // eslint-disable-next-line @typescript-eslint/no-empty-function + entry.timeoutHandle = setTimeout(() => {}, 60_000); + manager.register("f1", entry); + + await manager.finish("f1", Ok(undefined)); + + expect(manager.has("f1")).toBe(false); + }); + }); + + // ----------------------------------------------------------------------- + // shutdownAll + // ----------------------------------------------------------------------- + + describe("shutdownAll", () => { + it("finishes all registered flows", async () => { + const entry1 = createFlowEntry(); + const entry2 = createFlowEntry(); + manager.register("f1", entry1); + manager.register("f2", entry2); + + await manager.shutdownAll(); + + expect(manager.has("f1")).toBe(false); + expect(manager.has("f2")).toBe(false); + + const r1 = await entry1.resultDeferred.promise; + const r2 = await entry2.resultDeferred.promise; + expect(r1.success).toBe(false); + expect(r2.success).toBe(false); + if (!r1.success) expect(r1.error).toContain("shutting down"); + if (!r2.success) expect(r2.error).toContain("shutting down"); + }); + + it("is a no-op when there are no flows", async () => { + // Should not throw. + await manager.shutdownAll(); + }); + }); +}); diff --git a/src/node/utils/oauthFlowManager.ts b/src/node/utils/oauthFlowManager.ts new file mode 100644 index 0000000000..a4a07df398 --- /dev/null +++ b/src/node/utils/oauthFlowManager.ts @@ -0,0 +1,214 @@ +import type http from "node:http"; +import type { Result } from "@/common/types/result"; +import { Err } from "@/common/types/result"; +import { closeServer } from "@/node/utils/oauthUtils"; +import { log } from "@/node/services/log"; + +/** + * Shared desktop OAuth flow lifecycle manager. + * + * Four OAuth services (Gateway, Governor, Codex, MCP) track in-flight desktop + * flows with an identical `Map` + `waitFor`/`cancel`/ + * `finish`/`shutdownAll` pattern. This class extracts that shared lifecycle + * so each service can delegate flow bookkeeping here. + */ + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface OAuthFlowEntry { + /** The loopback HTTP server (null for device-code flows). */ + server: http.Server | null; + /** Deferred that resolves with the final flow result. */ + resultDeferred: { + promise: Promise>; + resolve: (value: Result) => void; + }; + /** Handle for the server-side timeout (set at registration time by the caller). */ + timeoutHandle: ReturnType | null; +} + +interface CompletedOAuthFlowEntry { + result: Result; + cleanupTimeout: ReturnType | null; +} + +const DEFAULT_COMPLETED_FLOW_TTL_MS = 60_000; + +// --------------------------------------------------------------------------- +// Manager +// --------------------------------------------------------------------------- + +export class OAuthFlowManager { + private readonly flows = new Map(); + private readonly completed = new Map(); + + constructor(private readonly completedFlowTtlMs = DEFAULT_COMPLETED_FLOW_TTL_MS) {} + + /** Register a new in-flight flow. */ + register(flowId: string, entry: OAuthFlowEntry): void { + // If a flow ID is re-used before the completed-flow TTL expires, ensure the + // old cleanup timer can't delete the new result. + this.clearCompleted(flowId); + + const existing = this.flows.get(flowId); + if (existing) { + // Defensive: avoid silently overwriting an active flow entry. + // + // This can happen if a provider accidentally re-uses a flow ID, or if a + // stale in-flight start attempt races a newer one. Best-effort cleanup to + // avoid leaking timeouts, deferred promises, or loopback servers. + log.debug(`[OAuthFlowManager] Duplicate register — replacing active flow (flowId=${flowId})`); + + if (existing.timeoutHandle !== null) { + clearTimeout(existing.timeoutHandle); + } + + try { + existing.resultDeferred.resolve(Err("OAuth flow replaced")); + + if (existing.server) { + // Stop accepting new connections. + void closeServer(existing.server); + } + } catch (error) { + log.debug("Failed to clean up replaced OAuth flow:", error); + } + } + + this.flows.set(flowId, entry); + } + + private clearCompleted(flowId: string): void { + const existing = this.completed.get(flowId); + if (!existing) return; + + if (existing.cleanupTimeout !== null) { + clearTimeout(existing.cleanupTimeout); + } + + this.completed.delete(flowId); + } + + /** Get a flow entry by ID, or undefined if not found. */ + get(flowId: string): OAuthFlowEntry | undefined { + return this.flows.get(flowId); + } + + /** Check whether a flow exists. */ + has(flowId: string): boolean { + return this.flows.has(flowId); + } + + /** + * Wait for a flow to complete with a caller-facing timeout race. + * + * Mirrors the `waitForDesktopFlow` pattern shared across all four services: + * - Creates a local timeout promise for this wait call only. + * - Races that local timeout against `flow.resultDeferred.promise`. + * - Registration-time timeout remains on `flow.timeoutHandle` and is cleared in `finish`. + * - On any error result (caller timeout or flow error), calls `finish` for shared cleanup. + */ + async waitFor(flowId: string, timeoutMs: number): Promise> { + const flow = this.flows.get(flowId); + if (!flow) { + const completed = this.completed.get(flowId); + if (completed) { + return completed.result; + } + + return Err("OAuth flow not found"); + } + + let timeoutHandle: ReturnType | null = null; + const timeoutPromise = new Promise>((resolve) => { + timeoutHandle = setTimeout(() => { + resolve(Err("Timed out waiting for OAuth callback")); + }, timeoutMs); + }); + + const result = await Promise.race([flow.resultDeferred.promise, timeoutPromise]); + + if (timeoutHandle !== null) { + clearTimeout(timeoutHandle); + } + + if (!result.success) { + // Ensure listener is closed on timeout/errors. + void this.finish(flowId, result); + } + + return result; + } + + /** + * Cancel a flow — resolves the deferred with an error and cleans up. + * + * Mirrors the `cancelDesktopFlow` pattern. + */ + async cancel(flowId: string): Promise { + const flow = this.flows.get(flowId); + if (!flow) return; + + await this.finish(flowId, Err("OAuth flow cancelled")); + } + + /** + * Finish a flow: resolve the deferred, clear the timeout, close the server, + * and remove the entry from the map. + * + * Idempotent — no-op if the flow was already removed. Mirrors the + * `finishDesktopFlow` pattern. + */ + async finish(flowId: string, result: Result): Promise { + const flow = this.flows.get(flowId); + if (!flow) return; + + // Remove from map first to make re-entrant calls no-ops. + this.flows.delete(flowId); + + if (flow.timeoutHandle !== null) { + clearTimeout(flow.timeoutHandle); + } + + // Preserve the final result briefly so late waiters can still retrieve it. + // + // Old per-service DesktopFlow implementations kept completed flows around for + // ~60s (cleanupTimeout) to avoid a race where the OAuth callback finishes + // before the frontend begins `waitFor`-ing. + this.clearCompleted(flowId); + const cleanupTimeout = setTimeout(() => { + this.completed.delete(flowId); + }, this.completedFlowTtlMs); + + // Don't keep the node process alive just to delete old completed entries. + if (typeof cleanupTimeout !== "number") { + cleanupTimeout.unref?.(); + } + + this.completed.set(flowId, { result, cleanupTimeout }); + + try { + flow.resultDeferred.resolve(result); + + if (flow.server) { + // Stop accepting new connections. + await closeServer(flow.server); + } + } catch (error) { + log.debug("Failed to close OAuth callback listener:", error); + } + } + + /** + * Shut down all active flows — resolves each with an error. + * + * Mirrors the `dispose` pattern where services iterate all flows + * and finish them with `Err("App shutting down")`. + */ + async shutdownAll(): Promise { + const flowIds = [...this.flows.keys()]; + await Promise.all(flowIds.map((id) => this.finish(id, Err("App shutting down")))); + } +} diff --git a/src/node/utils/oauthLoopbackServer.test.ts b/src/node/utils/oauthLoopbackServer.test.ts new file mode 100644 index 0000000000..d349aad54a --- /dev/null +++ b/src/node/utils/oauthLoopbackServer.test.ts @@ -0,0 +1,308 @@ +import http from "node:http"; +import { describe, it, expect, afterEach } from "bun:test"; +import type { LoopbackServer } from "./oauthLoopbackServer"; +import { startLoopbackServer } from "./oauthLoopbackServer"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/** Extract the port from a redirectUri like http://127.0.0.1:12345/callback */ +function portFromUri(uri: string): number { + return new URL(uri).port ? Number(new URL(uri).port) : 80; +} + +/** Simple GET helper that returns { status, body }. */ +async function httpGet(url: string): Promise<{ status: number; body: string }> { + return new Promise((resolve, reject) => { + http + .get(url, (res) => { + let body = ""; + res.on("data", (chunk: Buffer) => { + body += chunk.toString(); + }); + res.on("end", () => resolve({ status: res.statusCode ?? 0, body })); + }) + .on("error", reject); + }); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("startLoopbackServer", () => { + let loopback: LoopbackServer | undefined; + + afterEach(async () => { + // Ensure the server is always cleaned up. + if (loopback?.server.listening) { + await loopback.close(); + } + loopback = undefined; + }); + + it("starts a server and provides a redirectUri with the listening port", async () => { + loopback = await startLoopbackServer({ expectedState: "s1" }); + + expect(loopback.redirectUri).toMatch(/^http:\/\/127\.0\.0\.1:\d+\/callback$/); + const port = portFromUri(loopback.redirectUri); + expect(port).toBeGreaterThan(0); + expect(loopback.server.listening).toBe(true); + }); + + it("honors host when building redirectUri", async () => { + loopback = await startLoopbackServer({ expectedState: "s1", host: "localhost" }); + + const redirectUrl = new URL(loopback.redirectUri); + expect(redirectUrl.hostname).toBe("localhost"); + expect(loopback.redirectUri).toMatch(/^http:\/\/localhost:\d+\/callback$/); + }); + + it("returns 403 for non-loopback requests when validateLoopback is enabled", async () => { + loopback = await startLoopbackServer({ expectedState: "s1", validateLoopback: true }); + + // `validateLoopback` is based on `req.socket.remoteAddress`. It's tricky to + // force a non-loopback remoteAddress with bun's http client, so invoke the + // request handler directly with a mocked request/response. + const handler = loopback.server.listeners("request")[0] as unknown as ( + req: http.IncomingMessage, + res: http.ServerResponse + ) => void; + + let body = ""; + const headers: Record = {}; + const res = { + statusCode: 0, + setHeader: (name: string, value: unknown) => { + headers[name] = value; + }, + end: (chunk?: unknown) => { + if (typeof chunk === "string") { + body = chunk; + return; + } + if (Buffer.isBuffer(chunk)) { + body = chunk.toString(); + return; + } + body = ""; + }, + }; + + const req = { + method: "GET", + url: "/callback?state=s1&code=c1", + socket: { remoteAddress: "192.0.2.1" }, + }; + + handler(req as unknown as http.IncomingMessage, res as unknown as http.ServerResponse); + + expect(res.statusCode).toBe(403); + expect(body).toContain("Forbidden"); + + const resultState = await Promise.race([ + loopback.result.then(() => "settled" as const), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)), + ]); + expect(resultState).toBe("pending"); + + await loopback.cancel(); + }); + + it("resolves with Ok({code, state}) on a valid callback", async () => { + loopback = await startLoopbackServer({ expectedState: "state123" }); + + const callbackUrl = `${loopback.redirectUri}?state=state123&code=authcode456`; + const res = await httpGet(callbackUrl); + + expect(res.status).toBe(200); + expect(res.body).toContain(""); + + const result = await loopback.result; + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.code).toBe("authcode456"); + expect(result.data.state).toBe("state123"); + } + }); + + it("returns 400 on state mismatch but keeps waiting for a valid callback", async () => { + loopback = await startLoopbackServer({ expectedState: "good" }); + + const badCallbackUrl = `${loopback.redirectUri}?state=bad&code=c`; + const badRes = await httpGet(badCallbackUrl); + + expect(badRes.status).toBe(400); + expect(badRes.body).toContain("Invalid OAuth state"); + + const resultState = await Promise.race([ + loopback.result.then(() => "settled" as const), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)), + ]); + expect(resultState).toBe("pending"); + + const goodCallbackUrl = `${loopback.redirectUri}?state=good&code=authcode456`; + const goodRes = await httpGet(goodCallbackUrl); + + expect(goodRes.status).toBe(200); + + const result = await loopback.result; + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.code).toBe("authcode456"); + expect(result.data.state).toBe("good"); + } + }); + + it("resolves with Err when provider returns an error param", async () => { + loopback = await startLoopbackServer({ expectedState: "s1" }); + + const callbackUrl = `${loopback.redirectUri}?state=s1&error=access_denied&error_description=User+denied`; + const res = await httpGet(callbackUrl); + + expect(res.status).toBe(400); + + const result = await loopback.result; + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error).toContain("access_denied"); + expect(result.error).toContain("User denied"); + } + }); + + it("resolves with Err when code is missing", async () => { + loopback = await startLoopbackServer({ expectedState: "s1" }); + + const callbackUrl = `${loopback.redirectUri}?state=s1`; + const res = await httpGet(callbackUrl); + + expect(res.status).toBe(400); + + const result = await loopback.result; + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error).toContain("Missing authorization code"); + } + }); + + it("returns 404 for the wrong path", async () => { + loopback = await startLoopbackServer({ expectedState: "s1" }); + + const port = portFromUri(loopback.redirectUri); + const res = await httpGet(`http://127.0.0.1:${port}/wrong`); + + expect(res.status).toBe(404); + expect(res.body).toContain("Not found"); + }); + + it("cancel resolves result with Err and closes the server", async () => { + loopback = await startLoopbackServer({ expectedState: "s1" }); + expect(loopback.server.listening).toBe(true); + + await loopback.cancel(); + + const result = await loopback.result; + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error).toContain("cancelled"); + } + expect(loopback.server.listening).toBe(false); + }); + + it("uses a custom callbackPath", async () => { + loopback = await startLoopbackServer({ + expectedState: "s1", + callbackPath: "/oauth/done", + }); + + expect(loopback.redirectUri).toContain("/oauth/done"); + + const callbackUrl = `${loopback.redirectUri}?state=s1&code=c1`; + const res = await httpGet(callbackUrl); + expect(res.status).toBe(200); + + const result = await loopback.result; + expect(result.success).toBe(true); + }); + + it("defers success response until sendSuccessResponse is called", async () => { + loopback = await startLoopbackServer({ + expectedState: "s1", + deferSuccessResponse: true, + }); + + const callbackUrl = `${loopback.redirectUri}?state=s1&code=c1`; + const responsePromise = httpGet(callbackUrl); + + const callbackResult = await loopback.result; + expect(callbackResult.success).toBe(true); + + const responseState = await Promise.race([ + responsePromise.then(() => "settled" as const), + new Promise<"pending">((resolve) => setTimeout(() => resolve("pending"), 25)), + ]); + expect(responseState).toBe("pending"); + + loopback.sendSuccessResponse(); + + const response = await responsePromise; + expect(response.status).toBe(200); + expect(response.body).toContain("Login complete"); + }); + + it("can send deferred failure response after receiving a valid code", async () => { + loopback = await startLoopbackServer({ + expectedState: "s1", + deferSuccessResponse: true, + }); + + const callbackUrl = `${loopback.redirectUri}?state=s1&code=c1`; + const responsePromise = httpGet(callbackUrl); + + const callbackResult = await loopback.result; + expect(callbackResult.success).toBe(true); + + loopback.sendFailureResponse("Token exchange failed"); + + const response = await responsePromise; + expect(response.status).toBe(400); + expect(response.body).toContain("Token exchange failed"); + }); + + it("sends a cancellation response if deferred callback is still pending during close", async () => { + loopback = await startLoopbackServer({ + expectedState: "s1", + deferSuccessResponse: true, + }); + + const callbackUrl = `${loopback.redirectUri}?state=s1&code=c1`; + const responsePromise = httpGet(callbackUrl); + + const callbackResult = await loopback.result; + expect(callbackResult.success).toBe(true); + + await loopback.close(); + + const response = await responsePromise; + expect(response.status).toBe(400); + expect(response.body).toContain("OAuth flow cancelled"); + }); + + it("uses a custom renderHtml function", async () => { + const customHtml = "Custom!"; + loopback = await startLoopbackServer({ + expectedState: "s1", + renderHtml: () => customHtml, + }); + + const callbackUrl = `${loopback.redirectUri}?state=s1&code=c1`; + const res = await httpGet(callbackUrl); + + expect(res.status).toBe(200); + expect(res.body).toBe(customHtml); + + const result = await loopback.result; + expect(result.success).toBe(true); + }); +}); diff --git a/src/node/utils/oauthLoopbackServer.ts b/src/node/utils/oauthLoopbackServer.ts new file mode 100644 index 0000000000..a1a3291a91 --- /dev/null +++ b/src/node/utils/oauthLoopbackServer.ts @@ -0,0 +1,294 @@ +import http from "node:http"; +import type { Result } from "@/common/types/result"; +import { Err, Ok } from "@/common/types/result"; +import { closeServer, createDeferred, renderOAuthCallbackHtml } from "@/node/utils/oauthUtils"; + +/** + * Shared loopback OAuth callback server. + * + * Four OAuth services (Gateway, Governor, Codex, MCP) spin up a local HTTP + * server to receive the authorization code redirect. The pattern is identical + * across all four — this module extracts that into a single reusable utility. + */ + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface LoopbackServerOptions { + /** Port to listen on. 0 = random (default). Codex uses 1455. */ + port?: number; + /** Host to bind to. Default: "127.0.0.1" */ + host?: string; + /** Path to match for the OAuth callback. Default: "/callback" */ + callbackPath?: string; + /** Expected state parameter value for CSRF validation. */ + expectedState: string; + /** Whether to validate that remoteAddress is a loopback address. Default: false. Codex uses true. */ + validateLoopback?: boolean; + /** Custom HTML renderer. If not provided, uses renderOAuthCallbackHtml with generic branding. */ + renderHtml?: (result: { success: boolean; error?: string }) => string; + /** + * Defer success-page rendering until the caller finishes token exchange. + * + * When true, a valid callback resolves `result` but does not immediately + * respond to the browser tab. The caller must later call + * `sendSuccessResponse()` or `sendFailureResponse()`. + */ + deferSuccessResponse?: boolean; +} + +export interface LoopbackCallbackResult { + code: string; + state: string; +} + +export interface LoopbackServer { + /** The full redirect URI (http://{host}:{port}{callbackPath}). */ + redirectUri: string; + /** The underlying HTTP server (needed by OAuthFlowManager for cleanup). */ + server: http.Server; + /** + * Resolves when a valid callback is received (state matches `expectedState`). + * + * Note: Requests with an invalid/missing `state` respond 400 but are ignored + * so unrelated localhost probes/scanners can't break an in-flight OAuth flow. + */ + result: Promise>; + /** Cancel and close the server. */ + cancel: () => Promise; + /** Close the server. */ + close: () => Promise; + /** Send a deferred success page to the callback browser tab. */ + sendSuccessResponse: () => void; + /** Send a deferred failure page to the callback browser tab. */ + sendFailureResponse: (error: string) => void; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/** + * Check whether an address string is a loopback address. + * Node may normalize IPv4 loopback to an IPv6-mapped address. + * + * Extracted from codexOauthService.ts where validateLoopback is used. + */ +function isLoopbackAddress(address: string | undefined): boolean { + if (!address) return false; + + // Node may normalize IPv4 loopback to an IPv6-mapped address. + if (address === "::ffff:127.0.0.1") { + return true; + } + + return address === "127.0.0.1" || address === "::1"; +} + +/** + * Format a `server.listen()` host value for use inside a URL. + * + * - IPv6 literals must be wrapped in brackets (e.g. "::1" -> "[::1]") + * - Zone identifiers must be percent-encoded ("%" -> "%25") + */ +function hostForRedirectUri(rawHost: string): string { + const host = rawHost.startsWith("[") && rawHost.endsWith("]") ? rawHost.slice(1, -1) : rawHost; + + if (host.includes(":")) { + // RFC 6874: zone identifiers use "%25" in URIs. + return `[${host.replaceAll("%", "%25")}]`; + } + + return host; +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +/** + * Start a loopback HTTP server to receive an OAuth authorization code callback. + * + * Pattern extracted from the `http.createServer` blocks in Gateway, Governor, + * Codex, and MCP OAuth services. The server: + * + * 1. Optionally validates the remote address is loopback (Codex). + * 2. Matches only GET requests on `callbackPath`. + * 3. Validates the `state` query parameter against `expectedState`. + * 4. Extracts `code` from the query string. + * 5. Responds with HTML (success or error) unless success is deferred. + * 6. Resolves the result deferred — the caller then performs token exchange + * and calls `close()`. + * + * The server does NOT close itself after responding — the caller decides when + * to close (matching the existing pattern where services call `closeServer` + * after token exchange). + */ +export async function startLoopbackServer(options: LoopbackServerOptions): Promise { + const port = options.port ?? 0; + const host = options.host ?? "127.0.0.1"; + // `server.listen()` expects IPv6 hosts without brackets, but callers may pass + // "[::1]" since that's what URLs require. + const listenHost = host.startsWith("[") && host.endsWith("]") ? host.slice(1, -1) : host; + const callbackPath = options.callbackPath ?? "/callback"; + const validateLoopback = options.validateLoopback ?? false; + const deferSuccessResponse = options.deferSuccessResponse ?? false; + + const deferred = createDeferred>(); + let deferredSuccessResponse: http.ServerResponse | null = null; + + const render = + options.renderHtml ?? + ((r: { success: boolean; error?: string }) => + renderOAuthCallbackHtml({ + title: r.success ? "Login complete" : "Login failed", + message: r.success + ? "You can return to Mux. You may now close this tab." + : (r.error ?? "Unknown error"), + success: r.success, + })); + + const clearDeferredSuccessResponse = (response: http.ServerResponse) => { + if (deferredSuccessResponse === response) { + deferredSuccessResponse = null; + } + }; + + const sendDeferredResponse = (result: { success: boolean; error?: string }) => { + const pendingResponse = deferredSuccessResponse; + if (!pendingResponse || pendingResponse.writableEnded || pendingResponse.destroyed) { + deferredSuccessResponse = null; + return; + } + + deferredSuccessResponse = null; + pendingResponse.statusCode = result.success ? 200 : 400; + pendingResponse.setHeader("Content-Type", "text/html"); + pendingResponse.end(render(result)); + }; + + const sendSuccessResponse = () => { + sendDeferredResponse({ success: true }); + }; + + const sendFailureResponse = (error: string) => { + sendDeferredResponse({ success: false, error }); + }; + + const sendCancelledResponseIfPending = () => { + sendFailureResponse("OAuth flow cancelled"); + }; + + const server = http.createServer((req, res) => { + // Optionally reject non-loopback connections (Codex sets validateLoopback: true). + if (validateLoopback && !isLoopbackAddress(req.socket.remoteAddress)) { + res.statusCode = 403; + res.end("Forbidden"); + return; + } + + const reqUrl = req.url ?? "/"; + const url = new URL(reqUrl, "http://localhost"); + + if (req.method !== "GET" || url.pathname !== callbackPath) { + res.statusCode = 404; + res.end("Not found"); + return; + } + + const state = url.searchParams.get("state"); + if (!state || state !== options.expectedState) { + const errorMessage = "Invalid OAuth state"; + res.statusCode = 400; + res.setHeader("Content-Type", "text/html"); + res.end(render({ success: false, error: errorMessage })); + + // Intentionally ignore invalid state callbacks. Unrelated localhost + // probes/scanners can hit this endpoint and shouldn't break an in-flight + // OAuth flow. + return; + } + + const code = url.searchParams.get("code"); + const error = url.searchParams.get("error"); + const errorDescription = url.searchParams.get("error_description") ?? undefined; + + if (error) { + const errorMessage = errorDescription ? `${error}: ${errorDescription}` : error; + res.setHeader("Content-Type", "text/html"); + res.statusCode = 400; + res.end(render({ success: false, error: errorMessage })); + deferred.resolve(Err(errorMessage)); + return; + } + + if (!code) { + const errorMessage = "Missing authorization code"; + res.setHeader("Content-Type", "text/html"); + res.statusCode = 400; + res.end(render({ success: false, error: errorMessage })); + deferred.resolve(Err(errorMessage)); + return; + } + + if (deferSuccessResponse) { + if (deferredSuccessResponse && !deferredSuccessResponse.writableEnded) { + res.statusCode = 409; + res.setHeader("Content-Type", "text/html"); + res.end(render({ success: false, error: "OAuth callback already received" })); + return; + } + + deferredSuccessResponse = res; + res.once("close", () => clearDeferredSuccessResponse(res)); + deferred.resolve(Ok({ code, state })); + return; + } + + res.setHeader("Content-Type", "text/html"); + res.statusCode = 200; + res.end(render({ success: true })); + deferred.resolve(Ok({ code, state })); + }); + + // Ensure pending deferred browser responses are completed before server close, + // even when callers close via the raw `server` handle (e.g. OAuthFlowManager). + const originalClose = server.close.bind(server); + server.close = ((callback?: (err?: Error) => void) => { + sendCancelledResponseIfPending(); + return originalClose(callback); + }) as typeof server.close; + + // Listen on the specified host/port — mirrors the existing + // `server.listen(port, host, () => resolve())` pattern. + await new Promise((resolve, reject) => { + server.once("error", reject); + server.listen(port, listenHost, () => resolve()); + }); + + const address = server.address(); + if (!address || typeof address === "string") { + await closeServer(server); + throw new Error("Failed to determine OAuth callback listener port"); + } + + const redirectUri = `http://${hostForRedirectUri(listenHost)}:${address.port}${callbackPath}`; + + return { + redirectUri, + server, + result: deferred.promise, + cancel: async () => { + deferred.resolve(Err("OAuth flow cancelled")); + sendCancelledResponseIfPending(); + await closeServer(server); + }, + close: async () => { + sendCancelledResponseIfPending(); + await closeServer(server); + }, + sendSuccessResponse, + sendFailureResponse, + }; +} diff --git a/src/node/utils/oauthUtils.test.ts b/src/node/utils/oauthUtils.test.ts new file mode 100644 index 0000000000..5e6c1cff3c --- /dev/null +++ b/src/node/utils/oauthUtils.test.ts @@ -0,0 +1,195 @@ +import http from "node:http"; +import { describe, it, expect, afterEach } from "bun:test"; +import { createDeferred, closeServer, escapeHtml, renderOAuthCallbackHtml } from "./oauthUtils"; + +// --------------------------------------------------------------------------- +// createDeferred +// --------------------------------------------------------------------------- + +describe("createDeferred", () => { + it("resolves with the value passed to resolve()", async () => { + const d = createDeferred(); + d.resolve("hello"); + expect(await d.promise).toBe("hello"); + }); + + it("works with numeric types", async () => { + const d = createDeferred(); + d.resolve(42); + expect(await d.promise).toBe(42); + }); + + it("works with object types", async () => { + const d = createDeferred<{ ok: boolean }>(); + d.resolve({ ok: true }); + expect(await d.promise).toEqual({ ok: true }); + }); + + it("can be resolved asynchronously", async () => { + const d = createDeferred(); + setTimeout(() => d.resolve("async"), 5); + expect(await d.promise).toBe("async"); + }); +}); + +// --------------------------------------------------------------------------- +// closeServer +// --------------------------------------------------------------------------- + +describe("closeServer", () => { + let server: http.Server | undefined; + + afterEach(() => { + // Safety net in case a test fails before closing. + if (server?.listening) { + server.close(); + } + server = undefined; + }); + + it("closes a listening HTTP server", async () => { + server = http.createServer(); + await new Promise((resolve) => server!.listen(0, "127.0.0.1", resolve)); + expect(server.listening).toBe(true); + + await closeServer(server); + expect(server.listening).toBe(false); + }); +}); + +// --------------------------------------------------------------------------- +// escapeHtml +// --------------------------------------------------------------------------- + +describe("escapeHtml", () => { + it("escapes ampersand", () => { + expect(escapeHtml("a&b")).toBe("a&b"); + }); + + it("escapes less-than", () => { + expect(escapeHtml("a { + expect(escapeHtml("a>b")).toBe("a>b"); + }); + + it("escapes double quote", () => { + expect(escapeHtml('a"b')).toBe("a"b"); + }); + + it("escapes single quote", () => { + expect(escapeHtml("a'b")).toBe("a'b"); + }); + + it("escapes all special chars together", () => { + expect(escapeHtml(`&<>"'`)).toBe("&<>"'"); + }); + + it("returns plain strings unchanged", () => { + expect(escapeHtml("hello world")).toBe("hello world"); + }); +}); + +// --------------------------------------------------------------------------- +// renderOAuthCallbackHtml +// --------------------------------------------------------------------------- + +describe("renderOAuthCallbackHtml", () => { + it("renders a valid HTML page with the given title and message", () => { + const html = renderOAuthCallbackHtml({ + title: "Test Title", + message: "Test message", + success: true, + }); + expect(html).toContain(""); + expect(html).toContain("Test Title"); + expect(html).toContain("

Test Title

"); + expect(html).toContain("

Test message

"); + }); + + it("includes an auto-close script when success is true", () => { + const html = renderOAuthCallbackHtml({ + title: "Done", + message: "All good", + success: true, + }); + expect(html).toContain("window.close()"); + expect(html).toContain("const ok = true;"); + }); + + it("does NOT auto-close when success is false", () => { + const html = renderOAuthCallbackHtml({ + title: "Failed", + message: "Something went wrong", + success: false, + }); + expect(html).toContain("const ok = false;"); + // The script tag is present but guarded by `if (!ok) return;` + expect(html).toContain("if (!ok) return;"); + }); + + it("escapes title to prevent XSS", () => { + const html = renderOAuthCallbackHtml({ + title: '', + message: "ok", + success: true, + }); + expect(html).not.toContain(" + +`; +}