diff --git a/apps/server/package.json b/apps/server/package.json index 4f57523c..d36f42e4 100644 --- a/apps/server/package.json +++ b/apps/server/package.json @@ -26,6 +26,7 @@ "@anthropic-ai/claude-agent-sdk": "^0.2.77", "@effect/platform-node": "catalog:", "@effect/sql-sqlite-bun": "catalog:", + "@github/copilot-sdk": "^0.2.1", "@pierre/diffs": "^1.1.0-beta.16", "effect": "catalog:", "node-pty": "^1.1.0", diff --git a/apps/server/src/doctor.ts b/apps/server/src/doctor.ts index a949bc10..232f31fc 100644 --- a/apps/server/src/doctor.ts +++ b/apps/server/src/doctor.ts @@ -10,6 +10,7 @@ import { Effect } from "effect"; import { Command } from "effect/unstable/cli"; import { + checkCopilotProviderStatus, checkCodexProviderStatus, checkClaudeProviderStatus, } from "./provider/Layers/ProviderHealth"; @@ -32,6 +33,7 @@ const AUTH_LABELS: Record = { const PROVIDER_LABELS: Record = { codex: "Codex (OpenAI)", claudeAgent: "Claude (Anthropic)", + copilot: "GitHub Copilot", }; function printStatus(status: ServerProviderStatus): void { @@ -65,9 +67,12 @@ const doctorProgram = Effect.gen(function* () { console.log(""); console.log("Checking provider health..."); - const statuses = yield* Effect.all([checkCodexProviderStatus, checkClaudeProviderStatus], { - concurrency: "unbounded", - }); + const statuses = yield* Effect.all( + [checkCodexProviderStatus, checkClaudeProviderStatus, checkCopilotProviderStatus], + { + concurrency: "unbounded", + }, + ); for (const status of statuses) { printStatus(status); @@ -80,6 +85,7 @@ const doctorProgram = Effect.gen(function* () { console.log(""); console.log(" Codex: npm install -g @openai/codex && codex login"); console.log(" Claude: npm install -g @anthropic-ai/claude-code && claude auth login"); + console.log(" Copilot: npm install -g @github/copilot && copilot login"); } else if (readyCount === statuses.length) { console.log("All providers are ready."); } else { diff --git a/apps/server/src/provider/Layers/CopilotAdapter.ts b/apps/server/src/provider/Layers/CopilotAdapter.ts new file mode 100644 index 00000000..b91d0178 --- /dev/null +++ b/apps/server/src/provider/Layers/CopilotAdapter.ts @@ -0,0 +1,1159 @@ +import { + ApprovalRequestId, + type CanonicalItemType, + type CanonicalRequestType, + EventId, + type ProviderApprovalDecision, + type ProviderRuntimeEvent, + type ProviderSendTurnInput, + type ProviderSession, + type ProviderUserInputAnswers, + type RuntimeContentStreamKind, + RuntimeItemId, + RuntimeRequestId, + ThreadId, + type ThreadTokenUsageSnapshot, + TurnId, + type UserInputQuestion, +} from "@okcode/contracts"; +import { CopilotClient, type CopilotSession, type SessionEvent } from "@github/copilot-sdk"; +import { + compactNodeProcessEnv, + mergeNodeProcessEnv, + sanitizeShellEnvironment, +} from "@okcode/shared/environment"; +import { Effect, Layer, Queue, Stream } from "effect"; +import { + ProviderAdapterProcessError, + ProviderAdapterRequestError, + ProviderAdapterSessionClosedError, + ProviderAdapterSessionNotFoundError, + type ProviderAdapterError, +} from "../Errors.ts"; +import { CopilotAdapter, type CopilotAdapterShape } from "../Services/CopilotAdapter.ts"; +import type { EventNdjsonLogger } from "./EventNdjsonLogger.ts"; + +const PROVIDER = "copilot" as const; + +type CopilotPermissionRequest = + | { + readonly kind: "shell"; + readonly fullCommandText?: string; + } + | { + readonly kind: "write"; + readonly fileName?: string; + } + | { + readonly kind: "read"; + readonly path?: string; + } + | { + readonly kind: "mcp"; + readonly toolTitle?: string; + } + | { + readonly kind: "url"; + readonly url?: string; + } + | { + readonly kind: "custom-tool"; + readonly toolName?: string; + } + | { + readonly kind: string; + readonly fullCommandText?: string; + readonly fileName?: string; + readonly path?: string; + readonly toolTitle?: string; + readonly url?: string; + readonly toolName?: string; + }; + +type CopilotPermissionRequestResult = + | { readonly kind: "approved" } + | { readonly kind: "denied-interactively-by-user" }; + +interface CopilotUserInputRequest { + readonly question: string; + readonly choices?: ReadonlyArray; + readonly allowFreeform?: boolean; +} + +interface CopilotUserInputResponse { + readonly answer: string; + readonly wasFreeform: boolean; +} + +interface CopilotResumeCursor { + readonly version: 1; + readonly sessionId: string; +} + +interface PendingApproval { + readonly requestType: CanonicalRequestType; + readonly detail?: string; + readonly promise: Promise; + readonly resolve: (decision: ProviderApprovalDecision) => void; +} + +interface PendingUserInput { + readonly question: UserInputQuestion; + readonly promise: Promise; + readonly resolve: (answers: ProviderUserInputAnswers) => void; +} + +interface CopilotTurnState { + readonly turnId: TurnId; + readonly startedAt: string; + readonly items: Array; + readonly assistantItemIds: Set; + readonly toolItemIds: Set; +} + +interface CopilotSessionContext { + session: ProviderSession; + readonly client: CopilotClient; + readonly copilotSession: CopilotSession; + readonly pendingApprovals: Map; + readonly pendingUserInputs: Map; + readonly turns: Array<{ id: TurnId; items: Array }>; + turnState: CopilotTurnState | undefined; + lastUsage: ThreadTokenUsageSnapshot | undefined; + stopped: boolean; +} + +export interface CopilotAdapterLiveOptions { + readonly nativeEventLogger?: EventNdjsonLogger; +} + +function nowIsoString(): string { + return new Date().toISOString(); +} + +function toMessage(cause: unknown, fallback: string): string { + if (cause instanceof Error && cause.message.length > 0) { + return cause.message; + } + return fallback; +} + +function makeResumeCursor(sessionId: string): CopilotResumeCursor { + return { version: 1, sessionId }; +} + +function readResumeCursor(cursor: unknown): CopilotResumeCursor | undefined { + if ( + typeof cursor === "object" && + cursor !== null && + "version" in cursor && + "sessionId" in cursor && + (cursor as Record).version === 1 && + typeof (cursor as Record).sessionId === "string" + ) { + return cursor as CopilotResumeCursor; + } + return undefined; +} + +function toRequestError(threadId: ThreadId, method: string, cause: unknown): ProviderAdapterError { + const message = toMessage(cause, `${method} failed`); + if (message.toLowerCase().includes("not found")) { + return new ProviderAdapterSessionNotFoundError({ + provider: PROVIDER, + threadId, + cause, + }); + } + return new ProviderAdapterRequestError({ + provider: PROVIDER, + method, + detail: message, + cause, + }); +} + +function toProcessError(threadId: ThreadId, detail: string, cause?: unknown): ProviderAdapterError { + return new ProviderAdapterProcessError({ + provider: PROVIDER, + threadId, + detail, + ...(cause !== undefined ? { cause } : {}), + }); +} + +function mapInteractionModeToCopilotMode( + interactionMode: ProviderSendTurnInput["interactionMode"], +): "interactive" | "plan" | "autopilot" { + switch (interactionMode) { + case "plan": + return "plan"; + case "code": + return "autopilot"; + case "chat": + default: + return "interactive"; + } +} + +function mapApprovalDecisionToPermissionResult( + decision: ProviderApprovalDecision, +): CopilotPermissionRequestResult { + switch (decision) { + case "accept": + case "acceptForSession": + return { kind: "approved" }; + case "decline": + case "cancel": + default: + return { kind: "denied-interactively-by-user" }; + } +} + +function inferRequestType(request: CopilotPermissionRequest): CanonicalRequestType { + switch (request.kind) { + case "shell": + return "exec_command_approval"; + case "write": + return "apply_patch_approval"; + case "read": + return "file_read_approval"; + case "mcp": + case "custom-tool": + return "dynamic_tool_call"; + case "url": + return "unknown"; + default: + return "unknown"; + } +} + +function permissionDetail(request: CopilotPermissionRequest): string | undefined { + if (request.kind === "shell" && typeof request.fullCommandText === "string") { + return request.fullCommandText; + } + if (request.kind === "write" && typeof request.fileName === "string") { + return request.fileName; + } + if (request.kind === "read" && typeof request.path === "string") { + return request.path; + } + if (request.kind === "mcp" && typeof request.toolTitle === "string") { + return request.toolTitle; + } + if (request.kind === "url" && typeof request.url === "string") { + return request.url; + } + if (request.kind === "custom-tool" && typeof request.toolName === "string") { + return request.toolName; + } + return undefined; +} + +function inferToolItemType(toolName: string): CanonicalItemType { + const normalized = toolName.trim().toLowerCase(); + if (normalized === "shell" || normalized.includes("bash")) return "command_execution"; + if (normalized.includes("write") || normalized.includes("edit") || normalized.includes("patch")) + return "file_change"; + if (normalized.includes("read")) return "dynamic_tool_call"; + if (normalized.includes("mcp")) return "mcp_tool_call"; + if (normalized.includes("web") || normalized.includes("fetch") || normalized.includes("search")) + return "web_search"; + if (normalized.includes("image")) return "image_view"; + return "dynamic_tool_call"; +} + +function inferToolStreamKind(itemType: CanonicalItemType): RuntimeContentStreamKind { + switch (itemType) { + case "command_execution": + return "command_output"; + case "file_change": + return "file_change_output"; + default: + return "unknown"; + } +} + +function normalizeUsageFromAssistantEvent( + event: Extract, +) { + const usedTokens = (event.data.inputTokens ?? 0) + (event.data.outputTokens ?? 0); + if (usedTokens <= 0) { + return undefined; + } + return { + usedTokens, + ...(event.data.inputTokens !== undefined ? { inputTokens: event.data.inputTokens } : {}), + ...(event.data.outputTokens !== undefined ? { outputTokens: event.data.outputTokens } : {}), + ...(event.data.cacheReadTokens !== undefined + ? { cachedInputTokens: event.data.cacheReadTokens } + : {}), + ...(event.data.reasoningEffort ? { compactsAutomatically: true } : {}), + ...(event.data.duration !== undefined ? { durationMs: event.data.duration } : {}), + lastUsedTokens: usedTokens, + ...(event.data.inputTokens !== undefined ? { lastInputTokens: event.data.inputTokens } : {}), + ...(event.data.outputTokens !== undefined ? { lastOutputTokens: event.data.outputTokens } : {}), + } satisfies ThreadTokenUsageSnapshot; +} + +function buildUserInputQuestion( + requestId: string, + request: CopilotUserInputRequest, +): UserInputQuestion { + return { + id: requestId, + header: "Copilot", + question: request.question.trim() || "GitHub Copilot needs input.", + options: (request.choices ?? []).map((choice: string) => ({ + label: choice, + description: choice, + })), + }; +} + +function readCopilotProviderOptions(input: { readonly providerOptions?: unknown }) { + if (!input.providerOptions || typeof input.providerOptions !== "object") { + return {}; + } + const providerOptions = input.providerOptions as Record; + const copilot = providerOptions.copilot; + if (!copilot || typeof copilot !== "object") { + return {}; + } + const record = copilot as Record; + return { + ...(typeof record.binaryPath === "string" ? { binaryPath: record.binaryPath } : {}), + ...(typeof record.configDir === "string" ? { configDir: record.configDir } : {}), + }; +} + +function getCopilotReasoningEffort(input: ProviderSendTurnInput): string | undefined { + return input.modelOptions?.copilot?.reasoningEffort; +} + +export function makeCopilotAdapterLive(options?: CopilotAdapterLiveOptions) { + return Layer.effect(CopilotAdapter, makeCopilotAdapter(options)); +} + +const makeCopilotAdapter = (options?: CopilotAdapterLiveOptions) => + Effect.gen(function* () { + const runtimeEventQueue = yield* Queue.unbounded(); + const sessions = new Map(); + + const emitEvent = (event: ProviderRuntimeEvent) => + Queue.offer(runtimeEventQueue, event).pipe( + Effect.tap(() => + options?.nativeEventLogger ? options.nativeEventLogger.write(event, null) : Effect.void, + ), + Effect.asVoid, + ); + + const makeBase = ( + threadId: ThreadId, + extra?: { + readonly turnId?: TurnId; + readonly itemId?: string; + readonly requestId?: string; + }, + raw?: SessionEvent, + ) => ({ + eventId: EventId.makeUnsafe(crypto.randomUUID()), + provider: PROVIDER, + threadId, + createdAt: nowIsoString(), + ...(extra?.turnId ? { turnId: extra.turnId } : {}), + ...(extra?.itemId ? { itemId: RuntimeItemId.makeUnsafe(extra.itemId) } : {}), + ...(extra?.requestId ? { requestId: RuntimeRequestId.makeUnsafe(extra.requestId) } : {}), + providerRefs: {}, + ...(raw + ? { + raw: { + source: "copilot.sdk.event" as const, + messageType: raw.type, + payload: raw, + }, + } + : {}), + }); + + const getContext = (threadId: ThreadId) => { + const context = sessions.get(threadId); + if (!context) { + return Effect.fail( + new ProviderAdapterSessionNotFoundError({ + provider: PROVIDER, + threadId, + }), + ); + } + if (context.stopped) { + return Effect.fail( + new ProviderAdapterSessionClosedError({ + provider: PROVIDER, + threadId, + }), + ); + } + return Effect.succeed(context); + }; + + const ensureAssistantItem = ( + context: CopilotSessionContext, + messageId: string, + raw: SessionEvent, + ) => + Effect.gen(function* () { + const turnState = context.turnState; + if (!turnState) return; + const itemId = `assistant:${messageId}`; + if (turnState.assistantItemIds.has(itemId)) return; + turnState.assistantItemIds.add(itemId); + turnState.items.push({ itemId, itemType: "assistant_message" }); + yield* emitEvent({ + ...makeBase(context.session.threadId, { turnId: turnState.turnId, itemId }, raw), + type: "item.started", + payload: { + itemType: "assistant_message", + title: "Assistant message", + }, + }); + }); + + const ensureToolItem = ( + context: CopilotSessionContext, + toolCallId: string, + toolName: string, + raw: SessionEvent, + ) => + Effect.gen(function* () { + const turnState = context.turnState; + if (!turnState) + return { itemId: `tool:${toolCallId}`, itemType: inferToolItemType(toolName) }; + const itemId = `tool:${toolCallId}`; + const itemType = inferToolItemType(toolName); + if (!turnState.toolItemIds.has(itemId)) { + turnState.toolItemIds.add(itemId); + turnState.items.push({ itemId, itemType, toolName }); + yield* emitEvent({ + ...makeBase(context.session.threadId, { turnId: turnState.turnId, itemId }, raw), + type: "item.started", + payload: { + itemType, + title: toolName, + data: { + toolName, + }, + }, + }); + } + return { itemId, itemType }; + }); + + const completeTurn = ( + context: CopilotSessionContext, + state: "completed" | "failed" | "cancelled" | "interrupted", + raw?: SessionEvent, + errorMessage?: string, + ) => + Effect.gen(function* () { + const turnState = context.turnState; + if (!turnState) return; + context.turns.push({ id: turnState.turnId, items: [...turnState.items] }); + context.turnState = undefined; + context.session = { + ...context.session, + status: state === "failed" ? "error" : "ready", + activeTurnId: undefined, + updatedAt: nowIsoString(), + ...(errorMessage ? { lastError: errorMessage } : {}), + }; + yield* emitEvent({ + ...makeBase(context.session.threadId, { turnId: turnState.turnId }, raw), + type: "turn.completed", + payload: { + state, + ...(context.lastUsage ? { usage: context.lastUsage } : {}), + ...(errorMessage ? { errorMessage } : {}), + }, + }); + yield* emitEvent({ + ...makeBase(context.session.threadId, { turnId: turnState.turnId }, raw), + type: "session.state.changed", + payload: { + state: state === "failed" ? "error" : "ready", + ...(errorMessage ? { reason: errorMessage } : {}), + }, + }); + }); + + const handleSessionEvent = (context: CopilotSessionContext, event: SessionEvent) => + Effect.gen(function* () { + switch (event.type) { + case "assistant.message_delta": { + const turnState = context.turnState; + if (!turnState) return; + yield* ensureAssistantItem(context, event.data.messageId, event); + yield* emitEvent({ + ...makeBase( + context.session.threadId, + { turnId: turnState.turnId, itemId: `assistant:${event.data.messageId}` }, + event, + ), + type: "content.delta", + payload: { + streamKind: "assistant_text", + delta: event.data.deltaContent, + }, + }); + return; + } + case "assistant.message": { + const turnState = context.turnState; + if (!turnState) return; + const itemId = `assistant:${event.data.messageId}`; + yield* ensureAssistantItem(context, event.data.messageId, event); + if (event.data.reasoningText?.trim()) { + const reasoningItemId = `${itemId}:reasoning`; + yield* emitEvent({ + ...makeBase( + context.session.threadId, + { turnId: turnState.turnId, itemId: reasoningItemId }, + event, + ), + type: "item.started", + payload: { + itemType: "reasoning", + title: "Reasoning", + }, + }); + yield* emitEvent({ + ...makeBase( + context.session.threadId, + { turnId: turnState.turnId, itemId: reasoningItemId }, + event, + ), + type: "content.delta", + payload: { + streamKind: "reasoning_text", + delta: event.data.reasoningText, + }, + }); + yield* emitEvent({ + ...makeBase( + context.session.threadId, + { turnId: turnState.turnId, itemId: reasoningItemId }, + event, + ), + type: "item.completed", + payload: { + itemType: "reasoning", + status: "completed", + title: "Reasoning", + data: { + text: event.data.reasoningText, + }, + }, + }); + } + yield* emitEvent({ + ...makeBase(context.session.threadId, { turnId: turnState.turnId, itemId }, event), + type: "item.completed", + payload: { + itemType: "assistant_message", + status: "completed", + title: "Assistant message", + data: { + text: event.data.content, + ...(event.data.toolRequests ? { toolRequests: event.data.toolRequests } : {}), + }, + }, + }); + return; + } + case "assistant.usage": { + context.lastUsage = normalizeUsageFromAssistantEvent(event); + if (!context.lastUsage) return; + yield* emitEvent({ + ...makeBase( + context.session.threadId, + context.turnState ? { turnId: context.turnState.turnId } : undefined, + event, + ), + type: "thread.token-usage.updated", + payload: { + usage: context.lastUsage, + }, + }); + return; + } + case "tool.execution_start": { + const { itemId, itemType } = yield* ensureToolItem( + context, + event.data.toolCallId, + event.data.toolName, + event, + ); + yield* emitEvent({ + ...makeBase( + context.session.threadId, + context.turnState ? { turnId: context.turnState.turnId, itemId } : { itemId }, + event, + ), + type: "item.updated", + payload: { + itemType, + status: "inProgress", + title: event.data.toolName, + data: { + ...(event.data.arguments ? { arguments: event.data.arguments } : {}), + ...(event.data.mcpServerName ? { mcpServerName: event.data.mcpServerName } : {}), + }, + }, + }); + return; + } + case "tool.execution_partial_result": { + const turnState = context.turnState; + if (!turnState) return; + const toolItem = turnState.items.find( + (item) => + typeof item === "object" && + item !== null && + "itemId" in item && + (item as { itemId?: string }).itemId === `tool:${event.data.toolCallId}`, + ) as { itemId?: string; itemType?: CanonicalItemType } | undefined; + const itemId = toolItem?.itemId ?? `tool:${event.data.toolCallId}`; + const itemType = toolItem?.itemType ?? "dynamic_tool_call"; + yield* emitEvent({ + ...makeBase(context.session.threadId, { turnId: turnState.turnId, itemId }, event), + type: "content.delta", + payload: { + streamKind: inferToolStreamKind(itemType), + delta: event.data.partialOutput, + }, + }); + return; + } + case "tool.execution_complete": { + const turnState = context.turnState; + const toolItem = turnState?.items.find( + (item) => + typeof item === "object" && + item !== null && + "itemId" in item && + (item as { itemId?: string }).itemId === `tool:${event.data.toolCallId}`, + ) as { itemId?: string; itemType?: CanonicalItemType; toolName?: string } | undefined; + const itemId = toolItem?.itemId ?? `tool:${event.data.toolCallId}`; + const itemType = toolItem?.itemType ?? "dynamic_tool_call"; + yield* emitEvent({ + ...makeBase( + context.session.threadId, + turnState ? { turnId: turnState.turnId, itemId } : { itemId }, + event, + ), + type: "item.completed", + payload: { + itemType, + status: event.data.success ? "completed" : "failed", + title: toolItem?.toolName ?? "Tool execution", + data: { + success: event.data.success, + ...(event.data.result !== undefined ? { result: event.data.result } : {}), + ...(event.data.error ? { error: event.data.error } : {}), + }, + }, + }); + return; + } + case "session.idle": + yield* completeTurn(context, "completed", event); + return; + case "abort": + yield* completeTurn(context, "interrupted", event, event.data.reason); + return; + case "session.warning": + yield* emitEvent({ + ...makeBase( + context.session.threadId, + context.turnState ? { turnId: context.turnState.turnId } : undefined, + event, + ), + type: "runtime.warning", + payload: { + message: event.data.message, + detail: event.data, + }, + }); + return; + case "session.error": + yield* emitEvent({ + ...makeBase( + context.session.threadId, + context.turnState ? { turnId: context.turnState.turnId } : undefined, + event, + ), + type: "runtime.error", + payload: { + message: event.data.message, + class: "provider_error", + detail: event.data, + }, + }); + yield* completeTurn(context, "failed", event, event.data.message); + return; + default: + return; + } + }); + + const startSession: CopilotAdapterShape["startSession"] = (input) => + Effect.gen(function* () { + const now = nowIsoString(); + const threadId = input.threadId; + const resolvedCwd = input.cwd ?? process.cwd(); + const providerOptions = readCopilotProviderOptions(input); + const sessionEnv = sanitizeShellEnvironment( + mergeNodeProcessEnv( + process.env, + input.env ? compactNodeProcessEnv(input.env) : undefined, + ), + ); + + const client = new CopilotClient({ + ...(providerOptions.binaryPath ? { cliPath: providerOptions.binaryPath } : {}), + cwd: resolvedCwd, + env: sessionEnv, + logLevel: "error", + }); + + yield* Effect.tryPromise({ + try: () => client.start(), + catch: (cause) => + toProcessError(threadId, "Failed to start GitHub Copilot CLI client.", cause), + }); + + const pendingApprovals = new Map(); + let context: CopilotSessionContext | undefined; + const onPermissionRequest = (request: CopilotPermissionRequest) => { + if (input.runtimeMode === "full-access") { + return { kind: "approved" } satisfies CopilotPermissionRequestResult; + } + const requestId = ApprovalRequestId.makeUnsafe( + `copilot-permission:${crypto.randomUUID()}`, + ); + let resolve!: (decision: ProviderApprovalDecision) => void; + const promise = new Promise((resolvePromise) => { + resolve = resolvePromise; + }); + const detail = permissionDetail(request); + pendingApprovals.set(requestId, { + requestType: inferRequestType(request), + ...(detail ? { detail } : {}), + promise, + resolve, + }); + void Effect.runPromise( + emitEvent({ + ...makeBase( + threadId, + context?.turnState + ? { turnId: context.turnState.turnId, requestId } + : { requestId }, + ), + type: "request.opened", + payload: { + requestType: inferRequestType(request), + ...(detail ? { detail } : {}), + args: request, + }, + }), + ); + return promise.then((decision) => { + const result = mapApprovalDecisionToPermissionResult(decision); + void Effect.runPromise( + emitEvent({ + ...makeBase( + threadId, + context?.turnState + ? { turnId: context.turnState.turnId, requestId } + : { requestId }, + ), + type: "request.resolved", + payload: { + requestType: inferRequestType(request), + decision: result.kind, + resolution: result, + }, + }).pipe(Effect.ensuring(Effect.sync(() => pendingApprovals.delete(requestId)))), + ); + return result; + }); + }; + + const reasonEffort = input.modelOptions?.copilot?.reasoningEffort; + const sessionConfig = { + ...(input.model ? { model: input.model } : {}), + ...(reasonEffort ? { reasoningEffort: reasonEffort } : {}), + workingDirectory: resolvedCwd, + streaming: true, + onPermissionRequest, + ...(providerOptions.configDir ? { configDir: providerOptions.configDir } : {}), + }; + + const resumeCursor = readResumeCursor(input.resumeCursor); + const copilotSession = yield* Effect.tryPromise({ + try: () => + resumeCursor + ? client.resumeSession(resumeCursor.sessionId, { + ...sessionConfig, + }) + : client.createSession({ + ...sessionConfig, + }), + catch: (cause) => + toProcessError(threadId, "Failed to create GitHub Copilot session.", cause), + }).pipe( + Effect.tapError(() => + Effect.promise(() => + client + .stop() + .then(() => undefined) + .catch(() => undefined), + ), + ), + ); + + const session: ProviderSession = { + provider: PROVIDER, + status: "ready", + runtimeMode: input.runtimeMode, + cwd: resolvedCwd, + model: input.model, + threadId, + createdAt: now, + updatedAt: now, + resumeCursor: makeResumeCursor(copilotSession.sessionId), + }; + + context = { + session, + client, + copilotSession, + pendingApprovals, + pendingUserInputs: new Map(), + turns: [], + turnState: undefined, + lastUsage: undefined, + stopped: false, + }; + + copilotSession.on((event) => { + void Effect.runPromise(handleSessionEvent(context, event)); + }); + + copilotSession.registerUserInputHandler((request) => { + const requestId = ApprovalRequestId.makeUnsafe( + `copilot-user-input:${crypto.randomUUID()}`, + ); + let resolve!: (answers: ProviderUserInputAnswers) => void; + const promise = new Promise((resolvePromise) => { + resolve = resolvePromise; + }); + const question = buildUserInputQuestion(requestId, request); + context.pendingUserInputs.set(requestId, { + question, + promise, + resolve, + }); + void Effect.runPromise( + emitEvent({ + ...makeBase( + threadId, + context.turnState ? { turnId: context.turnState.turnId, requestId } : { requestId }, + ), + type: "user-input.requested", + payload: { + questions: [question], + }, + }), + ); + return promise.then((answers) => { + const answerValue = answers[requestId] ?? answers[question.id]; + const answer = typeof answerValue === "string" ? answerValue : ""; + void Effect.runPromise( + emitEvent({ + ...makeBase( + threadId, + context.turnState + ? { turnId: context.turnState.turnId, requestId } + : { requestId }, + ), + type: "user-input.resolved", + payload: { + answers: { + [question.id]: answer, + }, + }, + }).pipe( + Effect.ensuring(Effect.sync(() => context.pendingUserInputs.delete(requestId))), + ), + ); + return { + answer, + wasFreeform: !(request.choices ?? []).includes(answer), + } satisfies CopilotUserInputResponse; + }); + }); + + sessions.set(threadId, context); + + yield* emitEvent({ + ...makeBase(threadId), + type: "session.started", + payload: { + message: "GitHub Copilot session started.", + resume: makeResumeCursor(copilotSession.sessionId), + }, + }); + yield* emitEvent({ + ...makeBase(threadId), + type: "thread.started", + payload: { + providerThreadId: copilotSession.sessionId, + }, + }); + yield* emitEvent({ + ...makeBase(threadId), + type: "session.state.changed", + payload: { + state: "ready", + }, + }); + + return context.session; + }); + + const sendTurn: CopilotAdapterShape["sendTurn"] = (input) => + Effect.gen(function* () { + const context = yield* getContext(input.threadId); + if (context.turnState) { + return yield* new ProviderAdapterRequestError({ + provider: PROVIDER, + method: "session.send", + detail: "GitHub Copilot already has an active turn for this thread.", + }); + } + + const turnId = TurnId.makeUnsafe(crypto.randomUUID()); + context.turnState = { + turnId, + startedAt: nowIsoString(), + items: [], + assistantItemIds: new Set(), + toolItemIds: new Set(), + }; + context.lastUsage = undefined; + context.session = { + ...context.session, + status: "running", + activeTurnId: turnId, + updatedAt: nowIsoString(), + ...(input.model ? { model: input.model } : {}), + }; + + const reasoningEffort = getCopilotReasoningEffort(input); + const nextModel = input.model; + if (nextModel) { + yield* Effect.tryPromise({ + try: () => + context.copilotSession.rpc.model.switchTo({ + modelId: nextModel, + ...(reasoningEffort ? { reasoningEffort } : {}), + }), + catch: (cause) => toRequestError(input.threadId, "session.model.switchTo", cause), + }); + } + + yield* Effect.tryPromise({ + try: () => + context.copilotSession.rpc.mode.set({ + mode: mapInteractionModeToCopilotMode(input.interactionMode), + }), + catch: (cause) => toRequestError(input.threadId, "session.mode.set", cause), + }); + + yield* emitEvent({ + ...makeBase(input.threadId, { turnId }), + type: "session.state.changed", + payload: { + state: "running", + }, + }); + yield* emitEvent({ + ...makeBase(input.threadId, { turnId }), + type: "turn.started", + payload: { + ...(context.session.model ? { model: context.session.model } : {}), + ...(reasoningEffort ? { effort: reasoningEffort } : {}), + }, + }); + + yield* Effect.tryPromise({ + try: () => + context.copilotSession.send({ + prompt: input.input ?? "", + }), + catch: (cause) => toRequestError(input.threadId, "session.send", cause), + }); + + return { + threadId: input.threadId, + turnId, + resumeCursor: makeResumeCursor(context.copilotSession.sessionId), + }; + }); + + const interruptTurn: CopilotAdapterShape["interruptTurn"] = (threadId) => + Effect.gen(function* () { + const context = yield* getContext(threadId); + yield* Effect.tryPromise({ + try: () => context.copilotSession.abort(), + catch: (cause) => toRequestError(threadId, "session.abort", cause), + }); + }); + + const respondToRequest: CopilotAdapterShape["respondToRequest"] = ( + threadId, + requestId, + decision, + ) => + Effect.gen(function* () { + const context = yield* getContext(threadId); + const pending = context.pendingApprovals.get(requestId); + if (!pending) { + return yield* new ProviderAdapterRequestError({ + provider: PROVIDER, + method: "session.permissions.handlePendingPermissionRequest", + detail: `Unknown pending approval request '${requestId}'.`, + }); + } + yield* Effect.sync(() => pending.resolve(decision)); + }); + + const respondToUserInput: CopilotAdapterShape["respondToUserInput"] = ( + threadId, + requestId, + answers, + ) => + Effect.gen(function* () { + const context = yield* getContext(threadId); + const pending = context.pendingUserInputs.get(requestId); + if (!pending) { + return yield* new ProviderAdapterRequestError({ + provider: PROVIDER, + method: "session.userInput.respond", + detail: `Unknown pending user input request '${requestId}'.`, + }); + } + yield* Effect.sync(() => pending.resolve(answers)); + }); + + const stopContext = (context: CopilotSessionContext) => + Effect.gen(function* () { + if (context.stopped) return; + context.stopped = true; + for (const pending of context.pendingApprovals.values()) { + pending.resolve("cancel"); + } + for (const pending of context.pendingUserInputs.values()) { + pending.resolve({}); + } + yield* Effect.promise(() => context.copilotSession.disconnect().catch(() => undefined)); + yield* Effect.promise(() => + context.client + .stop() + .then(() => undefined) + .catch(() => undefined), + ); + yield* emitEvent({ + ...makeBase(context.session.threadId), + type: "session.exited", + payload: { + reason: "Session stopped.", + exitKind: "graceful", + }, + }); + }); + + const stopSession: CopilotAdapterShape["stopSession"] = (threadId) => + Effect.gen(function* () { + const context = yield* getContext(threadId); + yield* stopContext(context); + sessions.delete(threadId); + }); + + const listSessions: CopilotAdapterShape["listSessions"] = () => + Effect.sync(() => + [...sessions.values()] + .filter((context) => !context.stopped) + .map((context) => context.session), + ); + + const hasSession: CopilotAdapterShape["hasSession"] = (threadId) => + Effect.sync(() => { + const context = sessions.get(threadId); + return context !== undefined && !context.stopped; + }); + + const readThread: CopilotAdapterShape["readThread"] = (threadId) => + Effect.gen(function* () { + const context = yield* getContext(threadId); + const turns = [...context.turns]; + if (context.turnState) { + turns.push({ id: context.turnState.turnId, items: [...context.turnState.items] }); + } + return { + threadId, + turns, + }; + }); + + const rollbackThread: CopilotAdapterShape["rollbackThread"] = (threadId) => + Effect.fail( + new ProviderAdapterRequestError({ + provider: PROVIDER, + method: "session.rollback", + detail: `GitHub Copilot rollback is not implemented for thread '${threadId}'.`, + }), + ); + + const stopAll: CopilotAdapterShape["stopAll"] = () => + Effect.promise(async () => { + await Promise.all( + [...sessions.values()].map((context) => + Effect.runPromise(stopContext(context).pipe(Effect.asVoid)), + ), + ); + }); + + return { + provider: PROVIDER, + capabilities: { + sessionModelSwitch: "in-session", + }, + startSession, + sendTurn, + interruptTurn, + respondToRequest, + respondToUserInput, + stopSession, + listSessions, + hasSession, + readThread, + rollbackThread, + stopAll, + streamEvents: Stream.fromQueue(runtimeEventQueue), + } satisfies CopilotAdapterShape; + }); + +export const CopilotAdapterLive = makeCopilotAdapterLive(); diff --git a/apps/server/src/provider/Layers/ProviderAdapterRegistry.test.ts b/apps/server/src/provider/Layers/ProviderAdapterRegistry.test.ts index 2702b336..9d8939d1 100644 --- a/apps/server/src/provider/Layers/ProviderAdapterRegistry.test.ts +++ b/apps/server/src/provider/Layers/ProviderAdapterRegistry.test.ts @@ -5,6 +5,7 @@ import { assertFailure } from "@effect/vitest/utils"; import { Effect, Layer, Stream } from "effect"; import { ClaudeAdapter, ClaudeAdapterShape } from "../Services/ClaudeAdapter.ts"; +import { CopilotAdapter, CopilotAdapterShape } from "../Services/CopilotAdapter.ts"; import { CodexAdapter, CodexAdapterShape } from "../Services/CodexAdapter.ts"; import { OpenClawAdapter, OpenClawAdapterShape } from "../Services/OpenClawAdapter.ts"; import { ProviderAdapterRegistry } from "../Services/ProviderAdapterRegistry.ts"; @@ -63,6 +64,23 @@ const fakeOpenClawAdapter: OpenClawAdapterShape = { streamEvents: Stream.empty, }; +const fakeCopilotAdapter: CopilotAdapterShape = { + provider: "copilot", + capabilities: { sessionModelSwitch: "in-session" }, + startSession: vi.fn(), + sendTurn: vi.fn(), + interruptTurn: vi.fn(), + respondToRequest: vi.fn(), + respondToUserInput: vi.fn(), + stopSession: vi.fn(), + listSessions: vi.fn(), + hasSession: vi.fn(), + readThread: vi.fn(), + rollbackThread: vi.fn(), + stopAll: vi.fn(), + streamEvents: Stream.empty, +}; + const layer = it.layer( Layer.mergeAll( Layer.provide( @@ -71,6 +89,7 @@ const layer = it.layer( Layer.succeed(CodexAdapter, fakeCodexAdapter), Layer.succeed(ClaudeAdapter, fakeClaudeAdapter), Layer.succeed(OpenClawAdapter, fakeOpenClawAdapter), + Layer.succeed(CopilotAdapter, fakeCopilotAdapter), ), ), NodeServices.layer, @@ -87,7 +106,7 @@ layer("ProviderAdapterRegistryLive", (it) => { assert.equal(claude, fakeClaudeAdapter); const providers = yield* registry.listProviders(); - assert.deepEqual(providers, ["codex", "claudeAgent", "openclaw"]); + assert.deepEqual(providers, ["codex", "claudeAgent", "openclaw", "copilot"]); }), ); diff --git a/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts b/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts index cb4dd06d..257219a3 100644 --- a/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts +++ b/apps/server/src/provider/Layers/ProviderAdapterRegistry.ts @@ -16,6 +16,7 @@ import { type ProviderAdapterRegistryShape, } from "../Services/ProviderAdapterRegistry.ts"; import { ClaudeAdapter } from "../Services/ClaudeAdapter.ts"; +import { CopilotAdapter } from "../Services/CopilotAdapter.ts"; import { CodexAdapter } from "../Services/CodexAdapter.ts"; import { OpenClawAdapter } from "../Services/OpenClawAdapter.ts"; @@ -28,7 +29,12 @@ const makeProviderAdapterRegistry = (options?: ProviderAdapterRegistryLiveOption const adapters = options?.adapters !== undefined ? options.adapters - : [yield* CodexAdapter, yield* ClaudeAdapter, yield* OpenClawAdapter]; + : [ + yield* CodexAdapter, + yield* ClaudeAdapter, + yield* OpenClawAdapter, + yield* CopilotAdapter, + ]; const byProvider = new Map(adapters.map((adapter) => [adapter.provider, adapter])); const getByProvider: ProviderAdapterRegistryShape["getByProvider"] = (provider) => { diff --git a/apps/server/src/provider/Layers/ProviderHealth.ts b/apps/server/src/provider/Layers/ProviderHealth.ts index 2b24becb..c7e9f793 100644 --- a/apps/server/src/provider/Layers/ProviderHealth.ts +++ b/apps/server/src/provider/Layers/ProviderHealth.ts @@ -9,6 +9,7 @@ * @module ProviderHealthLive */ import * as OS from "node:os"; +import { CopilotClient } from "@github/copilot-sdk"; import type { ServerProviderAuthStatus, ServerProviderStatus, @@ -38,11 +39,20 @@ import { ProviderHealth, type ProviderHealthShape } from "../Services/ProviderHe const DEFAULT_TIMEOUT_MS = 4_000; const CODEX_PROVIDER = "codex" as const; const CLAUDE_AGENT_PROVIDER = "claudeAgent" as const; +const COPILOT_PROVIDER = "copilot" as const; class OpenClawHealthProbeError extends Data.TaggedError("OpenClawHealthProbeError")<{ cause: unknown; }> {} +class CopilotHealthProbeError extends Data.TaggedError("CopilotHealthProbeError")<{ + cause: unknown; +}> {} + +function formatHealthProbeCause(cause: unknown): string { + return cause instanceof Error ? cause.message : String(cause); +} + // ── Pure helpers ──────────────────────────────────────────────────── export interface CommandResult { @@ -298,6 +308,98 @@ const runClaudeCommand = (args: ReadonlyArray) => return { stdout, stderr, code: exitCode } satisfies CommandResult; }).pipe(Effect.scoped); +export const checkCopilotProviderStatus: Effect.Effect = + Effect.gen(function* () { + const checkedAt = new Date().toISOString(); + const client = new CopilotClient({ logLevel: "error" }); + + const started = yield* Effect.tryPromise({ + try: () => client.start(), + catch: (cause) => new CopilotHealthProbeError({ cause }), + }).pipe(Effect.timeoutOption(DEFAULT_TIMEOUT_MS), Effect.result); + + if (Result.isFailure(started)) { + const error = started.failure; + return { + provider: COPILOT_PROVIDER, + status: "error" as const, + available: false, + authStatus: "unknown" as const, + checkedAt, + message: + error instanceof CopilotHealthProbeError + ? `Failed to start GitHub Copilot CLI: ${formatHealthProbeCause(error.cause)}.` + : "Failed to start GitHub Copilot CLI.", + } satisfies ServerProviderStatus; + } + + if (Option.isNone(started.success)) { + yield* Effect.promise(() => + client + .forceStop() + .then(() => undefined) + .catch(() => undefined), + ); + return { + provider: COPILOT_PROVIDER, + status: "error" as const, + available: false, + authStatus: "unknown" as const, + checkedAt, + message: "GitHub Copilot CLI timed out while starting.", + } satisfies ServerProviderStatus; + } + + const authResult = yield* Effect.tryPromise({ + try: () => client.getAuthStatus(), + catch: (cause) => new CopilotHealthProbeError({ cause }), + }).pipe(Effect.timeoutOption(DEFAULT_TIMEOUT_MS), Effect.result); + yield* Effect.promise(() => + client + .stop() + .then(() => undefined) + .catch(() => undefined), + ); + + if (Result.isFailure(authResult)) { + const error = authResult.failure; + return { + provider: COPILOT_PROVIDER, + status: "warning" as const, + available: true, + authStatus: "unknown" as const, + checkedAt, + message: + error instanceof CopilotHealthProbeError + ? `Could not verify GitHub Copilot authentication status: ${formatHealthProbeCause(error.cause)}.` + : "Could not verify GitHub Copilot authentication status.", + } satisfies ServerProviderStatus; + } + + if (Option.isNone(authResult.success)) { + return { + provider: COPILOT_PROVIDER, + status: "warning" as const, + available: true, + authStatus: "unknown" as const, + checkedAt, + message: "Could not verify GitHub Copilot authentication status. Timed out while checking.", + } satisfies ServerProviderStatus; + } + + const authStatus = authResult.success.value; + return { + provider: COPILOT_PROVIDER, + status: authStatus.isAuthenticated ? ("ready" as const) : ("error" as const), + available: true, + authStatus: authStatus.isAuthenticated + ? ("authenticated" as const) + : ("unauthenticated" as const), + checkedAt, + ...(authStatus.statusMessage ? { message: authStatus.statusMessage } : {}), + } satisfies ServerProviderStatus; + }); + // ── Health check ──────────────────────────────────────────────────── export const checkCodexProviderStatus: Effect.Effect< @@ -687,7 +789,12 @@ export const ProviderHealthLive = Layer.effect( ProviderHealth, Effect.gen(function* () { const statusesFiber = yield* Effect.all( - [checkCodexProviderStatus, checkClaudeProviderStatus, checkOpenClawProviderStatus], + [ + checkCodexProviderStatus, + checkClaudeProviderStatus, + checkCopilotProviderStatus, + checkOpenClawProviderStatus, + ], { concurrency: "unbounded", }, diff --git a/apps/server/src/provider/Services/CopilotAdapter.ts b/apps/server/src/provider/Services/CopilotAdapter.ts new file mode 100644 index 00000000..2677e4ca --- /dev/null +++ b/apps/server/src/provider/Services/CopilotAdapter.ts @@ -0,0 +1,12 @@ +import { ServiceMap } from "effect"; + +import type { ProviderAdapterError } from "../Errors.ts"; +import type { ProviderAdapterShape } from "./ProviderAdapter.ts"; + +export interface CopilotAdapterShape extends ProviderAdapterShape { + readonly provider: "copilot"; +} + +export class CopilotAdapter extends ServiceMap.Service()( + "okcode/provider/Services/CopilotAdapter", +) {} diff --git a/apps/server/src/serverLayers.ts b/apps/server/src/serverLayers.ts index b7ee3bb6..7c3a0578 100644 --- a/apps/server/src/serverLayers.ts +++ b/apps/server/src/serverLayers.ts @@ -18,6 +18,7 @@ import { ProviderRuntimeIngestionLive } from "./orchestration/Layers/ProviderRun import { RuntimeReceiptBusLive } from "./orchestration/Layers/RuntimeReceiptBus"; import { ProviderUnsupportedError } from "./provider/Errors"; import { makeClaudeAdapterLive } from "./provider/Layers/ClaudeAdapter"; +import { makeCopilotAdapterLive } from "./provider/Layers/CopilotAdapter"; import { makeCodexAdapterLive } from "./provider/Layers/CodexAdapter"; import { makeOpenClawAdapterLive } from "./provider/Layers/OpenClawAdapter"; import { ProviderAdapterRegistryLive } from "./provider/Layers/ProviderAdapterRegistry"; @@ -95,10 +96,14 @@ export function makeServerProviderLayer(): Layer.Layer< const openclawAdapterLayer = makeOpenClawAdapterLive( nativeEventLogger ? { nativeEventLogger } : undefined, ); + const copilotAdapterLayer = makeCopilotAdapterLive( + nativeEventLogger ? { nativeEventLogger } : undefined, + ); const adapterRegistryLayer = ProviderAdapterRegistryLive.pipe( Layer.provide(codexAdapterLayer), Layer.provide(claudeAdapterLayer), Layer.provide(openclawAdapterLayer), + Layer.provide(copilotAdapterLayer), Layer.provideMerge(providerSessionDirectoryLayer), ); return makeProviderServiceLive( diff --git a/apps/server/src/sme/Layers/SmeChatServiceLive.ts b/apps/server/src/sme/Layers/SmeChatServiceLive.ts index 4e627115..0788b46b 100644 --- a/apps/server/src/sme/Layers/SmeChatServiceLive.ts +++ b/apps/server/src/sme/Layers/SmeChatServiceLive.ts @@ -192,6 +192,14 @@ const makeSmeChatService = (options: SmeChatServiceLiveOptions = {}) => new SmeChatError("validateSetup", "Failed to validate Codex setup.", cause), }); + case "copilot": + return { + ok: false, + severity: "warning" as const, + message: "GitHub Copilot is not available in SME Chat yet.", + resolvedAuthMethod: "auto" as const, + }; + case "openclaw": return validateOpenClawSetup({ authMethod: conversation.authMethod as Extract< @@ -499,18 +507,25 @@ const makeSmeChatService = (options: SmeChatServiceLiveOptions = {}) => abortSignal: controller.signal, }).pipe(Effect.ensuring(clearInterrupt(input.conversationId))); }) - : sendSmeViaProviderRuntime({ - providerService, - provider: conv.provider, - conversationId: input.conversationId, - assistantMessageId, - model: conv.model, - compiledPrompt, - ...(input.providerOptions ? { providerOptions: input.providerOptions } : {}), - ...(onEvent ? { onEvent } : {}), - setInterruptEffect: (interrupt) => setInterrupt(input.conversationId, interrupt), - clearInterruptEffect: clearInterrupt(input.conversationId), - }); + : conv.provider === "codex" || conv.provider === "openclaw" + ? sendSmeViaProviderRuntime({ + providerService, + provider: conv.provider, + conversationId: input.conversationId, + assistantMessageId, + model: conv.model, + compiledPrompt, + ...(input.providerOptions ? { providerOptions: input.providerOptions } : {}), + ...(onEvent ? { onEvent } : {}), + setInterruptEffect: (interrupt) => setInterrupt(input.conversationId, interrupt), + clearInterruptEffect: clearInterrupt(input.conversationId), + }) + : Effect.fail( + new SmeChatError( + "sendMessage:validate", + "GitHub Copilot is not available in SME Chat yet.", + ), + ); const responseText = yield* sendEffect.pipe( Effect.mapError((cause) => diff --git a/apps/server/src/sme/authValidation.ts b/apps/server/src/sme/authValidation.ts index 235cdaa5..52c430c6 100644 --- a/apps/server/src/sme/authValidation.ts +++ b/apps/server/src/sme/authValidation.ts @@ -63,6 +63,8 @@ export function getAllowedSmeAuthMethods(provider: ProviderKind): readonly SmeAu switch (provider) { case "claudeAgent": return ["auto", "apiKey", "authToken"]; + case "copilot": + return ["auto"]; case "codex": return ["auto", "chatgpt", "apiKey", "customProvider"]; case "openclaw": @@ -74,6 +76,8 @@ export function getDefaultSmeAuthMethod(provider: ProviderKind): SmeAuthMethod { switch (provider) { case "claudeAgent": return "apiKey"; + case "copilot": + return "auto"; case "codex": return "chatgpt"; case "openclaw": diff --git a/apps/web/src/appSettings.ts b/apps/web/src/appSettings.ts index a7381818..56a6b446 100644 --- a/apps/web/src/appSettings.ts +++ b/apps/web/src/appSettings.ts @@ -36,7 +36,11 @@ export const DEFAULT_SIDEBAR_THREAD_SORT_ORDER: SidebarThreadSortOrder = "update export const PrReviewRequestChangesTone = Schema.Literals(["warning", "brand", "neutral"]); export type PrReviewRequestChangesTone = typeof PrReviewRequestChangesTone.Type; export const DEFAULT_PR_REVIEW_REQUEST_CHANGES_TONE: PrReviewRequestChangesTone = "warning"; -type CustomModelSettingsKey = "customCodexModels" | "customClaudeModels" | "customOpenClawModels"; +type CustomModelSettingsKey = + | "customCodexModels" + | "customClaudeModels" + | "customOpenClawModels" + | "customCopilotModels"; export type ProviderCustomModelConfig = { provider: ProviderKind; settingsKey: CustomModelSettingsKey; @@ -51,6 +55,7 @@ const BUILT_IN_MODEL_SLUGS_BY_PROVIDER: Record codex: new Set(getModelOptions("codex").map((option) => option.slug)), claudeAgent: new Set(getModelOptions("claudeAgent").map((option) => option.slug)), openclaw: new Set(getModelOptions("openclaw").map((option) => option.slug)), + copilot: new Set(getModelOptions("copilot").map((option) => option.slug)), }; const withDefaults = @@ -68,6 +73,8 @@ const withDefaults = export const AppSettingsSchema = Schema.Struct({ claudeBinaryPath: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), + copilotBinaryPath: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), + copilotConfigDir: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), codexBinaryPath: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), codexHomePath: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), backgroundImageUrl: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), @@ -102,6 +109,7 @@ export const AppSettingsSchema = Schema.Struct({ codeViewerAutosave: Schema.Boolean.pipe(withDefaults(() => false)), customCodexModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), customClaudeModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), + customCopilotModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), customOpenClawModels: Schema.Array(Schema.String).pipe(withDefaults(() => [])), openclawGatewayUrl: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), openclawPassword: Schema.String.check(Schema.isMaxLength(4096)).pipe(withDefaults(() => "")), @@ -135,6 +143,15 @@ const PROVIDER_CUSTOM_MODEL_CONFIG: Record store.setPromptEnhancementState, ); - const setComposerDraftProvider = useComposerDraftStore((store) => store.setProvider); - const setComposerDraftModel = useComposerDraftStore((store) => store.setModel); + const setComposerDraftProviderAndModel = useComposerDraftStore( + (store) => store.setProviderAndModel, + ); const setComposerDraftRuntimeMode = useComposerDraftStore((store) => store.setRuntimeMode); const setComposerDraftInteractionMode = useComposerDraftStore( (store) => store.setInteractionMode, @@ -4251,8 +4252,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { return; } const resolvedModel = resolveAppModelSelection(provider, customModelsByProvider, model); - setComposerDraftProvider(activeThread.id, provider); - setComposerDraftModel(activeThread.id, resolvedModel); + setComposerDraftProviderAndModel(activeThread.id, provider, resolvedModel); setStickyComposerModel(resolvedModel); scheduleComposerFocus(); }, @@ -4260,8 +4260,7 @@ export default function ChatView({ threadId, onMinimize }: ChatViewProps) { activeThread, lockedProvider, scheduleComposerFocus, - setComposerDraftModel, - setComposerDraftProvider, + setComposerDraftProviderAndModel, setStickyComposerModel, customModelsByProvider, ], diff --git a/apps/web/src/components/chat/ProviderModelPicker.browser.tsx b/apps/web/src/components/chat/ProviderModelPicker.browser.tsx index 4ebba720..828973af 100644 --- a/apps/web/src/components/chat/ProviderModelPicker.browser.tsx +++ b/apps/web/src/components/chat/ProviderModelPicker.browser.tsx @@ -15,6 +15,7 @@ const MODEL_OPTIONS_BY_PROVIDER = { { slug: "gpt-5-codex", name: "GPT-5 Codex" }, { slug: "gpt-5.3-codex", name: "GPT-5.3 Codex" }, ], + copilot: [{ slug: "claude-sonnet-4-5", name: "Claude Sonnet 4.5" }], openclaw: [], } as const satisfies Record>; @@ -51,7 +52,7 @@ describe("ProviderModelPicker", () => { document.body.innerHTML = ""; }); - it("shows both Codex and Anthropic model groups when provider switching is allowed", async () => { + it("shows providers on the left and swaps models in the right pane", async () => { const mounted = await mountPicker({ provider: "claudeAgent", model: "claude-opus-4-6", @@ -63,11 +64,21 @@ describe("ProviderModelPicker", () => { await vi.waitFor(() => { const text = document.body.textContent ?? ""; + expect(text).toContain("Providers"); + expect(text).toContain("Models"); expect(text).toContain("Codex"); expect(text).toContain("Anthropic"); - expect(text).toContain("GPT-5 Codex"); expect(text).toContain("Claude Sonnet 4.6"); expect(text).toContain("Claude Haiku 4.5"); + expect(text).not.toContain("GPT-5 Codex"); + }); + + await page.getByRole("button", { name: /Codex/ }).click(); + + await vi.waitFor(() => { + const text = document.body.textContent ?? ""; + expect(text).toContain("GPT-5 Codex"); + expect(text).toContain("GPT-5.3 Codex"); }); } finally { await mounted.cleanup(); @@ -86,6 +97,7 @@ describe("ProviderModelPicker", () => { await vi.waitFor(() => { const text = document.body.textContent ?? ""; + expect(text).toContain("Locked"); expect(text).toContain("Claude Sonnet 4.6"); expect(text).toContain("Claude Haiku 4.5"); expect(text).not.toContain("Codex"); @@ -99,17 +111,15 @@ describe("ProviderModelPicker", () => { const mounted = await mountPicker({ provider: "claudeAgent", model: "claude-opus-4-6", - lockedProvider: "claudeAgent", + lockedProvider: null, }); try { await page.getByRole("button").click(); - await page.getByRole("menuitemradio", { name: "Claude Sonnet 4.6" }).click(); + await page.getByRole("button", { name: /Codex/ }).click(); + await page.getByRole("button", { name: "GPT-5 Codex" }).click(); - expect(mounted.onProviderModelChange).toHaveBeenCalledWith( - "claudeAgent", - "claude-sonnet-4-6", - ); + expect(mounted.onProviderModelChange).toHaveBeenCalledWith("codex", "gpt-5-codex"); } finally { await mounted.cleanup(); } diff --git a/apps/web/src/components/chat/ProviderModelPicker.tsx b/apps/web/src/components/chat/ProviderModelPicker.tsx index 1b1fc8db..7da07eec 100644 --- a/apps/web/src/components/chat/ProviderModelPicker.tsx +++ b/apps/web/src/components/chat/ProviderModelPicker.tsx @@ -1,21 +1,20 @@ import { type ModelSlug, type ProviderKind } from "@okcode/contracts"; import { resolveSelectableModel } from "@okcode/shared/model"; -import { memo, useState } from "react"; +import { memo, useCallback, useRef, useState } from "react"; import { type ProviderPickerKind, PROVIDER_OPTIONS } from "../../session-logic"; -import { ChevronDownIcon } from "lucide-react"; +import { CheckIcon, ChevronDownIcon } from "lucide-react"; import { Button } from "../ui/button"; +import { Menu, MenuPopup, MenuTrigger } from "../ui/menu"; import { - Menu, - MenuGroup, - MenuGroupLabel, - MenuItem, - MenuPopup, - MenuRadioGroup, - MenuRadioItem, - MenuSeparator as MenuDivider, - MenuTrigger, -} from "../ui/menu"; -import { ClaudeAI, CursorIcon, Gemini, Icon, OpenAI, OpenClawIcon, OpenCodeIcon } from "../Icons"; + ClaudeAI, + CursorIcon, + Gemini, + GitHubIcon, + Icon, + OpenAI, + OpenClawIcon, + OpenCodeIcon, +} from "../Icons"; import { cn } from "~/lib/utils"; function isAvailableProviderOption(option: (typeof PROVIDER_OPTIONS)[number]): option is { @@ -30,6 +29,7 @@ const PROVIDER_ICON_BY_PROVIDER: Record = { codex: OpenAI, claudeAgent: ClaudeAI, openclaw: OpenClawIcon, + copilot: GitHubIcon, cursor: CursorIcon, }; @@ -46,6 +46,7 @@ function providerIconClassName( ): string { if (provider === "claudeAgent") return "text-[#d97757]"; if (provider === "openclaw") return "text-[#6cb4ee]"; + if (provider === "copilot") return "text-white/85"; return fallbackClassName; } @@ -64,11 +65,26 @@ export const ProviderModelPicker = memo(function ProviderModelPicker(props: { onProviderModelChange: (provider: ProviderKind, model: ModelSlug) => void; }) { const [isMenuOpen, setIsMenuOpen] = useState(false); + const [activeMenuProvider, setActiveMenuProvider] = useState( + props.lockedProvider ?? props.provider, + ); + const menuJustOpenedRef = useRef(false); + const setActiveMenuProviderFromInteraction = useCallback((provider: ProviderKind) => { + if (menuJustOpenedRef.current) return; + setActiveMenuProvider(provider); + }, []); const activeProvider = props.lockedProvider ?? props.provider; const selectedProviderOptions = props.modelOptionsByProvider[activeProvider]; const selectedModelLabel = selectedProviderOptions.find((option) => option.slug === props.model)?.name ?? props.model; const ProviderIcon = PROVIDER_ICON_BY_PROVIDER[activeProvider]; + const previewProvider = props.lockedProvider ?? activeMenuProvider; + const previewProviderOptions = props.modelOptionsByProvider[previewProvider]; + const PreviewProviderIcon = PROVIDER_ICON_BY_PROVIDER[previewProvider]; + const providerList = + props.lockedProvider === null + ? AVAILABLE_PROVIDER_OPTIONS + : AVAILABLE_PROVIDER_OPTIONS.filter((option) => option.value === props.lockedProvider); const handleModelChange = (provider: ProviderKind, value: string) => { if (props.disabled) return; if (!value) return; @@ -90,6 +106,10 @@ export const ProviderModelPicker = memo(function ProviderModelPicker(props: { setIsMenuOpen(false); return; } + if (open) { + setActiveMenuProvider(props.lockedProvider ?? props.provider); + menuJustOpenedRef.current = true; + } setIsMenuOpen(open); }} > @@ -124,92 +144,157 @@ export const ProviderModelPicker = memo(function ProviderModelPicker(props: {