diff --git a/apps/code/src/main/services/handoff/schemas.ts b/apps/code/src/main/services/handoff/schemas.ts index f9675b6a6..290a818b8 100644 --- a/apps/code/src/main/services/handoff/schemas.ts +++ b/apps/code/src/main/services/handoff/schemas.ts @@ -14,9 +14,14 @@ const handoffApiInput = handoffBaseInput.extend({ teamId: z.number(), }); +export const handoffErrorCodeSchema = z.enum(["github_authorization_required"]); + +export type HandoffErrorCode = z.infer; + const handoffBaseResult = z.object({ success: z.boolean(), error: z.string().optional(), + code: handoffErrorCodeSchema.optional(), }); export const handoffPreflightInput = handoffApiInput; diff --git a/apps/code/src/main/services/handoff/service.test.ts b/apps/code/src/main/services/handoff/service.test.ts index 94bd62800..e2d624153 100644 --- a/apps/code/src/main/services/handoff/service.test.ts +++ b/apps/code/src/main/services/handoff/service.test.ts @@ -65,7 +65,7 @@ vi.mock("@main/di/tokens", () => ({ })); import type { HandoffPreflightInput } from "./schemas"; -import { HandoffService } from "./service"; +import { extractHandoffErrorCode, HandoffService } from "./service"; const DEFAULT_LOCAL_GIT_STATE = { head: "abc123", @@ -172,3 +172,20 @@ describe("HandoffService.preflight", () => { expect(result.localTreeDirty).toBe(false); }); }); + +describe("extractHandoffErrorCode", () => { + it("detects GitHub authorization failures in backend error payloads", () => { + const message = + 'Failed request: [400] {"type":"validation_error","code":"github_authorization_required","detail":"Link a GitHub account"}'; + + expect(extractHandoffErrorCode(message)).toBe( + "github_authorization_required", + ); + }); + + it("ignores unrelated failures", () => { + expect(extractHandoffErrorCode("Failed request: [500] boom")).toBe( + undefined, + ); + }); +}); diff --git a/apps/code/src/main/services/handoff/service.ts b/apps/code/src/main/services/handoff/service.ts index f7f3d3b18..9cb07a6b0 100644 --- a/apps/code/src/main/services/handoff/service.ts +++ b/apps/code/src/main/services/handoff/service.ts @@ -35,6 +35,7 @@ import { type HandoffToCloudSagaDeps, } from "./handoff-to-cloud-saga"; import { + type HandoffErrorCode, HandoffEvent, type HandoffExecuteInput, type HandoffExecuteResult, @@ -49,6 +50,18 @@ import { const log = logger.scope("handoff"); const CONTINUE_DIVERGENCE_BUTTON = 1; +const GITHUB_AUTHORIZATION_REQUIRED_CODE = "github_authorization_required"; +const GITHUB_AUTHORIZATION_REQUIRED_MESSAGE = + "Connect GitHub in your browser, then retry Continue in cloud."; + +export function extractHandoffErrorCode( + message: string | undefined, +): HandoffErrorCode | undefined { + if (message?.includes(GITHUB_AUTHORIZATION_REQUIRED_CODE)) { + return GITHUB_AUTHORIZATION_REQUIRED_CODE; + } + return undefined; +} @injectable() export class HandoffService extends TypedEventEmitter { @@ -350,9 +363,14 @@ export class HandoffService extends TypedEventEmitter { failedStep: result.failedStep, }); deps.onProgress("failed", result.error ?? "Handoff to cloud failed"); + const code = extractHandoffErrorCode(result.error); return { success: false, - error: `Handoff to cloud failed at step '${result.failedStep}': ${result.error}`, + code, + error: + code === GITHUB_AUTHORIZATION_REQUIRED_CODE + ? GITHUB_AUTHORIZATION_REQUIRED_MESSAGE + : `Handoff to cloud failed at step '${result.failedStep}': ${result.error}`, }; } diff --git a/apps/code/src/renderer/api/posthogClient.ts b/apps/code/src/renderer/api/posthogClient.ts index e6c697614..b2003f114 100644 --- a/apps/code/src/renderer/api/posthogClient.ts +++ b/apps/code/src/renderer/api/posthogClient.ts @@ -70,6 +70,19 @@ export type McpInstallationTool = Schemas.MCPServerInstallationTool; export type Evaluation = Schemas.Evaluation; +export interface UserGitHubIntegration { + id: string; + kind: "github"; + installation_id: string; + repository_selection?: string | null; + account?: { + type?: string | null; + name?: string | null; + } | null; + uses_shared_installation?: boolean; + created_at?: string; +} + export interface SignalSourceConfig { id: string; source_product: @@ -153,7 +166,6 @@ interface CloudRunOptions { prAuthorshipMode?: PrAuthorshipMode; runSource?: CloudRunSource; signalReportId?: string; - githubUserToken?: string; initialPermissionMode?: PermissionMode; } @@ -230,9 +242,6 @@ function buildCloudRunRequestBody( if (options?.signalReportId) { body.signal_report_id = options.signalReportId; } - if (options?.githubUserToken) { - body.github_user_token = options.githubUserToken; - } if (options?.initialPermissionMode) { body.initial_permission_mode = options.initialPermissionMode; } @@ -563,7 +572,7 @@ export class PostHogAPIClient { */ async startGithubUserIntegrationConnect(teamId?: number): Promise<{ install_url: string; - connect_flow?: "oauth_authorize" | "app_install"; + connect_flow?: "oauth_authorize" | "oauth_discover" | "app_install"; }> { const id = teamId ?? (await this.getTeamId()); const urlPath = `/api/users/@me/integrations/github/start/`; @@ -588,10 +597,31 @@ export class PostHogAPIClient { } return (await response.json()) as { install_url: string; - connect_flow?: "oauth_authorize" | "app_install"; + connect_flow?: "oauth_authorize" | "oauth_discover" | "app_install"; }; } + async getGithubUserIntegrations(): Promise { + const urlPath = `/api/users/@me/integrations/`; + const url = new URL(`${this.api.baseUrl}${urlPath}`); + const response = await this.api.fetcher.fetch({ + method: "get", + url, + path: urlPath, + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch personal GitHub integrations: ${response.statusText}`, + ); + } + + const data = (await response.json()) as { + results?: UserGitHubIntegration[]; + }; + return data.results ?? []; + } + async switchOrganization(orgId: string): Promise { await this.api.patch("/api/users/{uuid}/", { path: { uuid: "@me" }, @@ -837,17 +867,19 @@ export class PostHogAPIClient { > > & { github_integration?: number | null; + github_user_integration?: string | null; /** POST-only: `SignalReportTask.relationship` to create when linking to `signal_report`. */ signal_report_task_relationship?: SignalReportTaskRelationship; }, ) { const teamId = await this.getTeamId(); + const { origin_product: originProduct, ...taskOptions } = options; const data = await this.api.post(`/api/projects/{project_id}/tasks/`, { path: { project_id: teamId.toString() }, body: { - origin_product: "user_created", - ...options, + ...taskOptions, + origin_product: originProduct ?? "user_created", } as unknown as Schemas.Task, }); @@ -883,6 +915,7 @@ export class PostHogAPIClient { json_schema: task.json_schema, origin_product: task.origin_product, github_integration: task.github_integration, + github_user_integration: task.github_user_integration, }); } @@ -1438,6 +1471,45 @@ export class PostHogAPIClient { }; } + async getGithubUserBranchesPage( + installationId: string | number, + repo: string, + offset: number, + limit: number, + search?: string, + ): Promise<{ + branches: string[]; + defaultBranch: string | null; + hasMore: boolean; + }> { + const urlPath = `/api/users/@me/integrations/github/${installationId}/branches/`; + const url = new URL(`${this.api.baseUrl}${urlPath}`); + url.searchParams.set("repo", repo); + url.searchParams.set("offset", String(offset)); + url.searchParams.set("limit", String(limit)); + if (search?.trim()) { + url.searchParams.set("search", search.trim()); + } + const response = await this.api.fetcher.fetch({ + method: "get", + url, + path: urlPath, + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch personal GitHub branches: ${response.statusText}`, + ); + } + + const data = await response.json(); + return { + branches: data.branches ?? data.results ?? data ?? [], + defaultBranch: data.default_branch ?? null, + hasMore: data.has_more ?? false, + }; + } + async getGithubRepositories( integrationId: string | number, ): Promise { @@ -1497,6 +1569,63 @@ export class PostHogAPIClient { }; } + async getGithubUserRepositories( + installationId: string | number, + ): Promise { + const repositories: string[] = []; + let offset = 0; + + while (true) { + const page = await this.getGithubUserRepositoriesPage( + installationId, + offset, + 500, + ); + repositories.push(...page.repositories); + + if (!page.hasMore) { + return repositories; + } + + offset += page.repositories.length; + } + } + + async getGithubUserRepositoriesPage( + installationId: string | number, + offset: number, + limit: number, + search?: string, + ): Promise<{ + repositories: string[]; + hasMore: boolean; + }> { + const urlPath = `/api/users/@me/integrations/github/${installationId}/repos/`; + const url = new URL(`${this.api.baseUrl}${urlPath}`); + url.searchParams.set("offset", String(offset)); + url.searchParams.set("limit", String(limit)); + if (search?.trim()) { + url.searchParams.set("search", search.trim()); + } + const response = await this.api.fetcher.fetch({ + method: "get", + url, + path: urlPath, + }); + + if (!response.ok) { + throw new Error( + `Failed to fetch personal GitHub repositories: ${response.statusText}`, + ); + } + + const data = await response.json(); + return { + repositories: this.normalizeGithubRepositories(data), + hasMore: data.has_more ?? false, + }; + } + async refreshGithubRepositories( integrationId: string | number, ): Promise { @@ -1520,6 +1649,27 @@ export class PostHogAPIClient { return this.normalizeGithubRepositories(data); } + async refreshGithubUserRepositories( + installationId: string | number, + ): Promise { + const urlPath = `/api/users/@me/integrations/github/${installationId}/repos/refresh/`; + const url = new URL(`${this.api.baseUrl}${urlPath}`); + const response = await this.api.fetcher.fetch({ + method: "post", + url, + path: urlPath, + }); + + if (!response.ok) { + throw new Error( + `Failed to refresh personal GitHub repositories: ${response.statusText}`, + ); + } + + const data = await response.json(); + return this.normalizeGithubRepositories(data); + } + private normalizeGithubRepositories(data: unknown): string[] { const repos = (data as { repositories?: unknown[] }).repositories ?? diff --git a/apps/code/src/renderer/features/inbox/components/list/GitHubConnectionBanner.tsx b/apps/code/src/renderer/features/inbox/components/list/GitHubConnectionBanner.tsx index af78919c9..9213526a0 100644 --- a/apps/code/src/renderer/features/inbox/components/list/GitHubConnectionBanner.tsx +++ b/apps/code/src/renderer/features/inbox/components/list/GitHubConnectionBanner.tsx @@ -2,7 +2,7 @@ import { Button } from "@components/ui/Button"; import { useOptionalAuthenticatedClient } from "@features/auth/hooks/authClient"; import { useAuthStateValue } from "@features/auth/hooks/authQueries"; import { useAuthenticatedQuery } from "@hooks/useAuthenticatedQuery"; -import { useRepositoryIntegration } from "@hooks/useIntegrations"; +import { useUserRepositoryIntegration } from "@hooks/useIntegrations"; import { ArrowSquareOutIcon, GithubLogoIcon, @@ -22,7 +22,6 @@ async function openUrlInBrowser(url: string): Promise { } } -/** Uses project-scoped integrations (see useRepositoryIntegration), not session `current_team`. */ export function GitHubConnectionBanner() { const { data: githubLogin, isLoading: loginLoading } = useAuthenticatedQuery( ["github_login"], @@ -30,7 +29,7 @@ export function GitHubConnectionBanner() { { staleTime: 5 * 60 * 1000 }, ); const { hasGithubIntegration: hasGithubForProject } = - useRepositoryIntegration(); + useUserRepositoryIntegration(); const apiClient = useOptionalAuthenticatedClient(); const projectId = useAuthStateValue((s) => s.projectId); const cloudRegion = useAuthStateValue((s) => s.cloudRegion); @@ -50,6 +49,9 @@ export function GitHubConnectionBanner() { void queryClient.invalidateQueries({ queryKey: ["integrations", "list"], }); + void queryClient.invalidateQueries({ + queryKey: ["user-github-integrations"], + }); } }; window.addEventListener("focus", onFocus); diff --git a/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts b/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts index 5cffeddf7..2a931737a 100644 --- a/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts +++ b/apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts @@ -61,7 +61,7 @@ export const useInboxCloudTaskStore = create()( workspaceMode: "cloud", githubIntegrationId: params.githubIntegrationId, repository: selectedRepo, - cloudPrAuthorshipMode: "user", + cloudPrAuthorshipMode: "bot", cloudRunSource: "signal_report", signalReportId: params.reportId, }); diff --git a/apps/code/src/renderer/features/onboarding/components/GitIntegrationStep.tsx b/apps/code/src/renderer/features/onboarding/components/GitIntegrationStep.tsx index 97f1e2748..a1548d96e 100644 --- a/apps/code/src/renderer/features/onboarding/components/GitIntegrationStep.tsx +++ b/apps/code/src/renderer/features/onboarding/components/GitIntegrationStep.tsx @@ -2,9 +2,11 @@ import { useOptionalAuthenticatedClient } from "@features/auth/hooks/authClient" import { useSelectProjectMutation } from "@features/auth/hooks/authMutations"; import { useAuthStateValue } from "@features/auth/hooks/authQueries"; import { FolderPicker } from "@features/folder-picker/components/FolderPicker"; -import { useIntegrationSelectors } from "@features/integrations/stores/integrationStore"; import { useOnboardingStore } from "@features/onboarding/stores/onboardingStore"; -import { useRepositoryIntegration } from "@hooks/useIntegrations"; +import { + useUserGithubIntegrations, + useUserRepositoryIntegration, +} from "@hooks/useIntegrations"; import { ArrowLeft, ArrowRight, @@ -85,10 +87,13 @@ export function GitIntegrationStep({ [projects, selectedProjectId], ); - const hasGitIntegration = selectedProject?.hasGithubIntegration ?? false; - const { repositories, isLoadingRepos } = useRepositoryIntegration(); - const { githubIntegrations } = useIntegrationSelectors(); - const githubIntegration = githubIntegrations[0] ?? null; + const { + data: githubUserIntegrations = [], + isLoading: githubUserIntegrationsLoading, + } = useUserGithubIntegrations(); + const hasGitIntegration = githubUserIntegrations.length > 0; + const { repositories, isLoadingRepos } = useUserRepositoryIntegration(); + const githubIntegration = githubUserIntegrations[0] ?? null; const alternativeConnectedProject = useMemo(() => { if (hasGitIntegration) return null; @@ -135,11 +140,17 @@ export function GitIntegrationStep({ (projectId: number | null) => { if (projectId === null) { void queryClient.invalidateQueries({ queryKey: ["integrations"] }); + void queryClient.invalidateQueries({ + queryKey: ["user-github-integrations"], + }); return; } void queryClient.invalidateQueries({ queryKey: ["integrations", projectId], }); + void queryClient.invalidateQueries({ + queryKey: ["user-github-integrations"], + }); }, [queryClient], ); @@ -170,10 +181,13 @@ export function GitIntegrationStep({ setTimedOut(false); setConnectingGithub(true); try { - await trpcClient.githubIntegration.startFlow.mutate({ - region: cloudRegion, - projectId: selectedProjectId, - }); + const res = + await client.startGithubUserIntegrationConnect(selectedProjectId); + const installUrl = res.install_url?.trim() ?? ""; + if (!installUrl) { + throw new Error("GitHub connection did not return a URL"); + } + await trpcClient.os.openExternal.mutate({ url: installUrl }); // Dev-only fallback: GitHub returns via posthog-code-dev:// while the // browser flow may not always surface the same path; poll integrations. @@ -182,6 +196,9 @@ export function GitIntegrationStep({ void queryClient.invalidateQueries({ queryKey: ["integrations", selectedProjectId], }); + void queryClient.invalidateQueries({ + queryKey: ["user-github-integrations"], + }); }, POLL_INTERVAL_MS); } @@ -191,8 +208,13 @@ export function GitIntegrationStep({ setConnectingGithub(false); setTimedOut(true); }, POLL_TIMEOUT_MS); - } catch { + } catch (error) { setConnectingGithub(false); + toast.error( + error instanceof Error + ? error.message + : "Failed to start GitHub connection", + ); } }; @@ -391,7 +413,7 @@ export function GitIntegrationStep({ Connect GitHub - {isLoading ? ( + {isLoading || githubUserIntegrationsLoading ? ( ) : hasGitIntegration ? ( @@ -421,17 +443,8 @@ export function GitIntegrationStep({ variant="soft" color="gray" onClick={() => { - const config = githubIntegration?.config as - | { - installation_id?: number; - account?: { - name?: string; - type?: string; - }; - } - | undefined; - const id = config?.installation_id; - const account = config?.account; + const id = githubIntegration?.installation_id; + const account = githubIntegration?.account; const url = id ? account?.type === "Organization" && account.name @@ -452,6 +465,9 @@ export function GitIntegrationStep({ queryClient.invalidateQueries({ queryKey: ["integrations"], }); + queryClient.invalidateQueries({ + queryKey: ["user-github-integrations"], + }); }} > @@ -459,7 +475,7 @@ export function GitIntegrationStep({ - ) : !isLoading ? ( + ) : !isLoading && !githubUserIntegrationsLoading ? ( {timedOut diff --git a/apps/code/src/renderer/features/sessions/service/service.test.ts b/apps/code/src/renderer/features/sessions/service/service.test.ts index e64e4c4a5..69f7ff7ad 100644 --- a/apps/code/src/renderer/features/sessions/service/service.test.ts +++ b/apps/code/src/renderer/features/sessions/service/service.test.ts @@ -43,6 +43,15 @@ const mockTrpcFs = vi.hoisted(() => ({ readFileAsBase64: { query: vi.fn() }, })); +const mockTrpcHandoff = vi.hoisted(() => ({ + preflightToCloud: { query: vi.fn() }, + executeToCloud: { mutate: vi.fn() }, +})); + +const mockTrpcOs = vi.hoisted(() => ({ + openExternal: { mutate: vi.fn() }, +})); + vi.mock("@renderer/trpc/client", () => ({ trpcClient: { agent: mockTrpcAgent, @@ -50,6 +59,8 @@ vi.mock("@renderer/trpc/client", () => ({ logs: mockTrpcLogs, cloudTask: mockTrpcCloudTask, fs: mockTrpcFs, + handoff: mockTrpcHandoff, + os: mockTrpcOs, }, })); @@ -109,6 +120,7 @@ const mockAuthenticatedClient = vi.hoisted(() => ({ finalizeTaskRunArtifactUploads: vi.fn(), prepareTaskStagedArtifactUploads: vi.fn(), finalizeTaskStagedArtifactUploads: vi.fn(), + startGithubUserIntegrationConnect: vi.fn(), })); type MockAuthenticatedClient = typeof mockAuthenticatedClient; @@ -231,11 +243,12 @@ vi.mock("@utils/notifications", () => ({ notifyPromptComplete: vi.fn(), })); vi.mock("@renderer/utils/toast", () => ({ - toast: { error: vi.fn() }, + toast: { error: vi.fn(), info: vi.fn() }, })); vi.mock("@utils/queryClient", () => ({ queryClient: { invalidateQueries: vi.fn(), + refetchQueries: vi.fn(), setQueriesData: vi.fn(), }, })); @@ -282,6 +295,7 @@ vi.mock("@utils/session", async () => { }; }); +import { toast } from "@renderer/utils/toast"; import { getSessionService, resetSessionService } from "./service"; // --- Test Fixtures --- @@ -350,6 +364,14 @@ describe("SessionService", () => { unsubscribe: vi.fn(), }); mockTrpcFs.readFileAsBase64.query.mockResolvedValue(null); + mockTrpcHandoff.preflightToCloud.query.mockResolvedValue({ + canHandoff: true, + }); + mockTrpcHandoff.executeToCloud.mutate.mockResolvedValue({ + success: true, + logEntryCount: 0, + }); + mockTrpcOs.openExternal.mutate.mockResolvedValue(undefined); mockAuthenticatedClient.prepareTaskRunArtifactUploads.mockResolvedValue([]); mockAuthenticatedClient.finalizeTaskRunArtifactUploads.mockResolvedValue( [], @@ -360,6 +382,12 @@ describe("SessionService", () => { mockAuthenticatedClient.finalizeTaskStagedArtifactUploads.mockResolvedValue( [], ); + mockAuthenticatedClient.startGithubUserIntegrationConnect.mockResolvedValue( + { + install_url: "https://github.com/login/oauth/authorize", + connect_flow: "oauth_authorize", + }, + ); }); describe("singleton management", () => { @@ -2391,6 +2419,43 @@ describe("SessionService", () => { }); }); + describe("handoffToCloud", () => { + it("starts GitHub reauth when cloud handoff needs user authorization", async () => { + const service = getSessionService(); + mockSessionStoreSetters.getSessionByTaskId.mockReturnValue( + createMockSession(), + ); + mockTrpcHandoff.executeToCloud.mutate.mockResolvedValue({ + success: false, + code: "github_authorization_required", + error: "Connect GitHub in your browser, then retry Continue in cloud.", + }); + + await service.handoffToCloud("task-123", "/repo/path"); + + expect( + mockAuthenticatedClient.startGithubUserIntegrationConnect, + ).toHaveBeenCalledWith(123); + expect(mockTrpcOs.openExternal.mutate).toHaveBeenCalledWith({ + url: "https://github.com/login/oauth/authorize", + }); + expect(toast.info).toHaveBeenCalledWith( + "Connect GitHub to continue in cloud", + "Complete the authorization in your browser, then click Continue again.", + ); + expect(toast.error).not.toHaveBeenCalledWith( + expect.stringContaining("github_authorization_required"), + ); + expect(mockSessionStoreSetters.updateSession).toHaveBeenCalledWith( + "run-123", + { + handoffInProgress: false, + status: "disconnected", + }, + ); + }); + }); + describe("automatic local recovery", () => { it("reconnects automatically after a subscription error", async () => { vi.useFakeTimers(); diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index eda53cbbb..5dee66e3a 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -39,7 +39,6 @@ import { DEFAULT_GATEWAY_MODEL } from "@posthog/agent/gateway-models"; import { getIsOnline } from "@renderer/stores/connectivityStore"; import { trpc } from "@renderer/trpc"; import { trpcClient } from "@renderer/trpc/client"; -import { getGhUserTokenOrThrow } from "@renderer/utils/github"; import { toast } from "@renderer/utils/toast"; import { type CloudTaskPermissionRequestUpdate, @@ -93,6 +92,16 @@ const LOCAL_SESSION_RECOVERY_MESSAGE = "Lost connection to the agent. Reconnecting…"; const LOCAL_SESSION_RECOVERY_FAILED_MESSAGE = "Connecting to to the agent has been lost. Retry, or start a new session."; +const GITHUB_AUTHORIZATION_REQUIRED_CODE = "github_authorization_required"; + +class GitHubAuthorizationRequiredForCloudHandoffError extends Error { + constructor( + message = "Connect GitHub before continuing this task in cloud.", + ) { + super(message); + this.name = "GitHubAuthorizationRequiredForCloudHandoffError"; + } +} /** * Build default configOptions for cloud sessions so the mode switcher @@ -1733,11 +1742,10 @@ export class SessionService { transport.filePaths, ); - const [previousRun, task] = await Promise.all([ - authCredentials.client.getTaskRun(session.taskId, session.taskRunId), - authCredentials.client.getTask(session.taskId), - ]); - const hasGitHubRepo = !!task.repository && !!task.github_integration; + const previousRun = await authCredentials.client.getTaskRun( + session.taskId, + session.taskRunId, + ); const previousState = previousRun.state as Record; const previousOutput = (previousRun.output ?? {}) as Record< string, @@ -1757,10 +1765,6 @@ export class SessionService { : null) ?? session.cloudBranch; const prAuthorshipMode = this.getCloudPrAuthorshipMode(previousState); - const githubUserToken = - prAuthorshipMode === "user" && hasGitHubRepo - ? await getGhUserTokenOrThrow() - : undefined; log.info("Creating resume run for terminal cloud task", { taskId: session.taskId, @@ -1790,7 +1794,6 @@ export class SessionService { typeof previousState.signal_report_id === "string" ? previousState.signal_report_id : undefined, - githubUserToken, }, ); const newRun = updatedTask.latest_run; @@ -2780,6 +2783,11 @@ export class SessionService { localGitState: preflight.localGitState, }); if (!result.success) { + if (result.code === GITHUB_AUTHORIZATION_REQUIRED_CODE) { + throw new GitHubAuthorizationRequiredForCloudHandoffError( + result.error, + ); + } throw new Error(result.error ?? "Handoff to cloud failed"); } @@ -2803,9 +2811,13 @@ export class SessionService { log.info("Local-to-cloud handoff complete", { taskId, runId }); } catch (err) { log.error("Handoff to cloud failed", { taskId, err }); - toast.error( - err instanceof Error ? err.message : "Handoff to cloud failed", - ); + if (err instanceof GitHubAuthorizationRequiredForCloudHandoffError) { + await this.startGithubReauthForCloudHandoff(auth.projectId); + } else { + toast.error( + err instanceof Error ? err.message : "Handoff to cloud failed", + ); + } this.subscribeToChannel(runId); sessionStoreSetters.updateSession(runId, { handoffInProgress: false, @@ -2814,6 +2826,40 @@ export class SessionService { } } + private async startGithubReauthForCloudHandoff( + projectId: number, + ): Promise { + const client = await getAuthenticatedClient(); + if (!client) { + toast.error("Sign in before connecting GitHub."); + return; + } + + try { + const { install_url: installUrl } = + await client.startGithubUserIntegrationConnect(projectId); + const url = installUrl?.trim(); + if (!url) { + toast.error( + "GitHub connection did not return a URL. Please try again.", + ); + return; + } + + await trpcClient.os.openExternal.mutate({ url }); + toast.info( + "Connect GitHub to continue in cloud", + "Complete the authorization in your browser, then click Continue again.", + ); + } catch (error) { + toast.error( + error instanceof Error + ? error.message + : "Failed to start GitHub connection", + ); + } + } + private async getHandoffAuth(): Promise<{ apiHost: string; projectId: number; diff --git a/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx b/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx index 96bf46f13..1ebc10ec4 100644 --- a/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx +++ b/apps/code/src/renderer/features/task-detail/components/TaskInput.tsx @@ -24,9 +24,9 @@ import { useSettingsStore } from "@features/settings/stores/settingsStore"; import { useAutoFocusOnTyping } from "@hooks/useAutoFocusOnTyping"; import { useConnectivity } from "@hooks/useConnectivity"; import { - useGithubBranches, - useGithubRepositories, - useRepositoryIntegration, + useUserGithubBranches, + useUserGithubRepositories, + useUserRepositoryIntegration, } from "@hooks/useIntegrations"; import { X } from "@phosphor-icons/react"; import { ButtonGroup } from "@posthog/quill"; @@ -177,17 +177,18 @@ export function TaskInput({ const { repositories, - getIntegrationIdForRepo, + getInstallationIdForRepo, + getUserIntegrationIdForRepo, isLoadingRepos, isRefreshingRepos, refreshRepositories, - } = useRepositoryIntegration(); + } = useUserRepositoryIntegration(); const { repositories: visibleCloudRepositories, isPending: cloudRepositoriesLoading, hasMore: cloudRepositoriesHasMore, loadMore: loadMoreCloudRepositories, - } = useGithubRepositories(cloudRepoSearchQuery, isCloudRepoPickerOpen); + } = useUserGithubRepositories(cloudRepoSearchQuery, isCloudRepoPickerOpen); const [selectedRepository, setSelectedRepository] = useState( () => initialCloudRepository?.toLowerCase() ?? @@ -202,8 +203,11 @@ export function TaskInput({ const { currentBranch, branchLoading, defaultBranch } = useGitQueries(selectedDirectory); - const selectedIntegrationId = selectedCloudRepository - ? getIntegrationIdForRepo(selectedCloudRepository) + const selectedGithubUserIntegrationId = selectedCloudRepository + ? getUserIntegrationIdForRepo(selectedCloudRepository) + : undefined; + const selectedInstallationId = selectedCloudRepository + ? getInstallationIdForRepo(selectedCloudRepository) : undefined; const { @@ -214,8 +218,8 @@ export function TaskInput({ hasMore: cloudBranchesHasMore, loadMore: loadMoreCloudBranches, refresh: refreshCloudBranches, - } = useGithubBranches( - selectedIntegrationId, + } = useUserGithubBranches( + selectedInstallationId, selectedCloudRepository, cloudBranchSearchQuery, isCloudBranchPickerOpen, @@ -428,7 +432,7 @@ export function TaskInput({ editorRef, selectedDirectory, selectedRepository: selectedCloudRepository, - githubIntegrationId: selectedIntegrationId, + githubUserIntegrationId: selectedGithubUserIntegrationId, workspaceMode: effectiveWorkspaceMode, branch: branchForTaskCreation, editorIsEmpty, diff --git a/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts b/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts index 8d45858db..e7f36507d 100644 --- a/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts +++ b/apps/code/src/renderer/features/task-detail/hooks/useTaskCreation.ts @@ -28,6 +28,7 @@ interface UseTaskCreationOptions { selectedDirectory: string; selectedRepository?: string | null; githubIntegrationId?: number; + githubUserIntegrationId?: string; workspaceMode: WorkspaceMode; branch?: string | null; editorIsEmpty: boolean; @@ -53,6 +54,7 @@ function prepareTaskInput( selectedDirectory: string; selectedRepository?: string | null; githubIntegrationId?: number; + githubUserIntegrationId?: string; workspaceMode: WorkspaceMode; branch?: string | null; executionMode?: ExecutionMode; @@ -81,6 +83,7 @@ function prepareTaskInput( ? options.selectedRepository : undefined, githubIntegrationId: options.githubIntegrationId, + githubUserIntegrationId: options.githubUserIntegrationId, workspaceMode: options.workspaceMode, branch: options.branch, executionMode: options.executionMode, @@ -118,6 +121,7 @@ export function useTaskCreation({ selectedDirectory, selectedRepository, githubIntegrationId, + githubUserIntegrationId, workspaceMode, branch, editorIsEmpty, @@ -167,6 +171,7 @@ export function useTaskCreation({ selectedDirectory, selectedRepository, githubIntegrationId, + githubUserIntegrationId, workspaceMode, branch, executionMode, @@ -219,6 +224,7 @@ export function useTaskCreation({ selectedDirectory, selectedRepository, githubIntegrationId, + githubUserIntegrationId, workspaceMode, branch, executionMode, diff --git a/apps/code/src/renderer/hooks/useIntegrations.ts b/apps/code/src/renderer/hooks/useIntegrations.ts index 620a412f0..cb8137cbe 100644 --- a/apps/code/src/renderer/hooks/useIntegrations.ts +++ b/apps/code/src/renderer/hooks/useIntegrations.ts @@ -5,6 +5,7 @@ import { useIntegrationSelectors, useIntegrationStore, } from "@features/integrations/stores/integrationStore"; +import type { UserGitHubIntegration } from "@renderer/api/posthogClient"; import { useQueries, useQueryClient } from "@tanstack/react-query"; import { useCallback, @@ -33,6 +34,38 @@ const integrationKeys = { [...integrationKeys.all, "branches", integrationId, repo, search] as const, }; +const userGithubIntegrationKeys = { + all: ["user-github-integrations"] as const, + list: () => [...userGithubIntegrationKeys.all, "list"] as const, + repositories: (installationId?: string) => + [...userGithubIntegrationKeys.all, "repositories", installationId] as const, + repositoryPicker: ( + installationId?: string, + search?: string, + limit?: number, + ) => + [ + ...userGithubIntegrationKeys.all, + "repository-picker", + installationId, + search, + limit, + ] as const, + branches: (installationId?: string, repo?: string | null, search?: string) => + [ + ...userGithubIntegrationKeys.all, + "branches", + installationId, + repo, + search, + ] as const, +}; + +interface UserRepositoryIntegrationRef { + userIntegrationId: string; + installationId: string; +} + export function useIntegrations() { const setIntegrations = useIntegrationStore((state) => state.setIntegrations); @@ -82,6 +115,57 @@ function useAllGithubRepositories(githubIntegrations: Integration[]) { }); } +export function useUserGithubIntegrations() { + return useAuthenticatedQuery(userGithubIntegrationKeys.list(), (client) => + client.getGithubUserIntegrations(), + ); +} + +function useAllUserGithubRepositories( + githubIntegrations: UserGitHubIntegration[], +) { + const client = useOptionalAuthenticatedClient(); + + return useQueries({ + queries: githubIntegrations.map((integration) => ({ + queryKey: userGithubIntegrationKeys.repositories( + integration.installation_id, + ), + queryFn: async () => { + if (!client) throw new Error("Not authenticated"); + const repos = await client.getGithubUserRepositories( + integration.installation_id, + ); + return { + userIntegrationId: integration.id, + installationId: integration.installation_id, + repos, + }; + }, + enabled: !!client, + staleTime: 5 * 60 * 1000, + meta: AUTH_SCOPED_QUERY_META, + })), + combine: (results) => { + const map: Record = {}; + let pending = false; + for (const result of results) { + if (result.isPending) pending = true; + if (!result.data) continue; + for (const repo of result.data.repos ?? []) { + if (!(repo in map)) { + map[repo] = { + userIntegrationId: result.data.userIntegrationId, + installationId: result.data.installationId, + }; + } + } + } + return { repositoryMap: map, isPending: pending }; + }, + }); +} + const REPOSITORIES_PAGE_SIZE = 50; const BRANCHES_FIRST_PAGE_SIZE = 50; const BRANCHES_PAGE_SIZE = 100; @@ -168,6 +252,94 @@ export function useGithubRepositories( }; } +export function useUserGithubRepositories( + search?: string, + enabled: boolean = true, +) { + const client = useOptionalAuthenticatedClient(); + const { data: githubIntegrations = [] } = useUserGithubIntegrations(); + const deferredSearch = useDeferredValue(search?.trim() ?? ""); + const [requestedLimit, setRequestedLimit] = useState(REPOSITORIES_PAGE_SIZE); + const queryEnabled = enabled && !!client && githubIntegrations.length > 0; + + useEffect(() => { + setRequestedLimit(REPOSITORIES_PAGE_SIZE); + }, []); + + const { repositoryMap, isPending, isRefreshing, hasMore } = useQueries({ + queries: githubIntegrations.map((integration) => ({ + queryKey: userGithubIntegrationKeys.repositoryPicker( + integration.installation_id, + deferredSearch, + requestedLimit, + ), + queryFn: async () => { + if (!client) throw new Error("Not authenticated"); + + const page = await client.getGithubUserRepositoriesPage( + integration.installation_id, + 0, + requestedLimit, + deferredSearch, + ); + + return { + userIntegrationId: integration.id, + installationId: integration.installation_id, + ...page, + }; + }, + enabled: queryEnabled, + staleTime: 5 * 60 * 1000, + meta: AUTH_SCOPED_QUERY_META, + })), + combine: (results) => { + const map: Record = {}; + let pending = false; + let refreshing = false; + let hasMoreResults = false; + + for (const result of results) { + if (result.isPending) pending = true; + if (result.isRefetching) refreshing = true; + if (!result.data) continue; + + if (result.data.hasMore) { + hasMoreResults = true; + } + + for (const repo of result.data.repositories ?? []) { + if (!(repo in map)) { + map[repo] = { + userIntegrationId: result.data.userIntegrationId, + installationId: result.data.installationId, + }; + } + } + } + + return { + repositoryMap: map, + isPending: pending, + isRefreshing: refreshing, + hasMore: hasMoreResults, + }; + }, + }); + + const loadMore = useCallback(() => { + setRequestedLimit((currentLimit) => currentLimit + REPOSITORIES_PAGE_SIZE); + }, []); + + return { + repositories: Object.keys(repositoryMap), + isPending: queryEnabled ? isPending : false, + isRefreshing: queryEnabled ? isRefreshing : false, + hasMore, + loadMore, + }; +} + interface GithubBranchesPage { branches: string[]; defaultBranch: string | null; @@ -242,6 +414,150 @@ export function useGithubBranches( }; } +export function useUserGithubBranches( + installationId?: string, + repo?: string | null, + search?: string, + enabled: boolean = true, +) { + const deferredSearch = useDeferredValue(search?.trim() ?? ""); + const queryEnabled = enabled && !!installationId && !!repo; + + const query = useAuthenticatedInfiniteQuery( + userGithubIntegrationKeys.branches(installationId, repo, deferredSearch), + async (client, offset) => { + if (!installationId || !repo) { + return { branches: [], defaultBranch: null, hasMore: false }; + } + const pageSize = + offset === 0 ? BRANCHES_FIRST_PAGE_SIZE : BRANCHES_PAGE_SIZE; + return await client.getGithubUserBranchesPage( + installationId, + repo, + offset, + pageSize, + deferredSearch, + ); + }, + { + enabled: queryEnabled, + initialPageParam: 0, + getNextPageParam: (lastPage, allPages) => { + if (!lastPage.hasMore) return undefined; + return allPages.reduce((n, p) => n + p.branches.length, 0); + }, + }, + ); + + const data = useMemo(() => { + if (!query.data?.pages.length) { + return { branches: [] as string[], defaultBranch: null }; + } + return { + branches: query.data.pages.flatMap((p) => p.branches), + defaultBranch: query.data.pages[0]?.defaultBranch ?? null, + }; + }, [query.data?.pages]); + + const loadMore = useCallback(() => { + if (!query.hasNextPage || query.isFetchingNextPage) { + return; + } + + void query.fetchNextPage(); + }, [query.fetchNextPage, query.hasNextPage, query.isFetchingNextPage]); + + const refresh = useCallback(async () => { + await query.refetch(); + }, [query.refetch]); + + return { + data, + isPending: queryEnabled ? query.isPending : false, + isRefreshing: queryEnabled ? query.isRefetching : false, + isFetchingMore: query.isFetchingNextPage, + hasMore: query.hasNextPage ?? false, + loadMore, + refresh, + }; +} + +export function useUserRepositoryIntegration() { + const client = useOptionalAuthenticatedClient(); + const queryClient = useQueryClient(); + const { data: githubIntegrations = [], isPending: integrationsPending } = + useUserGithubIntegrations(); + const [isRefreshingRepos, setIsRefreshingRepos] = useState(false); + + const { repositoryMap, isPending: reposPending } = + useAllUserGithubRepositories(githubIntegrations); + + const repositories = useMemo( + () => Object.keys(repositoryMap), + [repositoryMap], + ); + + const getUserIntegrationIdForRepo = useCallback( + (repoKey: string) => + repositoryMap[repoKey?.toLowerCase()]?.userIntegrationId, + [repositoryMap], + ); + + const getInstallationIdForRepo = useCallback( + (repoKey: string) => repositoryMap[repoKey?.toLowerCase()]?.installationId, + [repositoryMap], + ); + + const isRepoInIntegration = useCallback( + (repoKey: string) => !repoKey || repoKey.toLowerCase() in repositoryMap, + [repositoryMap], + ); + + const refreshRepositories = useCallback(async () => { + if (!githubIntegrations.length || !client) { + return; + } + + setIsRefreshingRepos(true); + + try { + await Promise.all( + githubIntegrations.map((integration) => + client.refreshGithubUserRepositories(integration.installation_id), + ), + ); + + await Promise.all( + githubIntegrations.map((integration) => + queryClient.refetchQueries({ + queryKey: userGithubIntegrationKeys.repositories( + integration.installation_id, + ), + exact: true, + }), + ), + ); + + await queryClient.refetchQueries({ + queryKey: [...userGithubIntegrationKeys.all, "repository-picker"], + }); + } finally { + setIsRefreshingRepos(false); + } + }, [client, githubIntegrations, queryClient]); + + return { + repositories, + getUserIntegrationIdForRepo, + getInstallationIdForRepo, + isRepoInIntegration, + isLoadingRepos: integrationsPending || reposPending, + isRefreshingRepos, + refreshRepositories, + hasGithubIntegration: githubIntegrations.length > 0, + }; +} + export function useRepositoryIntegration() { const client = useOptionalAuthenticatedClient(); const queryClient = useQueryClient(); diff --git a/apps/code/src/renderer/sagas/task/task-creation.test.ts b/apps/code/src/renderer/sagas/task/task-creation.test.ts index ed7449989..ae214465e 100644 --- a/apps/code/src/renderer/sagas/task/task-creation.test.ts +++ b/apps/code/src/renderer/sagas/task/task-creation.test.ts @@ -166,7 +166,6 @@ describe("TaskCreationSaga", () => { prAuthorshipMode: "bot", runSource: "manual", signalReportId: undefined, - githubUserToken: undefined, initialPermissionMode: "auto", }); expect(startTaskRunMock).toHaveBeenCalledWith("task-123", "run-123", { @@ -351,7 +350,6 @@ describe("TaskCreationSaga", () => { prAuthorshipMode: "bot", runSource: "manual", signalReportId: undefined, - githubUserToken: undefined, initialPermissionMode: "auto", }); expect(startTaskRunMock).toHaveBeenCalledWith("task-123", "run-123", { @@ -369,4 +367,99 @@ describe("TaskCreationSaga", () => { onTaskReady.mock.invocationCallOrder[0], ); }); + + it("uses the selected user GitHub integration for cloud task creation", async () => { + const createdTask = createTask({ + github_user_integration: "user-integration-123", + }); + const startedTask = createTask({ latest_run: createRun() }); + const createTaskMock = vi.fn().mockResolvedValue(createdTask); + const createTaskRunMock = vi.fn().mockResolvedValue(createRun()); + const startTaskRunMock = vi.fn().mockResolvedValue(startedTask); + + const saga = new TaskCreationSaga({ + posthogClient: { + createTask: createTaskMock, + deleteTask: vi.fn(), + getTask: vi.fn(), + createTaskRun: createTaskRunMock, + startTaskRun: startTaskRunMock, + sendRunCommand: vi.fn(), + updateTask: vi.fn(), + } as never, + }); + + const result = await saga.run({ + content: "Ship the fix", + repository: "posthog/posthog", + workspaceMode: "cloud", + branch: "main", + githubUserIntegrationId: "user-integration-123", + }); + + expect(result.success).toBe(true); + expect(createTaskMock).toHaveBeenCalledWith( + expect.objectContaining({ + repository: "posthog/posthog", + github_user_integration: "user-integration-123", + github_integration: undefined, + }), + ); + expect(createTaskRunMock).toHaveBeenCalledWith( + "task-123", + expect.objectContaining({ + prAuthorshipMode: "user", + runSource: "manual", + }), + ); + }); + + it("uses user authorship for repo-less cloud tasks with a selected user GitHub integration", async () => { + const createdTask = createTask({ + repository: null, + github_user_integration: "user-integration-123", + }); + const startedTask = createTask({ + repository: null, + latest_run: createRun(), + }); + const createTaskMock = vi.fn().mockResolvedValue(createdTask); + const createTaskRunMock = vi.fn().mockResolvedValue(createRun()); + const startTaskRunMock = vi.fn().mockResolvedValue(startedTask); + + const saga = new TaskCreationSaga({ + posthogClient: { + createTask: createTaskMock, + deleteTask: vi.fn(), + getTask: vi.fn(), + createTaskRun: createTaskRunMock, + startTaskRun: startTaskRunMock, + sendRunCommand: vi.fn(), + updateTask: vi.fn(), + } as never, + }); + + const result = await saga.run({ + content: "Clone the private repo", + workspaceMode: "cloud", + branch: "main", + githubUserIntegrationId: "user-integration-123", + }); + + expect(result.success).toBe(true); + expect(createTaskMock).toHaveBeenCalledWith( + expect.objectContaining({ + repository: undefined, + github_user_integration: "user-integration-123", + github_integration: undefined, + }), + ); + expect(createTaskRunMock).toHaveBeenCalledWith( + "task-123", + expect.objectContaining({ + prAuthorshipMode: "user", + runSource: "manual", + }), + ); + }); }); diff --git a/apps/code/src/renderer/sagas/task/task-creation.ts b/apps/code/src/renderer/sagas/task/task-creation.ts index bea69991b..523066bf0 100644 --- a/apps/code/src/renderer/sagas/task/task-creation.ts +++ b/apps/code/src/renderer/sagas/task/task-creation.ts @@ -26,7 +26,6 @@ import { type Task, } from "@shared/types"; import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud"; -import { getGhUserTokenOrThrow } from "@utils/github"; import { logger } from "@utils/logger"; import { getCachedTask, queryClient } from "@utils/queryClient"; @@ -83,6 +82,7 @@ export interface TaskCreationInput { workspaceMode?: WorkspaceMode; branch?: string | null; githubIntegrationId?: number; + githubUserIntegrationId?: string; executionMode?: ExecutionMode; adapter?: "claude" | "codex"; model?: string; @@ -274,14 +274,11 @@ export class TaskCreationSaga extends Saga< task = await this.step({ name: "cloud_run", execute: async () => { - const hasGitHubRepo = !!task.repository && !!task.github_integration; + const hasUserGitHubIntegration = + !!input.githubUserIntegrationId || !!task.github_user_integration; const prAuthorshipMode = - input.cloudPrAuthorshipMode ?? (hasGitHubRepo ? "user" : "bot"); - let githubUserToken: string | undefined; - - if (prAuthorshipMode === "user" && hasGitHubRepo) { - githubUserToken = await getGhUserTokenOrThrow(); - } + input.cloudPrAuthorshipMode ?? + (hasUserGitHubIntegration ? "user" : "bot"); const transport = (input.content || input.filePaths?.length) && @@ -299,7 +296,6 @@ export class TaskCreationSaga extends Saga< prAuthorshipMode, runSource: input.cloudRunSource ?? "manual", signalReportId: input.signalReportId, - githubUserToken, initialPermissionMode: input.adapter ? (input.executionMode ?? (input.adapter === "codex" ? "auto" : "plan")) @@ -453,10 +449,18 @@ export class TaskCreationSaga extends Saga< description: input.taskDescription ?? input.content ?? "", repository: repository ?? undefined, github_integration: - input.workspaceMode === "cloud" + input.workspaceMode === "cloud" && + input.cloudRunSource === "signal_report" ? input.githubIntegrationId : undefined, - origin_product: input.signalReportId ? "signal_report" : undefined, + github_user_integration: + input.workspaceMode === "cloud" && + input.cloudRunSource !== "signal_report" + ? input.githubUserIntegrationId + : undefined, + origin_product: input.signalReportId + ? "signal_report" + : "user_created", signal_report: input.signalReportId ?? undefined, signal_report_task_relationship: input.signalReportId ? SIGNAL_REPORT_TASK_IMPLEMENTATION_RELATIONSHIP diff --git a/apps/code/src/renderer/utils/github.ts b/apps/code/src/renderer/utils/github.ts deleted file mode 100644 index 721cf619d..000000000 --- a/apps/code/src/renderer/utils/github.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { trpcClient } from "@renderer/trpc"; - -export async function getGhUserTokenOrThrow(): Promise { - const tokenResult = await trpcClient.git.getGhAuthToken.query(); - if (!tokenResult.success || !tokenResult.token) { - throw new Error( - tokenResult.error || - "Authenticate GitHub CLI with `gh auth login` before starting a cloud task.", - ); - } - return tokenResult.token; -} diff --git a/apps/code/src/shared/types.ts b/apps/code/src/shared/types.ts index 100987023..dc5cd119f 100644 --- a/apps/code/src/shared/types.ts +++ b/apps/code/src/shared/types.ts @@ -46,6 +46,7 @@ export interface Task { origin_product: string; repository?: string | null; // Format: "organization/repository" (e.g., "posthog/posthog-js") github_integration?: number | null; + github_user_integration?: string | null; json_schema?: Record | null; signal_report?: string | null; internal?: boolean;