diff --git a/cmd/common/compile.go b/cmd/common/compile.go index 2456ec77..d500e6e9 100644 --- a/cmd/common/compile.go +++ b/cmd/common/compile.go @@ -1,6 +1,7 @@ package common import ( + "context" "errors" "fmt" "os" @@ -33,7 +34,7 @@ type WorkflowCompileOptions struct { } // getBuildCmd returns a single step that builds the workflow and returns the WASM bytes. -func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCompileOptions) (func() ([]byte, error), error) { +func getBuildCmd(ctx context.Context, workflowRootFolder, mainFile, language string, opts WorkflowCompileOptions) (func() ([]byte, error), error) { tmpPath := filepath.Join(workflowRootFolder, ".cre_build_tmp.wasm") switch language { case constants.WorkflowLanguageTypeScript: @@ -41,7 +42,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if opts.SkipTypeChecks { args = append(args, SkipTypeChecksFlag) } - cmd := exec.Command("bun", args...) + cmd := exec.CommandContext(ctx, "bun", args...) cmd.Dir = workflowRootFolder return func() ([]byte, error) { out, err := cmd.CombinedOutput() @@ -67,7 +68,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if opts.StripSymbols { ldflags = "-buildid= -w -s" } - cmd := exec.Command( + cmd := exec.CommandContext(ctx, "go", "build", "-o", tmpPath, "-trimpath", @@ -92,7 +93,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if err != nil { return nil, err } - makeCmd := exec.Command("make", "build") + makeCmd := exec.CommandContext(ctx, "make", "build") makeCmd.Dir = makeRoot builtPath := filepath.Join(makeRoot, defaultWasmOutput) return func() ([]byte, error) { @@ -108,7 +109,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if opts.StripSymbols { ldflags = "-buildid= -w -s" } - cmd := exec.Command( + cmd := exec.CommandContext(ctx, "go", "build", "-o", tmpPath, "-trimpath", @@ -135,7 +136,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom // opts.StripSymbols: for Go builds, true strips debug symbols (deploy); false keeps them (simulate). // opts.SkipTypeChecks: for TypeScript, passes SkipTypeChecksFlag to cre-compile. // For custom Makefile WASM builds, StripSymbols and SkipTypeChecks have no effect. -func CompileWorkflowToWasm(workflowPath string, opts WorkflowCompileOptions) ([]byte, error) { +func CompileWorkflowToWasm(ctx context.Context, workflowPath string, opts WorkflowCompileOptions) ([]byte, error) { workflowRootFolder, workflowMainFile, err := WorkflowPathRootAndMain(workflowPath) if err != nil { return nil, fmt.Errorf("workflow path: %w", err) @@ -167,7 +168,7 @@ func CompileWorkflowToWasm(workflowPath string, opts WorkflowCompileOptions) ([] return nil, fmt.Errorf("unsupported workflow language for file %s", workflowMainFile) } - buildStep, err := getBuildCmd(workflowRootFolder, workflowMainFile, language, opts) + buildStep, err := getBuildCmd(ctx, workflowRootFolder, workflowMainFile, language, opts) if err != nil { return nil, err } diff --git a/cmd/common/compile_test.go b/cmd/common/compile_test.go index fdc7dc3d..fa4aa42c 100644 --- a/cmd/common/compile_test.go +++ b/cmd/common/compile_test.go @@ -2,6 +2,7 @@ package common import ( "bytes" + "context" "io" "os" "os/exec" @@ -47,21 +48,21 @@ func TestFindMakefileRoot(t *testing.T) { func TestCompileWorkflowToWasm_Go_Success(t *testing.T) { t.Run("basic_workflow", func(t *testing.T) { path := deployTestdataPath("basic_workflow", "main.go") - wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) }) t.Run("configless_workflow", func(t *testing.T) { path := deployTestdataPath("configless_workflow", "main.go") - wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) }) t.Run("missing_go_mod", func(t *testing.T) { path := deployTestdataPath("missing_go_mod", "main.go") - wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) }) @@ -69,7 +70,7 @@ func TestCompileWorkflowToWasm_Go_Success(t *testing.T) { func TestCompileWorkflowToWasm_Go_Malformed_Fails(t *testing.T) { path := deployTestdataPath("malformed_workflow", "main.go") - _, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + _, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.Error(t, err) assert.Contains(t, err.Error(), "failed to compile workflow") assert.Contains(t, err.Error(), "undefined: sdk.RemovedFunctionThatFailsCompilation") @@ -80,7 +81,7 @@ func TestCompileWorkflowToWasm_Wasm_Success(t *testing.T) { _ = os.Remove(wasmPath) t.Cleanup(func() { _ = os.Remove(wasmPath) }) - wasm, err := CompileWorkflowToWasm(wasmPath, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), wasmPath, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) @@ -96,14 +97,14 @@ func TestCompileWorkflowToWasm_Wasm_Fails(t *testing.T) { wasmPath := filepath.Join(wasmDir, "workflow.wasm") require.NoError(t, os.WriteFile(wasmPath, []byte("not really wasm"), 0600)) - _, err := CompileWorkflowToWasm(wasmPath, WorkflowCompileOptions{StripSymbols: true}) + _, err := CompileWorkflowToWasm(context.Background(), wasmPath, WorkflowCompileOptions{StripSymbols: true}) require.Error(t, err) assert.Contains(t, err.Error(), "no Makefile found") }) t.Run("make_build_fails", func(t *testing.T) { path := deployTestdataPath("wasm_make_fails", "wasm", "workflow.wasm") - _, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + _, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.Error(t, err) assert.Contains(t, err.Error(), "failed to compile workflow") assert.Contains(t, err.Error(), "build output:") @@ -138,7 +139,7 @@ func TestCompileWorkflowToWasm_TS_Success(t *testing.T) { "include": ["main.ts"] } `), 0600)) - wasm, err := CompileWorkflowToWasm(mainPath, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), mainPath, WorkflowCompileOptions{StripSymbols: true}) if err != nil { t.Skipf("TS compile failed (published cre-sdk may lack full layout): %v", err) } diff --git a/cmd/common/fetch.go b/cmd/common/fetch.go index 5f8ee4f4..bc5b69e0 100644 --- a/cmd/common/fetch.go +++ b/cmd/common/fetch.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "io" "net/http" @@ -28,8 +29,13 @@ func IsURL(s string) bool { } // FetchURL performs an HTTP GET and returns the response body bytes. -func FetchURL(url string) ([]byte, error) { - resp, err := http.Get(url) //nolint:gosec,noctx +func FetchURL(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("HTTP GET %s: %w", url, err) + } + + resp, err := http.DefaultClient.Do(req) //nolint:gosec if err != nil { return nil, fmt.Errorf("HTTP GET %s: %w", url, err) } diff --git a/cmd/common/fetch_test.go b/cmd/common/fetch_test.go index 10a6ade6..0291f3dd 100644 --- a/cmd/common/fetch_test.go +++ b/cmd/common/fetch_test.go @@ -1,6 +1,7 @@ package common import ( + "context" "net/http" "net/http/httptest" "testing" @@ -42,7 +43,7 @@ func TestFetchURL(t *testing.T) { })) defer srv.Close() - data, err := FetchURL(srv.URL) + data, err := FetchURL(context.Background(), srv.URL) require.NoError(t, err) assert.Equal(t, body, data) }) @@ -53,13 +54,13 @@ func TestFetchURL(t *testing.T) { })) defer srv.Close() - _, err := FetchURL(srv.URL) + _, err := FetchURL(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "returned status 404") }) t.Run("unreachable host", func(t *testing.T) { - _, err := FetchURL("http://127.0.0.1:1") + _, err := FetchURL(context.Background(), "http://127.0.0.1:1") require.Error(t, err) }) } diff --git a/cmd/workflow/build/build.go b/cmd/workflow/build/build.go index f92f6973..63b79edd 100644 --- a/cmd/workflow/build/build.go +++ b/cmd/workflow/build/build.go @@ -1,6 +1,7 @@ package build import ( + "context" "fmt" "os" "path/filepath" @@ -26,7 +27,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { outputPath, _ := cmd.Flags().GetString("output") skipTypeChecks, _ := cmd.Flags().GetBool(cmdcommon.SkipTypeChecksCLIFlag) - return execute(args[0], outputPath, skipTypeChecks) + return execute(cmd.Context(), args[0], outputPath, skipTypeChecks) }, } buildCmd.Flags().StringP("output", "o", "", "Output file path for the compiled WASM binary (default: /binary.wasm)") @@ -34,7 +35,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { return buildCmd } -func execute(workflowFolder, outputPath string, skipTypeChecks bool) error { +func execute(ctx context.Context, workflowFolder, outputPath string, skipTypeChecks bool) error { workflowDir, err := filepath.Abs(workflowFolder) if err != nil { return fmt.Errorf("resolve workflow folder: %w", err) @@ -60,7 +61,7 @@ func execute(workflowFolder, outputPath string, skipTypeChecks bool) error { outputPath = cmdcommon.EnsureWasmExtension(outputPath) ui.Dim("Compiling workflow...") - wasmBytes, err := cmdcommon.CompileWorkflowToWasm(resolvedPath, cmdcommon.WorkflowCompileOptions{ + wasmBytes, err := cmdcommon.CompileWorkflowToWasm(ctx, resolvedPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: skipTypeChecks, }) diff --git a/cmd/workflow/deploy/artifacts.go b/cmd/workflow/deploy/artifacts.go index 070b1421..7f7d6803 100644 --- a/cmd/workflow/deploy/artifacts.go +++ b/cmd/workflow/deploy/artifacts.go @@ -9,6 +9,10 @@ import ( ) func (h *handler) uploadArtifacts() error { + if err := h.executionContext().Err(); err != nil { + return err + } + if h.workflowArtifact == nil { return fmt.Errorf("workflowArtifact is nil") } @@ -46,7 +50,7 @@ func (h *handler) uploadArtifacts() error { if !binaryFromURL { ui.Success(fmt.Sprintf("Loaded binary from: %s", h.inputs.OutputPath)) binaryResp, err := storageClient.UploadArtifactWithRetriesAndGetURL( - workflowID, storageclient.ArtifactTypeBinary, binaryData, "application/octet-stream") + h.executionContext(), workflowID, storageclient.ArtifactTypeBinary, binaryData, "application/octet-stream") if err != nil { return fmt.Errorf("uploading binary artifact: %w", err) } @@ -59,7 +63,7 @@ func (h *handler) uploadArtifacts() error { ui.Success(fmt.Sprintf("Loaded config from: %s", h.inputs.ConfigPath)) var err error configURL, err = storageClient.UploadArtifactWithRetriesAndGetURL( - workflowID, storageclient.ArtifactTypeConfig, configData, "text/plain") + h.executionContext(), workflowID, storageclient.ArtifactTypeConfig, configData, "text/plain") if err != nil { return fmt.Errorf("uploading config artifact: %w", err) } diff --git a/cmd/workflow/deploy/artifacts_test.go b/cmd/workflow/deploy/artifacts_test.go index 24833d9c..1d18a759 100644 --- a/cmd/workflow/deploy/artifacts_test.go +++ b/cmd/workflow/deploy/artifacts_test.go @@ -71,7 +71,7 @@ func TestUpload_SuccessAndErrorCases(t *testing.T) { simulatedEnvironment := chainsim.NewSimulatedEnvironment(t) ctx, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - h := newHandler(ctx, buf) + h := newTestHandler(ctx, buf) h.inputs.WorkflowOwner = chainsim.TestAddress h.inputs.WorkflowName = "test_workflow" h.inputs.DonFamily = "test_label" @@ -147,7 +147,7 @@ func TestUploadArtifactToStorageService_OriginError(t *testing.T) { simulatedEnvironment := chainsim.NewSimulatedEnvironment(t) runtimeContext, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - h := newHandler(runtimeContext, buf) + h := newTestHandler(runtimeContext, buf) h.inputs.WorkflowOwner = chainsim.TestAddress h.inputs.WorkflowName = "test_workflow" h.inputs.DonFamily = "test_label" @@ -187,7 +187,7 @@ func TestUploadArtifactToStorageService_AlreadyExistsError(t *testing.T) { simulatedEnvironment := chainsim.NewSimulatedEnvironment(t) runtimeContext, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - h := newHandler(runtimeContext, buf) + h := newTestHandler(runtimeContext, buf) h.inputs.WorkflowOwner = chainsim.TestAddress h.inputs.WorkflowName = "test_workflow" h.inputs.DonFamily = "test_label" @@ -255,7 +255,7 @@ func TestUpload_UsesResolvedWorkflowOwnerForPresignedUrls(t *testing.T) { simulatedEnvironment := chainsim.NewSimulatedEnvironment(t) t.Cleanup(simulatedEnvironment.Close) ctx, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - h := newHandler(ctx, buf) + h := newTestHandler(ctx, buf) h.inputs.WorkflowOwner = "0x2222222222222222222222222222222222222222" h.inputs.WorkflowName = "test_workflow" h.inputs.DonFamily = "test_label" diff --git a/cmd/workflow/deploy/auto_link.go b/cmd/workflow/deploy/auto_link.go index 7e393dc3..fbcadb0d 100644 --- a/cmd/workflow/deploy/auto_link.go +++ b/cmd/workflow/deploy/auto_link.go @@ -1,7 +1,6 @@ package deploy import ( - "context" "fmt" "strings" "time" @@ -25,7 +24,7 @@ const ( func (h *handler) ensureOwnerLinkedOrFail(onChain *settings.OnChainRegistry) error { ownerAddr := common.HexToAddress(h.inputs.WorkflowOwner) - linked, err := h.wrc.IsOwnerLinked(h.execCtx, ownerAddr) + linked, err := h.wrc.IsOwnerLinked(h.executionContext(), ownerAddr) if err != nil { return fmt.Errorf("failed to check owner link status: %w", err) } @@ -66,7 +65,7 @@ func (h *handler) ensureOwnerLinkedOrFail(onChain *settings.OnChainRegistry) err func (h *handler) autoLinkMSIGAndExit(onChain *settings.OnChainRegistry) (halt bool, err error) { ownerAddr := common.HexToAddress(h.inputs.WorkflowOwner) - linked, err := h.wrc.IsOwnerLinked(h.execCtx, ownerAddr) + linked, err := h.wrc.IsOwnerLinked(h.executionContext(), ownerAddr) if err != nil { return false, fmt.Errorf("failed to check owner link status: %w", err) } @@ -107,7 +106,7 @@ func (h *handler) tryAutoLink(onChain *settings.OnChainRegistry) error { EnvironmentSet: h.environmentSet, } - return linkkey.Exec(h.execCtx, rtx, linkkey.Inputs{ + return linkkey.Exec(h.executionContext(), rtx, linkkey.Inputs{ WorkflowOwner: h.inputs.WorkflowOwner, WorkflowRegistryContractAddress: onChain.Address(), WorkflowOwnerLabel: h.inputs.OwnerLabel, @@ -137,7 +136,7 @@ func (h *handler) checkLinkStatusViaGraphQL(ownerAddr common.Address) (bool, err } gql := graphqlclient.New(h.credentials, h.environmentSet, h.log) - if err := gql.Execute(context.Background(), req, &resp); err != nil { + if err := gql.Execute(h.executionContext(), req, &resp); err != nil { return false, fmt.Errorf("GraphQL query failed: %w", err) } @@ -181,7 +180,12 @@ func (h *handler) waitForBackendLinkProcessing(ownerAddr common.Address) error { ui.Line() // Wait for 3 block confirmations before polling - time.Sleep(initialBlockWait) + ctx := h.executionContext() + select { + case <-time.After(initialBlockWait): + case <-ctx.Done(): + return ctx.Err() + } err := retry.Do( func() error { @@ -199,6 +203,7 @@ func (h *handler) waitForBackendLinkProcessing(ownerAddr common.Address) error { retry.Delay(retryDelay), retry.DelayType(retry.FixedDelay), // Use fixed 3s delay between retries retry.LastErrorOnly(true), + retry.Context(ctx), retry.OnRetry(func(n uint, err error) { h.log.Debug().Uint("attempt", n+1).Uint("maxAttempts", maxAttempts).Err(err).Msg("Retrying link status check") ui.Dim(fmt.Sprintf(" Waiting for verification... (attempt %d/%d)", n+1, maxAttempts)) diff --git a/cmd/workflow/deploy/auto_link_test.go b/cmd/workflow/deploy/auto_link_test.go index e192ccfa..aa4d8b3f 100644 --- a/cmd/workflow/deploy/auto_link_test.go +++ b/cmd/workflow/deploy/auto_link_test.go @@ -158,7 +158,7 @@ func TestCheckLinkStatusViaGraphQL(t *testing.T) { AuthType: credentials.AuthTypeApiKey, IsValidated: true, } - h := newHandler(ctx, nil) + h := newTestHandler(ctx, nil) h.inputs.WorkflowOwner = tt.ownerAddress h.environmentSet.GraphQLURL = server.URL + "/graphql" @@ -329,7 +329,7 @@ func TestWaitForBackendLinkProcessing(t *testing.T) { AuthType: credentials.AuthTypeApiKey, IsValidated: true, } - h := newHandler(ctx, nil) + h := newTestHandler(ctx, nil) h.inputs.WorkflowOwner = tt.ownerAddress h.environmentSet.GraphQLURL = server.URL + "/graphql" diff --git a/cmd/workflow/deploy/compile.go b/cmd/workflow/deploy/compile.go index ecb3c064..a3dda9c5 100644 --- a/cmd/workflow/deploy/compile.go +++ b/cmd/workflow/deploy/compile.go @@ -67,7 +67,7 @@ func (h *handler) Compile() error { h.runtimeContext.Workflow.Language = cmdcommon.GetWorkflowLanguage(workflowMainFile) } - wasmFile, err = cmdcommon.CompileWorkflowToWasm(resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasmFile, err = cmdcommon.CompileWorkflowToWasm(h.executionContext(), resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: h.inputs.SkipTypeChecks, }) diff --git a/cmd/workflow/deploy/compile_test.go b/cmd/workflow/deploy/compile_test.go index 149ac19a..ffd29c61 100644 --- a/cmd/workflow/deploy/compile_test.go +++ b/cmd/workflow/deploy/compile_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "encoding/base64" "errors" "io" @@ -254,7 +255,7 @@ func createTestSettings(workflowOwnerAddress, workflowOwnerType, workflowName, w func runCompile(simulatedEnvironment *chainsim.SimulatedEnvironment, inputs Inputs, ownerType string) error { ctx, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - handler := newHandler(ctx, buf) + handler := newTestHandler(ctx, buf) ctx.Settings = createTestSettings( inputs.WorkflowOwner, @@ -266,8 +267,7 @@ func runCompile(simulatedEnvironment *chainsim.SimulatedEnvironment, inputs Inpu handler.settings = ctx.Settings handler.inputs = inputs - err := handler.ValidateInputs() - if err != nil { + if err := handler.ValidateInputs(); err != nil { return err } @@ -286,7 +286,7 @@ func outputPathWithExtensions(path string) string { // file content equals CompileWorkflowToWasm(workflowPath) + brotli + base64. func assertCompileOutputMatchesUnderlying(t *testing.T, simulatedEnvironment *chainsim.SimulatedEnvironment, inputs Inputs, ownerType string) { t.Helper() - wasm, err := cmdcommon.CompileWorkflowToWasm(inputs.WorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasm, err := cmdcommon.CompileWorkflowToWasm(context.Background(), inputs.WorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: inputs.SkipTypeChecks, }) @@ -416,7 +416,7 @@ func TestCompileWithWasmPath(t *testing.T) { defer simulatedEnvironment.Close() ctx, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - handler := newHandler(ctx, buf) + handler := newTestHandler(ctx, buf) ctx.Settings = createTestSettings( chainsim.TestAddress, constants.WorkflowOwnerTypeEOA, diff --git a/cmd/workflow/deploy/deploy.go b/cmd/workflow/deploy/deploy.go index 34f68f9b..87e7dd28 100644 --- a/cmd/workflow/deploy/deploy.go +++ b/cmd/workflow/deploy/deploy.go @@ -133,6 +133,15 @@ func newHandler(ctx *runtime.Context, stdin io.Reader) *handler { return &h } +// executionContext returns the context from Execute(), or context.Background() +// when handler methods are invoked directly in unit tests. +func (h *handler) executionContext() context.Context { + if h.execCtx != nil { + return h.execCtx + } + return context.Background() +} + func (h *handler) ResolveInputs(v *viper.Viper) (Inputs, error) { var configURL *string if v.IsSet("config-url") { @@ -231,6 +240,10 @@ func (h *handler) Execute(ctx context.Context) error { return err } + if err := h.executionContext().Err(); err != nil { + return err + } + if err := adapter.RunPreDeployChecks(); err != nil { if errors.Is(err, errDeployHalted) { return nil @@ -274,6 +287,10 @@ func (h *handler) Execute(ctx context.Context) error { // Artifact upload is deferred to the deploy service so it runs after any // existing-workflow update confirmation. func (h *handler) prepareArtifacts() error { + if err := h.executionContext().Err(); err != nil { + return err + } + workflowcommon.DisplayWorkflowDetails( h.settings, h.runtimeContext, @@ -285,7 +302,7 @@ func (h *handler) prepareArtifacts() error { if cmdcommon.IsURL(h.inputs.WasmPath) { h.inputs.BinaryURL = h.inputs.WasmPath ui.Dim("Fetching binary from URL for workflow ID computation...") - fetched, err := cmdcommon.FetchURL(h.inputs.WasmPath) + fetched, err := cmdcommon.FetchURL(h.executionContext(), h.inputs.WasmPath) if err != nil { return fmt.Errorf("failed to fetch binary from URL: %w", err) } @@ -302,7 +319,7 @@ func (h *handler) prepareArtifacts() error { h.inputs.ConfigURL = &url h.inputs.ConfigPath = "" ui.Dim("Fetching config from URL for workflow ID computation...") - fetched, err := cmdcommon.FetchURL(url) + fetched, err := cmdcommon.FetchURL(h.executionContext(), url) if err != nil { return fmt.Errorf("failed to fetch config from URL: %w", err) } diff --git a/cmd/workflow/deploy/private_registry_test.go b/cmd/workflow/deploy/private_registry_test.go index db0ed61b..e7f839a6 100644 --- a/cmd/workflow/deploy/private_registry_test.go +++ b/cmd/workflow/deploy/private_registry_test.go @@ -1,7 +1,6 @@ package deploy import ( - "context" "encoding/base64" "encoding/hex" "encoding/json" @@ -303,7 +302,7 @@ func TestCheckWorkflowExists_PrivateRegistry(t *testing.T) { defer simulatedEnvironment.Close() ctx, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - h := newHandler(ctx, buf) + h := newTestHandler(ctx, buf) h.credentials = makeAPIKeyCredentials(t) gqlServer := newAssertGQLServer(t, func(t *testing.T, req deployMockGraphQLRequest) (int, map[string]any) { @@ -313,7 +312,6 @@ func TestCheckWorkflowExists_PrivateRegistry(t *testing.T) { defer gqlServer.Close() h.environmentSet.GraphQLURL = gqlServer.URL - h.execCtx = context.Background() strategy := newPrivateRegistryDeployStrategy(h) exists, status, err := strategy.CheckWorkflowExists("", "jnowak-workflow-test-v5", "", tt.workflowID) diff --git a/cmd/workflow/deploy/register.go b/cmd/workflow/deploy/register.go index 29c6ac67..b65aad3a 100644 --- a/cmd/workflow/deploy/register.go +++ b/cmd/workflow/deploy/register.go @@ -57,7 +57,7 @@ func (h *handler) handleUpsert(params client.RegisterWorkflowV2Parameters, onCha workflowName := h.inputs.WorkflowName workflowTag := h.inputs.WorkflowTag h.log.Debug().Interface("Workflow parameters", params).Msg("Registering workflow...") - txOut, err := h.wrc.UpsertWorkflow(h.execCtx, params) + txOut, err := h.wrc.UpsertWorkflow(h.executionContext(), params) if err != nil { return fmt.Errorf("failed to register workflow: %w", err) } diff --git a/cmd/workflow/deploy/register_test.go b/cmd/workflow/deploy/register_test.go index b039aaf5..be843bd7 100644 --- a/cmd/workflow/deploy/register_test.go +++ b/cmd/workflow/deploy/register_test.go @@ -46,7 +46,7 @@ func TestWorkflowUpsert(t *testing.T) { defer simulatedEnvironment.Close() ctx, buf := simulatedEnvironment.NewRuntimeContextWithBufferedOutput() - handler := newHandler(ctx, buf) + handler := newTestHandler(ctx, buf) wrc, err := handler.clientFactory.NewWorkflowRegistryV2Client(context.Background()) require.NoError(t, err) @@ -56,15 +56,12 @@ func TestWorkflowUpsert(t *testing.T) { err = handler.ValidateInputs() require.NoError(t, err) - wfArt := workflowArtifact{ + handler.workflowArtifact = &workflowArtifact{ BinaryData: []byte("0x1234"), ConfigData: []byte("config"), WorkflowID: "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", } - handler.workflowArtifact = &wfArt - handler.execCtx = context.Background() - onChain, err := settings.AsOnChain(ctx.ResolvedRegistry, "test") require.NoError(t, err) err = handler.upsert(onChain) diff --git a/cmd/workflow/deploy/registry_deploy_strategy_onchain.go b/cmd/workflow/deploy/registry_deploy_strategy_onchain.go index 032de838..1ed63bce 100644 --- a/cmd/workflow/deploy/registry_deploy_strategy_onchain.go +++ b/cmd/workflow/deploy/registry_deploy_strategy_onchain.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "fmt" "sync" @@ -33,7 +34,7 @@ func newOnchainRegistryDeployStrategy(h *handler) (*onchainRegistryDeployStrateg a.wg.Add(1) go func() { defer a.wg.Done() - wrc, err := h.clientFactory.NewWorkflowRegistryV2Client(h.execCtx) + wrc, err := h.clientFactory.NewWorkflowRegistryV2Client(h.executionContext()) if err != nil { a.initErr = fmt.Errorf("failed to create workflow registry client: %w", err) return @@ -44,10 +45,27 @@ func newOnchainRegistryDeployStrategy(h *handler) (*onchainRegistryDeployStrateg return a, nil } +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + func (a *onchainRegistryDeployStrategy) RunPreDeployChecks() error { h := a.h - a.wg.Wait() + if err := waitWithContext(a.h.executionContext(), &a.wg); err != nil { + return err + } if a.initErr != nil { return a.initErr } @@ -72,7 +90,7 @@ func (a *onchainRegistryDeployStrategy) RunPreDeployChecks() error { } func (a *onchainRegistryDeployStrategy) CheckWorkflowExists(workflowOwner, workflowName, workflowTag, workflowID string) (bool, *uint8, error) { - workflow, err := a.wrc.GetWorkflow(a.h.execCtx, common.HexToAddress(workflowOwner), workflowName, workflowTag) + workflow, err := a.wrc.GetWorkflow(a.h.executionContext(), common.HexToAddress(workflowOwner), workflowName, workflowTag) if err != nil { return false, nil, err } @@ -92,7 +110,7 @@ func (a *onchainRegistryDeployStrategy) Upsert() error { h := a.h if err := checkUserDonLimitBeforeDeploy( - h.execCtx, + h.executionContext(), a.wrc, a.wrc, common.HexToAddress(h.inputs.WorkflowOwner), diff --git a/cmd/workflow/deploy/registry_deploy_strategy_private.go b/cmd/workflow/deploy/registry_deploy_strategy_private.go index 9909800e..2a77d2b9 100644 --- a/cmd/workflow/deploy/registry_deploy_strategy_private.go +++ b/cmd/workflow/deploy/registry_deploy_strategy_private.go @@ -34,7 +34,7 @@ func (a *privateRegistryDeployStrategy) RunPreDeployChecks() error { func (a *privateRegistryDeployStrategy) CheckWorkflowExists(_, workflowName, _, workflowID string) (bool, *uint8, error) { a.ensureClient() - workflow, err := a.prc.GetWorkflowByName(a.h.execCtx, workflowName) + workflow, err := a.prc.GetWorkflowByName(a.h.executionContext(), workflowName) if err == nil { if workflow.WorkflowID == workflowID { return true, offchainStatusToUint8(workflow.Status), fmt.Errorf("workflow with id %s is already registered and unchanged; re-deployment skipped: %w", workflowID, errWorkflowUnchanged) @@ -57,7 +57,7 @@ func (a *privateRegistryDeployStrategy) Upsert() error { ui.Line() ui.Dim(fmt.Sprintf("Registering workflow in private registry (workflowID: %s)...", input.WorkflowID)) - result, err := a.prc.UpsertWorkflowInRegistry(a.h.execCtx, input) + result, err := a.prc.UpsertWorkflowInRegistry(a.h.executionContext(), input) if err != nil { return fmt.Errorf("failed to register workflow in private registry: %w", err) } diff --git a/cmd/workflow/deploy/test_helpers_test.go b/cmd/workflow/deploy/test_helpers_test.go new file mode 100644 index 00000000..e321e3c6 --- /dev/null +++ b/cmd/workflow/deploy/test_helpers_test.go @@ -0,0 +1,17 @@ +package deploy + +import ( + "context" + "io" + + "github.com/smartcontractkit/cre-cli/internal/runtime" +) + +// newTestHandler returns a handler suitable for unit tests that call handler +// methods directly instead of going through Execute(). It pre-sets execCtx so +// cancellation-aware code paths behave like a normal CLI invocation. +func newTestHandler(ctx *runtime.Context, stdin io.Reader) *handler { + h := newHandler(ctx, stdin) + h.execCtx = context.Background() + return h +} diff --git a/cmd/workflow/hash/hash.go b/cmd/workflow/hash/hash.go index 49533870..7efdb434 100644 --- a/cmd/workflow/hash/hash.go +++ b/cmd/workflow/hash/hash.go @@ -1,6 +1,7 @@ package hash import ( + "context" "fmt" "os" "strings" @@ -62,7 +63,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { DerivedOwner: runtimeContext.DerivedWorkflowOwner, } - return Execute(inputs) + return Execute(cmd.Context(), inputs) }, } @@ -81,8 +82,8 @@ func New(runtimeContext *runtime.Context) *cobra.Command { return hashCmd } -func Execute(inputs Inputs) error { - rawBinary, err := loadBinary(inputs.WasmPath, inputs.WorkflowPath, inputs.SkipTypeChecks) +func Execute(ctx context.Context, inputs Inputs) error { + rawBinary, err := loadBinary(ctx, inputs.WasmPath, inputs.WorkflowPath, inputs.SkipTypeChecks) if err != nil { return err } @@ -92,7 +93,7 @@ func Execute(inputs Inputs) error { return fmt.Errorf("failed to compress binary: %w", err) } - config, err := loadConfig(inputs.ConfigPath) + config, err := loadConfig(ctx, inputs.ConfigPath) if err != nil { return err } @@ -190,11 +191,11 @@ func isPrivateRegistryID(deploymentRegistry string) bool { return strings.EqualFold(deploymentRegistry, "private") } -func loadBinary(wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) ([]byte, error) { +func loadBinary(ctx context.Context, wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) ([]byte, error) { if wasmFlag != "" { if cmdcommon.IsURL(wasmFlag) { ui.Dim("Fetching WASM binary from URL...") - data, err := cmdcommon.FetchURL(wasmFlag) + data, err := cmdcommon.FetchURL(ctx, wasmFlag) if err != nil { return nil, fmt.Errorf("failed to fetch WASM from URL: %w", err) } @@ -221,7 +222,7 @@ func loadBinary(wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) spinner := ui.NewSpinner() spinner.Start("Compiling workflow...") - wasmBytes, err := cmdcommon.CompileWorkflowToWasm(resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasmBytes, err := cmdcommon.CompileWorkflowToWasm(ctx, resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: skipTypeChecks, }) @@ -235,13 +236,13 @@ func loadBinary(wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) return wasmBytes, nil } -func loadConfig(configPath string) ([]byte, error) { +func loadConfig(ctx context.Context, configPath string) ([]byte, error) { if configPath == "" { return nil, nil } if cmdcommon.IsURL(configPath) { ui.Dim("Fetching config from URL...") - data, err := cmdcommon.FetchURL(configPath) + data, err := cmdcommon.FetchURL(ctx, configPath) if err != nil { return nil, fmt.Errorf("failed to fetch config from URL: %w", err) } diff --git a/cmd/workflow/hash/hash_test.go b/cmd/workflow/hash/hash_test.go index 0ed08a82..14809402 100644 --- a/cmd/workflow/hash/hash_test.go +++ b/cmd/workflow/hash/hash_test.go @@ -1,6 +1,7 @@ package hash import ( + "context" "crypto/sha256" "encoding/hex" "io" @@ -80,7 +81,7 @@ func TestExecute_WithForUser(t *testing.T) { WorkflowName: "test-workflow", } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -94,7 +95,7 @@ func TestExecute_WithoutForUser_UsesPrivateKey(t *testing.T) { PrivateKey: testPrivateKey, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -107,7 +108,7 @@ func TestExecute_WithoutForUser_NoKey_Errors(t *testing.T) { WorkflowName: "test-workflow", } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.Error(t, err) assert.Contains(t, err.Error(), "--public_key") } @@ -173,7 +174,7 @@ func TestExecute_HashesAreDeterministic(t *testing.T) { "workflow ID should start with version byte 00") // Running Execute should succeed (hashes are printed via ui, verified above) - err = Execute(inputs) + err = Execute(context.Background(), inputs) require.NoError(t, err) } @@ -187,7 +188,7 @@ func TestExecute_EmptyConfig(t *testing.T) { WorkflowName: "test-workflow", } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -201,7 +202,7 @@ func TestExecute_OffChainRequiresPublicKey(t *testing.T) { RegistryType: settings.RegistryTypeOffChain, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.Error(t, err) assert.Contains(t, err.Error(), "--public_key") } @@ -218,7 +219,7 @@ func TestExecute_OffChainUsesPublicKey(t *testing.T) { DerivedOwner: testDerivedOwner, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -233,7 +234,7 @@ func TestExecute_OffChainUsesDerivedOwner(t *testing.T) { DerivedOwner: testDerivedOwner, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } diff --git a/cmd/workflow/simulate/simulate.go b/cmd/workflow/simulate/simulate.go index 7d0ca445..fe95ed2b 100644 --- a/cmd/workflow/simulate/simulate.go +++ b/cmd/workflow/simulate/simulate.go @@ -90,7 +90,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { if err != nil { return err } - return handler.Execute(inputs) + return handler.Execute(cmd.Context(), inputs) }, } @@ -252,14 +252,14 @@ func (h *handler) ValidateInputs(inputs Inputs) error { return nil } -func (h *handler) Execute(inputs Inputs) error { +func (h *handler) Execute(ctx context.Context, inputs Inputs) error { var wasmFileBinary []byte var err error if inputs.WasmPath != "" { if cmdcommon.IsURL(inputs.WasmPath) { ui.Dim("Fetching WASM binary from URL...") - wasmFileBinary, err = cmdcommon.FetchURL(inputs.WasmPath) + wasmFileBinary, err = cmdcommon.FetchURL(ctx, inputs.WasmPath) if err != nil { return fmt.Errorf("failed to fetch WASM from URL: %w", err) } @@ -298,7 +298,7 @@ func (h *handler) Execute(inputs Inputs) error { spinner := ui.NewSpinner() spinner.Start("Compiling workflow...") - wasmFileBinary, err = cmdcommon.CompileWorkflowToWasm(resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasmFileBinary, err = cmdcommon.CompileWorkflowToWasm(ctx, resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: false, SkipTypeChecks: inputs.SkipTypeChecks, }) @@ -343,7 +343,7 @@ func (h *handler) Execute(inputs Inputs) error { var config []byte if cmdcommon.IsURL(inputs.ConfigPath) { ui.Dim("Fetching config from URL...") - config, err = cmdcommon.FetchURL(inputs.ConfigPath) + config, err = cmdcommon.FetchURL(ctx, inputs.ConfigPath) if err != nil { return fmt.Errorf("failed to fetch config from URL: %w", err) } diff --git a/cmd/workflow/simulate/simulate_test.go b/cmd/workflow/simulate/simulate_test.go index 847d7ae2..4879b8f3 100644 --- a/cmd/workflow/simulate/simulate_test.go +++ b/cmd/workflow/simulate/simulate_test.go @@ -98,7 +98,7 @@ func TestBlankWorkflowSimulation(t *testing.T) { require.NoError(t, err) // Execute the simulation. We expect this to compile the workflow and run the simulator successfully. - err = handler.Execute(inputs) + err = handler.Execute(context.Background(), inputs) require.NoError(t, err, "Execute should not return an error") } diff --git a/internal/client/storageclient/storageclient.go b/internal/client/storageclient/storageclient.go index de16d790..9a798684 100644 --- a/internal/client/storageclient/storageclient.go +++ b/internal/client/storageclient/storageclient.go @@ -69,15 +69,15 @@ func (c *Client) SetHTTPTimeout(timeout time.Duration) { c.httpTimeout = timeout } -func (c *Client) CreateServiceContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers +func (c *Client) CreateServiceContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers } -func (c *Client) CreateHttpContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), c.httpTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers +func (c *Client) CreateHttpContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, c.httpTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers } -func (c *Client) GeneratePostUrlForArtifact(workflowId string, artifactType ArtifactType, content []byte) (GeneratePresignedPostUrlForArtifactResponse, error) { +func (c *Client) GeneratePostUrlForArtifact(ctx context.Context, workflowId string, artifactType ArtifactType, content []byte) (GeneratePresignedPostUrlForArtifactResponse, error) { const mutation = ` mutation GeneratePresignedPostUrlForArtifact($artifact: GeneratePresignedPostUrlRequest!) { generatePresignedPostUrlForArtifact(artifact: $artifact) { @@ -102,7 +102,7 @@ mutation GeneratePresignedPostUrlForArtifact($artifact: GeneratePresignedPostUrl GeneratePresignedPostUrlForArtifact GeneratePresignedPostUrlForArtifactResponse `json:"generatePresignedPostUrlForArtifact"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + ctx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() if err := c.graphql. @@ -116,7 +116,7 @@ mutation GeneratePresignedPostUrlForArtifact($artifact: GeneratePresignedPostUrl return container.GeneratePresignedPostUrlForArtifact, nil } -func (c *Client) GenerateUnsignedGetUrlForArtifact(workflowId string, artifactType ArtifactType) (GenerateUnsignedGetUrlForArtifactResponse, error) { +func (c *Client) GenerateUnsignedGetUrlForArtifact(ctx context.Context, workflowId string, artifactType ArtifactType) (GenerateUnsignedGetUrlForArtifactResponse, error) { const mutation = ` mutation GenerateUnsignedGetUrlForArtifact($artifact: GenerateUnsignedGetUrlRequest!) { generateUnsignedGetUrlForArtifact(artifact: $artifact) { @@ -134,7 +134,7 @@ mutation GenerateUnsignedGetUrlForArtifact($artifact: GenerateUnsignedGetUrlRequ GenerateUnsignedGetUrlForArtifact GenerateUnsignedGetUrlForArtifactResponse `json:"generateUnsignedGetUrlForArtifact"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + ctx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() if err := c.graphql. @@ -154,7 +154,7 @@ func calculateContentHash(content []byte) string { return contentHash } -func (c *Client) UploadToOrigin(g GeneratePresignedPostUrlForArtifactResponse, content []byte, contentType string) error { +func (c *Client) UploadToOrigin(ctx context.Context, g GeneratePresignedPostUrlForArtifactResponse, content []byte, contentType string) error { c.log.Debug().Str("URL", g.PresignedPostURL).Msg("Uploading content to origin") var b bytes.Buffer @@ -197,7 +197,7 @@ func (c *Client) UploadToOrigin(g GeneratePresignedPostUrlForArtifactResponse, c return err } - ctx, cancel := c.CreateHttpContextWithTimeout() + ctx, cancel := c.CreateHttpContextWithTimeout(ctx) defer cancel() httpReq, err := http.NewRequestWithContext(ctx, "POST", g.PresignedPostURL, &b) @@ -231,10 +231,14 @@ func (c *Client) UploadToOrigin(g GeneratePresignedPostUrlForArtifactResponse, c } func (c *Client) UploadArtifactWithRetriesAndGetURL( + ctx context.Context, workflowID string, artifactType ArtifactType, content []byte, contentType string) (GenerateUnsignedGetUrlForArtifactResponse, error) { + if err := ctx.Err(); err != nil { + return GenerateUnsignedGetUrlForArtifactResponse{}, err + } if len(workflowID) == 0 { return GenerateUnsignedGetUrlForArtifactResponse{}, fmt.Errorf("workflowID is empty") } @@ -251,7 +255,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( err := retry.Do( func() error { var err error - g, err = c.GeneratePostUrlForArtifact(workflowID, artifactType, content) + g, err = c.GeneratePostUrlForArtifact(ctx, workflowID, artifactType, content) if err != nil { if strings.Contains(err.Error(), "already exists") { shouldUpload = false @@ -264,6 +268,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( }, retry.Attempts(3), retry.LastErrorOnly(true), + retry.Context(ctx), ) if err != nil { c.log.Error().Err(err).Msg("Failed to generate presigned post URL for artifact") @@ -276,10 +281,11 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( if shouldUpload { err = retry.Do( func() error { - return c.UploadToOrigin(g, content, contentType) + return c.UploadToOrigin(ctx, g, content, contentType) }, retry.Attempts(3), retry.LastErrorOnly(true), + retry.Context(ctx), ) if err != nil { c.log.Error().Err(err).Msg("Failed to upload content to origin") @@ -290,7 +296,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( var g2 GenerateUnsignedGetUrlForArtifactResponse err = retry.Do( func() error { - g2, err = c.GenerateUnsignedGetUrlForArtifact(workflowID, artifactType) + g2, err = c.GenerateUnsignedGetUrlForArtifact(ctx, workflowID, artifactType) if err != nil { return fmt.Errorf("generate unsigned get url: %w", err) } @@ -298,6 +304,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( }, retry.Attempts(3), retry.LastErrorOnly(true), + retry.Context(ctx), ) if err != nil { c.log.Error().Err(err).Msg("Failed to generate unsigned get URL for artifact")