diff --git a/.changeset/add-step-context.md b/.changeset/add-step-context.md new file mode 100644 index 000000000000..c57994d15a6b --- /dev/null +++ b/.changeset/add-step-context.md @@ -0,0 +1,17 @@ +--- +"@cloudflare/workflows-shared": minor +--- + +Adds step context with attempt count to step.do() callbacks. + +Workflow step callbacks now receive a context object containing the current attempt number (1-indexed). +This allows developers to access which retry attempt is currently executing. + +Example: + +```ts +await step.do("my-step", async (ctx) => { + // ctx.attempt is 1 on first try, 2 on first retry, etc. + console.log(`Attempt ${ctx.attempt}`); +}); +``` diff --git a/.changeset/short-sloths-bake.md b/.changeset/short-sloths-bake.md new file mode 100644 index 000000000000..f9b0cc4a5d62 --- /dev/null +++ b/.changeset/short-sloths-bake.md @@ -0,0 +1,7 @@ +--- +"@cloudflare/workers-shared": patch +--- + +fix: Normalize backslash characters in `/cdn-cgi` paths + +Requests containing backslash characters in `/cdn-cgi` paths are now redirected to their normalized equivalents with forward slashes. This ensures consistent URL handling across different browsers and HTTP clients. diff --git a/.changeset/silver-coins-take.md b/.changeset/silver-coins-take.md new file mode 100644 index 000000000000..cdd6c5ee7c74 --- /dev/null +++ b/.changeset/silver-coins-take.md @@ -0,0 +1,17 @@ +--- +"wrangler": minor +--- + +feat(hyperdrive): add MySQL SSL mode and Custom CA support + +Hyperdrive now supports MySQL-specific SSL modes (`REQUIRED`, `VERIFY_CA`, `VERIFY_IDENTITY`) alongside the existing PostgreSQL modes. The `--sslmode` flag now validates the provided value based on the database scheme (PostgreSQL or MySQL) and enforces appropriate CA certificate requirements for each. + +**Usage:** + +```sh +# MySQL with CA verification +wrangler hyperdrive create my-config --connection-string="mysql://user:pass@host:3306/db" --sslmode=VERIFY_CA --ca-certificate-id= + +# PostgreSQL (unchanged) +wrangler hyperdrive create my-config --connection-string="postgres://user:pass@host:5432/db" --sslmode=verify-full --ca-certificate-id= +``` diff --git a/.github/workflows/codeowners.yml b/.github/workflows/codeowners.yml index 5e4a24ae5518..28438abb1ebf 100644 --- a/.github/workflows/codeowners.yml +++ b/.github/workflows/codeowners.yml @@ -1,6 +1,8 @@ name: "Code Owners" -# Re-evaluate when PRs are opened/updated, and when reviews are submitted/dismissed. +# Re-evaluate when PRs are opened/updated. +# When reviews are submitted/dismissed, the separate rerun_codeowners.yml workflow +# re-runs this check (rather than creating a second check context). # Using pull_request_target (not pull_request) so the workflow has access to secrets # for fork PRs. This is safe because: # - The checkout is the BASE branch (ownership rules come from the protected branch) @@ -9,8 +11,6 @@ name: "Code Owners" on: pull_request_target: types: [opened, reopened, synchronize, ready_for_review, labeled, unlabeled] - pull_request_review: - types: [submitted, dismissed] concurrency: group: codeowners-${{ github.event.pull_request.number }} @@ -27,16 +27,19 @@ jobs: runs-on: ubuntu-latest steps: - name: "Checkout Base Branch" + if: github.event.pull_request.head.ref != 'changeset-release/main' uses: actions/checkout@v4 with: fetch-depth: 0 - name: "Fetch PR Head (for diff computation)" + if: github.event.pull_request.head.ref != 'changeset-release/main' run: git fetch origin +refs/pull/${{ github.event.pull_request.number }}/head env: GITHUB_TOKEN: "${{ secrets.CODEOWNERS_GITHUB_PAT }}" - name: "Codeowners Plus" + if: github.event.pull_request.head.ref != 'changeset-release/main' uses: multimediallc/codeowners-plus@ff02aa993a92e8efe01642916d0877beb9439e9f # v1.9.0 with: github-token: "${{ secrets.CODEOWNERS_GITHUB_PAT }}" diff --git a/.github/workflows/rerun_codeowners.yml b/.github/workflows/rerun_codeowners.yml new file mode 100644 index 000000000000..3aa3f6817756 --- /dev/null +++ b/.github/workflows/rerun_codeowners.yml @@ -0,0 +1,40 @@ +name: "Rerun Code Owners" + +# When a review is submitted or dismissed, re-run the Code Owners check from the +# main codeowners.yml workflow (triggered by pull_request_target) rather than +# running the check again under a separate event context. This avoids duplicate +# "Code Owners" checks appearing on the PR. +on: + pull_request_review: + types: [submitted, dismissed] + +permissions: {} + +jobs: + rerun-codeowners: + name: "Rerun Codeowners Plus" + runs-on: ubuntu-latest + permissions: + actions: write + checks: read + steps: + - name: "Re-run Codeowners Check" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + REPO: ${{ github.repository }} + run: | + # Find the "Run Codeowners Plus" check run for the PR head commit + # and extract the Actions job ID from its details URL. + # Using the commit SHA (not branch name) so this works for fork PRs + # where the branch name doesn't exist in the base repository. + job_id=$(gh api "repos/${REPO}/commits/${HEAD_SHA}/check-runs?check_name=Run+Codeowners+Plus" \ + --jq '(.check_runs[0].details_url // "") | split("/") | last | split("?") | first' || true) + + if [ -n "$job_id" ]; then + gh api "repos/${REPO}/actions/jobs/${job_id}/rerun" --method POST \ + || echo "Job may already be running" + echo "Re-triggered 'Run Codeowners Plus' (job ${job_id})" + else + echo "Check 'Run Codeowners Plus' not found for SHA ${HEAD_SHA}" + fi diff --git a/packages/workers-shared/router-worker/src/analytics.ts b/packages/workers-shared/router-worker/src/analytics.ts index c8f82b9c7737..c5535bbe9286 100644 --- a/packages/workers-shared/router-worker/src/analytics.ts +++ b/packages/workers-shared/router-worker/src/analytics.ts @@ -55,6 +55,8 @@ type Data = { abuseMitigationURLHost?: string; // blob7 - XSS detection href parameter value xssDetectionImageHref?: string; + // blob8 - cdn-cgi backslash bypass attempt URL + cdnCgiBackslashBypassUrl?: string; }; export class Analytics { @@ -110,6 +112,7 @@ export class Analytics { this.data.coloRegion, // blob5 this.data.abuseMitigationURLHost, // blob6 this.data.xssDetectionImageHref, // blob7 + this.data.cdnCgiBackslashBypassUrl?.substring(0, 256), // blob8 - trim to 256 bytes ], }); } diff --git a/packages/workers-shared/router-worker/src/worker.ts b/packages/workers-shared/router-worker/src/worker.ts index c6fd766ce2ad..5461dd3717b1 100644 --- a/packages/workers-shared/router-worker/src/worker.ts +++ b/packages/workers-shared/router-worker/src/worker.ts @@ -1,5 +1,6 @@ import { generateStaticRoutingRuleMatcher } from "../../asset-worker/src/utils/rules-engine"; import { PerformanceTimer } from "../../utils/performance"; +import { TemporaryRedirectResponse } from "../../utils/responses"; import { setupSentry } from "../../utils/sentry"; import { mockJaegerBinding } from "../../utils/tracing"; import { Analytics, DISPATCH_TYPE, STATIC_ROUTING_DECISION } from "./analytics"; @@ -82,6 +83,19 @@ export default { }); } + // Handle /cdn-cgi\... backslash bypass attempts + // - in production if pathname starts with `/cdn-cgi/` then it bypassed the external + // routing and so must have actually started with `/cdn-cgi\`. + // - in local dev it is possible for pathname to start with `/cdn-cgi/` + // even if it doesn't start with `/cdn-cgi\` so we also check the raw URL for that. + if ( + url.pathname.startsWith("/cdn-cgi/") && + request.url.includes("/cdn-cgi\\") + ) { + analytics.setData({ cdnCgiBackslashBypassUrl: request.url }); + return new TemporaryRedirectResponse(url.href); + } + const routeToUserWorker = async ({ asset, }: { diff --git a/packages/workers-shared/router-worker/tests/index.test.ts b/packages/workers-shared/router-worker/tests/index.test.ts index 8a3bca146cd8..bd075200abd2 100644 --- a/packages/workers-shared/router-worker/tests/index.test.ts +++ b/packages/workers-shared/router-worker/tests/index.test.ts @@ -509,6 +509,323 @@ describe("unit tests", async () => { ); }); + describe( + String.raw`Handling /cdn-cgi\ backslash bypass with redirect`, + () => { + const backslashBypassCases = [ + { + description: String.raw`/cdn-cgi\image bypass`, + rawUrl: String.raw`https://example.com/cdn-cgi\image/q=75/https://evil.com/ssrf`, + expectedLocation: + "https://example.com/cdn-cgi/image/q=75/https://evil.com/ssrf", + }, + { + description: String.raw`/cdn-cgi\_next_cache bypass`, + rawUrl: String.raw`https://example.com/cdn-cgi\_next_cache/some-data`, + expectedLocation: "https://example.com/cdn-cgi/_next_cache/some-data", + }, + { + description: String.raw`/cdn-cgi\ bypass with arbitrary subpath`, + rawUrl: String.raw`https://example.com/cdn-cgi\something-else/path`, + expectedLocation: "https://example.com/cdn-cgi/something-else/path", + }, + { + description: String.raw`/cdn-cgi\ bypass with query params`, + rawUrl: String.raw`https://example.com/cdn-cgi\something-else/path?value=/cdn-cgi/param=foo`, + expectedLocation: + "https://example.com/cdn-cgi/something-else/path?value=/cdn-cgi/param=foo", + }, + ]; + + it.for(backslashBypassCases)( + "redirects $description to normalized URL when invoke_user_worker_ahead_of_assets is true", + async ({ rawUrl, expectedLocation }, { expect }) => { + const request = new Request(rawUrl); + // In production, raw backslashes in URLs are preserved in request.url by the + // Workers runtime. The Request constructor normalizes backslashes to forward + // slashes, so we use Object.defineProperty to simulate production behavior. + Object.defineProperty(request, "url", { + value: rawUrl, + configurable: true, + }); + + const env = { + CONFIG: { + has_user_worker: true, + invoke_user_worker_ahead_of_assets: true, + }, + USER_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach user worker as it should be redirected by the router worker" + ); + }, + }, + } as Env; + const ctx = createExecutionContext(); + + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(307); + expect(response.headers.get("Location")).toBe(expectedLocation); + } + ); + + it.for(backslashBypassCases)( + "redirects $description to normalized URL when invoke_user_worker_ahead_of_assets is false and no asset matches", + async ({ rawUrl, expectedLocation }, { expect }) => { + const request = new Request(rawUrl); + Object.defineProperty(request, "url", { + value: rawUrl, + configurable: true, + }); + + const env = { + CONFIG: { + has_user_worker: true, + invoke_user_worker_ahead_of_assets: false, + }, + USER_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach user worker as it should be redirected by the router worker" + ); + }, + }, + ASSET_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach asset worker as it should be redirected by the router worker" + ); + }, + async unstable_canFetch(_: Request): Promise { + return false; + }, + }, + } as Env; + const ctx = createExecutionContext(); + + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(307); + expect(response.headers.get("Location")).toBe(expectedLocation); + } + ); + + it.for(backslashBypassCases)( + "redirects $description to normalized URL when invoke_user_worker_ahead_of_assets is false even if asset exists", + async ({ rawUrl, expectedLocation }, { expect }) => { + const request = new Request(rawUrl); + Object.defineProperty(request, "url", { + value: rawUrl, + configurable: true, + }); + + const env = { + CONFIG: { + has_user_worker: true, + invoke_user_worker_ahead_of_assets: false, + }, + USER_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach user worker as it should be redirected by the router worker" + ); + }, + }, + ASSET_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach asset worker as it should be redirected by the router worker" + ); + }, + async unstable_canFetch(_: Request): Promise { + return true; + }, + }, + } as Env; + const ctx = createExecutionContext(); + + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(307); + expect(response.headers.get("Location")).toBe(expectedLocation); + } + ); + + const nonInterferenceCases = [ + { + description: + "does not interfere with legitimate /cdn-cgi/ forward-slash requests", + url: "https://example.com/cdn-cgi/image/q=75/https://other.com/image.jpg", + userWorkerResponse: { + body: "image data", + headers: { "content-type": "image/jpeg" }, + status: 200, + }, + expectedStatus: 200, + expectedBody: "image data", + }, + { + description: + "does not interfere with escaped backslashes /cdn-cgi%5C requests", + url: "https://example.com/cdn-cgi%5Cimage/q=75/https://other.com/image.jpg", + userWorkerResponse: { + body: "image data", + headers: { "content-type": "image/jpeg" }, + status: 200, + }, + expectedStatus: 200, + expectedBody: "image data", + }, + { + description: + "does not interfere with escaped forward slashes /cdn-cgi%2F requests", + url: "https://example.com/cdn-cgi%2Fimage/q=75/https://other.com/image.jpg", + userWorkerResponse: { + body: "image data", + headers: { "content-type": "image/jpeg" }, + status: 200, + }, + expectedStatus: 200, + expectedBody: "image data", + }, + { + description: "does not interfere with non-cdn-cgi requests", + url: "https://example.com/some-page", + userWorkerResponse: { + body: "page content", + headers: { "content-type": "text/html" }, + status: 200, + }, + expectedStatus: 200, + expectedBody: "page content", + }, + { + description: String.raw`does not interfere with non-cdn-cgi requests with /cdn-cgi\ in query`, + url: String.raw`https://example.com/some-page?redirect=/cdn-cgi\image`, + userWorkerResponse: { + body: "page content", + headers: { "content-type": "text/html" }, + status: 200, + }, + expectedStatus: 200, + expectedBody: "page content", + }, + ]; + + it.for(nonInterferenceCases)( + "$description when invoke_user_worker_ahead_of_assets is true", + async ( + { url, userWorkerResponse, expectedStatus, expectedBody }, + { expect } + ) => { + const request = new Request(url); + const env = { + CONFIG: { + has_user_worker: true, + invoke_user_worker_ahead_of_assets: true, + }, + USER_WORKER: { + async fetch(userWorkerRequest: Request): Promise { + const response = new Response(userWorkerResponse.body, { + status: userWorkerResponse.status, + headers: userWorkerResponse.headers, + }); + Object.defineProperty(response, "url", { + value: userWorkerRequest.url, + configurable: true, + }); + return response; + }, + }, + } as Env; + const ctx = createExecutionContext(); + + const response = await worker.fetch(request, env, ctx); + expect(response.url).toBe(url); + expect(response.status).toBe(expectedStatus); + expect(await response.text()).toBe(expectedBody); + } + ); + + it.for(nonInterferenceCases)( + "$description when invoke_user_worker_ahead_of_assets is false and no asset matches", + async ( + { url, userWorkerResponse, expectedStatus, expectedBody }, + { expect } + ) => { + const request = new Request(url); + const env = { + CONFIG: { + has_user_worker: true, + invoke_user_worker_ahead_of_assets: false, + }, + USER_WORKER: { + async fetch(userWorkerRequest: Request): Promise { + const response = new Response(userWorkerResponse.body, { + status: userWorkerResponse.status, + headers: userWorkerResponse.headers, + }); + Object.defineProperty(response, "url", { + value: userWorkerRequest.url, + configurable: true, + }); + return response; + }, + }, + ASSET_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach asset worker as asset does not exist" + ); + }, + async unstable_canFetch(_: Request): Promise { + return false; + }, + }, + } as Env; + const ctx = createExecutionContext(); + + const response = await worker.fetch(request, env, ctx); + expect(response.url).toBe(url); + expect(response.status).toBe(expectedStatus); + expect(await response.text()).toBe(expectedBody); + } + ); + + it.for(nonInterferenceCases)( + "$description when invoke_user_worker_ahead_of_assets is false even if asset exists", + async ({ url }, { expect }) => { + const request = new Request(url); + const env = { + CONFIG: { + has_user_worker: true, + invoke_user_worker_ahead_of_assets: false, + }, + USER_WORKER: { + async fetch(_: Request): Promise { + return new Response( + "should not reach user worker as it should be handled by asset worker" + ); + }, + }, + ASSET_WORKER: { + async fetch(_: Request): Promise { + return new Response("hello from asset worker"); + }, + async unstable_canFetch(_: Request): Promise { + return true; + }, + }, + } as Env; + const ctx = createExecutionContext(); + + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + expect(await response.text()).toBe("hello from asset worker"); + } + ); + } + ); + describe("free tier limiting", () => { it("returns fetch from asset worker for assets", async ({ expect }) => { const request = new Request("https://example.com/asset"); diff --git a/packages/workflows-shared/src/context.ts b/packages/workflows-shared/src/context.ts index 71a2460ac422..ab2a145d9119 100644 --- a/packages/workflows-shared/src/context.ts +++ b/packages/workflows-shared/src/context.ts @@ -42,6 +42,10 @@ export type StepState = { attemptedCount: number; }; +export type WorkflowStepContext = { + attempt: number; +}; + export class Context extends RpcTarget { #engine: Engine; #state: DurableObjectState; @@ -64,25 +68,30 @@ export class Context extends RpcTarget { return val; } - do(name: string, callback: () => Promise): Promise; + do( + name: string, + callback: (ctx: WorkflowStepContext) => Promise + ): Promise; do( name: string, config: WorkflowStepConfig, - callback: () => Promise + callback: (ctx: WorkflowStepContext) => Promise ): Promise; async do( name: string, - configOrCallback: WorkflowStepConfig | (() => Promise), - callback?: () => Promise + configOrCallback: + | WorkflowStepConfig + | ((ctx: WorkflowStepContext) => Promise), + callback?: (ctx: WorkflowStepContext) => Promise ): Promise { - let closure, stepConfig; + let closure: (ctx: WorkflowStepContext) => Promise, stepConfig; // If a user passes in a config, we'd like it to be the second arg so the callback is always last if (callback) { closure = callback; stepConfig = configOrCallback as WorkflowStepConfig; } else { - closure = configOrCallback as () => Promise; + closure = configOrCallback as (ctx: WorkflowStepContext) => Promise; stepConfig = {}; } @@ -126,7 +135,11 @@ export class Context extends RpcTarget { const stepNameWithCounter = `${name}-${count}`; const stepStateKey = `${cacheKey}-metadata`; - const maybeMap = await this.#state.storage.get([valueKey, configKey]); + const maybeMap = await this.#state.storage.get([ + valueKey, + configKey, + errorKey, + ]); // Check cache const maybeResult = maybeMap.get(valueKey); @@ -202,7 +215,7 @@ export class Context extends RpcTarget { } const doWrapper = async ( - doWrapperClosure: () => Promise + doWrapperClosure: (ctx: WorkflowStepContext) => Promise ): Promise => { const stepState = ((await this.#state.storage.get( stepStateKey @@ -324,7 +337,10 @@ export class Context extends RpcTarget { await this.#state.storage.delete(`replace-result-${valueKey}`); // if there is a timeout to be forced we dont want to race with closure } else { - result = await Promise.race([doWrapperClosure(), timeoutPromise()]); + result = await Promise.race([ + doWrapperClosure({ attempt: stepState.attemptedCount }), + timeoutPromise(), + ]); } // if we reach here, means that the clouse ran successfully and we can remove the timeout from the PQ diff --git a/packages/workflows-shared/tests/binding.test.ts b/packages/workflows-shared/tests/binding.test.ts index e1e76fc6127e..852440783f56 100644 --- a/packages/workflows-shared/tests/binding.test.ts +++ b/packages/workflows-shared/tests/binding.test.ts @@ -5,33 +5,7 @@ import { } from "cloudflare:test"; import { describe, it, vi } from "vitest"; import { WorkflowBinding } from "../src/binding"; -import type { Engine } from "../src/engine"; -import type { ProvidedEnv } from "cloudflare:test"; -import type { WorkflowEvent, WorkflowStep } from "cloudflare:workers"; - -async function setWorkflowEntrypoint( - stub: DurableObjectStub, - callback: (event: unknown, step: WorkflowStep) => Promise -) { - const ctx = createExecutionContext(); - await runInDurableObject(stub, (instance) => { - // @ts-expect-error this is only a stub for WorkflowEntrypoint - instance.env.USER_WORKFLOW = new (class { - constructor( - // eslint-disable-next-line @typescript-eslint/no-shadow - protected ctx: ExecutionContext, - // eslint-disable-next-line @typescript-eslint/no-shadow - protected env: ProvidedEnv - ) {} - public async run( - event: Readonly>, - step: WorkflowStep - ): Promise { - return await callback(event, step); - } - })(ctx, env); - }); -} +import { setWorkflowEntrypoint } from "./utils"; describe("WorkflowBinding", () => { it("should not call dispose when sending an event to an instance", async ({ diff --git a/packages/workflows-shared/tests/context.test.ts b/packages/workflows-shared/tests/context.test.ts new file mode 100644 index 000000000000..01a4c5b1f44f --- /dev/null +++ b/packages/workflows-shared/tests/context.test.ts @@ -0,0 +1,48 @@ +import { describe, it } from "vitest"; +import { runWorkflow } from "./utils"; + +describe("Context", () => { + it("should provide attempt count 1 on first successful attempt", async ({ + expect, + }) => { + let receivedAttempt: number | undefined; + + await runWorkflow("MOCK-INSTANCE-ID", async (_event, step) => { + // TODO: remove after types are updated + // @ts-expect-error WorkflowStep types + const result = await step.do("a successful step", async (ctx) => { + receivedAttempt = ctx.attempt; + return "success"; + }); + return result; + }); + + expect(receivedAttempt).toBe(1); + }); + + it("should provide attempt count to callback", async ({ expect }) => { + const receivedAttempts: number[] = []; + + await runWorkflow("MOCK-INSTANCE-ID", async (_event, step) => { + const result = await step.do( + "retrying step", + { + retries: { + limit: 2, + delay: 0, + }, + }, + // TODO: remove after types are updated + // @ts-expect-error WorkflowStep types + async (ctx) => { + receivedAttempts.push(ctx.attempt); + throw new Error(`Throwing`); + } + ); + return result; + }); + + // Should have received attempts 1, 2, and 3 + expect(receivedAttempts).toEqual([1, 2, 3]); + }); +}); diff --git a/packages/workflows-shared/tests/engine.test.ts b/packages/workflows-shared/tests/engine.test.ts index ef93206ff449..cadc57342eeb 100644 --- a/packages/workflows-shared/tests/engine.test.ts +++ b/packages/workflows-shared/tests/engine.test.ts @@ -1,84 +1,14 @@ -import { - createExecutionContext, - env, - runInDurableObject, -} from "cloudflare:test"; +import { env, runInDurableObject } from "cloudflare:test"; import { NonRetryableError } from "cloudflare:workflows"; import { describe, it, vi } from "vitest"; import { DEFAULT_STEP_LIMIT, InstanceEvent, InstanceStatus } from "../src"; +import { runWorkflow, runWorkflowDefer, setWorkflowEntrypoint } from "./utils"; import type { DatabaseInstance, DatabaseVersion, DatabaseWorkflow, - Engine, EngineLogs, } from "../src/engine"; -import type { ProvidedEnv } from "cloudflare:test"; -import type { WorkflowEvent, WorkflowStep } from "cloudflare:workers"; - -async function setWorkflowEntrypoint( - stub: DurableObjectStub, - callback: (event: unknown, step: WorkflowStep) => Promise -) { - const ctx = createExecutionContext(); - await runInDurableObject(stub, (instance) => { - // @ts-expect-error this is only a stub for WorkflowEntrypoint - instance.env.USER_WORKFLOW = new (class { - constructor( - // eslint-disable-next-line @typescript-eslint/no-shadow - protected ctx: ExecutionContext, - // eslint-disable-next-line @typescript-eslint/no-shadow - protected env: ProvidedEnv - ) {} - public async run( - event: Readonly>, - step: WorkflowStep - ): Promise { - return await callback(event, step); - } - })(ctx, env); - }); -} - -async function runWorkflow( - instanceId: string, - callback: (event: unknown, step: WorkflowStep) => Promise -): Promise> { - const engineId = env.ENGINE.idFromName(instanceId); - const engineStub = env.ENGINE.get(engineId); - - await setWorkflowEntrypoint(engineStub, callback); - - await engineStub.init( - 12346, - {} as DatabaseWorkflow, - {} as DatabaseVersion, - {} as DatabaseInstance, - { payload: {}, timestamp: new Date(), instanceId: "some-instance-id" } - ); - - return engineStub; -} - -async function runWorkflowDefer( - instanceId: string, - callback: (event: unknown, step: WorkflowStep) => Promise -): Promise> { - const engineId = env.ENGINE.idFromName(instanceId); - const engineStub = env.ENGINE.get(engineId); - - await setWorkflowEntrypoint(engineStub, callback); - - void engineStub.init( - 12346, - {} as DatabaseWorkflow, - {} as DatabaseVersion, - {} as DatabaseInstance, - { payload: {}, timestamp: new Date(), instanceId: "some-instance-id" } - ); - - return engineStub; -} describe("Engine", () => { it("should not retry after NonRetryableError is thrown", async ({ diff --git a/packages/workflows-shared/tests/utils.ts b/packages/workflows-shared/tests/utils.ts new file mode 100644 index 000000000000..4e58933d5b0c --- /dev/null +++ b/packages/workflows-shared/tests/utils.ts @@ -0,0 +1,77 @@ +import { + createExecutionContext, + env, + runInDurableObject, +} from "cloudflare:test"; +import type { + DatabaseInstance, + DatabaseVersion, + DatabaseWorkflow, + Engine, +} from "../src/engine"; +import type { ProvidedEnv } from "cloudflare:test"; +import type { WorkflowEvent, WorkflowStep } from "cloudflare:workers"; + +export async function setWorkflowEntrypoint( + stub: DurableObjectStub, + callback: (event: unknown, step: WorkflowStep) => Promise +) { + const ctx = createExecutionContext(); + await runInDurableObject(stub, (instance) => { + // @ts-expect-error this is only a stub for WorkflowEntrypoint + instance.env.USER_WORKFLOW = new (class { + constructor( + // eslint-disable-next-line @typescript-eslint/no-shadow + protected ctx: ExecutionContext, + // eslint-disable-next-line @typescript-eslint/no-shadow + protected env: ProvidedEnv + ) {} + public async run( + event: Readonly>, + step: WorkflowStep + ): Promise { + return await callback(event, step); + } + })(ctx, env); + }); +} + +export async function runWorkflow( + instanceId: string, + callback: (event: unknown, step: WorkflowStep) => Promise +): Promise> { + const engineId = env.ENGINE.idFromName(instanceId); + const engineStub = env.ENGINE.get(engineId); + + await setWorkflowEntrypoint(engineStub, callback); + + await engineStub.init( + 12346, + {} as DatabaseWorkflow, + {} as DatabaseVersion, + {} as DatabaseInstance, + { payload: {}, timestamp: new Date(), instanceId: "some-instance-id" } + ); + + return engineStub; +} + +export async function runWorkflowDefer( + instanceId: string, + callback: (event: unknown, step: WorkflowStep) => Promise +): Promise> { + const engineId = env.ENGINE.idFromName(instanceId); + const engineStub = env.ENGINE.get(engineId); + + await setWorkflowEntrypoint(engineStub, callback); + + void engineStub.init( + 12346, + {} as DatabaseWorkflow, + {} as DatabaseVersion, + {} as DatabaseInstance, + { payload: {}, timestamp: new Date(), instanceId: "some-instance-id" } + ); + + return engineStub; +} diff --git a/packages/wrangler/src/__tests__/hyperdrive.test.ts b/packages/wrangler/src/__tests__/hyperdrive.test.ts index ab4580419813..95b62c665847 100644 --- a/packages/wrangler/src/__tests__/hyperdrive.test.ts +++ b/packages/wrangler/src/__tests__/hyperdrive.test.ts @@ -829,65 +829,353 @@ describe("hyperdrive commands", () => { `); }); - it("should error on create hyperdrive with mtls config sslmode=require and CA flag set", async ({ + it("should allow create hyperdrive with mtls config sslmode=require and CA flag set", async ({ expect, }) => { - await expect(() => - runWrangler( - "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --ca-certificate-id=1234 --sslmode=require" - ) - ).rejects.toThrow(); - expect(std.err).toMatchInlineSnapshot(` - "X [ERROR] CA not allowed when sslmode = 'require' is set + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --ca-certificate-id=1234 --sslmode=require" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "1234", + "sslmode": "require", + }, + "name": "test123", + "origin": { + "database": "neondb", + "host": "example.com", + "password": "password", + "port": 1234, + "scheme": "postgresql", + "user": "test", + }, + } + `); + }); - " + it("should allow create hyperdrive with mtls config sslmode=verify-ca missing CA", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --mtls-certificate-id=1234 --sslmode=verify-ca" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "1234", + "sslmode": "verify-ca", + }, + "name": "test123", + "origin": { + "database": "neondb", + "host": "example.com", + "password": "password", + "port": 1234, + "scheme": "postgresql", + "user": "test", + }, + } `); }); - it("should error on create hyperdrive with mtls config sslmode=verify-ca missing CA", async ({ + it("should allow create hyperdrive with mtls config sslmode=verify-full missing CA", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --mtls-certificate-id=1234 --sslmode=verify-full" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "1234", + "sslmode": "verify-full", + }, + "name": "test123", + "origin": { + "database": "neondb", + "host": "example.com", + "password": "password", + "port": 1234, + "scheme": "postgresql", + "user": "test", + }, + } + `); + }); + + it("should error on create hyperdrive with invalid sslmode", async ({ expect, }) => { await expect(() => runWrangler( - "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --mtls-certificate-id=1234 --sslmode=verify-ca" + "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --mtls-certificate-id=1234 --sslmode=random" ) ).rejects.toThrow(); expect(std.err).toMatchInlineSnapshot(` - "X [ERROR] CA required when sslmode = 'verify-ca' or 'verify-full' is set + "X [ERROR] Invalid values: + + Argument: sslmode, Given: "random", Choices: "require", "verify-ca", "verify-full", "REQUIRED", + "VERIFY_CA", "VERIFY_IDENTITY" " `); }); - it("should error on create hyperdrive with mtls config sslmode=verify-full missing CA", async ({ + it("should successfully create a MySQL hyperdrive with mtls config and sslmode=VERIFY_IDENTITY", async ({ expect, }) => { - await expect(() => - runWrangler( - "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --mtls-certificate-id=1234 --sslmode=verify-full" - ) - ).rejects.toThrow(); - expect(std.err).toMatchInlineSnapshot(` - "X [ERROR] CA required when sslmode = 'verify-ca' or 'verify-full' is set + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --ca-certificate-id=12345 --mtls-certificate-id=1234 --sslmode=VERIFY_IDENTITY" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "12345", + "mtls_certificate_id": "1234", + "sslmode": "VERIFY_IDENTITY", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` + " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Creating 'test123' + ✅ Created new Hyperdrive MySQL config: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + To access your new Hyperdrive Config in your Worker, add the following snippet to your configuration file: + { + "hyperdrive": [ + { + "binding": "HYPERDRIVE", + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + } + ] + }" + `); + }); + it("should successfully create a MySQL hyperdrive with mtls config and sslmode=VERIFY_CA", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --ca-certificate-id=12345 --sslmode=VERIFY_CA" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "12345", + "sslmode": "VERIFY_CA", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Creating 'test123' + ✅ Created new Hyperdrive MySQL config: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + To access your new Hyperdrive Config in your Worker, add the following snippet to your configuration file: + { + "hyperdrive": [ + { + "binding": "HYPERDRIVE", + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + } + ] + }" `); }); - it("should error on create hyperdrive with mtls config sslmode=random", async ({ + it("should successfully create a MySQL hyperdrive with sslmode=REQUIRED", async ({ expect, }) => { - await expect(() => - runWrangler( - "hyperdrive create test123 --host=example.com --database=neondb --user=test --password=password --port=1234 --mtls-certificate-id=1234 --sslmode=random" - ) - ).rejects.toThrow(); - expect(std.err).toMatchInlineSnapshot(` - "X [ERROR] Invalid values: + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --mtls-certificate-id=1234 --sslmode=REQUIRED" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "1234", + "sslmode": "REQUIRED", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` + " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Creating 'test123' + ✅ Created new Hyperdrive MySQL config: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + To access your new Hyperdrive Config in your Worker, add the following snippet to your configuration file: + { + "hyperdrive": [ + { + "binding": "HYPERDRIVE", + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + } + ] + }" + `); + }); - Argument: sslmode, Given: "random", Choices: "require", "verify-ca", "verify-full" + it("should accept MySQL sslmode in lowercase", async ({ expect }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --mtls-certificate-id=1234 --sslmode=required" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "1234", + "sslmode": "REQUIRED", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + }); - " + it("should allow create MySQL hyperdrive with sslmode=REQUIRED and CA flag set", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --ca-certificate-id=1234 --sslmode=REQUIRED" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "1234", + "sslmode": "REQUIRED", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + }); + + it("should allow create MySQL hyperdrive with sslmode=VERIFY_CA missing CA", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --mtls-certificate-id=1234 --sslmode=VERIFY_CA" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "1234", + "sslmode": "VERIFY_CA", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + }); + + it("should allow create MySQL hyperdrive with sslmode=VERIFY_IDENTITY missing CA", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --mtls-certificate-id=1234 --sslmode=VERIFY_IDENTITY" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "1234", + "sslmode": "VERIFY_IDENTITY", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } + `); + }); + + it("should allow create MySQL hyperdrive with PostgreSQL sslmode value", async ({ + expect, + }) => { + const reqProm = mockHyperdriveCreate(); + await runWrangler( + "hyperdrive create test123 --host=example.com --database=mydb --user=test --password=password --port=3306 --origin-scheme=mysql --sslmode=verify-full" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "sslmode": "verify-full", + }, + "name": "test123", + "origin": { + "database": "mydb", + "host": "example.com", + "password": "password", + "port": 3306, + "scheme": "mysql", + "user": "test", + }, + } `); }); @@ -1358,6 +1646,166 @@ describe("hyperdrive commands", () => { }" `); }); + + it("should handle updating a PostgreSQL hyperdrive config's SSL settings without re-specifying origin (verify-ca)", async ({ + expect, + }) => { + const reqProm = mockHyperdriveUpdate(); + await runWrangler( + "hyperdrive update xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx --sslmode=verify-ca --ca-certificate-id=abc123" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "abc123", + "sslmode": "verify-ca", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` + " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Updating 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx' + ✅ Updated xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx Hyperdrive config + { + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "test123", + "origin": { + "scheme": "postgresql", + "host": "example.com", + "port": 5432, + "database": "neondb", + "user": "test" + }, + "mtls": { + "ca_certificate_id": "abc123", + "sslmode": "verify-ca" + }, + "origin_connection_limit": 25 + }" + `); + }); + + it("should handle updating a MySQL hyperdrive config's SSL settings without re-specifying origin (VERIFY_CA)", async ({ + expect, + }) => { + const reqProm = mockHyperdriveUpdate(defaultMysqlConfig); + await runWrangler( + "hyperdrive update xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx --sslmode=VERIFY_CA --ca-certificate-id=abc123" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "abc123", + "sslmode": "VERIFY_CA", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` + " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Updating 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx' + ✅ Updated xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx Hyperdrive config + { + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "test-mysql", + "origin": { + "scheme": "mysql", + "host": "mysql.example.com", + "port": 3306, + "database": "mydb", + "user": "test" + }, + "mtls": { + "ca_certificate_id": "abc123", + "sslmode": "VERIFY_CA" + }, + "origin_connection_limit": 25 + }" + `); + }); + + it("should handle updating a MySQL hyperdrive config's SSL settings without re-specifying origin (VERIFY_IDENTITY)", async ({ + expect, + }) => { + const reqProm = mockHyperdriveUpdate(defaultMysqlConfig); + await runWrangler( + "hyperdrive update xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx --sslmode=VERIFY_IDENTITY --ca-certificate-id=abc123" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "ca_certificate_id": "abc123", + "sslmode": "VERIFY_IDENTITY", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` + " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Updating 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx' + ✅ Updated xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx Hyperdrive config + { + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "test-mysql", + "origin": { + "scheme": "mysql", + "host": "mysql.example.com", + "port": 3306, + "database": "mydb", + "user": "test" + }, + "mtls": { + "ca_certificate_id": "abc123", + "sslmode": "VERIFY_IDENTITY" + }, + "origin_connection_limit": 25 + }" + `); + }); + + it("should handle updating a MySQL hyperdrive config's SSL settings without re-specifying origin (REQUIRED)", async ({ + expect, + }) => { + const reqProm = mockHyperdriveUpdate(defaultMysqlConfig); + await runWrangler( + "hyperdrive update xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx --sslmode=REQUIRED --mtls-certificate-id=cert123" + ); + await expect(reqProm).resolves.toMatchInlineSnapshot(` + { + "mtls": { + "mtls_certificate_id": "cert123", + "sslmode": "REQUIRED", + }, + } + `); + expect(std.out).toMatchInlineSnapshot(` + " + ⛅️ wrangler x.x.x + ────────────────── + 🚧 Updating 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx' + ✅ Updated xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx Hyperdrive config + { + "id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "test-mysql", + "origin": { + "scheme": "mysql", + "host": "mysql.example.com", + "port": 3306, + "database": "mydb", + "user": "test" + }, + "mtls": { + "mtls_certificate_id": "cert123", + "sslmode": "REQUIRED" + }, + "origin_connection_limit": 25 + }" + `); + }); }); const defaultConfig: HyperdriveConfig = { @@ -1373,6 +1821,19 @@ const defaultConfig: HyperdriveConfig = { origin_connection_limit: 25, }; +const defaultMysqlConfig: HyperdriveConfig = { + id: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + name: "test-mysql", + origin: { + scheme: "mysql", + host: "mysql.example.com", + port: 3306, + database: "mydb", + user: "test", + }, + origin_connection_limit: 25, +}; + /** Create a mock handler for Hyperdrive API */ function mockHyperdriveGetListOrDelete() { msw.use( @@ -1438,13 +1899,16 @@ function mockHyperdriveGetListOrDelete() { } /** Create a mock handler for Hyperdrive API */ -function mockHyperdriveUpdate(): Promise { +function mockHyperdriveUpdate( + configOverride?: HyperdriveConfig +): Promise { + const mockConfig = configOverride ?? defaultConfig; return new Promise((resolve) => { msw.use( http.get( "*/accounts/:accountId/hyperdrive/configs/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", () => { - return HttpResponse.json(createFetchResult(defaultConfig, true)); + return HttpResponse.json(createFetchResult(mockConfig, true)); }, { once: true } ), @@ -1455,7 +1919,7 @@ function mockHyperdriveUpdate(): Promise { resolve(reqBody); - let origin = defaultConfig.origin; + let origin = mockConfig.origin; if (reqBody.origin) { const { password: _, @@ -1474,7 +1938,7 @@ function mockHyperdriveUpdate(): Promise { delete origin.port; } } - const mtls = defaultConfig.mtls; + const mtls = mockConfig.mtls; if (mtls && reqBody.mtls) { mtls.ca_certificate_id = reqBody.mtls.ca_certificate_id; mtls.mtls_certificate_id = reqBody.mtls.mtls_certificate_id; @@ -1484,13 +1948,13 @@ function mockHyperdriveUpdate(): Promise { createFetchResult( { id: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - name: reqBody.name ?? defaultConfig.name, + name: reqBody.name ?? mockConfig.name, origin, - caching: reqBody.caching ?? defaultConfig.caching, + caching: reqBody.caching ?? mockConfig.caching, mtls: reqBody.mtls, origin_connection_limit: reqBody.origin_connection_limit ?? - defaultConfig.origin_connection_limit, + mockConfig.origin_connection_limit, }, true ) diff --git a/packages/wrangler/src/hyperdrive/client.ts b/packages/wrangler/src/hyperdrive/client.ts index c5adbf456a3f..35388c1d4749 100644 --- a/packages/wrangler/src/hyperdrive/client.ts +++ b/packages/wrangler/src/hyperdrive/client.ts @@ -93,7 +93,8 @@ export type Mtls = { sslmode?: string; }; -export const Sslmode = ["require", "verify-ca", "verify-full"]; +export const PostgresSslmode = ["require", "verify-ca", "verify-full"]; +export const MySqlSslmode = ["REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY"]; export async function createConfig( config: Config, diff --git a/packages/wrangler/src/hyperdrive/index.ts b/packages/wrangler/src/hyperdrive/index.ts index 350aa3423c46..cb4e126ddc29 100644 --- a/packages/wrangler/src/hyperdrive/index.ts +++ b/packages/wrangler/src/hyperdrive/index.ts @@ -1,5 +1,6 @@ import { UserError } from "@cloudflare/workers-utils"; import { createNamespace } from "../core/create-command"; +import { MySqlSslmode, PostgresSslmode } from "./client"; import type { CachingOptions, Mtls, @@ -20,6 +21,14 @@ export const hyperdriveNamespace = createNamespace({ }, }); +function normalizeMysqlSslmode(sslmode: string): string { + const mysqlSslmode = MySqlSslmode.find( + (mode) => mode.toLowerCase() === sslmode.toLowerCase() + ); + + return mysqlSslmode ?? sslmode; +} + export const upsertOptions = ( defaultOriginScheme: string | undefined = undefined ) => @@ -130,8 +139,9 @@ export const upsertOptions = ( }, sslmode: { type: "string", - choices: ["require", "verify-ca", "verify-full"], - description: "Sets CA sslmode for connecting to database.", + coerce: normalizeMysqlSslmode, + choices: [...PostgresSslmode, ...MySqlSslmode], + description: `Sets sslmode for connecting to database. For PostgreSQL: '${PostgresSslmode.join(", ")}'. For MySQL: '${MySqlSslmode.join(", ")}'.`, }, "origin-connection-limit": { type: "number", @@ -157,10 +167,10 @@ export function getOriginFromArgs< if ( url.port === "" && - (url.protocol == "postgresql:" || url.protocol == "postgres:") + (url.protocol == "postgresql:" || url.protocol === "postgres:") ) { url.port = "5432"; - } else if (url.port === "" && url.protocol == "mysql:") { + } else if (url.port === "" && url.protocol === "mysql:") { url.port = "3306"; } @@ -243,7 +253,7 @@ export function getOriginFromArgs< ); } - if (!args.originHost || args.originHost == "") { + if (!args.originHost || args.originHost === "") { throw new UserError( "You must provide an origin hostname for the database" ); @@ -315,22 +325,21 @@ export function getMtlsFromArgs( const mtls = { ca_certificate_id: args.caCertificateId, mtls_certificate_id: args.mtlsCertificateId, - sslmode: args.sslmode, + sslmode: args.sslmode ? normalizeMysqlSslmode(args.sslmode) : undefined, }; if (JSON.stringify(mtls) === "{}") { return undefined; } else { - if (mtls.sslmode == "require" && mtls.ca_certificate_id?.trim()) { - throw new UserError("CA not allowed when sslmode = 'require' is set"); - } - if ( - (mtls.sslmode == "verify-ca" || mtls.sslmode == "verify-full") && - !mtls.ca_certificate_id?.trim() + mtls.sslmode && + !PostgresSslmode.includes(mtls.sslmode) && + !MySqlSslmode.includes(mtls.sslmode) ) { throw new UserError( - "CA required when sslmode = 'verify-ca' or 'verify-full' is set" + `Invalid sslmode '${mtls.sslmode}'. Valid options are:\n` + + `- PostgreSQL: ${PostgresSslmode.join(", ")}\n` + + `- MySQL: ${MySqlSslmode.join(", ")}` ); } return mtls;