diff --git a/cmd/generate-bindings/bindings/README.md b/cmd/generate-bindings/evm/README.md similarity index 99% rename from cmd/generate-bindings/bindings/README.md rename to cmd/generate-bindings/evm/README.md index 6513b27a..da346939 100644 --- a/cmd/generate-bindings/bindings/README.md +++ b/cmd/generate-bindings/evm/README.md @@ -32,7 +32,7 @@ that lets you generate Go bindings for your smart contracts using a custom templ ### Programmatic API ```go -import "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" +import "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" func main() { err := bindings.GenerateBindings( diff --git a/cmd/generate-bindings/bindings/abigen/FORK_METADATA.md b/cmd/generate-bindings/evm/abigen/FORK_METADATA.md similarity index 100% rename from cmd/generate-bindings/bindings/abigen/FORK_METADATA.md rename to cmd/generate-bindings/evm/abigen/FORK_METADATA.md diff --git a/cmd/generate-bindings/bindings/abigen/bind.go b/cmd/generate-bindings/evm/abigen/bind.go similarity index 100% rename from cmd/generate-bindings/bindings/abigen/bind.go rename to cmd/generate-bindings/evm/abigen/bind.go diff --git a/cmd/generate-bindings/bindings/abigen/bindv2.go b/cmd/generate-bindings/evm/abigen/bindv2.go similarity index 100% rename from cmd/generate-bindings/bindings/abigen/bindv2.go rename to cmd/generate-bindings/evm/abigen/bindv2.go diff --git a/cmd/generate-bindings/bindings/abigen/source.go.tpl b/cmd/generate-bindings/evm/abigen/source.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/abigen/source.go.tpl rename to cmd/generate-bindings/evm/abigen/source.go.tpl diff --git a/cmd/generate-bindings/bindings/abigen/source2.go.tpl b/cmd/generate-bindings/evm/abigen/source2.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/abigen/source2.go.tpl rename to cmd/generate-bindings/evm/abigen/source2.go.tpl diff --git a/cmd/generate-bindings/bindings/abigen/template.go b/cmd/generate-bindings/evm/abigen/template.go similarity index 100% rename from cmd/generate-bindings/bindings/abigen/template.go rename to cmd/generate-bindings/evm/abigen/template.go diff --git a/cmd/generate-bindings/bindings/bindgen.go b/cmd/generate-bindings/evm/bindgen.go similarity index 98% rename from cmd/generate-bindings/bindings/bindgen.go rename to cmd/generate-bindings/evm/bindgen.go index 7b7478b4..f79e3717 100644 --- a/cmd/generate-bindings/bindings/bindgen.go +++ b/cmd/generate-bindings/evm/bindgen.go @@ -1,4 +1,4 @@ -package bindings +package evm import ( _ "embed" @@ -11,7 +11,7 @@ import ( "github.com/ethereum/go-ethereum/common/compiler" "github.com/ethereum/go-ethereum/crypto" - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings/abigen" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm/abigen" ) //go:embed sourcecre.go.tpl diff --git a/cmd/generate-bindings/bindings/bindings_test.go b/cmd/generate-bindings/evm/bindings_test.go similarity index 99% rename from cmd/generate-bindings/bindings/bindings_test.go rename to cmd/generate-bindings/evm/bindings_test.go index de225b36..ab558459 100644 --- a/cmd/generate-bindings/bindings/bindings_test.go +++ b/cmd/generate-bindings/evm/bindings_test.go @@ -1,4 +1,4 @@ -package bindings_test +package evm_test import ( "context" @@ -20,7 +20,7 @@ import ( "github.com/smartcontractkit/cre-sdk-go/cre/testutils" consensusmock "github.com/smartcontractkit/cre-sdk-go/internal_testing/capabilities/consensus/mock" - datastorage "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings/testdata" + datastorage "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm/testdata" ) const anyChainSelector = uint64(1337) diff --git a/cmd/generate-bindings/evm/evm.go b/cmd/generate-bindings/evm/evm.go new file mode 100644 index 00000000..f0850b40 --- /dev/null +++ b/cmd/generate-bindings/evm/evm.go @@ -0,0 +1,469 @@ +package evm + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/internal/runtime" + "github.com/smartcontractkit/cre-cli/internal/ui" + "github.com/smartcontractkit/cre-cli/internal/validation" +) + +type Inputs struct { + ProjectRoot string `validate:"required,dir" cli:"--project-root"` + GoLang bool + TypeScript bool + AbiPath string `validate:"required,path_read" cli:"--abi"` + PkgName string `validate:"required" cli:"--pkg"` + GoOutPath string // contracts/evm/src/generated — set when GoLang is true + TSOutPath string // contracts/evm/ts/generated — set when TypeScript is true +} + +func New(runtimeContext *runtime.Context) *cobra.Command { + generateBindingsCmd := &cobra.Command{ + Use: "evm", + Short: "Generate bindings from contract ABI", + Long: `This command generates bindings from contract ABI files. +Supports EVM chain family with Go and TypeScript languages. +The target language is auto-detected from project files, or can be +specified explicitly with --language. +Each contract gets its own package subdirectory to avoid naming conflicts. +For example, IERC20.abi generates bindings in generated/ierc20/ package. + +Both raw ABI files (*.abi) and JSON artifact files (*.json) are supported. +For JSON files the ABI is read from the top-level "abi" field.`, + Example: " cre generate-bindings evm", + RunE: func(cmd *cobra.Command, args []string) error { + handler := newHandler(runtimeContext) + + inputs, err := handler.ResolveInputs(runtimeContext.Viper) + if err != nil { + return err + } + if err := handler.ValidateInputs(inputs); err != nil { + return err + } + return handler.Execute(inputs) + }, + } + + generateBindingsCmd.Flags().StringP("project-root", "p", "", "Path to project root directory (defaults to current directory)") + generateBindingsCmd.Flags().StringP("language", "l", "", "Target language: go, typescript (auto-detected from project files when omitted)") + generateBindingsCmd.Flags().StringP("abi", "a", "", "Path to ABI directory (defaults to contracts/evm/src/abi/). Supports *.abi and *.json files") + generateBindingsCmd.Flags().StringP("pkg", "k", "bindings", "Base package name (each contract gets its own subdirectory)") + + return generateBindingsCmd +} + +type handler struct { + log *zerolog.Logger + validated bool +} + +func newHandler(ctx *runtime.Context) *handler { + return &handler{ + log: ctx.Logger, + validated: false, + } +} + +func detectLanguages(projectRoot string) (goLang, typescript bool) { + _ = filepath.WalkDir(projectRoot, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + if d.Name() == "node_modules" || d.Name() == ".git" { + return filepath.SkipDir + } + return nil + } + base := filepath.Base(path) + if strings.HasSuffix(base, ".go") { + goLang = true + } + if strings.HasSuffix(base, ".ts") && !strings.HasSuffix(base, ".d.ts") { + typescript = true + } + return nil + }) + return goLang, typescript +} + +func (h *handler) ResolveInputs(v *viper.Viper) (Inputs, error) { + currentDir, err := os.Getwd() + if err != nil { + return Inputs{}, fmt.Errorf("failed to get current working directory: %w", err) + } + + projectRoot := v.GetString("project-root") + if projectRoot == "" { + projectRoot = currentDir + } + + contractsPath := filepath.Join(projectRoot, "contracts") + if _, err := os.Stat(contractsPath); err != nil { + return Inputs{}, fmt.Errorf("contracts folder not found in project root: %s", contractsPath) + } + + // Resolve languages: --language flag takes precedence, else auto-detect + var goLang, typescript bool + langFlag := strings.ToLower(strings.TrimSpace(v.GetString("language"))) + switch langFlag { + case "": + goLang, typescript = detectLanguages(projectRoot) + if !goLang && !typescript { + return Inputs{}, fmt.Errorf("no target language detected (use --language go or --language typescript, or ensure project contains .go or .ts files)") + } + case constants.WorkflowLanguageGolang: + goLang = true + case constants.WorkflowLanguageTypeScript: + typescript = true + default: + return Inputs{}, fmt.Errorf("unsupported language %q (supported: go, typescript)", langFlag) + } + + // Unified ABI path for both languages: contracts/evm/src/abi + abiPath := v.GetString("abi") + if abiPath == "" { + abiPath = filepath.Join(projectRoot, "contracts", "evm", "src", "abi") + } + + pkgName := v.GetString("pkg") + + // Separate output paths: Go uses src/, TS uses ts/ (typescript convention) + var goOutPath, tsOutPath string + if goLang { + goOutPath = filepath.Join(projectRoot, "contracts", "evm", "src", "generated") + } + if typescript { + tsOutPath = filepath.Join(projectRoot, "contracts", "evm", "ts", "generated") + } + + return Inputs{ + ProjectRoot: projectRoot, + GoLang: goLang, + TypeScript: typescript, + AbiPath: abiPath, + PkgName: pkgName, + GoOutPath: goOutPath, + TSOutPath: tsOutPath, + }, nil +} + +// findAbiFiles returns all supported ABI files (*.abi and *.json) found in dir. +func findAbiFiles(dir string) ([]string, error) { + abiFiles, err := filepath.Glob(filepath.Join(dir, "*.abi")) + if err != nil { + return nil, err + } + jsonFiles, err := filepath.Glob(filepath.Join(dir, "*.json")) + if err != nil { + return nil, err + } + all := append(abiFiles, jsonFiles...) + sort.Strings(all) + return all, nil +} + +// contractNameFromFile returns the contract name by stripping the .abi or .json +// extension from the base filename. +func contractNameFromFile(path string) string { + name := filepath.Base(path) + ext := filepath.Ext(name) + if ext != "" { + name = name[:len(name)-len(ext)] + } + return name +} + +func (h *handler) ValidateInputs(inputs Inputs) error { + validate, err := validation.NewValidator() + if err != nil { + return fmt.Errorf("failed to initialize validator: %w", err) + } + + if err = validate.Struct(inputs); err != nil { + return validate.ParseValidationErrors(err) + } + + if _, err := os.Stat(inputs.AbiPath); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("ABI path does not exist: %s", inputs.AbiPath) + } + return fmt.Errorf("failed to access ABI path: %w", err) + } + + if info, err := os.Stat(inputs.AbiPath); err == nil && info.IsDir() { + files, err := findAbiFiles(inputs.AbiPath) + if err != nil { + return fmt.Errorf("failed to check for ABI files in directory: %w", err) + } + if len(files) == 0 { + return fmt.Errorf("no *.abi or *.json files found in directory: %s", inputs.AbiPath) + } + } + + if inputs.GoLang && inputs.GoOutPath == "" { + return fmt.Errorf("go output path is required when language is go") + } + if inputs.TypeScript && inputs.TSOutPath == "" { + return fmt.Errorf("typescript output path is required when language is typescript") + } + + h.validated = true + return nil +} + +// contractNameToPackage converts contract names to valid Go package names +// Examples: IERC20 -> ierc20, ReserveManager -> reserve_manager, IReserveManager -> ireserve_manager +func contractNameToPackage(contractName string) string { + if contractName == "" { + return "" + } + + var result []rune + runes := []rune(contractName) + + for i, r := range runes { + if r >= 'A' && r <= 'Z' { + lower := r - 'A' + 'a' + + // Add underscore before uppercase letters, but not: + // - At the beginning (i == 0) + // - If the previous character was also uppercase and this is followed by lowercase (e.g., "ERC" in "ERC20") + // - If this is part of a sequence of uppercase letters at the beginning (e.g., "IERC20" -> "ierc20") + if i > 0 { + prevIsUpper := runes[i-1] >= 'A' && runes[i-1] <= 'Z' + nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z' + + if !prevIsUpper || (prevIsUpper && nextIsLower && i > 1) { + result = append(result, '_') + } + } + + result = append(result, lower) + } else { + result = append(result, r) + } + } + + return string(result) +} + +func (h *handler) processAbiDirectory(inputs Inputs) error { + files, err := findAbiFiles(inputs.AbiPath) + if err != nil { + return fmt.Errorf("failed to find ABI files: %w", err) + } + + if len(files) == 0 { + return fmt.Errorf("no *.abi or *.json files found in directory: %s", inputs.AbiPath) + } + + // Detect duplicate contract names across extensions (e.g. Foo.abi and Foo.json) + contractNames := make(map[string]string) // contract name -> originating file + for _, f := range files { + name := contractNameFromFile(f) + if prev, exists := contractNames[name]; exists { + return fmt.Errorf("duplicate contract name %q: found in both %s and %s", name, filepath.Base(prev), filepath.Base(f)) + } + contractNames[name] = f + } + + if inputs.GoLang { + packageNames := make(map[string]bool) + for _, abiFile := range files { + contractName := contractNameFromFile(abiFile) + packageName := contractNameToPackage(contractName) + if _, exists := packageNames[packageName]; exists { + return fmt.Errorf("package name collision: multiple contracts would generate the same package name '%s' (contracts are converted to snake_case for package names). Please rename one of your contract files to avoid this conflict", packageName) + } + packageNames[packageName] = true + } + } + + var generatedContracts []string + + for _, abiFile := range files { + contractName := contractNameFromFile(abiFile) + + if inputs.TypeScript { + outputFile := filepath.Join(inputs.TSOutPath, contractName+".ts") + ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) + + err = GenerateBindingsTS( + abiFile, + contractName, + outputFile, + ) + if err != nil { + return fmt.Errorf("failed to generate TypeScript bindings for %s: %w", contractName, err) + } + generatedContracts = append(generatedContracts, contractName) + } + + if inputs.GoLang { + packageName := contractNameToPackage(contractName) + + contractOutDir := filepath.Join(inputs.GoOutPath, packageName) + if err := os.MkdirAll(contractOutDir, 0o755); err != nil { + return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) + } + + outputFile := filepath.Join(contractOutDir, contractName+".go") + ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) + + err = GenerateBindings( + "", + abiFile, + packageName, + contractName, + outputFile, + ) + if err != nil { + return fmt.Errorf("failed to generate bindings for %s: %w", contractName, err) + } + } + } + + // Generate barrel index.ts for TypeScript + if inputs.TypeScript && len(generatedContracts) > 0 { + indexPath := filepath.Join(inputs.TSOutPath, "index.ts") + var indexContent string + indexContent += "// Code generated — DO NOT EDIT.\n" + for _, name := range generatedContracts { + indexContent += fmt.Sprintf("export * from './%s'\n", name) + indexContent += fmt.Sprintf("export * from './%s_mock'\n", name) + } + if err := os.WriteFile(indexPath, []byte(indexContent), 0o600); err != nil { + return fmt.Errorf("failed to write index.ts: %w", err) + } + } + + return nil +} + +func (h *handler) processSingleAbi(inputs Inputs) error { + contractName := contractNameFromFile(inputs.AbiPath) + + if inputs.TypeScript { + outputFile := filepath.Join(inputs.TSOutPath, contractName+".ts") + ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) + + if err := GenerateBindingsTS( + inputs.AbiPath, + contractName, + outputFile, + ); err != nil { + return err + } + } + + if inputs.GoLang { + packageName := contractNameToPackage(contractName) + + contractOutDir := filepath.Join(inputs.GoOutPath, packageName) + if err := os.MkdirAll(contractOutDir, 0o755); err != nil { + return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) + } + + outputFile := filepath.Join(contractOutDir, contractName+".go") + ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) + + if err := GenerateBindings( + "", + inputs.AbiPath, + packageName, + contractName, + outputFile, + ); err != nil { + return err + } + } + + return nil +} + +func (h *handler) Execute(inputs Inputs) error { + langs := []string{} + if inputs.GoLang { + langs = append(langs, "go") + } + if inputs.TypeScript { + langs = append(langs, "typescript") + } + ui.Dim(fmt.Sprintf("Project: %s, Chain: evm, Languages: %v", inputs.ProjectRoot, langs)) + + if inputs.GoLang { + if err := os.MkdirAll(inputs.GoOutPath, 0o755); err != nil { + return fmt.Errorf("failed to create Go output directory: %w", err) + } + } + if inputs.TypeScript { + if err := os.MkdirAll(inputs.TSOutPath, 0o755); err != nil { + return fmt.Errorf("failed to create TypeScript output directory: %w", err) + } + } + + info, err := os.Stat(inputs.AbiPath) + if err != nil { + return fmt.Errorf("failed to access ABI path: %w", err) + } + + if info.IsDir() { + if err := h.processAbiDirectory(inputs); err != nil { + return err + } + } else { + if err := h.processSingleAbi(inputs); err != nil { + return err + } + } + + if inputs.GoLang { + spinner := ui.NewSpinner() + spinner.Start("Installing dependencies...") + + err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go@"+constants.SdkVersion) + if err != nil { + spinner.Stop() + return err + } + err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm@"+constants.EVMCapabilitiesVersion) + if err != nil { + spinner.Stop() + return err + } + if err = runCommand(inputs.ProjectRoot, "go", "mod", "tidy"); err != nil { + spinner.Stop() + return err + } + + spinner.Stop() + } + + ui.Success("Bindings generated successfully") + return nil +} + +func runCommand(dir string, command string, args ...string) error { + cmd := exec.Command(command, args...) + cmd.Dir = dir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run %s: %w", command, err) + } + + return nil +} diff --git a/cmd/generate-bindings/generate-bindings_test.go b/cmd/generate-bindings/evm/evm_test.go similarity index 85% rename from cmd/generate-bindings/generate-bindings_test.go rename to cmd/generate-bindings/evm/evm_test.go index e7411fb4..559b1c49 100644 --- a/cmd/generate-bindings/generate-bindings_test.go +++ b/cmd/generate-bindings/evm/evm_test.go @@ -1,4 +1,4 @@ -package generatebindings +package evm import ( "fmt" @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" "github.com/smartcontractkit/cre-cli/internal/runtime" ) @@ -43,12 +42,10 @@ func TestContractNameToPackage(t *testing.T) { } func TestResolveInputs_DefaultFallbacks(t *testing.T) { - // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - // Create required contracts directory and go.mod contractsDir := filepath.Join(tempDir, "contracts") err = os.MkdirAll(contractsDir, 0755) require.NoError(t, err) @@ -57,7 +54,6 @@ func TestResolveInputs_DefaultFallbacks(t *testing.T) { err = os.WriteFile(goModPath, []byte("module test/contracts\n\ngo 1.20\n"), 0600) require.NoError(t, err) - // Change to temp directory originalDir, err := os.Getwd() require.NoError(t, err) defer func() { @@ -76,14 +72,12 @@ func TestResolveInputs_DefaultFallbacks(t *testing.T) { v.Set("language", "go") v.Set("pkg", "bindings") - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) - // Use filepath.EvalSymlinks to handle macOS /var vs /private/var symlink issues expectedRoot, _ := filepath.EvalSymlinks(tempDir) actualRoot, _ := filepath.EvalSymlinks(inputs.ProjectRoot) assert.Equal(t, expectedRoot, actualRoot) - assert.Equal(t, "evm", inputs.ChainFamily) assert.True(t, inputs.GoLang) expectedAbi, _ := filepath.EvalSymlinks(filepath.Join(tempDir, "contracts", "evm", "src", "abi")) actualAbi, _ := filepath.EvalSymlinks(inputs.AbiPath) @@ -117,7 +111,7 @@ func TestResolveInputs_TypeScriptDefaults(t *testing.T) { v.Set("language", "typescript") v.Set("pkg", "bindings") - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) expectedRoot, _ := filepath.EvalSymlinks(tempDir) @@ -125,12 +119,10 @@ func TestResolveInputs_TypeScriptDefaults(t *testing.T) { assert.Equal(t, expectedRoot, actualRoot) assert.True(t, inputs.TypeScript) - // ABI path: contracts/evm/src/abi expectedAbi, _ := filepath.EvalSymlinks(filepath.Join(tempDir, "contracts", "evm", "src", "abi")) actualAbi, _ := filepath.EvalSymlinks(inputs.AbiPath) assert.Equal(t, expectedAbi, actualAbi) - // TS output path: contracts/evm/ts/generated expectedTSOut, _ := filepath.EvalSymlinks(filepath.Join(tempDir, "contracts", "evm", "ts", "generated")) actualTSOut, _ := filepath.EvalSymlinks(inputs.TSOutPath) assert.Equal(t, expectedTSOut, actualTSOut) @@ -157,7 +149,7 @@ func TestAutoDetect_GoOnly(t *testing.T) { handler := newHandler(runtimeCtx) v := viper.New() - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) assert.True(t, inputs.GoLang, "Go should be auto-detected") @@ -186,7 +178,7 @@ func TestAutoDetect_TypeScriptOnly(t *testing.T) { handler := newHandler(runtimeCtx) v := viper.New() - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) assert.False(t, inputs.GoLang, "Go should not be detected") @@ -220,7 +212,7 @@ func TestAutoDetect_Both(t *testing.T) { handler := newHandler(runtimeCtx) v := viper.New() - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) assert.True(t, inputs.GoLang, "Go should be auto-detected") @@ -247,7 +239,7 @@ func TestExplicitGoFlag(t *testing.T) { v := viper.New() v.Set("language", "go") - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) assert.True(t, inputs.GoLang) @@ -277,7 +269,7 @@ func TestExplicitTypeScriptFlag(t *testing.T) { v := viper.New() v.Set("language", "typescript") - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) assert.False(t, inputs.GoLang) @@ -310,7 +302,7 @@ func TestAutoDetectBothLanguages(t *testing.T) { handler := newHandler(runtimeCtx) v := viper.New() - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) assert.True(t, inputs.GoLang) @@ -340,18 +332,15 @@ func TestOutputPathsSeparation(t *testing.T) { handler := newHandler(runtimeCtx) v := viper.New() - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) - // Go path must contain src/generated assert.Contains(t, inputs.GoOutPath, "src", "Go output path should contain src") assert.Contains(t, inputs.GoOutPath, "generated", "Go output path should contain generated") - // TS path must contain ts/generated assert.Contains(t, inputs.TSOutPath, "ts", "TS output path should contain ts") assert.Contains(t, inputs.TSOutPath, "generated", "TS output path should contain generated") - // Paths must be different assert.NotEqual(t, inputs.GoOutPath, inputs.TSOutPath, "Go and TS output paths must be different") } @@ -384,7 +373,7 @@ func TestEndToEnd_TypeScriptGeneration(t *testing.T) { v := viper.New() v.Set("language", "typescript") v.Set("pkg", "bindings") - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) require.NoError(t, handler.ValidateInputs(inputs)) require.NoError(t, handler.Execute(inputs)) @@ -397,9 +386,7 @@ func TestEndToEnd_TypeScriptGeneration(t *testing.T) { require.FileExists(t, filepath.Join(tsOutDir, "index.ts")) } -// command should run in projectRoot which contains contracts directory -func TestResolveInputs_CustomProjectRoot(t *testing.T) { - // Create a temporary directory for testing +func TestResolveEvmInputs_CustomProjectRoot(t *testing.T) { tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) @@ -407,27 +394,23 @@ func TestResolveInputs_CustomProjectRoot(t *testing.T) { runtimeCtx := &runtime.Context{} handler := newHandler(runtimeCtx) - // Test with custom project root v := viper.New() v.Set("project-root", tempDir) v.Set("language", "go") v.Set("pkg", "bindings") - _, err = handler.ResolveInputs([]string{"evm"}, v) + _, err = handler.ResolveInputs(v) require.Error(t, err) expectedErrMsg := fmt.Sprintf("contracts folder not found in project root: %s", tempDir) require.Contains(t, err.Error(), expectedErrMsg) } -// Empty project root should default to current directory, and this should contain contracts and go.mod -func TestResolveInputs_EmptyProjectRoot(t *testing.T) { - // Create a temporary directory for testing +func TestResolveEvmInputs_EmptyProjectRoot(t *testing.T) { tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - // Create required contracts directory and go.mod contractsDir := filepath.Join(tempDir, "contracts") err = os.MkdirAll(contractsDir, 0755) require.NoError(t, err) @@ -436,7 +419,6 @@ func TestResolveInputs_EmptyProjectRoot(t *testing.T) { err = os.WriteFile(goModPath, []byte("module test/contracts\n\ngo 1.20\n"), 0600) require.NoError(t, err) - // Change to temp directory originalDir, err := os.Getwd() require.NoError(t, err) defer func() { @@ -451,20 +433,17 @@ func TestResolveInputs_EmptyProjectRoot(t *testing.T) { runtimeCtx := &runtime.Context{} handler := newHandler(runtimeCtx) - // Test with empty project root (should use current directory) v := viper.New() v.Set("project-root", "") v.Set("language", "go") v.Set("pkg", "bindings") - inputs, err := handler.ResolveInputs([]string{"evm"}, v) + inputs, err := handler.ResolveInputs(v) require.NoError(t, err) - // Use filepath.EvalSymlinks to handle macOS /var vs /private/var symlink issues expectedRoot, _ := filepath.EvalSymlinks(tempDir) actualRoot, _ := filepath.EvalSymlinks(inputs.ProjectRoot) assert.Equal(t, expectedRoot, actualRoot) - assert.Equal(t, "evm", inputs.ChainFamily) assert.True(t, inputs.GoLang) expectedAbi, _ := filepath.EvalSymlinks(filepath.Join(tempDir, "contracts", "evm", "src", "abi")) actualAbi, _ := filepath.EvalSymlinks(inputs.AbiPath) @@ -475,32 +454,11 @@ func TestResolveInputs_EmptyProjectRoot(t *testing.T) { assert.Equal(t, expectedGoOut, actualGoOut) } -func TestValidateInputs_RequiredChainFamily(t *testing.T) { - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - - // Test validation with missing chain family - inputs := Inputs{ - ProjectRoot: "/tmp", - ChainFamily: "", // Missing required field - GoLang: true, - AbiPath: "/tmp/abi", - PkgName: "bindings", - GoOutPath: "/tmp/out", - } - - err := handler.ValidateInputs(inputs) - require.Error(t, err) - assert.Contains(t, err.Error(), "chain-family") -} - func TestValidateInputs_ValidInputs(t *testing.T) { - // Create a temporary directory for testing tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - // Create a valid ABI file abiContent := `[{"type":"function","name":"test","inputs":[],"outputs":[]}]` abiFile := filepath.Join(tempDir, "test.abi") err = os.WriteFile(abiFile, []byte(abiContent), 0600) @@ -509,10 +467,8 @@ func TestValidateInputs_ValidInputs(t *testing.T) { runtimeCtx := &runtime.Context{} handler := newHandler(runtimeCtx) - // Test validation with valid inputs (using single file) inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiFile, PkgName: "bindings", @@ -523,7 +479,6 @@ func TestValidateInputs_ValidInputs(t *testing.T) { require.NoError(t, err) assert.True(t, handler.validated) - // Test validation with directory containing .abi files abiDir := filepath.Join(tempDir, "abi") err = os.MkdirAll(abiDir, 0755) require.NoError(t, err) @@ -535,7 +490,6 @@ func TestValidateInputs_ValidInputs(t *testing.T) { require.NoError(t, err) assert.True(t, handler.validated) - // Test validation with directory containing .abi files for TypeScript (unified extension) abiDir2 := filepath.Join(tempDir, "abi_ts") err = os.MkdirAll(abiDir2, 0755) require.NoError(t, err) @@ -544,7 +498,6 @@ func TestValidateInputs_ValidInputs(t *testing.T) { tsInputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", TypeScript: true, AbiPath: abiDir2, PkgName: "bindings", @@ -555,7 +508,6 @@ func TestValidateInputs_ValidInputs(t *testing.T) { require.NoError(t, err) assert.True(t, handler2.validated) - // Test validation with directory containing only .json files abiDir3 := filepath.Join(tempDir, "abi_json") err = os.MkdirAll(abiDir3, 0755) require.NoError(t, err) @@ -565,7 +517,6 @@ func TestValidateInputs_ValidInputs(t *testing.T) { jsonInputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiDir3, PkgName: "bindings", @@ -577,36 +528,11 @@ func TestValidateInputs_ValidInputs(t *testing.T) { assert.True(t, handler3.validated) } -func TestValidateInputs_InvalidChainFamily(t *testing.T) { - // Create a temporary directory for testing - tempDir, err := os.MkdirTemp("", "generate-bindings-test") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - runtimeCtx := &runtime.Context{} - handler := newHandler(runtimeCtx) - - // Test validation with invalid chain family - inputs := Inputs{ - ProjectRoot: tempDir, - ChainFamily: "solana", // No longer supported - GoLang: true, - AbiPath: tempDir, - PkgName: "bindings", - GoOutPath: tempDir, - } - - err = handler.ValidateInputs(inputs) - require.Error(t, err) - assert.Contains(t, err.Error(), "chain-family") -} - func TestValidateInputs_NoLanguageSpecified(t *testing.T) { tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - // Create contracts dir but no .go or .ts files for auto-detect contractsDir := filepath.Join(tempDir, "contracts") err = os.MkdirAll(contractsDir, 0755) require.NoError(t, err) @@ -619,21 +545,18 @@ func TestValidateInputs_NoLanguageSpecified(t *testing.T) { runtimeCtx := &runtime.Context{} handler := newHandler(runtimeCtx) - // ResolveInputs should error when no --language and nothing detected v := viper.New() - _, err = handler.ResolveInputs([]string{"evm"}, v) + _, err = handler.ResolveInputs(v) require.Error(t, err) assert.Contains(t, err.Error(), "no target language") } -func TestValidateInputs_NonExistentDirectory(t *testing.T) { +func TestValidateEvmInputs_NonExistentDirectory(t *testing.T) { runtimeCtx := &runtime.Context{} handler := newHandler(runtimeCtx) - // Test validation with non-existent directory inputs := Inputs{ ProjectRoot: "/non/existent/path", - ChainFamily: "evm", GoLang: true, AbiPath: "/non/existent/abi", PkgName: "bindings", @@ -646,7 +569,6 @@ func TestValidateInputs_NonExistentDirectory(t *testing.T) { } func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { - // Create a temporary directory structure tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) @@ -657,7 +579,6 @@ func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { err = os.MkdirAll(abiDir, 0755) require.NoError(t, err) - // Create mock ABI files (both .abi and .json formats) abiContent := `[{"type":"function","name":"test","inputs":[],"outputs":[]}]` jsonContent := `{"abi":[{"type":"function","name":"test","inputs":[],"outputs":[]}]}` err = os.WriteFile(filepath.Join(abiDir, "Contract1.abi"), []byte(abiContent), 0600) @@ -667,7 +588,6 @@ func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { err = os.WriteFile(filepath.Join(abiDir, "Contract3.json"), []byte(jsonContent), 0600) require.NoError(t, err) - // Create a mock logger to prevent nil pointer dereference logger := zerolog.New(os.Stderr).With().Timestamp().Logger() runtimeCtx := &runtime.Context{ Logger: &logger, @@ -676,25 +596,19 @@ func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiDir, PkgName: "bindings", GoOutPath: outDir, } - // This test will fail because it tries to call the actual bindings.GenerateBindings - // but it tests the directory processing logic err = handler.processAbiDirectory(inputs) - // We expect an error because the bindings package requires actual ABI format - // but we can check that it created the expected directory structure if err == nil { t.Log("Unexpectedly succeeded - bindings generation worked with mock ABI") } else { assert.Contains(t, err.Error(), "Contract1") } - // Verify that per-contract directories were created contract1Dir := filepath.Join(outDir, "contract1") contract2Dir := filepath.Join(outDir, "contract2") contract3Dir := filepath.Join(outDir, "contract3") @@ -704,7 +618,6 @@ func TestProcessAbiDirectory_MultipleFiles(t *testing.T) { } func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { - // Create a temporary directory structure tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) @@ -715,7 +628,6 @@ func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { err = os.MkdirAll(abiDir, 0755) require.NoError(t, err) - // Create mock ABI files with different naming patterns (both .abi and .json) abiContent := `[{"type":"function","name":"test","inputs":[],"outputs":[]}]` jsonContent := `{"abi":[{"type":"function","name":"test","inputs":[],"outputs":[]}]}` testCases := []struct { @@ -737,7 +649,6 @@ func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { require.NoError(t, err) } - // Create a mock logger logger := zerolog.New(os.Stderr).With().Timestamp().Logger() runtimeCtx := &runtime.Context{ Logger: &logger, @@ -746,20 +657,17 @@ func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiDir, PkgName: "bindings", GoOutPath: outDir, } - // Try to process - the mock ABI content might actually work err = handler.processAbiDirectory(inputs) if err != nil { t.Logf("Expected error occurred: %v", err) } - // Verify that per-contract directories were created with correct names for _, tc := range testCases { contractDir := filepath.Join(outDir, tc.expectedPackage) assert.DirExists(t, contractDir, "Expected directory %s to be created", contractDir) @@ -767,7 +675,6 @@ func TestProcessAbiDirectory_CreatesPerContractDirectories(t *testing.T) { } func TestProcessAbiDirectory_NoAbiFiles(t *testing.T) { - // Create a temporary directory structure tempDir, err := os.MkdirTemp("", "generate-bindings-test") require.NoError(t, err) defer os.RemoveAll(tempDir) @@ -786,7 +693,6 @@ func TestProcessAbiDirectory_NoAbiFiles(t *testing.T) { inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiDir, PkgName: "bindings", @@ -814,7 +720,6 @@ func TestProcessAbiDirectory_NoAbiFiles_TypeScript(t *testing.T) { inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", TypeScript: true, AbiPath: abiDir, PkgName: "bindings", @@ -839,8 +744,6 @@ func TestProcessAbiDirectory_PackageNameCollision(t *testing.T) { abiContent := `[{"type":"function","name":"test","inputs":[],"outputs":[]}]` - // "TestContract" -> "test_contract" - // "test_contract" -> "test_contract" err = os.WriteFile(filepath.Join(abiDir, "TestContract.abi"), []byte(abiContent), 0600) require.NoError(t, err) err = os.WriteFile(filepath.Join(abiDir, "test_contract.abi"), []byte(abiContent), 0600) @@ -854,7 +757,6 @@ func TestProcessAbiDirectory_PackageNameCollision(t *testing.T) { inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiDir, PkgName: "bindings", @@ -862,7 +764,6 @@ func TestProcessAbiDirectory_PackageNameCollision(t *testing.T) { } err = handler.processAbiDirectory(inputs) - fmt.Println(err.Error()) require.Error(t, err) require.Equal(t, err.Error(), "package name collision: multiple contracts would generate the same package name 'test_contract' (contracts are converted to snake_case for package names). Please rename one of your contract files to avoid this conflict") } @@ -890,7 +791,6 @@ func TestProcessAbiDirectory_DuplicateContractNameAcrossExtensions(t *testing.T) inputs := Inputs{ ProjectRoot: tempDir, - ChainFamily: "evm", GoLang: true, AbiPath: abiDir, PkgName: "bindings", @@ -912,7 +812,6 @@ func TestProcessAbiDirectory_NonExistentDirectory(t *testing.T) { inputs := Inputs{ ProjectRoot: "/tmp", - ChainFamily: "evm", GoLang: true, AbiPath: "/non/existent/abi", PkgName: "bindings", @@ -930,7 +829,7 @@ func TestProcessAbiDirectory_NonExistentDirectory(t *testing.T) { func TestGenerateBindings_UnconventionalNaming(t *testing.T) { tests := []struct { name string - contractABI string // raw ABI JSON array + contractABI string pkgName string typeName string shouldFail bool @@ -1044,7 +943,7 @@ func TestGenerateBindings_UnconventionalNaming(t *testing.T) { require.NoError(t, err) outFile := filepath.Join(tempDir, "bindings.go") - err = bindings.GenerateBindings("", abiFile, tc.pkgName, tc.typeName, outFile) + err = GenerateBindings("", abiFile, tc.pkgName, tc.typeName, outFile) if tc.shouldFail { require.Error(t, err, "Expected binding generation to fail for %s", tc.name) @@ -1106,20 +1005,18 @@ func TestGenerateBindings_StructNamePrefixStripping(t *testing.T) { require.NoError(t, err) outFile := filepath.Join(tempDir, "bindings.go") - err = bindings.GenerateBindings("", abiFile, "mycontract", "MyContract", outFile) + err = GenerateBindings("", abiFile, "mycontract", "MyContract", outFile) require.NoError(t, err) content, err := os.ReadFile(outFile) require.NoError(t, err) src := string(content) - // Struct declarations should have the prefix stripped. assert.Contains(t, src, "type DONInfo struct") assert.Contains(t, src, "type CapabilityConfiguration struct") assert.NotContains(t, src, "type MyContractDONInfo struct") assert.NotContains(t, src, "type MyContractCapabilityConfiguration struct") - // Field type references inside structs should also be stripped. assert.Contains(t, src, "[]CapabilityConfiguration") assert.NotContains(t, src, "[]MyContractCapabilityConfiguration") } diff --git a/cmd/generate-bindings/bindings/gen.go b/cmd/generate-bindings/evm/gen.go similarity index 67% rename from cmd/generate-bindings/bindings/gen.go rename to cmd/generate-bindings/evm/gen.go index febfa9dc..34410db7 100644 --- a/cmd/generate-bindings/bindings/gen.go +++ b/cmd/generate-bindings/evm/gen.go @@ -1,2 +1,2 @@ //go:generate go run ./testdata/gen -package bindings +package evm diff --git a/cmd/generate-bindings/evm/gen_test.go b/cmd/generate-bindings/evm/gen_test.go new file mode 100644 index 00000000..a2fa4688 --- /dev/null +++ b/cmd/generate-bindings/evm/gen_test.go @@ -0,0 +1,19 @@ +package evm_test + +import ( + "testing" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" +) + +func TestGenerateBindings(t *testing.T) { + if err := evm.GenerateBindings( + "./testdata/DataStorage_combined.json", + "", + "bindings", + "", + "./testdata/bindings.go", + ); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/generate-bindings/bindings/mockcontract.go.tpl b/cmd/generate-bindings/evm/mockcontract.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/mockcontract.go.tpl rename to cmd/generate-bindings/evm/mockcontract.go.tpl diff --git a/cmd/generate-bindings/bindings/mockcontract.ts.tpl b/cmd/generate-bindings/evm/mockcontract.ts.tpl similarity index 100% rename from cmd/generate-bindings/bindings/mockcontract.ts.tpl rename to cmd/generate-bindings/evm/mockcontract.ts.tpl diff --git a/cmd/generate-bindings/bindings/sourcecre.go.tpl b/cmd/generate-bindings/evm/sourcecre.go.tpl similarity index 100% rename from cmd/generate-bindings/bindings/sourcecre.go.tpl rename to cmd/generate-bindings/evm/sourcecre.go.tpl diff --git a/cmd/generate-bindings/bindings/sourcecre.ts.tpl b/cmd/generate-bindings/evm/sourcecre.ts.tpl similarity index 100% rename from cmd/generate-bindings/bindings/sourcecre.ts.tpl rename to cmd/generate-bindings/evm/sourcecre.ts.tpl diff --git a/cmd/generate-bindings/bindings/testdata/DataStorage.sol b/cmd/generate-bindings/evm/testdata/DataStorage.sol similarity index 100% rename from cmd/generate-bindings/bindings/testdata/DataStorage.sol rename to cmd/generate-bindings/evm/testdata/DataStorage.sol diff --git a/cmd/generate-bindings/bindings/testdata/DataStorage_combined.json b/cmd/generate-bindings/evm/testdata/DataStorage_combined.json similarity index 100% rename from cmd/generate-bindings/bindings/testdata/DataStorage_combined.json rename to cmd/generate-bindings/evm/testdata/DataStorage_combined.json diff --git a/cmd/generate-bindings/bindings/testdata/bindings.go b/cmd/generate-bindings/evm/testdata/bindings.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/bindings.go rename to cmd/generate-bindings/evm/testdata/bindings.go diff --git a/cmd/generate-bindings/bindings/testdata/bindings_mock.go b/cmd/generate-bindings/evm/testdata/bindings_mock.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/bindings_mock.go rename to cmd/generate-bindings/evm/testdata/bindings_mock.go diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract.sol b/cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract.sol similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract.sol rename to cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract.sol diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract_combined.json b/cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract_combined.json similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/EmptyContract_combined.json rename to cmd/generate-bindings/evm/testdata/emptybindings/EmptyContract_combined.json diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings.go b/cmd/generate-bindings/evm/testdata/emptybindings/emptybindings.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings.go rename to cmd/generate-bindings/evm/testdata/emptybindings/emptybindings.go diff --git a/cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings_mock.go b/cmd/generate-bindings/evm/testdata/emptybindings/emptybindings_mock.go similarity index 100% rename from cmd/generate-bindings/bindings/testdata/emptybindings/emptybindings_mock.go rename to cmd/generate-bindings/evm/testdata/emptybindings/emptybindings_mock.go diff --git a/cmd/generate-bindings/bindings/testdata/gen/main.go b/cmd/generate-bindings/evm/testdata/gen/main.go similarity index 70% rename from cmd/generate-bindings/bindings/testdata/gen/main.go rename to cmd/generate-bindings/evm/testdata/gen/main.go index 2eda5a71..44836dd4 100644 --- a/cmd/generate-bindings/bindings/testdata/gen/main.go +++ b/cmd/generate-bindings/evm/testdata/gen/main.go @@ -1,11 +1,11 @@ package main import ( - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" ) func main() { - if err := bindings.GenerateBindings( + if err := evm.GenerateBindings( "./testdata/DataStorage_combined.json", "", "bindings", @@ -15,7 +15,7 @@ func main() { panic(err) } - if err := bindings.GenerateBindings( + if err := evm.GenerateBindings( "./testdata/emptybindings/EmptyContract_combined.json", "", "emptybindings", diff --git a/cmd/generate-bindings/generate-bindings.go b/cmd/generate-bindings/generate-bindings.go index 63691e80..0411122f 100644 --- a/cmd/generate-bindings/generate-bindings.go +++ b/cmd/generate-bindings/generate-bindings.go @@ -1,498 +1,22 @@ package generatebindings import ( - "fmt" - "os" - "os/exec" - "path/filepath" - "sort" - "strings" - - "github.com/rs/zerolog" "github.com/spf13/cobra" - "github.com/spf13/viper" - "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/bindings" - "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/evm" + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana" "github.com/smartcontractkit/cre-cli/internal/runtime" - "github.com/smartcontractkit/cre-cli/internal/ui" - "github.com/smartcontractkit/cre-cli/internal/validation" ) -type Inputs struct { - ProjectRoot string `validate:"required,dir" cli:"--project-root"` - ChainFamily string `validate:"required,oneof=evm" cli:"--chain-family"` - GoLang bool - TypeScript bool - AbiPath string `validate:"required,path_read" cli:"--abi"` - PkgName string `validate:"required" cli:"--pkg"` - GoOutPath string // contracts/{chain}/src/generated — set when GoLang is true - TSOutPath string // contracts/{chain}/ts/generated — set when TypeScript is true -} - func New(runtimeContext *runtime.Context) *cobra.Command { generateBindingsCmd := &cobra.Command{ - Use: "generate-bindings ", - Short: "Generate bindings from contract ABI", - Long: `This command generates bindings from contract ABI files. -Supports EVM chain family with Go and TypeScript languages. -The target language is auto-detected from project files, or can be -specified explicitly with --language. -Each contract gets its own package subdirectory to avoid naming conflicts. -For example, IERC20.abi generates bindings in generated/ierc20/ package. - -Both raw ABI files (*.abi) and JSON artifact files (*.json) are supported. -For JSON files the ABI is read from the top-level "abi" field.`, - Example: " cre generate-bindings evm", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - handler := newHandler(runtimeContext) - - inputs, err := handler.ResolveInputs(args, runtimeContext.Viper) - if err != nil { - return err - } - err = handler.ValidateInputs(inputs) - if err != nil { - return err - } - return handler.Execute(inputs) - }, + Use: "generate-bindings", + Short: "Generate bindings for contracts", + Long: `The generate-bindings command allows you to generate bindings for contracts.`, } - generateBindingsCmd.Flags().StringP("project-root", "p", "", "Path to project root directory (defaults to current directory)") - generateBindingsCmd.Flags().StringP("language", "l", "", "Target language: go, typescript (auto-detected from project files when omitted)") - generateBindingsCmd.Flags().StringP("abi", "a", "", "Path to ABI directory (defaults to contracts/{chain-family}/src/abi/). Supports *.abi and *.json files") - generateBindingsCmd.Flags().StringP("pkg", "k", "bindings", "Base package name (each contract gets its own subdirectory)") + generateBindingsCmd.AddCommand(evm.New(runtimeContext)) + generateBindingsCmd.AddCommand(solana.New(runtimeContext)) return generateBindingsCmd } - -type handler struct { - log *zerolog.Logger - validated bool -} - -func newHandler(ctx *runtime.Context) *handler { - return &handler{ - log: ctx.Logger, - validated: false, - } -} - -func detectLanguages(projectRoot string) (goLang, typescript bool) { - _ = filepath.WalkDir(projectRoot, func(path string, d os.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - // Skip node_modules and other dependency directories - if d.Name() == "node_modules" || d.Name() == ".git" { - return filepath.SkipDir - } - return nil - } - base := filepath.Base(path) - if strings.HasSuffix(base, ".go") { - goLang = true - } - if strings.HasSuffix(base, ".ts") && !strings.HasSuffix(base, ".d.ts") { - typescript = true - } - return nil - }) - return goLang, typescript -} - -func (h *handler) ResolveInputs(args []string, v *viper.Viper) (Inputs, error) { - // Get current working directory as default project root - currentDir, err := os.Getwd() - if err != nil { - return Inputs{}, fmt.Errorf("failed to get current working directory: %w", err) - } - - // Resolve project root with fallback to current directory - projectRoot := v.GetString("project-root") - if projectRoot == "" { - projectRoot = currentDir - } - - contractsPath := filepath.Join(projectRoot, "contracts") - if _, err := os.Stat(contractsPath); err != nil { - return Inputs{}, fmt.Errorf("contracts folder not found in project root: %s", contractsPath) - } - - // Chain family is now a positional argument - chainFamily := args[0] - - // Resolve languages: --language flag takes precedence, else auto-detect - var goLang, typescript bool - langFlag := strings.ToLower(strings.TrimSpace(v.GetString("language"))) - switch langFlag { - case "": - goLang, typescript = detectLanguages(projectRoot) - if !goLang && !typescript { - return Inputs{}, fmt.Errorf("no target language detected (use --language go or --language typescript, or ensure project contains .go or .ts files)") - } - case constants.WorkflowLanguageGolang: - goLang = true - case constants.WorkflowLanguageTypeScript: - typescript = true - default: - return Inputs{}, fmt.Errorf("unsupported language %q (supported: go, typescript)", langFlag) - } - - // Unified ABI path for both languages: contracts/{chain}/src/abi - abiPath := v.GetString("abi") - if abiPath == "" { - abiPath = filepath.Join(projectRoot, "contracts", chainFamily, "src", "abi") - } - - // Package name defaults are handled by StringP - pkgName := v.GetString("pkg") - - // Separate output paths: Go uses src/, TS uses ts/ (typescript convention) - var goOutPath, tsOutPath string - if goLang { - goOutPath = filepath.Join(projectRoot, "contracts", chainFamily, "src", "generated") - } - if typescript { - tsOutPath = filepath.Join(projectRoot, "contracts", chainFamily, "ts", "generated") - } - - return Inputs{ - ProjectRoot: projectRoot, - ChainFamily: chainFamily, - GoLang: goLang, - TypeScript: typescript, - AbiPath: abiPath, - PkgName: pkgName, - GoOutPath: goOutPath, - TSOutPath: tsOutPath, - }, nil -} - -// findAbiFiles returns all supported ABI files (*.abi and *.json) found in dir. -func findAbiFiles(dir string) ([]string, error) { - abiFiles, err := filepath.Glob(filepath.Join(dir, "*.abi")) - if err != nil { - return nil, err - } - jsonFiles, err := filepath.Glob(filepath.Join(dir, "*.json")) - if err != nil { - return nil, err - } - all := append(abiFiles, jsonFiles...) - sort.Strings(all) - return all, nil -} - -// contractNameFromFile returns the contract name by stripping the .abi or .json -// extension from the base filename. -func contractNameFromFile(path string) string { - name := filepath.Base(path) - ext := filepath.Ext(name) - if ext != "" { - name = name[:len(name)-len(ext)] - } - return name -} - -func (h *handler) ValidateInputs(inputs Inputs) error { - validate, err := validation.NewValidator() - if err != nil { - return fmt.Errorf("failed to initialize validator: %w", err) - } - - if err = validate.Struct(inputs); err != nil { - return validate.ParseValidationErrors(err) - } - - // Additional validation for ABI path - if _, err := os.Stat(inputs.AbiPath); err != nil { - if os.IsNotExist(err) { - return fmt.Errorf("ABI path does not exist: %s", inputs.AbiPath) - } - return fmt.Errorf("failed to access ABI path: %w", err) - } - - // Validate that if AbiPath is a directory, it contains ABI files (*.abi or *.json) - if info, err := os.Stat(inputs.AbiPath); err == nil && info.IsDir() { - files, err := findAbiFiles(inputs.AbiPath) - if err != nil { - return fmt.Errorf("failed to check for ABI files in directory: %w", err) - } - if len(files) == 0 { - return fmt.Errorf("no *.abi or *.json files found in directory: %s", inputs.AbiPath) - } - } - - // Ensure at least one output path is set for the active language(s) - if inputs.GoLang && inputs.GoOutPath == "" { - return fmt.Errorf("go output path is required when language is go") - } - if inputs.TypeScript && inputs.TSOutPath == "" { - return fmt.Errorf("typescript output path is required when language is typescript") - } - - h.validated = true - return nil -} - -// contractNameToPackage converts contract names to valid Go package names -// Examples: IERC20 -> ierc20, ReserveManager -> reserve_manager, IReserveManager -> ireserve_manager -func contractNameToPackage(contractName string) string { - if contractName == "" { - return "" - } - - var result []rune - runes := []rune(contractName) - - for i, r := range runes { - // Convert to lowercase - if r >= 'A' && r <= 'Z' { - lower := r - 'A' + 'a' - - // Add underscore before uppercase letters, but not: - // - At the beginning (i == 0) - // - If the previous character was also uppercase and this is followed by lowercase (e.g., "ERC" in "ERC20") - // - If this is part of a sequence of uppercase letters at the beginning (e.g., "IERC20" -> "ierc20") - if i > 0 { - prevIsUpper := runes[i-1] >= 'A' && runes[i-1] <= 'Z' - nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z' - - // Add underscore if: - // - Previous char was lowercase (CamelCase boundary) - // - Previous char was uppercase but this char is followed by lowercase (end of acronym) - if !prevIsUpper || (prevIsUpper && nextIsLower && i > 1) { - result = append(result, '_') - } - } - - result = append(result, lower) - } else { - result = append(result, r) - } - } - - return string(result) -} - -func (h *handler) processAbiDirectory(inputs Inputs) error { - files, err := findAbiFiles(inputs.AbiPath) - if err != nil { - return fmt.Errorf("failed to find ABI files: %w", err) - } - - if len(files) == 0 { - return fmt.Errorf("no *.abi or *.json files found in directory: %s", inputs.AbiPath) - } - - // Detect duplicate contract names across extensions (e.g. Foo.abi and Foo.json) - contractNames := make(map[string]string) // contract name -> originating file - for _, f := range files { - name := contractNameFromFile(f) - if prev, exists := contractNames[name]; exists { - return fmt.Errorf("duplicate contract name %q: found in both %s and %s", name, filepath.Base(prev), filepath.Base(f)) - } - contractNames[name] = f - } - - if inputs.GoLang { - packageNames := make(map[string]bool) - for _, abiFile := range files { - contractName := contractNameFromFile(abiFile) - packageName := contractNameToPackage(contractName) - if _, exists := packageNames[packageName]; exists { - return fmt.Errorf("package name collision: multiple contracts would generate the same package name '%s' (contracts are converted to snake_case for package names). Please rename one of your contract files to avoid this conflict", packageName) - } - packageNames[packageName] = true - } - } - - // Track generated files for TypeScript barrel export - var generatedContracts []string - - // Process each ABI file - for _, abiFile := range files { - contractName := contractNameFromFile(abiFile) - - if inputs.TypeScript { - outputFile := filepath.Join(inputs.TSOutPath, contractName+".ts") - ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) - - err = bindings.GenerateBindingsTS( - abiFile, - contractName, - outputFile, - ) - if err != nil { - return fmt.Errorf("failed to generate TypeScript bindings for %s: %w", contractName, err) - } - generatedContracts = append(generatedContracts, contractName) - } - - if inputs.GoLang { - packageName := contractNameToPackage(contractName) - - contractOutDir := filepath.Join(inputs.GoOutPath, packageName) - if err := os.MkdirAll(contractOutDir, 0o755); err != nil { - return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) - } - - outputFile := filepath.Join(contractOutDir, contractName+".go") - ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) - - err = bindings.GenerateBindings( - "", - abiFile, - packageName, - contractName, - outputFile, - ) - if err != nil { - return fmt.Errorf("failed to generate bindings for %s: %w", contractName, err) - } - } - } - - // Generate barrel index.ts for TypeScript - if inputs.TypeScript && len(generatedContracts) > 0 { - indexPath := filepath.Join(inputs.TSOutPath, "index.ts") - var indexContent string - indexContent += "// Code generated — DO NOT EDIT.\n" - for _, name := range generatedContracts { - indexContent += fmt.Sprintf("export * from './%s'\n", name) - indexContent += fmt.Sprintf("export * from './%s_mock'\n", name) - } - if err := os.WriteFile(indexPath, []byte(indexContent), 0o600); err != nil { - return fmt.Errorf("failed to write index.ts: %w", err) - } - } - - return nil -} - -func (h *handler) processSingleAbi(inputs Inputs) error { - contractName := contractNameFromFile(inputs.AbiPath) - - if inputs.TypeScript { - outputFile := filepath.Join(inputs.TSOutPath, contractName+".ts") - ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) - - if err := bindings.GenerateBindingsTS( - inputs.AbiPath, - contractName, - outputFile, - ); err != nil { - return err - } - } - - if inputs.GoLang { - packageName := contractNameToPackage(contractName) - - contractOutDir := filepath.Join(inputs.GoOutPath, packageName) - if err := os.MkdirAll(contractOutDir, 0o755); err != nil { - return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) - } - - outputFile := filepath.Join(contractOutDir, contractName+".go") - ui.Dim(fmt.Sprintf("Processing: %s -> %s", contractName, outputFile)) - - if err := bindings.GenerateBindings( - "", - inputs.AbiPath, - packageName, - contractName, - outputFile, - ); err != nil { - return err - } - } - - return nil -} - -func (h *handler) Execute(inputs Inputs) error { - langs := []string{} - if inputs.GoLang { - langs = append(langs, "go") - } - if inputs.TypeScript { - langs = append(langs, "typescript") - } - ui.Dim(fmt.Sprintf("Project: %s, Chain: %s, Languages: %v", inputs.ProjectRoot, inputs.ChainFamily, langs)) - - // Validate chain family and handle accordingly - switch inputs.ChainFamily { - case "evm": - // Create output directories for active language(s) - if inputs.GoLang { - if err := os.MkdirAll(inputs.GoOutPath, 0o755); err != nil { - return fmt.Errorf("failed to create Go output directory: %w", err) - } - } - if inputs.TypeScript { - if err := os.MkdirAll(inputs.TSOutPath, 0o755); err != nil { - return fmt.Errorf("failed to create TypeScript output directory: %w", err) - } - } - - // Check if ABI path is a directory or file - info, err := os.Stat(inputs.AbiPath) - if err != nil { - return fmt.Errorf("failed to access ABI path: %w", err) - } - - if info.IsDir() { - if err := h.processAbiDirectory(inputs); err != nil { - return err - } - } else { - if err := h.processSingleAbi(inputs); err != nil { - return err - } - } - - if inputs.GoLang { - spinner := ui.NewSpinner() - spinner.Start("Installing dependencies...") - - err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go@"+constants.SdkVersion) - if err != nil { - spinner.Stop() - return err - } - err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm@"+constants.EVMCapabilitiesVersion) - if err != nil { - spinner.Stop() - return err - } - if err = runCommand(inputs.ProjectRoot, "go", "mod", "tidy"); err != nil { - spinner.Stop() - return err - } - - spinner.Stop() - } - - ui.Success("Bindings generated successfully") - return nil - default: - return fmt.Errorf("unsupported chain family: %s", inputs.ChainFamily) - } -} - -func runCommand(dir string, command string, args ...string) error { - cmd := exec.Command(command, args...) - cmd.Dir = dir - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run %s: %w", command, err) - } - - return nil -} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/accounts.go b/cmd/generate-bindings/solana/anchor-go/generator/accounts.go new file mode 100644 index 00000000..ea7d16fc --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/accounts.go @@ -0,0 +1,409 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + "strconv" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/anchor-go/tools" +) + +func (g *Generator) genfile_accounts() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains parsers for the accounts defined in the IDL.") + file.HeaderComment("Code generated by https://github.com/smartcontractkit/cre-cli. DO NOT EDIT.") + + names := []string{} + { + for _, acc := range g.idl.Accounts { + names = append(names, tools.ToCamelUpper(acc.Name)) + } + } + { + code, err := g.gen_accountParser(names) + if err != nil { + return nil, fmt.Errorf("error generating account parser: %w", err) + } + file.Add(code) + } + + return &OutputFile{ + Name: "accounts.go", + File: file, + }, nil +} + +func (g *Generator) gen_accountParser(accountNames []string) (Code, error) { + code := Empty() + { + code.Func().Id("ParseAnyAccount"). + Params(Id("accountData").Index().Byte()). + Params(Any(), Error()). + BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("accountData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek account discriminator: %w"), Err()), + ), + ) + + block.Switch(Id("discriminator")).BlockFunc(func(switchBlock *Group) { + for _, name := range accountNames { + switchBlock.Case(Id(FormatAccountDiscriminatorName(name))).Block( + Id("value").Op(":=").New(Id(name)), + Err().Op(":=").Id("value").Dot("UnmarshalWithDecoder").Call(Id("decoder")), + If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal account as "+name+": %w"), Err()), + ), + ), + Return(Id("value"), Nil()), + ) + } + switchBlock.Default().Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("unknown discriminator: %s"), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + }) + }) + } + { + code.Line().Line() + // for each account, generate a function to parse it: + for _, name := range accountNames { + discriminatorName := FormatAccountDiscriminatorName(name) + + code.Func().Id("ParseAccount_"+name). + Params(Id("accountData").Index().Byte()). + Params(Op("*").Id(name), Error()). + BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("accountData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek discriminator: %w"), Err()), + ), + ) + + block.If(Id("discriminator").Op("!=").Id(discriminatorName)).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("expected discriminator %v, got %s"), Id(discriminatorName), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + + block.Id("acc").Op(":=").New(Id(name)) + block.Err().Op("=").Id("acc").Dot("UnmarshalWithDecoder").Call(Id("decoder")) + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal account of type "+name+": %w"), Err()), + ), + ) + + block.Return(Id("acc"), Nil()) + }) + code.Line().Line() + + // DecodeAccount method for the codec + code.Add(creDecodeAccountFn(name)) + code.Line().Line() + } + } + return code, nil +} + +func (g *Generator) gen_IDLTypeDefTyStruct( + name string, + docs []string, + typ idl.IdlTypeDefTyStruct, + withDiscriminator bool, +) (Code, error) { + st := newStatement() + + exportedAccountName := tools.ToCamelUpper(name) + { + // Declare the struct: + code := Empty() + addComments(code, docs) + code.Type().Id(exportedAccountName).StructFunc(func(fieldsGroup *Group) { + switch fields := typ.Fields.(type) { + case idl.IdlDefinedFieldsNamed: + // Generate unique field names to handle duplicates + uniqueFieldNames := generateUniqueFieldNames(fields) + + for fieldIndex, field := range fields { + + // Add docs for the field: + for docIndex, doc := range field.Docs { + if docIndex == 0 && fieldIndex > 0 { + fieldsGroup.Line() + } + fieldsGroup.Comment(doc) + } + // fieldsGroup.Line() + optionality := IsOption(field.Ty) || IsCOption(field.Ty) + + // TODO: optionality for complex enums is a nil interface. + uniqueFieldName := uniqueFieldNames[field.Name] + fieldsGroup.Add(g.genFieldWithName(field, uniqueFieldName, optionality)). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(field.Ty) { + tagMap["bin"] = "optional" + } + if IsCOption(field.Ty) { + tagMap["bin"] = "coption" + } + // add json tag: use original field name to avoid duplicates + tagMap["json"] = field.Name + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + case idl.IdlDefinedFieldsTuple: + // panic(fmt.Errorf("tuple fields not supported: %s", spew.Sdump(fields))) + for fieldIndex, field := range fields { + + fieldsGroup.Line() + optionality := IsOption(field) || IsCOption(field) + + fieldsGroup.Add(g.genFieldNamed( + FormatTupleItemName(fieldIndex), + field, + optionality, + )). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(field) { + tagMap["bin"] = "optional" + } + if IsCOption(field) { + tagMap["bin"] = "coption" + } + // add json tag: + tagMap["json"] = tools.ToCamelLower(FormatTupleItemName(fieldIndex)) + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + + case nil: + // No fields, just an empty struct. + // TODO: should we panic here? + default: + panic(fmt.Errorf("unknown fields type: %T", typ.Fields)) + } + }) + st.Add(code.Line()) + } + { + // Declare the decoder/encoder methods: + code := Empty() + + { + discriminatorName := FormatAccountDiscriminatorName(exportedAccountName) + + // Declare MarshalWithEncoder: + // TODO: + code.Line().Line().Add( + g.gen_MarshalWithEncoder_struct( + g.idl, + withDiscriminator, + exportedAccountName, + discriminatorName, + typ.Fields, + true, + )) + + // Declare UnmarshalWithDecoder + code.Line().Line().Add( + g.gen_UnmarshalWithDecoder_struct( + g.idl, + withDiscriminator, + exportedAccountName, + discriminatorName, + typ.Fields, + )) + } + st.Add(code.Line().Line()) + } + { + code := Empty() + code.Add(creGenerateCodecEncoderForTypes(exportedAccountName)) + st.Add(code.Line().Line()) + } + { + // Declare the WriteReportFrom methods: + // TODO: should i exclude events here ? currently it does accounts/structs/events + code := Empty() + code.Add(creWriteReportFromStructs(exportedAccountName, g)) + code.Line().Line() + code.Add(creWriteReportFromStructsSlice(exportedAccountName, g)) + st.Add(code.Line().Line()) + } + return st, nil +} + +// generateUniqueFieldNames creates unique Go field names from IDL field names, +// handling cases where multiple IDL fields would map to the same Go field name +func generateUniqueFieldNames(fields []idl.IdlField) map[string]string { + fieldNameMap := make(map[string]string) + usedNames := make(map[string]int) + + for _, field := range fields { + baseName := tools.ToCamelUpper(field.Name) + finalName := baseName + + // Check if this name has been used before + if count, exists := usedNames[baseName]; exists { + // Add a numeric suffix to make it unique + finalName = baseName + fmt.Sprintf("%d", count+1) + usedNames[baseName] = count + 1 + } else { + usedNames[baseName] = 0 + } + + fieldNameMap[field.Name] = finalName + } + + return fieldNameMap +} + +func (g *Generator) genField(field idl.IdlField, pointer bool) Code { + return g.genFieldNamed(field.Name, field.Ty, pointer) +} + +func (g *Generator) genFieldWithName(field idl.IdlField, fieldName string, pointer bool) Code { + return g.genFieldNamed(fieldName, field.Ty, pointer) +} + +func (g *Generator) genFieldNamed(name string, typ idltype.IdlType, pointer bool) Code { + st := newStatement() + st.Id(tools.ToCamelUpper(name)). + Add(func() Code { + if g.isComplexEnum(typ) { + return nil + } + if pointer { + return Op("*") + } + return nil + }()). + Add(genTypeName(typ)) + return st +} + +func genTypeName(idlTypeEnv idltype.IdlType) Code { + st := newStatement() + switch { + case IsIDLTypeKind(idlTypeEnv): + { + str := idlTypeEnv + st.Add(IDLTypeKind_ToTypeDeclCode(str)) + } + case IsOption(idlTypeEnv): + { + opt := idlTypeEnv.(*idltype.Option) + // TODO: optional = pointer? or that's determined upstream? + st.Add(genTypeName(opt.Option)) + } + case IsCOption(idlTypeEnv): + { + copt := idlTypeEnv.(*idltype.COption) + st.Add(genTypeName(copt.COption)) + } + case IsVec(idlTypeEnv): + { + vec := idlTypeEnv.(*idltype.Vec) + st.Index().Add(genTypeName(vec.Vec)) + } + case IsDefined(idlTypeEnv): + { + def := idlTypeEnv.(*idltype.Defined) + st.Add(Id(tools.ToCamelUpper(def.Name))) + } + case IsArray(idlTypeEnv): + { + arr := idlTypeEnv.(*idltype.Array) + { + switch size := arr.Size.(type) { + case *idltype.IdlArrayLenGeneric: + panic(fmt.Sprintf("generic array length not supported: %s", spew.Sdump(size))) + case *idltype.IdlArrayLenValue: + if size.Value < 0 { + panic(fmt.Sprintf("expected positive integer, got %d", size.Value)) + } + st.Index(Id(strconv.Itoa(int(size.Value)))).Add(genTypeName(arr.Type)) + } + } + } + default: + panic("unhandled type: " + spew.Sdump(idlTypeEnv)) + } + return st +} + +func IDLTypeKind_ToTypeDeclCode(ts idltype.IdlType) *Statement { + stat := newStatement() + switch ts.(type) { + case *idltype.Bool: + stat.Bool() + case *idltype.U8: + stat.Uint8() + case *idltype.I8: + stat.Int8() + case *idltype.U16: + // TODO: some types have their implementation in github.com/gagliardetto/binary + stat.Uint16() + case *idltype.I16: + stat.Int16() + case *idltype.U32: + stat.Uint32() + case *idltype.I32: + stat.Int32() + case *idltype.F32: + stat.Float32() + case *idltype.U64: + stat.Uint64() + case *idltype.I64: + stat.Int64() + case *idltype.F64: + stat.Float64() + case *idltype.U128: + stat.Qual(PkgBinary, "Uint128") + case *idltype.I128: + stat.Qual(PkgBinary, "Int128") + case *idltype.U256: + stat.Index(Lit(32)).Byte() + case *idltype.I256: + stat.Index(Lit(32)).Byte() + case *idltype.Bytes: + stat.Index().Byte() + case *idltype.String: + stat.String() + case *idltype.Pubkey: + stat.Qual(PkgSolanaGo, "PublicKey") + + default: + panic(fmt.Sprintf("unhandled type: %s", spew.Sdump(ts))) + } + + return stat +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go b/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go new file mode 100644 index 00000000..477cefa8 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go @@ -0,0 +1,140 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// collidingNamedFields is an IDL shape where two distinct field names normalize to the +// same Go identifier via tools.ToCamelUpper (foo_bar and fooBar -> FooBar). Struct +// generation deconflicts these as FooBar and FooBar1; marshal/unmarshal must use the same names. +func collidingNamedFields() idl.IdlDefinedFieldsNamed { + return idl.IdlDefinedFieldsNamed{ + {Name: "foo_bar", Ty: &idltype.U8{}}, + {Name: "fooBar", Ty: &idltype.U8{}}, + } +} + +func TestGenerateUniqueFieldNames_collidingIDLNames(t *testing.T) { + fields := collidingNamedFields() + m := generateUniqueFieldNames(fields) + require.Len(t, m, 2) + assert.Equal(t, "FooBar", m["foo_bar"]) + assert.Equal(t, "FooBar1", m["fooBar"]) +} + +// TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames documents the regression where +// gen_MarshalWithEncoder_struct / gen_UnmarshalWithDecoder_struct used tools.ToCamelUpper(field.Name) +// for accessors instead of generateUniqueFieldNames: both fields targeted obj.FooBar, so one +// value was serialized twice and the FooBar1 sibling was never written or read. +// +// The expected assertions describe the correct fixed behavior; they fail until uniquified names +// are threaded through marshal/unmarshal generation. +func TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames(t *testing.T) { + idlMinimal := &idl.Idl{} + fields := collidingNamedFields() + receiver := "CollideAccount" + + g := &Generator{ + idl: idlMinimal, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } + + marshalCode := g.gen_MarshalWithEncoder_struct( + idlMinimal, + false, + receiver, + "", + fields, + true, + ) + unmarshalCode := g.gen_UnmarshalWithDecoder_struct( + idlMinimal, + false, + receiver, + "", + fields, + ) + + f := jen.NewFile("fixture") + f.Add(marshalCode) + f.Add(unmarshalCode) + src := f.GoString() + + // Correct codegen must reference both uniquified struct fields. + assert.Contains(t, src, "obj.FooBar1", "marshal/unmarshal must access the deconflicted FooBar1 field") + + // Buggy codegen encodes/decodes the same field twice; reject duplicate bare obj.FooBar + // Encode/Decode when a second distinct IDL field exists. + encodeFooBar := strings.Count(src, "Encode(obj.FooBar)") + decodeFooBar := strings.Count(src, "Decode(&obj.FooBar)") + assert.Equal(t, 1, encodeFooBar, "each IDL field must map to a single Encode(obj.); duplicate Encode(obj.FooBar) indicates silent corruption") + assert.Equal(t, 1, decodeFooBar, "each IDL field must map to a single Decode(&obj.); duplicate Decode(&obj.FooBar) indicates silent corruption") + + assert.Contains(t, src, "Encode(obj.FooBar1)") + assert.Contains(t, src, "Decode(&obj.FooBar1)") +} + +func TestGenerateUniqueParamNames_collidingIDLNames(t *testing.T) { + fields := collidingNamedFields() + m := generateUniqueParamNames(fields) + require.Len(t, m, 2) + assert.NotEqual(t, m[fields[0].Name], m[fields[1].Name]) + b0 := formatParamName(fields[0].Name) + b1 := formatParamName(fields[1].Name) + if b0 == b1 { + assert.Equal(t, b0+"1", m[fields[1].Name]) + } +} + +func TestGenInstructionType_uniquifiesArgFieldsAndDecode(t *testing.T) { + ins := idl.IdlInstruction{ + Name: "do_test", + Args: []idl.IdlField(collidingNamedFields()), + Accounts: []idl.IdlInstructionAccountItem{}, + } + g := &Generator{idl: &idl.Idl{}, options: &GeneratorOptions{Package: "test"}} + code, err := g.gen_instructionType(ins) + require.NoError(t, err) + + f := jen.NewFile("test") + f.Add(code) + src := f.GoString() + + assert.Contains(t, src, "FooBar1") + assert.Equal(t, 1, strings.Count(src, "Decode(&obj.FooBar)")) + assert.Contains(t, src, "Decode(&obj.FooBar1)") +} + +func TestGenInstructions_builderEncodesEachArgOnce(t *testing.T) { + fields := collidingNamedFields() + paramNames := generateUniqueParamNames(fields) + + idlData := &idl.Idl{ + Instructions: []idl.IdlInstruction{ + { + Name: "do_test", + Args: []idl.IdlField(fields), + Accounts: []idl.IdlInstructionAccountItem{}, + }, + }, + } + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + out, err := gen.gen_instructions() + require.NoError(t, err) + s := out.File.GoString() + + p0 := paramNames[fields[0].Name] + p1 := paramNames[fields[1].Name] + assert.Equal(t, 1, strings.Count(s, "Encode("+p0+")"), "each arg must be encoded exactly once") + assert.Equal(t, 1, strings.Count(s, "Encode("+p1+")")) + assert.NotEqual(t, p0, p1) +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go new file mode 100644 index 00000000..62e66642 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go @@ -0,0 +1,48 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" +) + +func (g *Generator) isComplexEnum(envel idltype.IdlType) bool { + switch vv := envel.(type) { + case *idltype.Defined: + _, ok := g.complexEnumRegistry[vv.Name] + return ok + } + return false +} + +func (g *Generator) registerComplexEnumType(name string) { + if g.complexEnumRegistry == nil { + g.complexEnumRegistry = make(map[string]struct{}) + } + g.complexEnumRegistry[name] = struct{}{} +} + +func (g *Generator) isOptionalComplexEnum(ty idltype.IdlType) bool { + switch v := ty.(type) { + case *idltype.Option: + return g.isComplexEnum(v.Option) + case *idltype.COption: + return g.isComplexEnum(v.COption) + } + return false +} + +func (g *Generator) registerComplexEnums(def idl.IdlTypeDef) { + switch vv := def.Ty.(type) { + case *idl.IdlTypeDefTyEnum: + enumTypeName := def.Name + if !vv.IsAllSimple() { + g.registerComplexEnumType(enumTypeName) + } + case idl.IdlTypeDefTyEnum: + enumTypeName := def.Name + if !vv.IsAllSimple() { + g.registerComplexEnumType(enumTypeName) + } + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex_enum_encode_test.go b/cmd/generate-bindings/solana/anchor-go/generator/complex_enum_encode_test.go new file mode 100644 index 00000000..b92501e1 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex_enum_encode_test.go @@ -0,0 +1,86 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// complexEnumIDL returns a minimal IDL containing a two-variant complex enum +// ("MyAction") suitable for exercising gen_complexEnum codegen. +func complexEnumIDL() *idl.Idl { + enumType := &idl.IdlTypeDefTyEnum{ + Variants: idl.VariantSlice{ + { + Name: "Transfer", + Fields: idl.Some[idl.IdlDefinedFields](idl.IdlDefinedFieldsNamed{ + {Name: "amount", Ty: &idltype.U64{}}, + }), + }, + { + Name: "Burn", + Fields: idl.Some[idl.IdlDefinedFields](idl.IdlDefinedFieldsNamed{ + {Name: "quantity", Ty: &idltype.U32{}}, + }), + }, + }, + } + return &idl.Idl{ + Types: idl.IdTypeDef_slice{ + { + Name: "MyAction", + Ty: enumType, + }, + }, + } +} + +func genComplexEnumSource(t *testing.T) string { + t.Helper() + idlData := complexEnumIDL() + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + enumType := idlData.Types[0].Ty.(*idl.IdlTypeDefTyEnum) + code, err := gen.gen_complexEnum("MyAction", nil, *enumType) + require.NoError(t, err) + + f := jen.NewFile("test") + f.Add(code) + return f.GoString() +} + +func TestComplexEnumEncode_nilInterfaceReturnsError(t *testing.T) { + src := genComplexEnumSource(t) + + assert.Contains(t, src, "case nil:", "encoder must reject nil interface values") + assert.Contains(t, src, `cannot encode nil value`, "nil case must return a descriptive error") +} + +func TestComplexEnumEncode_defaultArmReturnsError(t *testing.T) { + src := genComplexEnumSource(t) + + assert.Contains(t, src, "default:", "encoder must reject unknown variant types") + assert.Contains(t, src, `unknown variant type`, "default case must return a descriptive error") +} + +func TestComplexEnumEncode_nilPointerGuardPerVariant(t *testing.T) { + src := genComplexEnumSource(t) + + assert.Contains(t, src, "realvalue == nil", "each variant case must guard against typed nil pointers") + assert.Contains(t, src, `cannot encode nil *MyAction_Transfer`, + "Transfer variant must have a nil-pointer error message") + assert.Contains(t, src, `cannot encode nil *MyAction_Burn`, + "Burn variant must have a nil-pointer error message") + + nilGuards := strings.Count(src, "realvalue == nil") + assert.Equal(t, 2, nilGuards, "must have exactly one nil-pointer guard per variant") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go b/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go new file mode 100644 index 00000000..c8049cf1 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go @@ -0,0 +1,114 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" +) + +// complexEnumGuard mirrors the condition used in gen_marshal_DefinedFieldsNamed +// and gen_unmarshal_DefinedFieldsNamed to decide whether a field is routed to +// the specialized enum encoder/parser or falls through to the generic +// Encode/Decode path. +func complexEnumGuard(g *Generator, ty idltype.IdlType) bool { + return g.isComplexEnum(ty) || + (IsArray(ty) && g.isComplexEnum(ty.(*idltype.Array).Type)) || + (IsVec(ty) && g.isComplexEnum(ty.(*idltype.Vec).Vec)) || + g.isOptionalComplexEnum(ty) +} + +func newTestGenerator() *Generator { + return &Generator{ + idl: &idl.Idl{}, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } +} + +func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) { + const name = "Outcome" + g := newTestGenerator() + g.registerComplexEnumType(name) + + defined := &idltype.Defined{Name: name} + + assert.True(t, complexEnumGuard(g, defined), "bare Defined") + assert.True(t, complexEnumGuard(g, &idltype.Option{Option: defined}), "Option") + assert.True(t, complexEnumGuard(g, &idltype.COption{COption: defined}), "COption") +} + +// TestComplexEnumGuard_rejectsNonComplexOptionals ensures the guard does NOT +// fire for Option/COption wrapping a non-complex Defined or a primitive. +// A false positive here would cause the switch to enter the Option/COption case +// where .Option.(*idltype.Defined) would panic on a non-Defined inner type. +func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) { + const complexName = "Outcome" + g := newTestGenerator() + g.registerComplexEnumType(complexName) + + nonComplex := &idltype.Defined{Name: "PlainStruct"} + + assert.False(t, complexEnumGuard(g, &idltype.Option{Option: nonComplex}), + "Option must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(g, &idltype.COption{COption: nonComplex}), + "COption must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(g, &idltype.Option{Option: &idltype.U64{}}), + "Option must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(g, &idltype.COption{COption: &idltype.U8{}}), + "COption must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(g, &idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}), + "Option> — nested containers not supported, must not match") +} + +// TestComplexEnumCodegen_optionalComplexEnum runs the actual marshal/unmarshal +// generator with Option and COption fields and +// verifies the generated Go source uses the specialized enum encoder/parser +// instead of the generic Encode/Decode. +func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) { + const enumName = "Outcome" + g := newTestGenerator() + g.registerComplexEnumType(enumName) + + fields := idl.IdlDefinedFieldsNamed{ + {Name: "id", Ty: &idltype.U64{}}, + {Name: "verdict", Ty: &idltype.Option{Option: &idltype.Defined{Name: enumName}}}, + {Name: "alt_verdict", Ty: &idltype.COption{COption: &idltype.Defined{Name: enumName}}}, + {Name: "checksum", Ty: &idltype.U64{}}, + } + + marshalCode := g.gen_MarshalWithEncoder_struct( + &idl.Idl{}, false, "Report", "", fields, true, + ) + unmarshalCode := g.gen_UnmarshalWithDecoder_struct( + &idl.Idl{}, false, "Report", "", fields, + ) + + f := jen.NewFile("fixture") + f.Add(marshalCode) + f.Add(unmarshalCode) + src := f.GoString() + + // Specialized enum encoder/parser must appear. + assert.Contains(t, src, "EncodeOutcome", + "Option/COption fields must call the specialized enum encoder") + assert.Contains(t, src, "DecodeOutcome", + "Option/COption fields must call the specialized enum parser") + + // Option flags must still be written/read. + assert.Contains(t, src, "WriteOption") + assert.Contains(t, src, "WriteCOption") + assert.Contains(t, src, "ReadOption") + assert.Contains(t, src, "ReadCOption") + + // Only the two plain U64 fields (Id, Checksum) should use the generic + // encoder/decoder. If the enum fields also fall through, the count is 4. + assert.Equal(t, 2, strings.Count(src, ".Encode("), + "generic Encode must only be used for non-enum fields") + assert.Equal(t, 2, strings.Count(src, ".Decode("), + "generic Decode must only be used for non-enum fields") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/constants.go b/cmd/generate-bindings/solana/anchor-go/generator/constants.go new file mode 100644 index 00000000..ad2d6ad3 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/constants.go @@ -0,0 +1,408 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "encoding/json" + "fmt" + "math/big" + "strconv" + "strings" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/solana-go" +) + +func (g *Generator) gen_constants() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains constants.") + { + if len(g.idl.Constants) > 0 { + file.Comment("Constants defined in the IDL:") + file.Line() + } + code := Empty() + for coi, co := range g.idl.Constants { + if co.Name == "" { + continue // Skip constants without a name. + } + if len(co.Value) == 0 { + continue // Skip constants without a value. + } + + addComments(code, co.Docs) + + switch ty := co.Ty.(type) { + case *idltype.String: + _ = ty + // "value":"\"organism\"" + v, err := strconv.Unquote(co.Value) + if err != nil { + return nil, fmt.Errorf("failed to unquote string constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(v) + code.Line() + case *idltype.Bytes: + _ = ty + // "value":"[102, 101, 101, 95, 118, 97, 117, 108, 116]" + var b []byte + err := json.Unmarshal([]byte(co.Value), &b) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal bytes constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Var().Id(co.Name).Op("=").Index().Byte().Op("{").ListFunc(func(byteGroup *Group) { + for _, byteVal := range b[:] { + byteGroup.Lit(int(byteVal)) + } + }).Op("}") + code.Line() + case *idltype.Pubkey: + _ = ty + // "value":"MiNTdCbWwAu3boEeTL6HzS5VgLb89mhf8VhMLtMrmWL" + pk, err := solana.PublicKeyFromBase58(co.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse pubkey constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Var().Id(co.Name).Op("=").Qual(PkgSolanaGo, "MustPublicKeyFromBase58").Call(Lit(pk.String())) + code.Line() + case *idltype.Bool: + _ = ty + // "value":"true" + v, err := strconv.ParseBool(co.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse bool constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Var().Id(co.Name).Op("=").Lit(v) + code.Line() + case *idltype.U8: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse u8 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint8(v)) + code.Line() + case *idltype.I8: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse i8 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int8(v)) + code.Line() + case *idltype.U16: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse u16 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint16(v)) + code.Line() + case *idltype.I16: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse i16 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int16(v)) + code.Line() + case *idltype.U32: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse u32 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint32(v)) + code.Line() + case *idltype.I32: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse i32 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int32(v)) + code.Line() + case *idltype.U64: + _ = ty + // "value":"42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse u64 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint64(v)) + code.Line() + case *idltype.I64: + _ = ty + // "value":"-42" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse i64 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int64(v)) + code.Line() + case *idltype.U128: + _ = ty + // "value":"100_000_000" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + bigInt := new(big.Int) + _, ok := bigInt.SetString(cleanValue, 10) + if !ok { + return nil, fmt.Errorf("failed to parse u128 constants[%d] %s: invalid format", coi, spew.Sdump(co)) + } + // Generate code that creates a big.Int from string + code.Var().Id(co.Name).Op("=").Func().Params().Op("*").Qual("math/big", "Int").Block( + Id("val").Op(",").Id("ok").Op(":=").New(Qual("math/big", "Int")).Dot("SetString").Call(Lit(cleanValue), Lit(10)), + If(Op("!").Id("ok")).Block( + Panic(Lit(fmt.Sprintf("invalid u128 constant %s", co.Name))), + ), + Return(Id("val")), + ).Call() + code.Line() + case *idltype.I128: + _ = ty + // "value":"-100_000_000" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + bigInt := new(big.Int) + _, ok := bigInt.SetString(cleanValue, 10) + if !ok { + return nil, fmt.Errorf("failed to parse i128 constants[%d] %s: invalid format", coi, spew.Sdump(co)) + } + // Generate code that creates a big.Int from string + code.Var().Id(co.Name).Op("=").Func().Params().Op("*").Qual("math/big", "Int").Block( + Id("val").Op(",").Id("ok").Op(":=").New(Qual("math/big", "Int")).Dot("SetString").Call(Lit(cleanValue), Lit(10)), + If(Op("!").Id("ok")).Block( + Panic(Lit(fmt.Sprintf("invalid i128 constant %s", co.Name))), + ), + Return(Id("val")), + ).Call() + code.Line() + case *idltype.U256: + _ = ty + cleanValue := strings.ReplaceAll(co.Value, "_", "") + bigInt := new(big.Int) + _, ok := bigInt.SetString(cleanValue, 10) + if !ok { + return nil, fmt.Errorf("failed to parse u256 constants[%d] %s: invalid format", coi, spew.Sdump(co)) + } + code.Var().Id(co.Name).Op("=").Func().Params().Op("*").Qual("math/big", "Int").Block( + Id("val").Op(",").Id("ok").Op(":=").New(Qual("math/big", "Int")).Dot("SetString").Call(Lit(cleanValue), Lit(10)), + If(Op("!").Id("ok")).Block( + Panic(Lit(fmt.Sprintf("invalid u256 constant %s", co.Name))), + ), + Return(Id("val")), + ).Call() + code.Line() + case *idltype.I256: + _ = ty + cleanValue := strings.ReplaceAll(co.Value, "_", "") + bigInt := new(big.Int) + _, ok := bigInt.SetString(cleanValue, 10) + if !ok { + return nil, fmt.Errorf("failed to parse i256 constants[%d] %s: invalid format", coi, spew.Sdump(co)) + } + code.Var().Id(co.Name).Op("=").Func().Params().Op("*").Qual("math/big", "Int").Block( + Id("val").Op(",").Id("ok").Op(":=").New(Qual("math/big", "Int")).Dot("SetString").Call(Lit(cleanValue), Lit(10)), + If(Op("!").Id("ok")).Block( + Panic(Lit(fmt.Sprintf("invalid i256 constant %s", co.Name))), + ), + Return(Id("val")), + ).Call() + code.Line() + case *idltype.F32: + _ = ty + // "value":"3.14" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseFloat(cleanValue, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse f32 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(float32(v)) + code.Line() + case *idltype.F64: + _ = ty + // "value":"3.14" + // "value":"4e-6" + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseFloat(cleanValue, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse f64 constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(v) + code.Line() + case *idltype.Array: + _ = ty + // "type":{"array":["u8",23]},"value":"[115, 101, 110, 100, 95, 119, 105, 116, 104, 95, 115, 119, 97, 112, 95, 100, 101, 108, 101, 103, 97, 116, 101]" + var b []any + dec := json.NewDecoder(strings.NewReader(co.Value)) + dec.UseNumber() + err := dec.Decode(&b) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal array constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + size, ok := ty.Size.(*idltype.IdlArrayLenValue) + if !ok { + return nil, fmt.Errorf("expected IdlArrayLenValue for constants[%d] %s, got %T", coi, spew.Sdump(co), ty.Size) + } + if len(b) != size.Value { + return nil, fmt.Errorf("expected %d elements in array constants[%d] %s, got %d", ty.Size, coi, spew.Sdump(co), len(b)) + } + code.Var().Id(co.Name).Op("=").Index(Lit(size.Value)).Do(func(index *Statement) { + switch ty.Type.(type) { + case *idltype.U8: + index.Byte() + case *idltype.I8: + index.Int8() + case *idltype.U16: + index.Uint16() + case *idltype.I16: + index.Int16() + case *idltype.U32: + index.Uint32() + case *idltype.I32: + index.Int32() + case *idltype.U64: + index.Uint64() + case *idltype.I64: + index.Int64() + case *idltype.F32: + index.Float32() + case *idltype.F64: + index.Float64() + case *idltype.String: + index.String() + case *idltype.Bool: + index.Bool() + default: + panic(fmt.Errorf("unsupported array type for constants[%d] %s: %T", coi, spew.Sdump(co), ty.Type)) + } + }).Op("{").ListFunc(func(byteGroup *Group) { + for _, val := range b[:] { + switch ty.Type.(type) { + case *idltype.U8: + v, err := strconv.ParseUint(val.(json.Number).String(), 10, 8) + if err != nil { + panic(fmt.Errorf("failed to parse u8 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(byte(v)) + case *idltype.I8: + v, err := strconv.ParseInt(val.(json.Number).String(), 10, 8) + if err != nil { + panic(fmt.Errorf("failed to parse i8 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(int8(v)) + case *idltype.U16: + v, err := strconv.ParseUint(val.(json.Number).String(), 10, 16) + if err != nil { + panic(fmt.Errorf("failed to parse u16 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(uint16(v)) + case *idltype.I16: + v, err := strconv.ParseInt(val.(json.Number).String(), 10, 16) + if err != nil { + panic(fmt.Errorf("failed to parse i16 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(int16(v)) + case *idltype.U32: + v, err := strconv.ParseUint(val.(json.Number).String(), 10, 32) + if err != nil { + panic(fmt.Errorf("failed to parse u32 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(uint32(v)) + case *idltype.I32: + v, err := strconv.ParseInt(val.(json.Number).String(), 10, 32) + if err != nil { + panic(fmt.Errorf("failed to parse i32 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(int32(v)) + case *idltype.U64: + v, err := strconv.ParseUint(val.(json.Number).String(), 10, 64) + if err != nil { + panic(fmt.Errorf("failed to parse u64 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(uint64(v)) + case *idltype.I64: + v, err := strconv.ParseInt(val.(json.Number).String(), 10, 64) + if err != nil { + panic(fmt.Errorf("failed to parse i64 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(int64(v)) + case *idltype.F32: + v, err := strconv.ParseFloat(val.(json.Number).String(), 32) + if err != nil { + panic(fmt.Errorf("failed to parse f32 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(float32(v)) + case *idltype.F64: + v, err := strconv.ParseFloat(val.(json.Number).String(), 64) + if err != nil { + panic(fmt.Errorf("failed to parse f64 in constants[%d] %s: %w", coi, spew.Sdump(co), err)) + } + byteGroup.Lit(v) + case *idltype.String: + byteGroup.Lit(val.(string)) + case *idltype.Bool: + byteGroup.Lit(val.(bool)) + default: + panic(fmt.Errorf("unsupported array type for constants[%d] %s: %T", coi, spew.Sdump(co), ty.Type)) + } + } + }).Op("}") + code.Line() + + case *idltype.Defined: + _ = ty + // Handle user-defined types like usize, isize, etc. + switch ty.Name { + case "usize": + // usize is typically a pointer-sized unsigned integer + // In most cases, this is equivalent to u64 on 64-bit systems + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseUint(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse usize constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(uint64(v)) + code.Line() + case "isize": + // isize is typically a pointer-sized signed integer + // In most cases, this is equivalent to i64 on 64-bit systems + cleanValue := strings.ReplaceAll(co.Value, "_", "") + v, err := strconv.ParseInt(cleanValue, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse isize constants[%d] %s: %w", coi, spew.Sdump(co), err) + } + code.Const().Id(co.Name).Op("=").Lit(int64(v)) + code.Line() + default: + // For other defined types, we could try to resolve them, + // but for now, we'll return an error with more specific information + return nil, fmt.Errorf("unsupported defined type '%s' for constants[%d] %s: %T", ty.Name, coi, spew.Sdump(co), ty) + } + + default: + return nil, fmt.Errorf("unsupported constant type for constants[%d] %s: %T", coi, spew.Sdump(co), ty) + } + } + file.Add(code) + } + return &OutputFile{ + Name: "constants.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/constants_test.go b/cmd/generate-bindings/solana/anchor-go/generator/constants_test.go new file mode 100644 index 00000000..83f298cb --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/constants_test.go @@ -0,0 +1,1183 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + "strings" + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenConstants(t *testing.T) { + tests := []struct { + name string + constants []idl.IdlConst + expectError bool + expectCode []string // 期望在生成的代码中找到的字符串 + }{ + { + name: "String constant", + constants: []idl.IdlConst{ + { + Name: "TEST_STRING", + Ty: &idltype.String{}, + Value: `"hello world"`, + }, + }, + expectCode: []string{ + "const TEST_STRING = \"hello world\"", + }, + }, + { + name: "Boolean constants", + constants: []idl.IdlConst{ + { + Name: "IS_ENABLED", + Ty: &idltype.Bool{}, + Value: "true", + }, + { + Name: "IS_DISABLED", + Ty: &idltype.Bool{}, + Value: "false", + }, + }, + expectCode: []string{ + "var IS_ENABLED = true", + "var IS_DISABLED = false", + }, + }, + { + name: "Unsigned integer constants", + constants: []idl.IdlConst{ + { + Name: "MAX_U8", + Ty: &idltype.U8{}, + Value: "255", + }, + { + Name: "MAX_U16", + Ty: &idltype.U16{}, + Value: "65535", + }, + { + Name: "MAX_U32", + Ty: &idltype.U32{}, + Value: "4294967295", + }, + { + Name: "MAX_U64", + Ty: &idltype.U64{}, + Value: "18446744073709551615", + }, + }, + expectCode: []string{ + "const MAX_U8 = uint8(0xff)", + "const MAX_U16 = uint16(0xffff)", + "const MAX_U32 = uint32(0xffffffff)", + "const MAX_U64 = uint64(0xffffffffffffffff)", + }, + }, + { + name: "Signed integer constants", + constants: []idl.IdlConst{ + { + Name: "MIN_I8", + Ty: &idltype.I8{}, + Value: "-128", + }, + { + Name: "MIN_I16", + Ty: &idltype.I16{}, + Value: "-32768", + }, + { + Name: "MIN_I32", + Ty: &idltype.I32{}, + Value: "-2147483648", + }, + { + Name: "MIN_I64", + Ty: &idltype.I64{}, + Value: "-9223372036854775808", + }, + }, + expectCode: []string{ + "const MIN_I8 = int8(-128)", + "const MIN_I16 = int16(-32768)", + "const MIN_I32 = int32(-2147483648)", + "const MIN_I64 = int64(-9223372036854775808)", + }, + }, + { + name: "Float constants", + constants: []idl.IdlConst{ + { + Name: "PI_F32", + Ty: &idltype.F32{}, + Value: "3.14159", + }, + { + Name: "E_F64", + Ty: &idltype.F64{}, + Value: "2.718281828459045", + }, + }, + expectCode: []string{ + "const PI_F32 = float32(3.14159)", + "const E_F64 = 2.718281828459045", + }, + }, + { + name: "Numbers with underscores", + constants: []idl.IdlConst{ + { + Name: "LARGE_NUMBER", + Ty: &idltype.U64{}, + Value: "100_000_000", + }, + { + Name: "NEGATIVE_NUMBER", + Ty: &idltype.I32{}, + Value: "-1_000_000", + }, + }, + expectCode: []string{ + "const LARGE_NUMBER = uint64(0x5f5e100)", + "const NEGATIVE_NUMBER = int32(-1000000)", + }, + }, + { + name: "usize constant", + constants: []idl.IdlConst{ + { + Name: "MAX_BIN_PER_ARRAY", + Ty: &idltype.Defined{ + Name: "usize", + }, + Value: "70", + }, + }, + expectCode: []string{ + "const MAX_BIN_PER_ARRAY = uint64(0x46)", + }, + }, + { + name: "isize constant", + constants: []idl.IdlConst{ + { + Name: "MIN_BIN_ID", + Ty: &idltype.Defined{ + Name: "isize", + }, + Value: "-443636", + }, + }, + expectCode: []string{ + "const MIN_BIN_ID = int64(-443636)", + }, + }, + { + name: "u128 constant", + constants: []idl.IdlConst{ + { + Name: "MAX_BASE_FEE", + Ty: &idltype.U128{}, + Value: "100_000_000", + }, + }, + expectCode: []string{ + "var MAX_BASE_FEE = func() *big.Int", + ".SetString(\"100000000\", 10)", + }, + }, + { + name: "i128 constant", + constants: []idl.IdlConst{ + { + Name: "MIN_BALANCE", + Ty: &idltype.I128{}, + Value: "-1_000_000_000_000", + }, + }, + expectCode: []string{ + "var MIN_BALANCE = func() *big.Int", + ".SetString(\"-1000000000000\", 10)", + }, + }, + { + name: "Bytes constant", + constants: []idl.IdlConst{ + { + Name: "SEED_BYTES", + Ty: &idltype.Bytes{}, + Value: "[102, 101, 101, 95, 118, 97, 117, 108, 116]", + }, + }, + expectCode: []string{ + "var SEED_BYTES = []byte{102, 101, 101, 95, 118, 97, 117, 108, 116}", + }, + }, + { + name: "Pubkey constant", + constants: []idl.IdlConst{ + { + Name: "PROGRAM_ID", + Ty: &idltype.Pubkey{}, + Value: "11111111111111111111111111111112", // System Program ID + }, + }, + expectCode: []string{ + "var PROGRAM_ID = solanago.MustPublicKeyFromBase58(\"11111111111111111111111111111112\")", + }, + }, + { + name: "Empty name - should be skipped", + constants: []idl.IdlConst{ + { + Name: "", + Ty: &idltype.U8{}, + Value: "42", + }, + { + Name: "VALID_CONST", + Ty: &idltype.U8{}, + Value: "42", + }, + }, + expectCode: []string{ + "const VALID_CONST = uint8(0x2a)", + }, + }, + { + name: "Empty value - should be skipped", + constants: []idl.IdlConst{ + { + Name: "EMPTY_VALUE", + Ty: &idltype.U8{}, + Value: "", + }, + { + Name: "VALID_CONST", + Ty: &idltype.U8{}, + Value: "42", + }, + }, + expectCode: []string{ + "const VALID_CONST = uint8(0x2a)", + }, + }, + { + name: "Unsupported defined type", + constants: []idl.IdlConst{ + { + Name: "CUSTOM_TYPE", + Ty: &idltype.Defined{ + Name: "CustomType", + }, + Value: "42", + }, + }, + expectError: true, + }, + { + name: "Invalid string format", + constants: []idl.IdlConst{ + { + Name: "INVALID_STRING", + Ty: &idltype.String{}, + Value: "invalid string format", // 应该有引号 + }, + }, + expectError: true, + }, + { + name: "Invalid number format", + constants: []idl.IdlConst{ + { + Name: "INVALID_NUMBER", + Ty: &idltype.U8{}, + Value: "not_a_number", + }, + }, + expectError: true, + }, + { + name: "Invalid u128 format", + constants: []idl.IdlConst{ + { + Name: "INVALID_U128", + Ty: &idltype.U128{}, + Value: "not_a_number", + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 创建一个最小的 IDL 结构 + idlData := &idl.Idl{ + Constants: tt.constants, + } + + // 创建生成器 + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + // 生成常量 + outputFile, err := gen.gen_constants() + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, outputFile) + + // 获取生成的代码 + generatedCode := outputFile.File.GoString() + + // 检查期望的代码片段是否存在 + for _, expectedCode := range tt.expectCode { + assert.Contains(t, generatedCode, expectedCode, + "Expected code snippet not found: %s\nGenerated code:\n%s", + expectedCode, generatedCode) + } + + // 基本的结构检查 + assert.Contains(t, generatedCode, "package test") + assert.Contains(t, generatedCode, "Code generated by https://github.com/gagliardetto/anchor-go") + assert.Contains(t, generatedCode, "This file contains constants") + }) + } +} + +func TestGenConstantsWithArrays(t *testing.T) { + // 测试数组常量 + constants := []idl.IdlConst{ + { + Name: "BYTE_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.U8{}, + Size: &idltype.IdlArrayLenValue{Value: 3}, + }, + Value: "[1, 2, 3]", + }, + } + + idlData := &idl.Idl{ + Constants: constants, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var BYTE_ARRAY = [3]byte{uint8(0x1), uint8(0x2), uint8(0x3)}") +} + +func TestGenConstantsWithF32Array(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "FLOAT_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.F32{}, + Size: &idltype.IdlArrayLenValue{Value: 3}, + }, + Value: "[1.5, 2.5, 3.5]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var FLOAT_ARRAY = [3]float32{float32(1.5), float32(2.5), float32(3.5)}") +} + +func TestGenConstantsWithF64Array(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "DOUBLE_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.F64{}, + Size: &idltype.IdlArrayLenValue{Value: 2}, + }, + Value: "[3.14159, 2.71828]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var DOUBLE_ARRAY = [2]float64{3.14159, 2.71828}") +} + +func TestGenConstantsWithBoolArray(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "BOOL_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.Bool{}, + Size: &idltype.IdlArrayLenValue{Value: 3}, + }, + Value: "[true, false, true]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var BOOL_ARRAY = [3]bool{true, false, true}") +} + +func TestGenConstantsWithStringArray(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "STRING_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.String{}, + Size: &idltype.IdlArrayLenValue{Value: 2}, + }, + Value: `["hello", "world"]`, + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, `var STRING_ARRAY = [2]string{"hello", "world"}`) +} + +func TestGenConstantsEdgeCases(t *testing.T) { + t.Run("No constants", func(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{}, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "package test") + // 不应该包含 "Constants defined in the IDL:" 注释 + assert.NotContains(t, generatedCode, "Constants defined in the IDL:") + }) + + t.Run("Underscore cleaning", func(t *testing.T) { + // 测试下划线清理功能 + testCases := []struct { + value string + expected string + }{ + {"1_000", "1000"}, + {"1_000_000", "1000000"}, + {"1_2_3_4", "1234"}, + {"100", "100"}, // 没有下划线 + } + + for _, tc := range testCases { + constants := []idl.IdlConst{ + { + Name: "TEST_VALUE", + Ty: &idltype.U64{}, + Value: tc.value, + }, + } + + idlData := &idl.Idl{ + Constants: constants, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err, "Failed for value: %s", tc.value) + + generatedCode := outputFile.File.GoString() + + // 验证生成的代码不包含原始的下划线值 + if strings.Contains(tc.value, "_") { + assert.NotContains(t, generatedCode, tc.value) + } + } + }) +} + +func TestGenConstantsPerformance(t *testing.T) { + // 测试大量常量的性能 + constants := make([]idl.IdlConst, 1000) + for i := 0; i < 1000; i++ { + constants[i] = idl.IdlConst{ + Name: fmt.Sprintf("CONST_%d", i), + Ty: &idltype.U32{}, + Value: fmt.Sprintf("%d", i), + } + } + + idlData := &idl.Idl{ + Constants: constants, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{ + Package: "test", + }, + } + + // 测试性能(应该在合理时间内完成) + outputFile, err := gen.gen_constants() + require.NoError(t, err) + require.NotNil(t, outputFile) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "CONST_0") + assert.Contains(t, generatedCode, "CONST_999") +} + +// TestGenConstantsSpecialCases 测试特殊情况 +func TestGenConstantsSpecialCases(t *testing.T) { + t.Run("Zero values", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "ZERO_U8", + Ty: &idltype.U8{}, + Value: "0", + }, + { + Name: "ZERO_I32", + Ty: &idltype.I32{}, + Value: "0", + }, + { + Name: "ZERO_F64", + Ty: &idltype.F64{}, + Value: "0.0", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "const ZERO_U8 = uint8(0x0)") + assert.Contains(t, generatedCode, "const ZERO_I32 = int32(0)") + assert.Contains(t, generatedCode, "const ZERO_F64 = 0") + }) + + t.Run("Maximum values", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "MAX_U8_VALUE", + Ty: &idltype.U8{}, + Value: "255", + }, + { + Name: "MAX_I8_VALUE", + Ty: &idltype.I8{}, + Value: "127", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "const MAX_U8_VALUE = uint8(0xff)") + assert.Contains(t, generatedCode, "const MAX_I8_VALUE = int8(127)") + }) + + t.Run("Complex underscores", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "COMPLEX_NUMBER", + Ty: &idltype.U64{}, + Value: "1_000_000_000_000_000_000", + }, + { + Name: "HEX_LIKE_NUMBER", + Ty: &idltype.U32{}, + Value: "0_x_F_F_F_F", // 这不是真正的十六进制,只是包含下划线的数字 + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + // 第二个应该失败,因为它不是有效的数字 + outputFile, err := gen.gen_constants() + assert.Error(t, err) // 应该失败 + _ = outputFile + }) + + t.Run("Scientific notation", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "SCIENTIFIC_F32", + Ty: &idltype.F32{}, + Value: "1.23e-4", + }, + { + Name: "SCIENTIFIC_F64", + Ty: &idltype.F64{}, + Value: "1.23456789e10", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "const SCIENTIFIC_F32 = float32(0.000123)") + assert.Contains(t, generatedCode, "const SCIENTIFIC_F64 = 1.23456789e+10") + }) + + t.Run("Empty bytes array", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "EMPTY_BYTES", + Ty: &idltype.Bytes{}, + Value: "[]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var EMPTY_BYTES = []byte{}") + }) + + t.Run("With docs", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "DOCUMENTED_CONST", + Docs: []string{"This is a test constant", "With multiple lines of documentation"}, + Ty: &idltype.U32{}, + Value: "42", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "// This is a test constant") + assert.Contains(t, generatedCode, "// With multiple lines of documentation") + assert.Contains(t, generatedCode, "const DOCUMENTED_CONST = uint32(0x2a)") + }) +} + +// TestGenConstantsErrorCases 测试各种错误情况 +func TestGenConstantsErrorCases(t *testing.T) { + t.Run("Invalid pubkey", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "INVALID_PUBKEY", + Ty: &idltype.Pubkey{}, + Value: "invalid_pubkey_format", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse pubkey") + }) + + t.Run("Invalid bytes format", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "INVALID_BYTES", + Ty: &idltype.Bytes{}, + Value: "[1, 2, invalid]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal bytes") + }) + + t.Run("Invalid array format", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "INVALID_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.U8{}, + Size: &idltype.IdlArrayLenValue{Value: 3}, + }, + Value: "[1, 2]", // 只有2个元素,但期望3个 + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "got 2") + }) + + t.Run("Number overflow", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "OVERFLOW_U8", + Ty: &idltype.U8{}, + Value: "256", // 超出 u8 范围 + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + _, err := gen.gen_constants() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse u8") + }) +} + +// TestGenConstantsLargeU64I64ArrayPrecision verifies that u64 and i64 array +// elements above 2^53 are emitted with full 64-bit precision. json.Unmarshal +// into []any decodes numbers as float64, which silently rounds integers larger +// than 2^53. This test catches that: if the generator still uses float64 casts, +// the expected exact values will not appear in the generated code. +func TestGenConstantsLargeU64I64ArrayPrecision(t *testing.T) { + t.Run("u64 array with values above 2^53", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "LARGE_U64_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.U64{}, + Size: &idltype.IdlArrayLenValue{Value: 4}, + }, + // 2^53 = 9007199254740992 is the last integer float64 represents exactly. + // 2^53+1 and 2^53+3 are NOT representable in float64 and will be rounded + // to 2^53 and 2^53+4 respectively if parsed through float64. + Value: "[9007199254740993, 9007199254740995, 18446744073709551615, 9007199254740992]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + + // 2^53+1 = 0x20000000000001 — NOT exactly representable in float64 + assert.Contains(t, generatedCode, "uint64(0x20000000000001)", + "9007199254740993 (2^53+1) was rounded; float64 precision loss in u64 array element") + // 2^53+3 = 0x20000000000003 — NOT exactly representable in float64 + assert.Contains(t, generatedCode, "uint64(0x20000000000003)", + "9007199254740995 (2^53+3) was rounded; float64 precision loss in u64 array element") + // max u64 = 0xffffffffffffffff + assert.Contains(t, generatedCode, "uint64(0xffffffffffffffff)", + "18446744073709551615 (max u64) was rounded; float64 precision loss in u64 array element") + // 2^53 exactly representable — should always work + assert.Contains(t, generatedCode, "uint64(0x20000000000000)", + "9007199254740992 (2^53) should be emitted correctly") + }) + + t.Run("i64 array with values above 2^53", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "LARGE_I64_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.I64{}, + Size: &idltype.IdlArrayLenValue{Value: 4}, + }, + Value: "[9007199254740993, -9007199254740993, 9223372036854775807, -9223372036854775808]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + + // 2^53+1 positive + assert.Contains(t, generatedCode, "int64(9007199254740993)", + "9007199254740993 (2^53+1) was rounded; float64 precision loss in i64 array element") + // 2^53+1 negative + assert.Contains(t, generatedCode, "int64(-9007199254740993)", + "-9007199254740993 was rounded; float64 precision loss in i64 array element") + // max i64 + assert.Contains(t, generatedCode, "int64(9223372036854775807)", + "max i64 was rounded; float64 precision loss in i64 array element") + // min i64 + assert.Contains(t, generatedCode, "int64(-9223372036854775808)", + "min i64 was rounded; float64 precision loss in i64 array element") + }) + + t.Run("u32 array is not affected", func(t *testing.T) { + // u32 max = 4294967295 < 2^53, so float64 is fine + constants := []idl.IdlConst{ + { + Name: "U32_ARRAY", + Ty: &idltype.Array{ + Type: &idltype.U32{}, + Size: &idltype.IdlArrayLenValue{Value: 2}, + }, + Value: "[4294967295, 0]", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "uint32(0xffffffff)") + assert.Contains(t, generatedCode, "uint32(0x0)") + }) +} + +// TestGenConstantsRealWorldExamples 测试真实世界的例子 +func TestGenConstantsRealWorldExamples(t *testing.T) { + t.Run("Solana program constants", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "LAMPORTS_PER_SOL", + Ty: &idltype.U64{}, + Value: "1_000_000_000", + }, + { + Name: "SEED_PREFIX", + Ty: &idltype.String{}, + Value: `"anchor"`, + }, + { + Name: "MAX_SEED_LEN", + Ty: &idltype.U32{}, + Value: "32", + }, + { + Name: "SYSTEM_PROGRAM_ID", + Ty: &idltype.Pubkey{}, + Value: "11111111111111111111111111111112", + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "myprogram"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "package myprogram") + assert.Contains(t, generatedCode, "const LAMPORTS_PER_SOL = uint64(0x3b9aca00)") + assert.Contains(t, generatedCode, "const SEED_PREFIX = \"anchor\"") + assert.Contains(t, generatedCode, "const MAX_SEED_LEN = uint32(0x20)") + assert.Contains(t, generatedCode, "var SYSTEM_PROGRAM_ID = solanago.MustPublicKeyFromBase58") + }) + + t.Run("Mixed types with all supported features", func(t *testing.T) { + constants := []idl.IdlConst{ + { + Name: "FEATURE_ENABLED", + Docs: []string{"Feature flag for new functionality"}, + Ty: &idltype.Bool{}, + Value: "true", + }, + { + Name: "MAX_BIN_COUNT", + Docs: []string{"Maximum number of bins per array"}, + Ty: &idltype.Defined{ + Name: "usize", + }, + Value: "70", + }, + { + Name: "PROTOCOL_FEE", + Docs: []string{"Protocol fee in basis points"}, + Ty: &idltype.U128{}, + Value: "10_000_000_000_000_000_000", + }, + { + Name: "SIGNATURE_SEED", + Ty: &idltype.Array{ + Type: &idltype.U8{}, + Size: &idltype.IdlArrayLenValue{Value: 8}, + }, + Value: "[115, 105, 103, 110, 97, 116, 117, 114]", // "signatur" in ASCII + }, + } + + idlData := &idl.Idl{Constants: constants} + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + + // 检查注释 + assert.Contains(t, generatedCode, "// Feature flag for new functionality") + assert.Contains(t, generatedCode, "// Maximum number of bins per array") + assert.Contains(t, generatedCode, "// Protocol fee in basis points") + + // 检查生成的常量 + assert.Contains(t, generatedCode, "var FEATURE_ENABLED = true") + assert.Contains(t, generatedCode, "const MAX_BIN_COUNT = uint64(0x46)") + assert.Contains(t, generatedCode, "var PROTOCOL_FEE = func() *big.Int") + assert.Contains(t, generatedCode, "var SIGNATURE_SEED = [8]byte{uint8(0x73), uint8(0x69), uint8(0x67), uint8(0x6e), uint8(0x61), uint8(0x74), uint8(0x75), uint8(0x72)}") + }) +} + +func TestGenerateCodecStructMethod_SkipsEnums(t *testing.T) { + idlData := &idl.Idl{ + Types: idl.IdTypeDef_slice{ + { + Name: "MyStruct", + Ty: &idl.IdlTypeDefTyStruct{ + Fields: idl.IdlDefinedFieldsNamed{ + {Name: "value", Ty: &idltype.U64{}}, + }, + }, + }, + { + Name: "MySimpleEnum", + Ty: &idl.IdlTypeDefTyEnum{ + Variants: idl.VariantSlice{ + {Name: "VariantA"}, + {Name: "VariantB"}, + }, + }, + }, + { + Name: "AnotherStruct", + Ty: &idl.IdlTypeDefTyStruct{ + Fields: idl.IdlDefinedFieldsNamed{ + {Name: "name", Ty: &idltype.String{}}, + }, + }, + }, + }, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + methods, err := gen.generateCodecStructMethod() + require.NoError(t, err) + + require.Len(t, methods, 2, "only the two struct types should produce codec methods") + + rendered := make([]string, len(methods)) + for i, m := range methods { + rendered[i] = fmt.Sprintf("%#v", m) + } + + assert.Contains(t, rendered[0], "EncodeMyStructStruct") + assert.Contains(t, rendered[1], "EncodeAnotherStructStruct") + + for _, r := range rendered { + assert.NotContains(t, r, "MySimpleEnum", + "enum type should not appear in codec struct methods") + } +} + +func TestGenerateCodecMethods_EnumOnlyIDL(t *testing.T) { + idlData := &idl.Idl{ + Types: idl.IdTypeDef_slice{ + { + Name: "Status", + Ty: &idl.IdlTypeDefTyEnum{ + Variants: idl.VariantSlice{ + {Name: "Active"}, + {Name: "Inactive"}, + }, + }, + }, + }, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + methods, err := gen.generateCodecMethods() + require.NoError(t, err) + assert.Empty(t, methods, "codec should have no methods for an enum-only IDL") +} + +func TestGenerateCodecStructMethod_CodecInterfaceMatchesImpl(t *testing.T) { + idlData := &idl.Idl{ + Types: idl.IdTypeDef_slice{ + { + Name: "MyStruct", + Ty: &idl.IdlTypeDefTyStruct{ + Fields: idl.IdlDefinedFieldsNamed{ + {Name: "value", Ty: &idltype.U64{}}, + }, + }, + }, + { + Name: "MyEnum", + Ty: &idl.IdlTypeDefTyEnum{ + Variants: idl.VariantSlice{ + {Name: "On"}, + {Name: "Off"}, + }, + }, + }, + }, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + interfaceMethods, err := gen.generateCodecStructMethod() + require.NoError(t, err) + + interfaceMethodNames := make(map[string]bool) + for _, m := range interfaceMethods { + s := fmt.Sprintf("%#v", m) + for _, typ := range idlData.Types { + name := "Encode" + typ.Name + "Struct" + if strings.Contains(s, name) { + interfaceMethodNames[name] = true + } + } + } + + for _, typ := range idlData.Types { + name := "Encode" + typ.Name + "Struct" + if _, isStruct := typ.Ty.(*idl.IdlTypeDefTyStruct); isStruct { + assert.True(t, interfaceMethodNames[name], + "struct %q should have codec interface method %s", typ.Name, name) + + implCode := creGenerateCodecEncoderForTypes(typ.Name) + implStr := fmt.Sprintf("%#v", implCode) + assert.Contains(t, implStr, name, + "struct %q should have matching implementation", typ.Name) + } else { + assert.False(t, interfaceMethodNames[name], + "enum %q must not have codec interface method %s", typ.Name, name) + } + } +} + +func TestGenfileConstructor_WithEnumTypes(t *testing.T) { + idlData := &idl.Idl{ + Metadata: idl.IdlMetadata{ + Name: "test_program", + Version: "0.1.0", + Spec: "0.1.0", + }, + Types: idl.IdTypeDef_slice{ + { + Name: "Config", + Ty: &idl.IdlTypeDefTyStruct{ + Fields: idl.IdlDefinedFieldsNamed{ + {Name: "value", Ty: &idltype.U64{}}, + }, + }, + }, + { + Name: "Mode", + Ty: &idl.IdlTypeDefTyEnum{ + Variants: idl.VariantSlice{ + {Name: "Fast"}, + {Name: "Slow"}, + }, + }, + }, + }, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "testpkg"}, + } + + outputFile, err := gen.genfile_constructor() + require.NoError(t, err) + require.NotNil(t, outputFile) + + code := fmt.Sprintf("%#v", outputFile.File) + + assert.Contains(t, code, "EncodeConfigStruct", + "struct type Config should have a codec interface method") + assert.NotContains(t, code, "EncodeModeStruct", + "enum type Mode should not produce a codec interface method") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/cre.go b/cmd/generate-bindings/solana/anchor-go/generator/cre.go new file mode 100644 index 00000000..630722c2 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/cre.go @@ -0,0 +1,413 @@ +// This file contains all the cre specific code for the generator. +// The other files are copied from https://github.com/gagliardetto/anchor-go/blob/main/generator/ +// They simply call functions in this file. +// +//nolint:all +package generator + +import ( + "encoding/json" + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +const ( + PkgCRE = "github.com/smartcontractkit/cre-sdk-go/cre" + PkgPbSdk = "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + PkgSolanaCre = "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + PkgBindings = "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/bindings" +) + +// func (c *Codec) Decode(data []byte) (*, error) { +func creDecodeAccountFn(name string) Code { + return Func(). + Params(Id("c").Op("*").Id("Codec")). + Id("Decode"+name). + Params(Id("data").Index().Byte()). + Params(Op("*").Id(name), Error()). + Block(Return(Id("ParseAccount_" + name).Call(Id("data")))) +} + +// func (c *Codec) EncodeStruct(in ) ([]byte, error) { +// return in.Marshal() +// } +func creGenerateCodecEncoderForTypes(exportedAccountName string) Code { + return Func(). + Params(Id("c").Op("*").Id("Codec")). + Id("Encode"+exportedAccountName+"Struct"). + Params(Id("in").Id(exportedAccountName)). + Params(Index().Byte(), Error()). + Block(Return(Id("in").Dot("Marshal").Call())) +} + +// if err block +// +// return cre.PromiseFromResult[*](nil, err) +// } +func creWriteReportErrorBlock() Code { + code := Empty() + code.If(Id("err").Op("!=").Nil()).Block( + Return( + Qual(PkgCRE, "PromiseFromResult").Types(Op("*").Qual(PkgSolanaCre, "WriteReportReply")).Call( + Nil(), Id("err"), + ))) + code.Line().Line() + return code +} + +func creWriteReportFromStructs(exportedAccountName string, g *Generator) Code { + code := Empty() + declarerName := newWriteReportFromInstructionFuncName(exportedAccountName) + code.Commentf("%s encodes the input struct, hashes the provided accounts,", declarerName) + code.Comment("generates a signed report, and submits it via WriteReport.") + code.Comment("") + code.Comment("remainingAccounts must follow the keystone-forwarder account layout:") + code.Comment(" - Index 0: forwarderState – the forwarder program's state account.") + code.Comment(" - Index 1: forwarderAuthority – PDA derived from seeds") + code.Comment(" [\"forwarder\", forwarderState, receiverProgram] under the forwarder program ID.") + code.Comment(" - Index 2+: receiver-specific accounts required by the target program.") + code.Comment("") + code.Comment("The full slice is hashed (via CalculateAccountsHash) into the report and forwarded") + code.Comment("as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1") + code.Comment("before CPI-ing into the receiver, so they must be present and correctly ordered.") + code.Line() + code.Func(). + Params(Id("c").Op("*").Id(tools.ToCamelUpper(g.options.Package))). // method receiver + Id(declarerName). + // params + Params( + ListMultiline(func(p *Group) { + p.Id("runtime").Qual(PkgCRE, "Runtime") + p.Id("input").Id(exportedAccountName) + p.Id("remainingAccounts").Index().Op("*").Qual(PkgSolanaCre, "AccountMeta") + p.Id("computeConfig").Op("*").Qual(PkgSolanaCre, "ComputeConfig") + }), + ). + // return type + Params(Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgSolanaCre, "WriteReportReply"))). + BlockFunc(func(block *Group) { + // encoded, err := c.Codec.EncodeStruct(input) + block.List(Id("encodedInput"), Id("err")).Op(":="). + Id("c").Dot("Codec").Dot("Encode" + exportedAccountName + "Struct").Call(Id("input")) + + // if err block + block.Add(creWriteReportErrorBlock()) + + // encodedAccountList, err := EncodeAccountList(accountList) + block.Id("encodedAccountList").Op(":="). + Qual(PkgBindings, "CalculateAccountsHash").Call(Id("remainingAccounts")).Line() + + // fwdReport := ForwarderReport{Payload: encodedInput, AccountHash: encodedAccountList} + block.Id("fwdReport").Op(":=").Qual(PkgBindings, "ForwarderReport").Values(Dict{ + Id("Payload"): Id("encodedInput"), + Id("AccountHash"): Id("encodedAccountList"), + }) + + // encodedFwdReport, err := fwdReport.Marshal() + block.List(Id("encodedFwdReport"), Id("err")).Op(":=").Id("fwdReport").Dot("Marshal").Call() + + // if err block + block.Add(creWriteReportErrorBlock()) + + // promise := runtime.GenerateReport(&pb2.ReportRequest{ ... }) + block.Id("promise").Op(":=").Id("runtime").Dot("GenerateReport").Call( + Op("&").Qual(PkgPbSdk, "ReportRequest").Values(Dict{ + Id("EncodedPayload"): Id("encodedFwdReport"), + Id("EncoderName"): Lit("solana"), + Id("SigningAlgo"): Lit("ecdsa"), + Id("HashingAlgo"): Lit("keccak256"), + }), + ).Line() + + //return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + // return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + // AccountList: typedAccountList, + // Receiver: ProgramID.Bytes(), + // Report: report, + // }) + // }) + block.Return( + Qual(PkgCRE, "ThenPromise").Call( + Id("promise"), + creWriteReportFromStructsLambda(), + ), + ) + }) + return code +} + +func creEncodeBorshVecU32() Code { + st := Empty() + st.Comment(`EncodeBorshVecU32 returns Anchor/Borsh encoding of a Vec whose elements are opaque byte payloads.`) + st.Comment(`Each [][]byte element must already be fully serialized for one Vec item on the wire.`) + st.Comment(`Layout: little-endian u32 length followed by concatenated element payloads.`) + st.Line() + st.Func(). + Id("EncodeBorshVecU32"). + Params(Id("elements").Index().Index().Byte()). + Params(Index().Byte(), Error()). + BlockFunc(func(b *Group) { + b.Id("buf").Op(":=").Qual("bytes", "NewBuffer").Call(Nil()) + b.If( + Err().Op(":=").Qual("encoding/binary", "Write").Call( + Id("buf"), + Qual("encoding/binary", "LittleEndian"), + Id("uint32").Call(Len(Id("elements"))), + ), + Err().Op("!=").Nil(), + ).Block(Return(Nil(), Err())) + b.For(Id("_").Op(",").Id("elem").Op(":=").Range().Id("elements")).Block( + List(Id("_"), Err()).Op(":=").Id("buf").Dot("Write").Call(Id("elem")), + If(Err().Op("!=").Nil()).Block( + Return(Nil(), Err()), + ), + ) + b.Return(Id("buf").Dot("Bytes").Call(), Nil()) + }) + return st +} + +// WriteReportFromBorshEncodedVec forwards a CRE report whose inner payload is EncodeBorshVecU32(elementPayloads). +func creWriteReportFromBorshEncodedVec(g *Generator) Code { + pkg := tools.ToCamelUpper(g.options.Package) + code := Empty() + code.Comment(`WriteReportFromBorshEncodedVec publishes through the CRE signer using a forwarder payload built from`) + code.Comment(`Borsh Vec semantics (EncodeBorshVecU32). Compose each elementPayload for your program (e.g. one encoded struct per row).`) + code.Comment(`Pass computeConfig = nil to use the host default Solana compute budget.`) + code.Line() + code.Func(). + Params(Id("c").Op("*").Id(pkg)). + Id("WriteReportFromBorshEncodedVec"). + Params(ListFunc(func(pl *Group) { + pl.Id("runtime").Qual(PkgCRE, "Runtime") + pl.Id("elementPayloads").Index().Index().Byte() + pl.Id("remainingAccounts").Index().Op("*").Qual(PkgSolanaCre, "AccountMeta") + pl.Id("computeConfig").Op("*").Qual(PkgSolanaCre, "ComputeConfig") + })). + Params(Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgSolanaCre, "WriteReportReply"))). + BlockFunc(func(block *Group) { + block.List(Id("payload"), Id("err")).Op(":=").Id("EncodeBorshVecU32").Call(Id("elementPayloads")) + block.Add(creWriteReportErrorBlock()) + block.Id("encodedAccountList").Op(":=").Qual(PkgBindings, "CalculateAccountsHash").Call(Id("remainingAccounts")) + block.Id("fwdReport").Op(":=").Qual(PkgBindings, "ForwarderReport").Values(Dict{ + Id("AccountHash"): Id("encodedAccountList"), + Id("Payload"): Id("payload"), + }) + block.List(Id("encodedFwdReport"), Id("err")).Op(":=").Id("fwdReport").Dot("Marshal").Call() + block.Add(creWriteReportErrorBlock()) + block.Id("promise").Op(":=").Id("runtime").Dot("GenerateReport").Call( + Op("&").Qual(PkgPbSdk, "ReportRequest").Values(Dict{ + Id("EncodedPayload"): Id("encodedFwdReport"), + Id("EncoderName"): Lit("solana"), + Id("SigningAlgo"): Lit("ecdsa"), + Id("HashingAlgo"): Lit("keccak256"), + }), + ).Line() + block.Return(Qual(PkgCRE, "ThenPromise").Call(Id("promise"), creWriteReportFromStructsLambda())) + }) + return code +} + +// creWriteReportFromStructsSlice emits: +// +// func (c *) WriteReportFroms(runtime cre.Runtime, inputs [], remainingAccounts []*solana.AccountMeta, computeConfig *solana.ComputeConfig) cre.Promise[*solana.WriteReportReply] { +// elements := make([][]byte, len(inputs)) +// for i, input := range inputs { +// encoded, err := c.Codec.EncodeStruct(input) +// if err != nil { return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) } +// elements[i] = encoded +// } +// return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +// } +func creWriteReportFromStructsSlice(exportedStructName string, g *Generator) Code { + pkg := tools.ToCamelUpper(g.options.Package) + declarerName := newWriteReportFromInstructionFuncName(exportedStructName) + "s" + return Func(). + Params(Id("c").Op("*").Id(pkg)). + Id(declarerName). + Params(ListMultiline(func(p *Group) { + p.Id("runtime").Qual(PkgCRE, "Runtime") + p.Id("inputs").Index().Id(exportedStructName) + p.Id("remainingAccounts").Index().Op("*").Qual(PkgSolanaCre, "AccountMeta") + p.Id("computeConfig").Op("*").Qual(PkgSolanaCre, "ComputeConfig") + })). + Params(Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgSolanaCre, "WriteReportReply"))). + BlockFunc(func(block *Group) { + block.Id("elements").Op(":=").Make(Index().Index().Byte(), Len(Id("inputs"))) + block.For(Id("i").Op(",").Id("input").Op(":=").Range().Id("inputs")).Block( + List(Id("encoded"), Err()).Op(":="). + Id("c").Dot("Codec").Dot("Encode"+exportedStructName+"Struct").Call(Id("input")), + If(Err().Op("!=").Nil()).Block( + Return(Qual(PkgCRE, "PromiseFromResult"). + Types(Op("*").Qual(PkgSolanaCre, "WriteReportReply")). + Call(Nil(), Err())), + ), + Id("elements").Index(Id("i")).Op("=").Id("encoded"), + ) + block.Return(Id("c").Dot("WriteReportFromBorshEncodedVec").Call( + Id("runtime"), + Id("elements"), + Id("remainingAccounts"), + Id("computeConfig"), + )) + }) +} + +func creWriteReportFromStructsLambda() *Statement { + // func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + // return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + // AccountList: typedAccountList, + // Receiver: ProgramID.Bytes(), + // Report: report, + // }) + // } + return Func(). + Params(Id("report").Op("*").Qual(PkgCRE, "Report")). + Qual(PkgCRE, "Promise").Types(Op("*").Qual(PkgSolanaCre, "WriteReportReply")). + Block( + Return( + Id("c").Dot("client").Dot("WriteReport").Call( + Id("runtime"), + Op("&").Qual(PkgSolanaCre, "WriteCreReportRequest").Values(Dict{ + Id("Receiver"): Id("ProgramID").Dot("Bytes").Call(), + Id("Report"): Id("report"), + Id("RemainingAccounts"): Id("remainingAccounts"), + Id("ComputeConfig"): Id("computeConfig"), + }), + ), + ), + ) +} + +// genfile_constructor generates the file `constructor.go`. +func (g *Generator) genfile_constructor() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains the constructor for the program.") + + { + // idl string + code := newStatement() + idlData, err := json.Marshal(g.idl) + if err != nil { + return nil, fmt.Errorf("error reading IDL file: %w", err) + } + code.Var().Id("IDL").Op("=").Lit(string(idlData)) + file.Add(code) + code.Line() + + // contract type + code = newStatement() + code.Type().Id(tools.ToCamelUpper(g.options.Package)).Struct( + Id("client").Op("*").Qual(PkgSolanaCre, "Client"), + Id("Codec").Id(tools.ToCamelUpper(g.options.Package)+"Codec"), + ) + code.Line() + file.Add(code) + code.Line() + + // codec type + code = newStatement() + code.Type().Id("Codec").Struct() + code.Line() + file.Add(code) + + // new constructor + code = newStatement() + code.Func(). + Id("New"+tools.ToCamelUpper(g.options.Package)). + Params( + Id("client").Op("*").Qual(PkgSolanaCre, "Client"), + ). + Params( + Op("*").Id(tools.ToCamelUpper(g.options.Package)), Error(), + ). + Block( + Return( + Op("&").Id(tools.ToCamelUpper(g.options.Package)).Values(Dict{ + Id("Codec"): Op("&").Id("Codec").Values(), + Id("client"): Id("client"), + }), + Nil(), + ), + ) + file.Add(code) + code.Line() + + file.Add(creEncodeBorshVecU32()) + code.Line() + file.Add(creWriteReportFromBorshEncodedVec(g)) + code.Line() + + methods, err := g.generateCodecMethods() + if err != nil { + return nil, err + } + + // Codec interface + code = newStatement() + code.Type().Id(tools.ToCamelUpper(g.options.Package) + "Codec").Interface(methods...) + file.Add(code) + code.Line() + } + + return &OutputFile{ + Name: "constructor.go", + File: file, + }, nil +} + +func (g *Generator) generateCodecAccountMethods() ([]Code, error) { + accountMethods := make([]Code, 0, len(g.idl.Accounts)) + for _, acc := range g.idl.Accounts { + exportedName := tools.ToCamelUpper(acc.Name) + methodName := "Decode" + exportedName + m := Id(methodName). + Params(Id("data").Index().Byte()). // ([]byte) + Params( + Op("*").Id(exportedName), // (*DataAccount) + Error(), // error + ) + + accountMethods = append(accountMethods, m) + } + + return accountMethods, nil +} + +func (g *Generator) generateCodecStructMethod() ([]Code, error) { + structMethods := make([]Code, 0, len(g.idl.Types)) + for _, typ := range g.idl.Types { + exportedName := tools.ToCamelUpper(typ.Name) + methodName := "Encode" + exportedName + "Struct" + if _, isEnum := typ.Ty.(*idl.IdlTypeDefTyEnum); isEnum { + continue + } + m := Id(methodName). + Params( + Id("in").Id(exportedName), // e.g., AccessLogged / DataAccount / ... + ). + Params( + Index().Byte(), // []byte + Error(), // error + ) + structMethods = append(structMethods, m) + } + return structMethods, nil +} + +func (g *Generator) generateCodecMethods() ([]Code, error) { + accountMethods, err := g.generateCodecAccountMethods() + if err != nil { + return nil, err + } + + structMethods, err := g.generateCodecStructMethod() + if err != nil { + return nil, err + } + return append(accountMethods, structMethods...), nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/discriminator.go b/cmd/generate-bindings/solana/anchor-go/generator/discriminator.go new file mode 100644 index 00000000..749d0079 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/discriminator.go @@ -0,0 +1,118 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) gen_discriminators() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains the discriminators for accounts and events defined in the IDL.") + + { + accountDiscriminatorsCodes := Empty() + accountDiscriminatorsCodes.Comment("Account discriminators") + accountDiscriminatorsCodes.Line() + accountDiscriminatorsCodes.Var().Parens( + DoGroup(func(code *Group) { + for _, account := range g.idl.Accounts { + if account.Discriminator == nil { + continue + } + + discriminator := account.Discriminator + if len(discriminator) != 8 { + panic(fmt.Errorf("discriminator for account %s must be exactly 8 bytes long, got %d bytes", account.Name, len(discriminator))) + } + + discriminatorName := FormatAccountDiscriminatorName(account.Name) + { + // binary.TypeID (not [8]byte) matches ReadDiscriminator's return type; mixing + // types breaks equality checks on wasm/wasip1 (runtime.memequal pointer types). + code.Id(discriminatorName).Op("=").Qual(PkgBinary, "TypeID").ValuesFunc(func(byteGroup *Group) { + for _, byteVal := range discriminator[:] { + byteGroup.Lit(int(byteVal)) + } + }) + } + code.Line() + } + }), + ) + file.Add(accountDiscriminatorsCodes) + file.Line() + } + { + // Generate the discriminators for events. + eventDiscriminatorsCodes := Empty() + eventDiscriminatorsCodes.Comment("Event discriminators") + eventDiscriminatorsCodes.Line() + eventDiscriminatorsCodes.Var().Parens( + DoGroup(func(code *Group) { + for _, event := range g.idl.Events { + if event.Discriminator == nil { + continue + } + + discriminator := event.Discriminator + if len(discriminator) != 8 { + panic(fmt.Errorf("discriminator for event %s must be exactly 8 bytes long", event.Name)) + } + + discriminatorName := FormatEventDiscriminatorName(event.Name) + { + code.Id(discriminatorName).Op("=").Qual(PkgBinary, "TypeID").ValuesFunc(func(byteGroup *Group) { + for _, byteVal := range discriminator[:] { + byteGroup.Lit(int(byteVal)) + } + }) + } + code.Line() + } + }), + ) + file.Add(eventDiscriminatorsCodes) + file.Line() + } + { + // Generate the discriminators for instructions. + instructionDiscriminatorsCodes := Empty() + instructionDiscriminatorsCodes.Comment("Instruction discriminators") + instructionDiscriminatorsCodes.Line() + instructionDiscriminatorsCodes.Var().Parens( + DoGroup( + func(code *Group) { + for _, instruction := range g.idl.Instructions { + if instruction.Discriminator == nil { + continue + } + + discriminator := instruction.Discriminator + if len(discriminator) != 8 { + panic(fmt.Errorf("discriminator for instruction %s must be exactly 8 bytes long", instruction.Name)) + } + + discriminatorName := FormatInstructionDiscriminatorName(instruction.Name) + { + code.Id(discriminatorName).Op("=").Qual(PkgBinary, "TypeID").ValuesFunc(func(byteGroup *Group) { + for _, byteVal := range discriminator[:] { + byteGroup.Lit(int(byteVal)) + } + }) + } + code.Line() + } + }, + ), + ) + file.Add(instructionDiscriminatorsCodes) + file.Line() + } + return &OutputFile{ + Name: "discriminators.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/doc.go b/cmd/generate-bindings/solana/anchor-go/generator/doc.go new file mode 100644 index 00000000..152d5a80 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/doc.go @@ -0,0 +1,34 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + . "github.com/dave/jennifer/jen" //nolint:staticcheck // ST1019: dot import used for code generation convenience +) + +func (g *Generator) genfile_doc() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains documentation and example usage for the generated code.") + + // TODO: + // - example usage + // - documentation + + file.Line().Line() + + if len(g.idl.Docs) == 0 { + file.Comment("No documentation available from the IDL.") + file.Comment("Please refer to the IDL source or the program documentation for more information.") + file.Line() + } else { + file.Comment("Documentation from the IDL:") + for _, comment := range g.idl.Docs { + file.Comment(comment) + } + } + + return &OutputFile{ + Name: "doc.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/errors.go b/cmd/generate-bindings/solana/anchor-go/generator/errors.go new file mode 100644 index 00000000..f8efd9a4 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/errors.go @@ -0,0 +1,101 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "encoding/json" + "errors" + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/solana-go/rpc/jsonrpc" +) + +func (g *Generator) gen_errors() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains errors.") + { + code := Empty() + for _, e := range g.idl.Errors { + _ = e + // spew.Dump(e) + + // type IdlErrorCode struct { + // Code uint32 `json:"code"` + // Name string `json:"name"` + // // #[serde(skip_serializing_if = "is_default")] + // // pub msg: Option, + // Msg Option[string] `json:"msg,omitzero"` + // } + } + file.Add(code) + } + return &OutputFile{ + Name: "errors.go", + File: file, + }, nil +} + +type CustomError interface { + Code() int + Name() string + Error() string +} +type customErrorDef struct { + code int + name string + msg string +} + +func (e *customErrorDef) Code() int { + return e.code +} + +func (e *customErrorDef) Name() string { + return e.name +} + +func (e *customErrorDef) Error() string { + return fmt.Sprintf("%s(%d): %s", e.name, e.code, e.msg) +} + +var Errors = map[int]CustomError{} + +func DecodeCustomError(rpcErr error) (err error, ok bool) { + if errCode, o := decodeErrorCode(rpcErr); o { + if customErr, o := Errors[errCode]; o { + err = customErr + ok = true + return + } + } + return +} + +func decodeErrorCode(rpcErr error) (errorCode int, ok bool) { + var jErr *jsonrpc.RPCError + if errors.As(rpcErr, &jErr) && jErr.Data != nil { + if root, o := jErr.Data.(map[string]any); o { + if rootErr, o := root["err"].(map[string]any); o { + if rootErrInstructionError, o := rootErr["InstructionError"]; o { + if rootErrInstructionErrorItems, o := rootErrInstructionError.([]any); o { + if len(rootErrInstructionErrorItems) == 2 { + if v, o := rootErrInstructionErrorItems[1].(map[string]any); o { + if v2, o := v["Custom"].(json.Number); o { + if code, err := v2.Int64(); err == nil { + ok = true + errorCode = int(code) + } + } else if v2, o := v["Custom"].(float64); o { + ok = true + errorCode = int(v2) + } + } + } + } + } + } + } + } + return +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/events.go b/cmd/generate-bindings/solana/anchor-go/generator/events.go new file mode 100644 index 00000000..61969234 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/events.go @@ -0,0 +1,113 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/tools" +) + +func (g *Generator) genfile_events() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains parsers for the events defined in the IDL.") + + names := []string{} + { + for _, event := range g.idl.Events { + names = append(names, tools.ToCamelUpper(event.Name)) + } + } + { + code, err := g.gen_eventParser(names) + if err != nil { + return nil, fmt.Errorf("error generating event parser: %w", err) + } + file.Add(code) + } + + return &OutputFile{ + Name: "events.go", + File: file, + }, nil +} + +func (g *Generator) gen_eventParser(eventNames []string) (Code, error) { + code := Empty() + { + code.Func().Id("ParseAnyEvent"). + Params(Id("eventData").Index().Byte()). + Params(Any(), Error()). + BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("eventData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek event discriminator: %w"), Err()), + ), + ) + + block.Switch(Id("discriminator")).BlockFunc(func(switchBlock *Group) { + for _, name := range eventNames { + switchBlock.Case(Id(FormatEventDiscriminatorName(name))).Block( + Id("value").Op(":=").New(Id(name)), + Err().Op(":=").Id("value").Dot("UnmarshalWithDecoder").Call(Id("decoder")), + If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal event as "+name+": %w"), Err()), + ), + ), + Return(Id("value"), Nil()), + ) + } + switchBlock.Default().Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("unknown discriminator: %s"), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + }) + }) + } + { + code.Line().Line() + // for each event, generate a function to parse it: + for _, name := range eventNames { + discriminatorName := FormatEventDiscriminatorName(name) + + code.Func().Id("ParseEvent_"+name). + Params(Id("eventData").Index().Byte()). + Params(Op("*").Id(name), Error()). + BlockFunc(func(block *Group) { + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("eventData")) + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to peek discriminator: %w"), Err()), + ), + ) + + block.If(Id("discriminator").Op("!=").Id(discriminatorName)).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("expected discriminator %v, got %s"), Id(discriminatorName), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + + block.Id("event").Op(":=").New(Id(name)) + block.Err().Op("=").Id("event").Dot("UnmarshalWithDecoder").Call(Id("decoder")) + + block.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to unmarshal event of type "+name+": %w"), Err()), + ), + ) + + block.Return(Id("event"), Nil()) + }) + code.Line().Line() + } + } + return code, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/fetchers.go b/cmd/generate-bindings/solana/anchor-go/generator/fetchers.go new file mode 100644 index 00000000..b5007851 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/fetchers.go @@ -0,0 +1,18 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) gen_fetchers() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains fetcher functions.") + { + } + return &OutputFile{ + Name: "fetchers.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/generator.go b/cmd/generate-bindings/solana/anchor-go/generator/generator.go new file mode 100644 index 00000000..b95e1bc5 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/generator.go @@ -0,0 +1,170 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/solana-go" +) + +var Debug = false // Set to true to enable debug logging. + +type Generator struct { + options *GeneratorOptions + idl *idl.Idl + complexEnumRegistry map[string]struct{} +} + +type GeneratorOptions struct { + OutputDir string // Directory to write the generated code to. + Package string // Package name for the generated code. + ModPath string // Module path for the generated code. E.g. "github.com/gagliardetto/mysolana-program-go" + ProgramId *solana.PublicKey // Program ID to use in the generated code. + ProgramName string // Name of the program for the generated code. + SkipGoMod bool // If true, skip generating the go.mod file. +} + +func NewGenerator(idl *idl.Idl, options *GeneratorOptions) *Generator { + return &Generator{ + idl: idl, + options: options, + } +} + +type OutputFile struct { + Name string // Name of the output file. + File *File +} + +type Output struct { + Files []*OutputFile // List of output files to be generated. + GoMod []byte // Go module file content. +} + +func (g *Generator) Generate() (*Output, error) { + if g.idl == nil { + return nil, fmt.Errorf("IDL is nil, cannot generate code") + } + if g.options == nil { + g.options = &GeneratorOptions{ + OutputDir: "generated", + Package: "idlclient", + ModPath: "github.com/gagliardetto/anchor-go/idlclient", + ProgramId: nil, + ProgramName: "myprogram", + } + } + if err := g.idl.Validate(); err != nil { + return nil, fmt.Errorf("invalid IDL: %w", err) + } + output := &Output{ + Files: make([]*OutputFile, 0), + } + + g.complexEnumRegistry = make(map[string]struct{}) + + { + // Register complex enums. + { + // TODO: .types is the only place where we can find complex enums? (or enums in general?) + for _, typ := range g.idl.Types { + g.registerComplexEnums(typ) + } + } + if len(g.idl.Docs) > 0 { + file, err := g.genfile_doc() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if len(g.idl.Accounts) > 0 { + file, err := g.genfile_accounts() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if len(g.idl.Events) > 0 { + file, err := g.genfile_events() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.genfile_constructor() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.genfile_types() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_discriminators() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_fetchers() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_errors() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_constants() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_tests() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + { + file, err := g.gen_instructions() + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if g.options.ProgramId != nil { + file, err := g.genfile_programID(*g.options.ProgramId) + if err != nil { + return nil, err + } + output.Files = append(output.Files, file) + } + if !g.options.SkipGoMod { + goMod, err := g.gen_gomod() + if err != nil { + return nil, err + } + output.GoMod = goMod + } + } + + return output, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/gomod.go b/cmd/generate-bindings/solana/anchor-go/generator/gomod.go new file mode 100644 index 00000000..6ac113f6 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/gomod.go @@ -0,0 +1,27 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "golang.org/x/mod/modfile" +) + +// gen_gomod generates a `go.mod` file for the generated code, and writes +// it to the destination directory. +func (g *Generator) gen_gomod() ([]byte, error) { + mdf := &modfile.File{} + mdf.AddModuleStmt(g.options.ModPath) + + mdf.AddNewRequire("github.com/gagliardetto/solana-go", "v1.12.0", false) + mdf.AddNewRequire("github.com/gagliardetto/anchor-go", "v0.3.2", false) + mdf.AddNewRequire("github.com/gagliardetto/binary", "v0.8.0", false) + mdf.AddNewRequire("github.com/gagliardetto/treeout", "v0.1.4", false) + mdf.AddNewRequire("github.com/gagliardetto/gofuzz", "v1.2.2", false) + mdf.AddNewRequire("github.com/stretchr/testify", "v1.10.0", false) + mdf.AddNewRequire("github.com/davecgh/go-spew", "v1.1.1", false) + + // add replacement for "github.com/gagliardetto/anchor-go/errors" to ../../demo-anchor-go/errors + // mdf.AddReplace("github.com/gagliardetto/anchor-go", "", "../../demo-anchor-go", "") + mdf.Cleanup() + + return mdf.Format() +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/id.go b/cmd/generate-bindings/solana/anchor-go/generator/id.go new file mode 100644 index 00000000..9476a497 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/id.go @@ -0,0 +1,26 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/solana-go" +) + +// TODO: +// - generate program IDs for mainnet, devnet, testnet, and localnet. + +func (g *Generator) genfile_programID(id solana.PublicKey) (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains the program ID.") + + { + idDecl := Var().Id("ProgramID").Op("=").Qual(PkgSolanaGo, "MustPublicKeyFromBase58").Call(Lit(id.String())) + file.Add(idDecl) + } + + return &OutputFile{ + Name: "program_id.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/idl_validate.go b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate.go new file mode 100644 index 00000000..35f41e69 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate.go @@ -0,0 +1,215 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +// ValidateIDLDerivedIdentifiers checks that names from the IDL produce valid Go identifiers +// after the same transforms used by the Jennifer-based generator. Call this before Generate(). +func ValidateIDLDerivedIdentifiers(i *idl.Idl) error { + if i == nil { + return fmt.Errorf("idl is nil") + } + for ai, acc := range i.Accounts { + ctx := fmt.Sprintf("accounts[%d](name=%q)", ai, acc.Name) + if err := validatePascalIdent(ctx, acc.Name); err != nil { + return err + } + disc := FormatAccountDiscriminatorName(acc.Name) + if err := validateRawIdent(ctx+".discriminatorVar", acc.Name, disc); err != nil { + return err + } + } + for ei, ev := range i.Events { + ctx := fmt.Sprintf("events[%d](name=%q)", ei, ev.Name) + if err := validatePascalIdent(ctx, ev.Name); err != nil { + return err + } + disc := FormatEventDiscriminatorName(ev.Name) + if err := validateRawIdent(ctx+".discriminatorVar", ev.Name, disc); err != nil { + return err + } + } + for ci, co := range i.Constants { + if co.Name == "" { + continue + } + ctx := fmt.Sprintf("constants[%d]", ci) + if err := validateRawIdent(ctx, co.Name, co.Name); err != nil { + return err + } + } + for ixIdx, ix := range i.Instructions { + ctx := fmt.Sprintf("instructions[%d](name=%q)", ixIdx, ix.Name) + if err := validatePascalIdent(ctx, ix.Name); err != nil { + return err + } + disc := FormatInstructionDiscriminatorName(ix.Name) + if err := validateRawIdent(ctx+".discriminatorVar", ix.Name, disc); err != nil { + return err + } + fn := newInstructionFuncName(ix.Name) + if err := validateRawIdent(ctx+".constructor", ix.Name, fn); err != nil { + return err + } + typeName := instructionStructTypeName(ix.Name) + if err := validateRawIdent(ctx+".instructionStructType", ix.Name, typeName); err != nil { + return err + } + for _, arg := range ix.Args { + argCtx := ctx + ".args(name=" + quoteIDL(arg.Name) + ")" + if err := validatePascalIdent(argCtx, arg.Name); err != nil { + return err + } + param := formatParamName(arg.Name) + if err := validateRawIdent(argCtx+".builderParam", arg.Name, param); err != nil { + return err + } + } + for ai, accItem := range ix.Accounts { + switch acc := accItem.(type) { + case *idl.IdlInstructionAccount: + acCtx := fmt.Sprintf("%s.accounts[%d](name=%q)", ctx, ai, acc.Name) + if err := validatePascalIdent(acCtx, acc.Name); err != nil { + return err + } + fieldBase := tools.ToCamelUpper(acc.Name) + if err := validateRawIdent(acCtx+".accountField", acc.Name, fieldBase); err != nil { + return err + } + if acc.Writable { + if err := validateRawIdent(acCtx+".writableFlag", acc.Name, fieldBase+"Writable"); err != nil { + return err + } + } + if acc.Signer { + if err := validateRawIdent(acCtx+".signerFlag", acc.Name, fieldBase+"Signer"); err != nil { + return err + } + } + if acc.Optional { + if err := validateRawIdent(acCtx+".optionalFlag", acc.Name, fieldBase+"Optional"); err != nil { + return err + } + } + param := formatAccountNameParam(acc.Name) + if err := validateRawIdent(acCtx+".builderParam", acc.Name, param); err != nil { + return err + } + case *idl.IdlInstructionAccounts: + return fmt.Errorf("%s.accounts[%d]: composite account groups are not supported", ctx, ai) + default: + return fmt.Errorf("%s.accounts[%d]: unknown account item type %T", ctx, ai, accItem) + } + } + } + for ti, def := range i.Types { + ctx := fmt.Sprintf("types[%d](name=%q)", ti, def.Name) + if err := validatePascalIdent(ctx, def.Name); err != nil { + return err + } + if err := validateTypeDefTy(ctx, def.Name, def.Ty); err != nil { + return err + } + } + return nil +} + +func instructionStructTypeName(instructionName string) string { + lower := strings.ToLower(instructionName) + if strings.HasSuffix(lower, "instruction") { + return tools.ToCamelUpper(instructionName) + } + return tools.ToCamelUpper(instructionName) + "Instruction" +} + +func quoteIDL(s string) string { + return fmt.Sprintf("%q", s) +} + +func validateTypeDefTy(ctx, typeName string, ty idl.IdlTypeDefTy) error { + if ty == nil { + return fmt.Errorf("%s: type definition has nil type body", ctx) + } + switch vv := ty.(type) { + case *idl.IdlTypeDefTyStruct: + fields := vv.Fields + if fields == nil { + return nil + } + switch f := fields.(type) { + case idl.IdlDefinedFieldsNamed: + for fi, field := range f { + fctx := fmt.Sprintf("%s.fields[%d](name=%q)", ctx, fi, field.Name) + if err := validatePascalIdent(fctx, field.Name); err != nil { + return err + } + } + case idl.IdlDefinedFieldsTuple: + _ = f + } + case *idl.IdlTypeDefTyEnum: + enumExported := tools.ToCamelUpper(typeName) + if vv.Variants.IsAllSimple() { + for vi, variant := range vv.Variants { + vctx := fmt.Sprintf("%s.variants[%d](name=%q)", ctx, vi, variant.Name) + if err := validatePascalIdent(vctx, variant.Name); err != nil { + return err + } + combo := formatSimpleEnumVariantName(variant.Name, enumExported) + if err := validateRawIdent(vctx+".simpleEnumConst", variant.Name, combo); err != nil { + return err + } + } + } else { + for vi, variant := range vv.Variants { + vctx := fmt.Sprintf("%s.variants[%d](name=%q)", ctx, vi, variant.Name) + if err := validatePascalIdent(vctx, variant.Name); err != nil { + return err + } + vt := formatComplexEnumVariantTypeName(enumExported, variant.Name) + if err := validateRawIdent(vctx+".complexVariantType", variant.Name, vt); err != nil { + return err + } + if !variant.Fields.IsSome() { + continue + } + switch df := variant.Fields.Unwrap().(type) { + case idl.IdlDefinedFieldsNamed: + for fi, field := range df { + fctx := fmt.Sprintf("%s.fields[%d](name=%q)", vctx, fi, field.Name) + if err := validatePascalIdent(fctx, field.Name); err != nil { + return err + } + } + case idl.IdlDefinedFieldsTuple: + } + } + } + default: + return fmt.Errorf("%s: unsupported IDL type definition shape %T", ctx, ty) + } + return nil +} + +func validatePascalIdent(context, raw string) error { + ident := tools.ToCamelUpper(raw) + return validateRawIdent(context, raw, ident) +} + +func validateRawIdent(context, idlSource, goIdent string) error { + if goIdent == "" { + return fmt.Errorf("%s: empty Go identifier derived from IDL name %q", context, idlSource) + } + if !tools.IsValidIdent(goIdent) { + return fmt.Errorf("%s: IDL name %q yields invalid Go identifier %q (must be a valid Go identifier for generated bindings)", context, idlSource, goIdent) + } + if tools.IsReservedKeyword(goIdent) { + return fmt.Errorf("%s: IDL name %q yields Go reserved keyword %q", context, idlSource, goIdent) + } + return nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/idl_validate_test.go b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate_test.go new file mode 100644 index 00000000..dda3a799 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate_test.go @@ -0,0 +1,49 @@ +package generator + +import ( + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/solana-go" + "github.com/stretchr/testify/require" +) + +func testProgramID(t *testing.T) *solana.PublicKey { + t.Helper() + pk, err := solana.PublicKeyFromBase58("ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL") + require.NoError(t, err) + return &pk +} + +func minimalInstruction(name string) idl.IdlInstruction { + return idl.IdlInstruction{ + Name: name, + Discriminator: idl.IdlDiscriminator{175, 175, 109, 31, 13, 152, 155, 237}, + Accounts: []idl.IdlInstructionAccountItem{}, + Args: []idl.IdlField{}, + } +} + +func TestValidateIDLDerivedIdentifiers_valid(t *testing.T) { + i := &idl.Idl{ + Address: testProgramID(t), + Instructions: []idl.IdlInstruction{minimalInstruction("initialize")}, + } + require.NoError(t, ValidateIDLDerivedIdentifiers(i)) +} + +func TestValidateIDLDerivedIdentifiers_invalidConstantName(t *testing.T) { + i := &idl.Idl{ + Address: testProgramID(t), + Instructions: []idl.IdlInstruction{minimalInstruction("initialize")}, + Constants: []idl.IdlConst{{ + Name: "123bad", + Ty: &idltype.U8{}, + Value: "1", + }}, + } + err := ValidateIDLDerivedIdentifiers(i) + require.Error(t, err) + require.Contains(t, err.Error(), "123bad") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/instructions.go b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go new file mode 100644 index 00000000..679949a0 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go @@ -0,0 +1,797 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + "strings" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +func (g *Generator) gen_instructions() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains instructions and instruction parsers.") + { + for _, instruction := range g.idl.Instructions { + uniqueParamNames := generateUniqueParamNames(instruction.Args) + ixCode := Empty() + { + declarerName := newInstructionFuncName(instruction.Name) + ixCode.Commentf("Builds a %q instruction.", instruction.Name) + { + if len(instruction.Docs) > 0 { + ixCode.Line() + // Add documentation comments for the instruction. + for _, doc := range instruction.Docs { + ixCode.Comment(doc) + } + } + } + ixCode.Line() + ixCode.Func().Id(declarerName). + Params( + DoGroup( + func(g *Group) { + addCommentSections := len(instruction.Args) > 0 && len(instruction.Accounts) > 0 + if addCommentSections { + g.Line().Comment("Params:") + } + g.Add( + ListMultiline( + func(paramsCode *Group) { + for _, param := range instruction.Args { + paramType := genTypeName(param.Ty) + if IsOption(param.Ty) || IsCOption(param.Ty) { + paramType = Op("*").Add(paramType) + } + paramsCode.Id(uniqueParamNames[param.Name]).Add(paramType) + } + }, + ), + ) + if addCommentSections { + g.Line().Comment("Accounts:") + } + g.Add( + ListMultiline( + func(accountsCode *Group) { + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + accountsCode.Id(formatAccountNameParam(acc.Name)).Qual(PkgSolanaGo, "PublicKey") + } + // TODO: for accounts: + // - Optional? + // - PDA? + // - Address? + // - Relations? + case *idl.IdlInstructionAccounts: + { + panic(fmt.Errorf("Accounts groups are not supported yet: %s", acc.Name)) + // accs := acc.Accounts + // // add comment for the accounts + // if len(accs) > 0 { + // accountsCode.Commentf("Accounts group %q:", acc.Name) + // } + // for _, acc := range accs { + // // If the account has a name, use it as the parameter name. + // // Otherwise, use a generic name. + // acc := acc.(*idl.IdlInstructionAccount) + // accountName := formatAccountNameParam(acc.Name) + // accountsCode.Id(accountName).Qual(PkgSolanaGo, "PublicKey") + // } + } + default: + panic("unknown account type: " + spew.Sdump(account)) + } + } + }, + ), + ) + }, + ), + ). + ParamsFunc(func(returnsCode *Group) { + returnsCode.Qual(PkgSolanaGo, "Instruction") + returnsCode.Error() + }).BlockFunc(func(body *Group) { + body.Id("buf__").Op(":=").New(Qual("bytes", "Buffer")) + body.Id("enc__").Op(":=").Qual(PkgBinary, "NewBorshEncoder").Call(Id("buf__")) + + { + body.Line().Comment("Encode the instruction discriminator.") + discriminatorName := FormatInstructionDiscriminatorName(instruction.Name) + body.Err().Op(":=").Id("enc__").Dot("WriteBytes").Call(Id(discriminatorName).Index(Op(":")), False()) + body.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed to write instruction discriminator: %w"), Err()), + ), + ) + } + if len(instruction.Args) > 0 { + checkNil := true + body.BlockFunc(func(grp *Group) { + g.gen_marshal_DefinedFieldsNamed( + grp, + instruction.Args, + checkNil, + func(param idl.IdlField) *Statement { + return Id(uniqueParamNames[param.Name]) + }, + "enc__", + true, // returnNilErr + func(param idl.IdlField) string { + return uniqueParamNames[param.Name] + }, + ) + }) + } + body.Id("accounts__").Op(":=").Qual(PkgSolanaGo, "AccountMetaSlice").Block() + if len(instruction.Accounts) > 0 { + body.Line().Comment("Add the accounts to the instruction.") + + body.Block( + DoGroup(func(body *Group) { + for ai, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + if ai > 0 { + body.Line() + } + body.Comment(formatAccountCommentDocs(ai, acc)) + body.Line() + { + // add comment for the account + if len(acc.Docs) > 0 { + for _, doc := range acc.Docs { + body.Comment(doc).Line() + } + } + } + accountName := formatAccountNameParam(acc.Name) + body.Id("accounts__").Dot("Append").Call( + Qual(PkgSolanaGo, "NewAccountMeta").Call( + Id(accountName), + Lit(acc.Writable), + Lit(acc.Signer), + ), + ) + } + + case *idl.IdlInstructionAccounts: + { + panic(fmt.Errorf("Accounts groups are not supported yet: %s", acc.Name)) + // if ai > 0 { + // body.Line() + // } + // body.Commentf("Accounts group: %s", acc.Name) + // body.Line() + // accs := acc.Accounts + // for acci, acc := range accs { + // acc := acc.(*idl.IdlInstructionAccount) + // body.Comment(formatAccountCommentDocs(acci, acc)) + // body.Line() + // accountName := formatAccountNameParam(acc.Name) + // body.Id("accounts__").Dot("Append").Call( + // Qual(PkgSolanaGo, "NewAccountMeta").Call( + // Id(accountName), + // Lit(acc.Writable), + // Lit(acc.Signer), + // ), + // ) + // } + } + default: + panic("unknown account type: " + spew.Sdump(account)) + } + } + }), + ) + } + + // create the return instruction + body.Line().Comment("Create the instruction.") + body.Return( + Qual(PkgSolanaGo, "NewInstruction").CallFunc( + func(g *Group) { + g.Add( + ListMultiline(func(gg *Group) { + gg.Id("ProgramID") + gg.Id("accounts__") + gg.Id("buf__").Dot("Bytes").Call() + }), + ) + }, + ), + Nil(), // No error + ) + }) + } + file.Add(ixCode) + } + } + + // Add instruction types and parsers + { + typeNames := []string{} + discriminatorNames := []string{} + for _, instruction := range g.idl.Instructions { + // Check if the instruction name already ends with "instruction" (case-insensitive) + instructionNameLower := strings.ToLower(instruction.Name) + if strings.HasSuffix(instructionNameLower, "instruction") { + // Already has "instruction" suffix, don't add it again + typeNames = append(typeNames, tools.ToCamelUpper(instruction.Name)) + } else { + // Add "Instruction" suffix + typeNames = append(typeNames, tools.ToCamelUpper(instruction.Name)+"Instruction") + } + discriminatorNames = append(discriminatorNames, tools.ToCamelUpper(instruction.Name)) + } + + // Generate instruction struct types + { + for _, instruction := range g.idl.Instructions { + typeCode, err := g.gen_instructionType(instruction) + if err != nil { + return nil, fmt.Errorf("error generating instruction type for %s: %w", instruction.Name, err) + } + file.Add(typeCode) + } + } + + // Generate instruction parsers + { + code, err := g.gen_instructionParser(typeNames, discriminatorNames) + if err != nil { + return nil, fmt.Errorf("error generating instruction parser: %w", err) + } + file.Add(code) + } + } + + return &OutputFile{ + Name: "instructions.go", + File: file, + }, nil +} + +func formatAccountNameParam(accountName string) string { + accountName = accountName + "Account" + if tools.IsReservedKeyword(accountName) { + return accountName + "_" + } + if !tools.IsValidIdent(accountName) { + return "a_" + tools.ToCamelUpper(accountName) + } + return tools.ToCamelLower(accountName) +} + +func formatParamName(paramName string) string { + paramName = paramName + "Param" + if tools.IsReservedKeyword(paramName) { + return paramName + "_" + } + if !tools.IsValidIdent(paramName) { + return "p_" + tools.ToCamelUpper(paramName) + } + return tools.ToCamelLower(paramName) +} + +// generateUniqueParamNames creates unique Go parameter names for instruction arguments, +// mirroring generateUniqueFieldNames but using formatParamName as the base identifier +// (builder params use a different convention than struct field names). +func generateUniqueParamNames(fields []idl.IdlField) map[string]string { + fieldNameMap := make(map[string]string) + usedNames := make(map[string]int) + + for _, field := range fields { + baseName := formatParamName(field.Name) + finalName := baseName + + if count, exists := usedNames[baseName]; exists { + finalName = baseName + fmt.Sprintf("%d", count+1) + usedNames[baseName] = count + 1 + } else { + usedNames[baseName] = 0 + } + + fieldNameMap[field.Name] = finalName + } + + return fieldNameMap +} + +func newInstructionFuncName(instructionName string) string { + // Check if the instruction name already ends with "instruction" (case-insensitive) + instructionNameLower := strings.ToLower(instructionName) + if strings.HasSuffix(instructionNameLower, "instruction") { + // Already has "instruction" suffix, don't add it again + return "New" + tools.ToCamelUpper(instructionName) + } else { + // Add "Instruction" suffix + return "New" + tools.ToCamelUpper(instructionName) + "Instruction" + } +} + +func newWriteReportFromInstructionFuncName(instructionName string) string { + return "WriteReportFrom" + tools.ToCamelUpper(instructionName) +} + +func formatAccountCommentDocs(index int, account *idl.IdlInstructionAccount) string { + buf := new(strings.Builder) + buf.WriteString(fmt.Sprintf("Account %d %q", index, account.Name)) + buf.WriteString(": ") + if account.Writable { + buf.WriteString("Writable") + } else { + buf.WriteString("Read-only") + } + if account.Signer { + buf.WriteString(", Signer") + } else { + buf.WriteString(", Non-signer") + } + if account.Optional { + buf.WriteString(", Optional") + } else { + buf.WriteString(", Required") + } + if account.Address.IsSome() && !account.Address.Unwrap().IsZero() { + buf.WriteString(fmt.Sprintf(", Address: %s", account.Address.Unwrap().String())) + } + // TODO: Handle PDA and Relations + return buf.String() +} + +func (g *Generator) gen_instructionParser(typeNames []string, discriminatorNames []string) (Code, error) { + code := Empty() + + // Generate Instruction interface + code.Line().Line() + code.Comment("Instruction interface defines common methods for all instruction types") + code.Line() + code.Type().Id("Instruction").Interface( + Id("GetDiscriminator").Params().Params(Index().Byte()), + Line(), + Id("UnmarshalWithDecoder").Params(Id("decoder").Op("*").Qual(PkgBinary, "Decoder")).Params(Error()), + Line(), + Id("UnmarshalAccountIndices").Params(Id("buf").Index().Byte()).Params(Index().Uint8(), Error()), + Line(), + Id("PopulateFromAccountIndices").Params(Id("indices").Index().Uint8(), Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey")).Params(Error()), + Line(), + Id("GetAccountKeys").Params().Params(Index().Qual(PkgSolanaGo, "PublicKey")), + ) + + // Single unified ParseInstruction function with optional accounts + code.Line().Line() + code.Comment("ParseInstruction parses instruction data and optionally populates accounts").Line() + code.Comment("If accountIndicesData is nil or empty, accounts will not be populated") + code.Line() + code.Func().Id("ParseInstruction"). + Params( + Id("instructionData").Index().Byte(), + Id("accountIndicesData").Index().Byte(), + Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey"), + ). + Params(Id("Instruction"), Error()). + BlockFunc(func(block *Group) { + block.Comment("Validate inputs") + block.If(Len(Id("instructionData")).Op("<").Lit(8)).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("instruction data too short: expected at least 8 bytes, got %d"), Len(Id("instructionData")))), + ) + + block.Comment("Extract discriminator (TypeID for consistent equality with generated constants)") + block.Id("discriminator").Op(":=").Qual(PkgBinary, "TypeIDFromBytes").Call(Id("instructionData").Index(Lit(0), Lit(8))) + + block.Comment("Parse based on discriminator") + block.Switch(Id("discriminator")).BlockFunc(func(switchBlock *Group) { + // This for loop runs during code generation, not at runtime + for i, typeName := range typeNames { + discriminatorName := discriminatorNames[i] + switchBlock.Case(Id(FormatInstructionDiscriminatorName(discriminatorName))).Block( + Id("instruction").Op(":=").New(Id(typeName)), + Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("instructionData")), + Id("err").Op(":=").Id("instruction").Dot("UnmarshalWithDecoder").Call(Id("decoder")), + If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to unmarshal instruction as "+typeName+": %w"), Id("err"))), + ), + If(Id("accountIndicesData").Op("!=").Nil().Op("&&").Len(Id("accountIndicesData")).Op(">").Lit(0)).Block( + Id("indices").Op(",").Id("err").Op(":=").Id("instruction").Dot("UnmarshalAccountIndices").Call(Id("accountIndicesData")), + If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to unmarshal account indices: %w"), Id("err"))), + ), + Id("err").Op("=").Id("instruction").Dot("PopulateFromAccountIndices").Call(Id("indices"), Id("accountKeys")), + If(Id("err").Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to populate accounts: %w"), Id("err"))), + ), + ), + Return(Id("instruction"), Nil()), + ) + } + switchBlock.Default().Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("unknown instruction discriminator: %s"), Qual(PkgBinary, "FormatDiscriminator").Call(Id("discriminator")))), + ) + }) + }) + + // Generic ParseInstructionTyped function for type-safe parsing + code.Line().Line() + code.Comment("ParseInstructionTyped parses instruction data and returns a specific instruction type") + code.Comment("T must implement the Instruction interface") + code.Line() + code.Func().Id("ParseInstructionTyped"). + Types(Id("T").Id("Instruction")). + Params( + Id("instructionData").Index().Byte(), + Id("accountIndicesData").Index().Byte(), + Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey"), + ). + Params(Id("T"), Error()). + BlockFunc(func(block *Group) { + block.Id("instruction").Op(",").Id("err").Op(":=").Id("ParseInstruction").Call(Id("instructionData"), Id("accountIndicesData"), Id("accountKeys")) + block.If(Id("err").Op("!=").Nil()).Block( + Return(Op("*").New(Id("T")), Id("err")), + ) + block.Id("typed").Op(",").Id("ok").Op(":=").Id("instruction").Assert(Id("T")) + block.If(Op("!").Id("ok")).Block( + Return(Op("*").New(Id("T")), Qual("fmt", "Errorf").Call(Lit("instruction is not of expected type"))), + ) + block.Return(Id("typed"), Nil()) + }) + + // Convenience function for parsing without accounts + code.Line().Line() + code.Comment("ParseInstructionWithoutAccounts parses instruction data without account information") + code.Line() + code.Func().Id("ParseInstructionWithoutAccounts"). + Params(Id("instructionData").Index().Byte()). + Params(Id("Instruction"), Error()). + Block( + Return(Id("ParseInstruction").Call(Id("instructionData"), Nil(), Index().Qual(PkgSolanaGo, "PublicKey").Op("{}"))), + ) + + // Convenience function for parsing with accounts + code.Line().Line() + code.Comment("ParseInstructionWithAccounts parses instruction data with account information") + code.Line() + code.Func().Id("ParseInstructionWithAccounts"). + Params( + Id("instructionData").Index().Byte(), + Id("accountIndicesData").Index().Byte(), + Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey"), + ). + Params(Id("Instruction"), Error()). + Block( + Return(Id("ParseInstruction").Call(Id("instructionData"), Id("accountIndicesData"), Id("accountKeys"))), + ) + + return code, nil +} + +func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, error) { + code := Empty() + + uniqueArgFieldNames := generateUniqueFieldNames(instruction.Args) + + // Check if the instruction name already ends with "instruction" (case-insensitive) + instructionNameLower := strings.ToLower(instruction.Name) + var typeName string + if strings.HasSuffix(instructionNameLower, "instruction") { + // Already has "instruction" suffix, don't add it again + typeName = tools.ToCamelUpper(instruction.Name) + } else { + // Add "Instruction" suffix + typeName = tools.ToCamelUpper(instruction.Name) + "Instruction" + } + + // Generate the instruction struct type + code.Type().Id(typeName).StructFunc(func(structGroup *Group) { + // Add fields for each instruction argument + for _, arg := range instruction.Args { + fieldType := genTypeName(arg.Ty) + if IsOption(arg.Ty) || IsCOption(arg.Ty) { + fieldType = Op("*").Add(fieldType) + } + structGroup.Id(uniqueArgFieldNames[arg.Name]).Add(fieldType).Tag(map[string]string{ + "json": arg.Name, + }) + } + + // Add fields for each instruction account + if len(instruction.Accounts) > 0 { + structGroup.Line().Comment("Accounts:") + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + // Add account field with metadata + fieldName := tools.ToCamelUpper(acc.Name) + structGroup.Id(fieldName).Qual(PkgSolanaGo, "PublicKey").Tag(map[string]string{ + "json": acc.Name, + }) + + // Add account metadata fields + if acc.Writable { + structGroup.Id(fieldName + "Writable").Bool().Tag(map[string]string{ + "json": acc.Name + "_writable", + }) + } + if acc.Signer { + structGroup.Id(fieldName + "Signer").Bool().Tag(map[string]string{ + "json": acc.Name + "_signer", + }) + } + if acc.Optional { + structGroup.Id(fieldName + "Optional").Bool().Tag(map[string]string{ + "json": acc.Name + "_optional", + }) + } + } + case *idl.IdlInstructionAccounts: + { + // Handle account groups (not fully implemented yet) + structGroup.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + } + }) + + // Generate GetDiscriminator method (required by Instruction interface) + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("GetDiscriminator"). + Params(). + Params(Index().Byte()). + Block( + Return(Id(FormatInstructionDiscriminatorName(tools.ToCamelUpper(instruction.Name))).Index(Op(":"))), + ) + + // Generate UnmarshalWithDecoder method + code.Line().Line() + code.Commentf("UnmarshalWithDecoder unmarshals the %s from Borsh-encoded bytes prefixed with its discriminator.", typeName).Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("UnmarshalWithDecoder"). + Params(Id("decoder").Op("*").Qual(PkgBinary, "Decoder")). + Params(Error()). + BlockFunc(func(block *Group) { + // Note: discriminator has already been read and validated by the parser + // Read instruction arguments + if len(instruction.Args) > 0 { + block.Var().Id("err").Error() + } + { + // Read the discriminator and check it against the expected value + block.Comment("Read the discriminator and check it against the expected value:") + block.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + block.If(Err().Op("!=").Nil()).Block( + Return(Qual("fmt", "Errorf").Call(Lit("failed to read instruction discriminator for %s: %w"), Lit(typeName), Err())), + ) + block.If(Id("discriminator").Op("!=").Id(FormatInstructionDiscriminatorName(tools.ToCamelUpper(instruction.Name)))).Block( + Return( + Qual("fmt", "Errorf").Call( + Lit("instruction discriminator mismatch for %s: expected %s, got %s"), + Lit(typeName), + Id(FormatInstructionDiscriminatorName(tools.ToCamelUpper(instruction.Name))), + Id("discriminator"), + ), + ), + ) + } + for _, arg := range instruction.Args { + fieldName := uniqueArgFieldNames[arg.Name] + block.Commentf("Deserialize `%s`:", fieldName) + + if IsOption(arg.Ty) || IsCOption(arg.Ty) { + var optionalityReaderName string + switch { + case IsOption(arg.Ty): + optionalityReaderName = "ReadOption" + case IsCOption(arg.Ty): + optionalityReaderName = "ReadCOption" + } + + block.BlockFunc(func(optGroup *Group) { + optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call() + optGroup.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + optGroup.If(Id("ok")).Block( + List(Err()).Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(fieldName)), + If(Err().Op("!=").Nil()).Block( + Return(Err()), + ), + ) + }) + } else { + block.List(Err()).Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(fieldName)) + block.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + } + } + + // Note: Accounts are not typically serialized in instruction data + // They are passed as part of the transaction's account metas + // This method only deserializes the instruction arguments + + block.Return(Nil()) + }) + + // Generate account-related methods if instruction has accounts + if len(instruction.Accounts) > 0 { + // Generate UnmarshalAccountIndices method + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("UnmarshalAccountIndices"). + Params(Id("buf").Index().Byte()). + Params(Index().Uint8(), Error()). + BlockFunc(func(block *Group) { + block.Comment("UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes") + block.Id("decoder").Op(":=").Qual(PkgBinary, "NewBorshDecoder").Call(Id("buf")) + block.Id("indices").Op(":=").Make(Index().Uint8(), Lit(0)) + block.Id("index").Op(":=").Uint8().Call(Lit(0)) + block.Var().Id("err").Error() + + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + block.Commentf("Decode from %s account index", acc.Name) + block.Id("index").Op("=").Uint8().Call(Lit(0)) + block.List(Err()).Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("index")) + block.If(Err().Op("!=").Nil()).Block( + Return(Nil(), Qual("fmt", "Errorf").Call(Lit("failed to decode %s account index: %w"), Lit(acc.Name), Err())), + ) + block.Id("indices").Op("=").Append(Id("indices"), Id("index")) + } + case *idl.IdlInstructionAccounts: + { + block.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + + block.Return(Id("indices"), Nil()) + }) + + // Generate PopulateFromAccountIndices method + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("PopulateFromAccountIndices"). + Params(Id("indices").Index().Uint8(), Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey")). + Params(Error()). + BlockFunc(func(block *Group) { + block.Comment("PopulateFromAccountIndices sets account public keys from indices and account keys array") + + // Count expected accounts + expectedAccountCount := 0 + for _, account := range instruction.Accounts { + switch account.(type) { + case *idl.IdlInstructionAccount: + expectedAccountCount++ + } + } + + block.If(Len(Id("indices")).Op("!=").Lit(expectedAccountCount)).Block( + Return(Qual("fmt", "Errorf").Call(Lit("mismatch between expected accounts (%d) and provided indices (%d)"), Lit(expectedAccountCount), Len(Id("indices")))), + ) + + block.Id("indexOffset").Op(":=").Lit(0) + + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + fieldName := tools.ToCamelUpper(acc.Name) + block.Commentf("Set %s account from index", acc.Name) + block.If(Id("indices").Index(Id("indexOffset")).Op(">=").Uint8().Call(Len(Id("accountKeys")))).Block( + Return(Qual("fmt", "Errorf").Call(Lit("account index %d for %s is out of bounds (max: %d)"), Id("indices").Index(Id("indexOffset")), Lit(acc.Name), Len(Id("accountKeys")).Op("-").Lit(1))), + ) + block.Id("obj").Dot(fieldName).Op("=").Id("accountKeys").Index(Id("indices").Index(Id("indexOffset"))) + block.Id("indexOffset").Op("++") + } + case *idl.IdlInstructionAccounts: + { + block.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + + block.Return(Nil()) + }) + + // Generate GetAccountKeys method + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("GetAccountKeys"). + Params(). + Params(Index().Qual(PkgSolanaGo, "PublicKey")). + BlockFunc(func(block *Group) { + block.Id("keys").Op(":=").Make(Index().Qual(PkgSolanaGo, "PublicKey"), Lit(0)) + + for _, account := range instruction.Accounts { + switch acc := account.(type) { + case *idl.IdlInstructionAccount: + { + fieldName := tools.ToCamelUpper(acc.Name) + block.Id("keys").Op("=").Append(Id("keys"), Id("obj").Dot(fieldName)) + } + case *idl.IdlInstructionAccounts: + { + block.Commentf("Account group: %s (not fully supported)", acc.Name) + } + } + } + + block.Return(Id("keys")) + }) + } else { + // Generate empty implementations for instructions without accounts + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("UnmarshalAccountIndices"). + Params(Id("buf").Index().Byte()). + Params(Index().Uint8(), Error()). + Block( + Return(Index().Uint8().Op("{}"), Nil()), + ) + + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("PopulateFromAccountIndices"). + Params(Id("indices").Index().Uint8(), Id("accountKeys").Index().Qual(PkgSolanaGo, "PublicKey")). + Params(Error()). + Block( + Return(Nil()), + ) + + code.Line().Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("GetAccountKeys"). + Params(). + Params(Index().Qual(PkgSolanaGo, "PublicKey")). + Block( + Return(Index().Qual(PkgSolanaGo, "PublicKey").Op("{}")), + ) + } + + // Generate Unmarshal method + code.Line().Line() + code.Commentf("Unmarshal unmarshals the %s from Borsh-encoded bytes prefixed with the discriminator.", typeName).Line() + code.Func().Params(Id("obj").Op("*").Id(typeName)).Id("Unmarshal"). + Params(Id("buf").Index().Byte()). + Params(Error()). + BlockFunc(func(block *Group) { + block.Var().Id("err").Error() + block.List(Err()).Op("=").Id("obj").Dot("UnmarshalWithDecoder").Call( + Qual(PkgBinary, "NewBorshDecoder").Call(Id("buf")), + ) + block.If(Err().Op("!=").Nil()).Block( + Return( + Qual("fmt", "Errorf").Call( + Lit("error while unmarshaling "+typeName+": %w"), + Err(), + ), + ), + ) + block.Return(Nil()) + }) + + // Generate Unmarshal function + code.Line().Line() + code.Commentf("Unmarshal%s unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator.", typeName).Line() + code.Func().Id("Unmarshal"+typeName). + Params(Id("buf").Index().Byte()). + Params(Op("*").Id(typeName), Error()). + BlockFunc(func(block *Group) { + block.Id("obj").Op(":=").New(Id(typeName)) + block.Var().Id("err").Error() + block.List(Err()).Op("=").Id("obj").Dot("Unmarshal").Call(Id("buf")) + block.If(Err().Op("!=").Nil()).Block( + Return(Nil(), Err()), + ) + block.Return(Id("obj"), Nil()) + }) + + return code, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/instructions_test.go b/cmd/generate-bindings/solana/anchor-go/generator/instructions_test.go new file mode 100644 index 00000000..05e5fc74 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/instructions_test.go @@ -0,0 +1,48 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenInstructionsZeroArgsAlwaysWritesDiscriminator(t *testing.T) { + idlData := &idl.Idl{ + Instructions: []idl.IdlInstruction{ + { + Name: "ping", + Discriminator: idl.IdlDiscriminator{1, 2, 3, 4, 5, 6, 7, 8}, + Args: []idl.IdlField{}, + }, + }, + } + + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + outputFile, err := gen.gen_instructions() + require.NoError(t, err) + require.NotNil(t, outputFile) + + code := outputFile.File.GoString() + + assert.Contains(t, code, "NewBorshEncoder", + "zero-arg instruction builder must allocate an encoder for the discriminator") + assert.Contains(t, code, "WriteBytes", + "zero-arg instruction builder must write the discriminator bytes") + assert.NotContains(t, code, "nil, // No arguments to encode", + "zero-arg instruction must not pass nil as instruction data") + + // The generated code must reference buf__.Bytes() so the discriminator is sent. + builderIdx := strings.Index(code, "func NewPingInstruction") + require.Greater(t, builderIdx, 0, "expected to find NewPingInstruction in generated code") + builderSnippet := code[builderIdx:] + assert.Contains(t, builderSnippet, "buf__.Bytes()", + "zero-arg instruction must pass buf__.Bytes() (containing the discriminator) as instruction data") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/is.go b/cmd/generate-bindings/solana/anchor-go/generator/is.go new file mode 100644 index 00000000..46050b74 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/is.go @@ -0,0 +1,101 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import "github.com/gagliardetto/anchor-go/idl/idltype" + +func IsOption(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Option: + return true + default: + return false + } +} + +func IsCOption(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.COption: + return true + default: + return false + } +} + +func IsDefined(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Defined: + return true + default: + return false + } +} + +func IsVec(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Vec: + return true + default: + return false + } +} + +func IsArray(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Array: + return true + default: + return false + } +} + +func IsIDLTypeKind(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Bool: + return true + case *idltype.U8: + return true + case *idltype.I8: + return true + case *idltype.U16: + return true + case *idltype.I16: + return true + case *idltype.U32: + return true + case *idltype.I32: + return true + case *idltype.F32: + return true + case *idltype.U64: + return true + case *idltype.I64: + return true + case *idltype.F64: + return true + case *idltype.U128: + return true + case *idltype.I128: + return true + case *idltype.U256: + return true + case *idltype.I256: + return true + case *idltype.Bytes: + return true + case *idltype.String: + return true + case *idltype.Pubkey: + return true + default: + return false + } +} + +func IsBool(v idltype.IdlType) bool { + switch v.(type) { + case *idltype.Bool: + return true + default: + return false + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go new file mode 100644 index 00000000..cc024ac6 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go @@ -0,0 +1,444 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" +) + +func (g *Generator) gen_MarshalWithEncoder_struct( + idl_ *idl.Idl, + withDiscriminator bool, + receiverTypeName string, + discriminatorName string, + fields idl.IdlDefinedFields, + checkNil bool, +) Code { + code := Empty() + { + // Declare MarshalWithEncoder + code.Func().Params(Id("obj").Id(receiverTypeName)).Id("MarshalWithEncoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("encoder").Op("*").Qual(PkgBinary, "Encoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + if withDiscriminator && discriminatorName != "" { + body.Comment("Write account discriminator:") + body.Err().Op("=").Id("encoder").Dot("WriteBytes").Call(Id(discriminatorName).Index(Op(":")), False()) + body.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + } + switch fields := fields.(type) { + case idl.IdlDefinedFieldsNamed: + uniqueFieldNames := generateUniqueFieldNames(fields) + g.gen_marshal_DefinedFieldsNamed( + body, + fields, + checkNil, + func(field idl.IdlField) *Statement { + return Id("obj").Dot(uniqueFieldNames[field.Name]) + }, + "encoder", + false, // returnNilErr + func(field idl.IdlField) string { + return uniqueFieldNames[field.Name] + }, + ) + case idl.IdlDefinedFieldsTuple: + convertedFields := tupleToFieldsNamed(fields) + uniqueFieldNames := generateUniqueFieldNames(convertedFields) + g.gen_marshal_DefinedFieldsNamed( + body, + convertedFields, + checkNil, + func(field idl.IdlField) *Statement { + return Id("obj").Dot(uniqueFieldNames[field.Name]) + }, + "encoder", + false, // returnNilErr + func(field idl.IdlField) string { + return uniqueFieldNames[field.Name] + }, + ) + case nil: + // No fields, just an empty struct. + // TODO: should we panic here? + default: + panic(fmt.Sprintf("unexpected fields type: %T", fields)) + } + + body.Return(Nil()) + }) + } + { + code.Line().Line() + // also generate a + // func (obj ) Marshal() ([]byte, error) { + // return obj.MarshalWithEncoder(bin.NewBorshEncoder(buf)) + // } + // func (obj ) Marshal() ([]byte, error) { + // buf := new(bytes.Buffer) + // enc := bin.NewBorshEncoder(buf) + // err := enc.Encode(meta) + // if err != nil { + // return nil, err + // } + // return buf.Bytes(), nil + // } + code.Func().Params(Id("obj").Id(receiverTypeName)).Id("Marshal"). + Params( + ListFunc(func(results *Group) { + // no parameters + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Index().Byte() + results.Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + body.Id("buf").Op(":=").Qual("bytes", "NewBuffer").Call(Nil()) + body.Id("encoder").Op(":=").Qual(PkgBinary, "NewBorshEncoder").Call(Id("buf")) + body.Err().Op(":=").Id("obj").Dot("MarshalWithEncoder").Call(Id("encoder")) + body.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call( + Lit("error while encoding "+receiverTypeName+": %w"), + Err(), + ), + ), + ) + body.Return( + Id("buf").Dot("Bytes").Call(), + Nil(), + ) + }) + } + + return code +} + +func (g *Generator) gen_marshal_DefinedFieldsNamed( + body *Group, + fields idl.IdlDefinedFieldsNamed, + checkNil bool, + nameFormatter func(field idl.IdlField) *Statement, + encoderVariableName string, + returnNilErr bool, + traceNameFormatter func(field idl.IdlField) string, +) { + for _, field := range fields { + exportedArgName := traceNameFormatter(field) + if IsOption(field.Ty) || IsCOption(field.Ty) { + body.Commentf("Serialize `%s` (optional):", exportedArgName) + } else { + body.Commentf("Serialize `%s`:", exportedArgName) + } + + if g.isComplexEnum(field.Ty) || (IsArray(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || g.isOptionalComplexEnum(field.Ty) { + switch field.Ty.(type) { + case *idltype.Defined: + enumTypeName := field.Ty.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + argBody.Err().Op(":=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)) + argBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ) + }) + case *idltype.Array: + enumTypeName := field.Ty.(*idltype.Array).Type.(*idltype.Defined).Name + // TODO: handle array length, which is defined in the type. + body.BlockFunc(func(argBody *Group) { + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Len(nameFormatter(field)), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.Err().Op(":=").Id(formatEnumEncoderName(enumTypeName)).Call( + Id(encoderVariableName), + nameFormatter(field).Index(Id("i")), + ) + forBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ) + }, + ), + ) + }) + }) + case *idltype.Vec: + enumTypeName := field.Ty.(*idltype.Vec).Vec.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + argBody.Err().Op(":=").Id(encoderVariableName).Dot("WriteLength").Call( + Len(nameFormatter(field)), + ) + argBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while writing vector length: %w"), + Err(), + ), + ) + }, + ), + ) + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Len(nameFormatter(field)), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.Err().Op(":=").Id(formatEnumEncoderName(enumTypeName)).Call( + Id(encoderVariableName), + nameFormatter(field).Index(Id("i")), + ) + forBody.If( + Err().Op("!=").Nil(), + ).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ) + }, + ), + ) + }) + }) + case *idltype.Option: + enumTypeName := field.Ty.(*idltype.Option).Option.(*idltype.Defined).Name + gen_marshal_optionalComplexEnum(body, "WriteOption", enumTypeName, field, checkNil, nameFormatter, encoderVariableName, returnNilErr, exportedArgName) + case *idltype.COption: + enumTypeName := field.Ty.(*idltype.COption).COption.(*idltype.Defined).Name + gen_marshal_optionalComplexEnum(body, "WriteCOption", enumTypeName, field, checkNil, nameFormatter, encoderVariableName, returnNilErr, exportedArgName) + } + } else { + if IsOption(field.Ty) || IsCOption(field.Ty) { + var optionalityWriterName string + if IsOption(field.Ty) { + optionalityWriterName = "WriteOption" + } else { + optionalityWriterName = "WriteCOption" + } + if checkNil { + body.BlockFunc(func(optGroup *Group) { + // if nil: + optGroup.If(nameFormatter(field).Op("==").Nil()).Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(False()), + If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while encoding optionality: %w"), + Err(), + ), + ) + }, + ), + ), + ).Else().Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()), + If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while encoding optionality: %w"), + Err(), + ), + ) + }, + ), + ), + Err().Op("=").Id(encoderVariableName).Dot("Encode").Call(nameFormatter(field)), + If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ), + ) + }) + } else { + body.BlockFunc(func(optGroup *Group) { + // TODO: make optional fields of accounts a pointer. + // Write as if not nil: + optGroup.Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()) + optGroup.If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while encoding optionality: %w"), + Err(), + ), + ) + }, + ), + ) + optGroup.Err().Op("=").Id(encoderVariableName).Dot("Encode").Call(nameFormatter(field)) + optGroup.If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ) + }) + } + } else { + body.Err().Op("=").Id(encoderVariableName).Dot("Encode").Call(nameFormatter(field)) + body.If(Err().Op("!=").Nil()).Block( + ReturnFunc( + func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ) + }, + ), + ) + } + } + } +} + +func gen_marshal_optionalComplexEnum( + body *Group, + optionalityWriterName string, + enumTypeName string, + field idl.IdlField, + checkNil bool, + nameFormatter func(field idl.IdlField) *Statement, + encoderVariableName string, + returnNilErr bool, + exportedArgName string, +) { + errReturn := func(wrapped Code) *Statement { + return ReturnFunc(func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Add(wrapped) + }) + } + optionalityErr := func() *Statement { + return errReturn( + Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call(Lit("error while encoding optionality: %w"), Err()), + ), + ) + } + fieldErr := func() *Statement { + return errReturn( + Qual(PkgAnchorGoErrors, "NewField").Call(Lit(exportedArgName), Err()), + ) + } + + if checkNil { + body.BlockFunc(func(optGroup *Group) { + optGroup.If(nameFormatter(field).Op("==").Nil()).Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(False()), + If(Err().Op("!=").Nil()).Block(optionalityErr()), + ).Else().Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()), + If(Err().Op("!=").Nil()).Block(optionalityErr()), + Err().Op("=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)), + If(Err().Op("!=").Nil()).Block(fieldErr()), + ) + }) + } else { + body.BlockFunc(func(optGroup *Group) { + optGroup.Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()) + optGroup.If(Err().Op("!=").Nil()).Block(optionalityErr()) + optGroup.Err().Op("=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)) + optGroup.If(Err().Op("!=").Nil()).Block(fieldErr()) + }) + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/tests.go b/cmd/generate-bindings/solana/anchor-go/generator/tests.go new file mode 100644 index 00000000..64389fed --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/tests.go @@ -0,0 +1,18 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + . "github.com/dave/jennifer/jen" +) + +func (g *Generator) gen_tests() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains tests.") + { + } + return &OutputFile{ + Name: "tests_test.go", + File: file, + }, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/tools.go b/cmd/generate-bindings/solana/anchor-go/generator/tools.go new file mode 100644 index 00000000..91dd956a --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/tools.go @@ -0,0 +1,69 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "bytes" + "os" + "path" + + . "github.com/dave/jennifer/jen" +) + +const ( + PkgBinary = "github.com/gagliardetto/binary" + PkgSolanaGo = "github.com/gagliardetto/solana-go" + PkgSolanaGoText = "github.com/gagliardetto/solana-go/text" + PkgAnchorGoErrors = "github.com/gagliardetto/anchor-go/errors" + + // TODO: use or remove this: + PkgTreeout = "github.com/gagliardetto/treeout" + PkgFormat = "github.com/gagliardetto/solana-go/text/format" + PkgGoFuzz = "github.com/gagliardetto/gofuzz" + PkgTestifyRequire = "github.com/stretchr/testify/require" +) + +func WriteFile(outDir string, assetFileName string, file *File) error { + assetFilepath := path.Join(outDir, assetFileName) + var buf bytes.Buffer + if err := file.Render(&buf); err != nil { + return err + } + return os.WriteFile(assetFilepath, buf.Bytes(), 0o644) +} + +func DoGroup(f func(*Group)) *Statement { + g := &Group{} + g.CustomFunc(Options{ + Multi: false, + }, f) + s := newStatement() + *s = append(*s, g) + return s +} + +func DoGroupMultiline(f func(*Group)) *Statement { + g := &Group{} + g.CustomFunc(Options{ + Multi: true, + }, f) + s := newStatement() + *s = append(*s, g) + return s +} + +func ListMultiline(f func(*Group)) *Statement { + g := &Group{} + g.CustomFunc(Options{ + Multi: true, + Separator: ",", + Open: "", + Close: " ", + }, f) + s := newStatement() + *s = append(*s, g) + return s +} + +func newStatement() *Statement { + return &Statement{} +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/types.go b/cmd/generate-bindings/solana/anchor-go/generator/types.go new file mode 100644 index 00000000..ea39a20d --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/types.go @@ -0,0 +1,418 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/davecgh/go-spew/spew" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +// genfile_types generates the file `types.go`. +func (g *Generator) genfile_types() (*OutputFile, error) { + file := NewFile(g.options.Package) + file.HeaderComment("Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT.") + file.HeaderComment("This file contains parsers for the types defined in the IDL.") + + { + for index, typ := range g.idl.Types { + code, err := g.gen_IDLTypeDef(typ) + if err != nil { + return nil, fmt.Errorf("error generating type %d: %w", index, err) + } + file.Add(code) + } + } + + return &OutputFile{ + Name: "types.go", + File: file, + }, nil +} + +// `def.Type` is `IDLTypeDefTy` (which is an interface): +// either `IDLTypeDefTyEnum` or `IDLTypeDefTyStruct`. +func (g *Generator) gen_IDLTypeDef(def idl.IdlTypeDef) (Code, error) { + switch vv := def.Ty.(type) { + case *idl.IdlTypeDefTyStruct: + return g.gen_IDLTypeDefTyStruct(def.Name, def.Docs, *vv, false) + case *idl.IdlTypeDefTyEnum: + return g.gen_IDLTypeDefTyEnum(def.Name, def.Docs, *vv) + default: + panic(fmt.Errorf("unhandled type: %T", vv)) + } +} + +func (g *Generator) gen_IDLTypeDefTyEnum(name string, docs []string, typ idl.IdlTypeDefTyEnum) (Code, error) { + if typ.Variants.IsAllSimple() { + return g.gen_simpleEnum(name, docs, typ) + } + return g.gen_complexEnum(name, docs, typ) +} + +func (g *Generator) gen_simpleEnum(name string, docs []string, typ idl.IdlTypeDefTyEnum) (Code, error) { + st := newStatement() + + code := newStatement() + enumTypeName := tools.ToCamelUpper(name) + + addComments(code, docs) + { + code.Type().Id(enumTypeName).Qual(PkgBinary, "BorshEnum") + code.Line().Const().Parens(DoGroup(func(gr *Group) { + for variantIndex, variant := range typ.Variants { + // TODO: enum variants should have docs too. + // for docIndex, doc := range variant.Docs { + // if docIndex == 0 { + // gr.Line() + // } + // gr.Comment(doc).Line() + // } + + gr.Id(formatSimpleEnumVariantName(variant.Name, enumTypeName)).Add(func() Code { + if variantIndex == 0 { + return Id(enumTypeName).Op("=").Iota() + } + return nil + }()).Line() + } + // TODO: check for fields, etc. + })) + + // Generate stringer for the uint8 enum values: + code.Line().Line().Func().Params(Id("value").Id(enumTypeName)).Id("String"). + Params(). + Params(String()). + BlockFunc(func(body *Group) { + body.Switch(Id("value")).BlockFunc(func(switchBlock *Group) { + for _, variant := range typ.Variants { + switchBlock.Case(Id(formatSimpleEnumVariantName(variant.Name, enumTypeName))).Line().Return(Lit(variant.Name)) + } + switchBlock.Default().Line().Return(Lit("")) + }) + }) + st.Add(code.Line()) + } + return st, nil +} + +func addComments(code *Statement, docs []string) { + for _, doc := range docs { + code.Line() + code.Comment(doc) + } + if len(docs) > 0 { + code.Line() + } +} + +func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeDefTyEnum) (Code, error) { + st := newStatement() + + code := newStatement() + enumTypeName := tools.ToCamelUpper(name) + + // Add comments for the enum type: + addComments(code, docs) + { + g.registerComplexEnumType(name) + containerName := formatEnumContainerName(enumTypeName) + interfaceMethodName := formatInterfaceMethodName(enumTypeName) + + // Declare the interface of the enum type: + code.Commentf("The %q interface for the %q complex enum.", interfaceMethodName, enumTypeName).Line() + code.Type().Id(enumTypeName).Interface( + Id(interfaceMethodName).Call(), + ).Line().Line() + + // Declare the enum variants container (non-exported, used internally) + code.Type().Id(containerName).StructFunc( + func(structGroup *Group) { + structGroup.Id("Enum").Qual(PkgBinary, "BorshEnum").Tag(map[string]string{ + "bin": "enum", + }) + + for _, variant := range typ.Variants { + structGroup.Id(tools.ToCamelUpper(variant.Name)).Id(formatComplexEnumVariantTypeName(enumTypeName, variant.Name)) + } + }, + ).Line().Line() + + // Declare parser function for the enum type: + code.Func().Id(formatEnumParserName(enumTypeName)).Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("decoder").Op("*").Qual(PkgBinary, "Decoder") + }), + ).Params( + ListFunc(func(results *Group) { + // Results: + results.Id(enumTypeName) + results.Error() + }), + ). + BlockFunc(func(body *Group) { + enumName := enumTypeName + body.BlockFunc(func(argBody *Group) { + argBody.List(Id("tmp")).Op(":=").New(Id(formatEnumContainerName(enumName))) + + argBody.Err().Op(":=").Id("decoder").Dot("Decode").Call(Id("tmp")) + + argBody.If( + Err().Op("!=").Nil(), + ).Block( + Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit("failed parsing "+enumTypeName+": %w"), Err()), + ), + ) + + argBody.Switch(Id("tmp").Dot("Enum")). + BlockFunc(func(switchGroup *Group) { + interfaceType := g.idl.Types.ByName(name) + if interfaceType == nil { + panic(fmt.Errorf("complex enum type %q not found in IDL types", name)) + } + + for variantIndex, variant := range interfaceType.Ty.(*idl.IdlTypeDefTyEnum).Variants { + switchGroup.Case(Lit(variantIndex)). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Op("&").Id("tmp").Dot(tools.ToCamelUpper(variant.Name)), + Nil(), + ) + }) + } + switchGroup.Default(). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Nil(), + Qual("fmt", "Errorf").Call(Lit(enumTypeName+": unknown enum index: %v"), Id("tmp").Dot("Enum")), + ) + }) + }) + }) + }).Line().Line() + + // Declare the marshaler for the enum type: + code.Func().Id(formatEnumEncoderName(enumTypeName)).Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("encoder").Op("*").Qual(PkgBinary, "Encoder") + params.Id("value").Id(enumTypeName) + }), + ).Params( + ListFunc(func(results *Group) { + // Results: + results.Error() + }), + ). + BlockFunc(func(body *Group) { + body.BlockFunc(func(argBody *Group) { + argBody.List(Id("tmp")).Op(":=").Id(formatEnumContainerName(enumTypeName)).Block() + argBody.Switch(Id("realvalue").Op(":=").Id("value").Op(".").Parens(Type())). + BlockFunc(func(switchGroup *Group) { + switchGroup.Case(Nil()). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Qual("fmt", "Errorf").Call(Lit(enumTypeName + ": cannot encode nil value")), + ) + }) + + interfaceType := g.idl.Types.ByName(name) + if interfaceType == nil { + panic(fmt.Errorf("complex enum type %q not found in IDL types", name)) + } + for variantIndex, variant := range interfaceType.Ty.(*idl.IdlTypeDefTyEnum).Variants { + variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) + + switchGroup.Case(Op("*").Id(variantTypeNameStruct)). + BlockFunc(func(caseGroup *Group) { + caseGroup.If(Id("realvalue").Op("==").Nil()).Block( + Return(Qual("fmt", "Errorf").Call(Lit(enumTypeName+": cannot encode nil *"+variantTypeNameStruct))), + ) + caseGroup.Id("tmp").Dot("Enum").Op("=").Lit(variantIndex) + caseGroup.Id("tmp").Dot(tools.ToCamelUpper(variant.Name)).Op("=").Op("*").Id("realvalue") + }) + } + + switchGroup.Default(). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Qual("fmt", "Errorf").Call(Lit(enumTypeName+": unknown variant type %T"), Id("value")), + ) + }) + }) + + argBody.Return(Id("encoder").Dot("Encode").Call(Id("tmp"))) + }) + }).Line().Line() + + for _, variant := range typ.Variants { + // Name of the variant type if the enum is a complex enum (i.e. enum variants are inline structs): + variantTypeNameComplex := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) + + // Declare the enum variant types: + if variant.IsSimple() { + code.Type().Id(variantTypeNameComplex).Qual(PkgBinary, "EmptyVariant").Line().Line() + } else if variant.Fields.IsSome() { + code.Commentf("Variant %q of enum %q", variant.Name, enumTypeName).Line() + code.Type().Id(variantTypeNameComplex).StructFunc( + func(structGroup *Group) { + switch fields := variant.Fields.Unwrap().(type) { + case idl.IdlDefinedFieldsNamed: + for _, variantField := range fields { + optionality := IsOption(variantField.Ty) || IsCOption(variantField.Ty) + structGroup.Add(g.genField(variantField, optionality)). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(variantField.Ty) { + tagMap["bin"] = "optional" + } + if IsCOption(variantField.Ty) { + tagMap["bin"] = "coption" + } + // add json tag: + tagMap["json"] = tools.ToCamelLower(variantField.Name) + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + case idl.IdlDefinedFieldsTuple: + for itemIndex, tupleItem := range fields { + optionality := IsOption(tupleItem) || IsCOption(tupleItem) + tupleItemName := FormatTupleItemName(itemIndex) + structGroup.Add(g.genFieldNamed(tupleItemName, tupleItem, optionality)). + Add(func() Code { + tagMap := map[string]string{} + if IsOption(tupleItem) { + tagMap["bin"] = "optional" + } + if IsCOption(tupleItem) { + tagMap["bin"] = "coption" + } + // add json tag: + tagMap["json"] = tools.ToCamelLower(tupleItemName) + func() string { + if optionality { + return ",omitempty" + } + return "" + }() + return Tag(tagMap) + }()) + } + default: + panic("not handled: " + spew.Sdump(variant.Fields)) + } + }, + ).Line().Line() + } + + if variant.IsSimple() { + // Declare MarshalWithEncoder + code.Line().Line().Func().Params(Id("obj").Id(variantTypeNameComplex)).Id("MarshalWithEncoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("encoder").Op("*").Qual(PkgBinary, "Encoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). + BlockFunc(func(body *Group) { + body.Return(Nil()) + }) + code.Line().Line() + + // Declare UnmarshalWithDecoder + code.Func().Params(Id("obj").Op("*").Id(variantTypeNameComplex)).Id("UnmarshalWithDecoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("decoder").Op("*").Qual(PkgBinary, "Decoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). + BlockFunc(func(body *Group) { + body.Return(Nil()) + }) + code.Line().Line() + } else if variant.Fields.IsSome() { + switch fields := variant.Fields.Unwrap().(type) { + case idl.IdlDefinedFieldsNamed: + // Declare MarshalWithEncoder: + code.Line().Line().Add( + g.gen_MarshalWithEncoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + true, + )) + + // Declare UnmarshalWithDecoder + code.Line().Line().Add( + g.gen_UnmarshalWithDecoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + )) + code.Line().Line() + case idl.IdlDefinedFieldsTuple: + // TODO: handle tuples + // Declare MarshalWithEncoder: + code.Line().Line().Add( + g.gen_MarshalWithEncoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + true, + )) + + // Declare UnmarshalWithDecoder + code.Line().Line().Add( + g.gen_UnmarshalWithDecoder_struct( + g.idl, + false, + variantTypeNameComplex, + "", + fields, + )) + code.Line().Line() + default: + panic("not handled: " + spew.Sdump(variant.Fields)) + } + } + + // Declare the method to implement the parent enum interface: + if variant.IsSimple() { + code.Func().Params(Id("_").Op("*").Id(variantTypeNameComplex)).Id(interfaceMethodName).Params().Block().Line().Line() + } else { + code.Func().Params(Id("_").Op("*").Id(variantTypeNameComplex)).Id(interfaceMethodName).Params().Block().Line().Line() + } + } + + st.Add(code.Line().Line()) + } + return st, nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/types_test.go b/cmd/generate-bindings/solana/anchor-go/generator/types_test.go new file mode 100644 index 00000000..72712eb8 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/types_test.go @@ -0,0 +1,103 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makeComplexEnumIDL(enumName string) *idl.Idl { + enumType := &idl.IdlTypeDefTyEnum{ + Kind: "enum", + Variants: idl.VariantSlice{ + {Name: "Simple"}, + { + Name: "WithFields", + Fields: idl.Some[idl.IdlDefinedFields](idl.IdlDefinedFieldsNamed{ + {Name: "value", Ty: &idltype.U64{}}, + }), + }, + }, + } + + return &idl.Idl{ + Types: idl.IdTypeDef_slice{ + { + Name: enumName, + Ty: enumType, + }, + }, + } +} + +func TestGenComplexEnum_ConsecutiveUppercase(t *testing.T) { + // "HTTPStatus" is stored in the IDL as-is. ToCamelUpper converts it to + // "HttpStatus" (via snake_case intermediary), so ByName("HttpStatus") + // won't find the original "HTTPStatus" entry and returns nil. + idlData := makeComplexEnumIDL("HTTPStatus") + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } + + // Register the complex enum as the generator normally would. + for _, typ := range gen.idl.Types { + gen.registerComplexEnums(typ) + } + + outputFile, err := gen.genfile_types() + require.NoError(t, err, "genfile_types should not panic or error for enum named HTTPStatus") + require.NotNil(t, outputFile) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "HttpStatus") +} + +func TestGenComplexEnum_SnakeCaseName(t *testing.T) { + // "my_status" is stored in the IDL. ToCamelUpper converts it to + // "MyStatus", so ByName("MyStatus") won't find "my_status". + idlData := makeComplexEnumIDL("my_status") + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } + + for _, typ := range gen.idl.Types { + gen.registerComplexEnums(typ) + } + + outputFile, err := gen.genfile_types() + require.NoError(t, err, "genfile_types should not panic or error for enum named my_status") + require.NotNil(t, outputFile) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "MyStatus") +} + +func TestGenComplexEnum_AlreadyCamelCase(t *testing.T) { + // "MyStatus" is already CamelCase. ToCamelUpper("MyStatus") == "MyStatus", + // so ByName should find it. This should always work. + idlData := makeComplexEnumIDL("MyStatus") + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } + + for _, typ := range gen.idl.Types { + gen.registerComplexEnums(typ) + } + + outputFile, err := gen.genfile_types() + require.NoError(t, err, "genfile_types should not panic or error for enum named MyStatus") + require.NotNil(t, outputFile) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "MyStatus") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/u256_test.go b/cmd/generate-bindings/solana/anchor-go/generator/u256_test.go new file mode 100644 index 00000000..13fb77ae --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/u256_test.go @@ -0,0 +1,148 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIDLTypeKind_ToTypeDeclCode_U256(t *testing.T) { + assert.NotPanics(t, func() { + result := IDLTypeKind_ToTypeDeclCode(&idltype.U256{}) + assert.NotNil(t, result) + }) +} + +func TestIDLTypeKind_ToTypeDeclCode_I256(t *testing.T) { + assert.NotPanics(t, func() { + result := IDLTypeKind_ToTypeDeclCode(&idltype.I256{}) + assert.NotNil(t, result) + }) +} + +func TestGenTypeName_U256(t *testing.T) { + assert.NotPanics(t, func() { + result := genTypeName(&idltype.U256{}) + assert.NotNil(t, result) + }) +} + +func TestGenTypeName_I256(t *testing.T) { + assert.NotPanics(t, func() { + result := genTypeName(&idltype.I256{}) + assert.NotNil(t, result) + }) +} + +func TestGenConstants_U256(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{ + { + Name: "MAX_SUPPLY", + Ty: &idltype.U256{}, + Value: "115792089237316195423570985008687907853269984665640564039457584007913129639935", + }, + }, + } + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var MAX_SUPPLY = func() *big.Int") + assert.Contains(t, generatedCode, ".SetString(\"115792089237316195423570985008687907853269984665640564039457584007913129639935\", 10)") +} + +func TestGenConstants_I256(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{ + { + Name: "MIN_VALUE", + Ty: &idltype.I256{}, + Value: "-57896044618658097711785492504343953926634992332820282019728792003956564819968", + }, + }, + } + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var MIN_VALUE = func() *big.Int") + assert.Contains(t, generatedCode, ".SetString(\"-57896044618658097711785492504343953926634992332820282019728792003956564819968\", 10)") +} + +func TestGenConstants_U256_Invalid(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{ + { + Name: "INVALID_U256", + Ty: &idltype.U256{}, + Value: "not_a_number", + }, + }, + } + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + _, err := gen.gen_constants() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse u256") +} + +func TestGenConstants_I256_Invalid(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{ + { + Name: "INVALID_I256", + Ty: &idltype.I256{}, + Value: "not_a_number", + }, + }, + } + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + _, err := gen.gen_constants() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse i256") +} + +func TestGenConstants_U256_WithUnderscores(t *testing.T) { + idlData := &idl.Idl{ + Constants: []idl.IdlConst{ + { + Name: "LARGE_U256", + Ty: &idltype.U256{}, + Value: "1_000_000_000_000_000_000_000_000_000", + }, + }, + } + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + outputFile, err := gen.gen_constants() + require.NoError(t, err) + + generatedCode := outputFile.File.GoString() + assert.Contains(t, generatedCode, "var LARGE_U256 = func() *big.Int") + assert.Contains(t, generatedCode, ".SetString(\"1000000000000000000000000000\", 10)") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go new file mode 100644 index 00000000..a0f45fb1 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go @@ -0,0 +1,432 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "fmt" + + . "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/anchor-go/tools" +) + +func formatComplexEnumVariantTypeName(enumTypeName string, variantName string) string { + return fmt.Sprintf("%s_%s", tools.ToCamelUpper(enumTypeName), tools.ToCamelUpper(variantName)) +} + +func formatSimpleEnumVariantName(variantName string, enumTypeName string) string { + return fmt.Sprintf("%s_%s", tools.ToCamelUpper(enumTypeName), tools.ToCamelUpper(variantName)) +} + +func FormatTupleItemName(index int) string { + return tools.ToCamelUpper(fmt.Sprintf("V%d", index)) +} + +func formatEnumContainerName(enumTypeName string) string { + return tools.ToCamelLower(enumTypeName) + "EnumContainer" +} + +func formatInterfaceMethodName(enumTypeName string) string { + return "is" + tools.ToCamelUpper(enumTypeName) +} + +func formatDiscriminatorName(kind string, exportedAccountName string) string { + // trim prefix or suffix "Account" or "Event" from exportedAccountName + exportedAccountName = tools.ToCamelUpper(exportedAccountName) + + // // TODO: sometimes there's accounts/events like this: + // // - "Fund" + // // - "FundAccount" + // // This will create a name collision and fail to compile because + // // we remove the "Account" or "Event" suffix from the second one, + // // so there's a duplicate name "Fund". + // exportedAccountName = strings.TrimSuffix(exportedAccountName, "Account") + // exportedAccountName = strings.TrimSuffix(exportedAccountName, "Event") + // exportedAccountName = strings.TrimPrefix(exportedAccountName, "Account") + // exportedAccountName = strings.TrimPrefix(exportedAccountName, "Event") + + return kind + "_" + tools.ToCamelUpper(exportedAccountName) +} + +func FormatAccountDiscriminatorName(exportedAccountName string) string { + return formatDiscriminatorName("Account", exportedAccountName) +} + +func FormatEventDiscriminatorName(exportedEventName string) string { + return formatDiscriminatorName("Event", exportedEventName) +} + +func FormatInstructionDiscriminatorName(exportedInstructionName string) string { + return formatDiscriminatorName("Instruction", exportedInstructionName) +} + +func formatBuilderFuncName(insExportedName string) string { + return "New" + insExportedName + "InstructionBuilder" +} + +func formatEnumParserName(enumTypeName string) string { + return "Decode" + enumTypeName +} + +func formatEnumEncoderName(enumTypeName string) string { + return "Encode" + enumTypeName +} + +func (g *Generator) gen_UnmarshalWithDecoder_struct( + idl_ *idl.Idl, + withDiscriminator bool, + receiverTypeName string, + discriminatorName string, + fields idl.IdlDefinedFields, +) Code { + code := Empty() + { + // Declare UnmarshalWithDecoder + code.Func().Params(Id("obj").Op("*").Id(receiverTypeName)).Id("UnmarshalWithDecoder"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("decoder").Op("*").Qual(PkgBinary, "Decoder") + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Err().Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + if withDiscriminator && discriminatorName != "" { + body.Comment("Read and check account discriminator:") + body.BlockFunc(func(discReadBody *Group) { + discReadBody.List(Id("discriminator"), Err()).Op(":=").Id("decoder").Dot("ReadDiscriminator").Call() + discReadBody.If(Err().Op("!=").Nil()).Block( + Return(Err()), + ) + discReadBody.If(Op("!").Id("discriminator").Dot("Equal").Call(Id(discriminatorName).Index(Op(":")))).Block( + Return( + Qual("fmt", "Errorf").Call( + Line().Lit("wrong discriminator: wanted %s, got %s"), + Line().Id(discriminatorName).Index(Op(":")), + Line().Qual("fmt", "Sprint").Call(Id("discriminator").Index(Op(":"))), + ), + ), + ) + }) + } + + switch fields := fields.(type) { + case idl.IdlDefinedFieldsNamed: + g.gen_unmarshal_DefinedFieldsNamed(body, fields, generateUniqueFieldNames(fields)) + case idl.IdlDefinedFieldsTuple: + convertedFields := tupleToFieldsNamed(fields) + g.gen_unmarshal_DefinedFieldsNamed(body, convertedFields, generateUniqueFieldNames(convertedFields)) + case nil: + // No fields, just an empty struct. + // TODO: should we panic here? + default: + panic(fmt.Sprintf("unexpected fields type: %T", fields)) + } + + body.Return(Nil()) + }) + } + { + code.Line().Line() + // func (obj *) Unmarshal(buf []byte) (err error) { + // return obj.UnmarshalWithDecoder(bin.NewBorshDecoder(buf)) + // } + code.Func().Params(Id("obj").Op("*").Id(receiverTypeName)).Id("Unmarshal"). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("buf").Index().Byte() + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + body.Err().Op(":=").Id("obj").Dot("UnmarshalWithDecoder").Call( + Qual(PkgBinary, "NewBorshDecoder").Call(Id("buf")), + ) + body.If(Err().Op("!=").Nil()).Block( + // If there was an error, return it. + Return( + Qual("fmt", "Errorf").Call( + Lit("error while unmarshaling "+receiverTypeName+": %w"), + Err(), + ), + ), + ) + body.Return( + Nil(), // No error. + ) + }) + } + { + code.Line().Line() + // func Unmarshal(buf []byte) (, error) { + // obj := new() + // err := obj.Unmarshal(buf) + // if err != nil { + // return nil, err + // } + // return obj, nil + // } + code.Func().Id("Unmarshal" + receiverTypeName). + Params( + ListFunc(func(params *Group) { + // Parameters: + params.Id("buf").Index().Byte() + }), + ). + Params( + ListFunc(func(results *Group) { + // Results: + results.Op("*").Id(receiverTypeName) + results.Error() + }), + ). + BlockFunc(func(body *Group) { + // Body: + body.Id("obj").Op(":=").New(Id(receiverTypeName)) + body.Err().Op(":=").Id("obj").Dot("Unmarshal").Call(Id("buf")) + body.If(Err().Op("!=").Nil()).Block( + Return( + Nil(), + Err(), + ), + ) + body.Return( + Id("obj"), + Nil(), // No error. + ) + }) + } + return code +} + +func tupleToFieldsNamed( + tuple idl.IdlDefinedFieldsTuple, +) idl.IdlDefinedFieldsNamed { + fields := make(idl.IdlDefinedFieldsNamed, len(tuple)) + for i, item := range tuple { + tupleItemName := FormatTupleItemName(i) + fields[i] = idl.IdlField{ + Name: tupleItemName, + Ty: item, + } + } + return fields +} + +func (g *Generator) gen_unmarshal_DefinedFieldsNamed( + body *Group, + fields idl.IdlDefinedFieldsNamed, + uniqueFieldNames map[string]string, +) { + for _, field := range fields { + goFieldName := uniqueFieldNames[field.Name] + exportedArgName := goFieldName + if IsOption(field.Ty) || IsCOption(field.Ty) { + body.Commentf("Deserialize `%s` (optional):", exportedArgName) + } else { + body.Commentf("Deserialize `%s`:", exportedArgName) + } + + if g.isComplexEnum(field.Ty) || (IsArray(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || g.isOptionalComplexEnum(field.Ty) { + switch field.Ty.(type) { + case *idltype.Defined: + enumName := field.Ty.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + { + argBody.Var().Err().Error() + argBody.List( + Id("obj").Dot(goFieldName), + Err(), + ).Op("=").Id(formatEnumParserName(enumName)).Call(Id("decoder")) + } + argBody.If( + Err().Op("!=").Nil(), + ).Block( + Return(Err()), + ) + }) + case *idltype.Array: + enumTypeName := field.Ty.(*idltype.Array).Type.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + // Read the array items: + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Len(Id("obj").Dot(goFieldName)), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.List( + Id("obj").Dot(goFieldName).Index(Id("i")), + Err(), + ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")) + forBody.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ), + ), + ) + }) + }) + case *idltype.Vec: + enumTypeName := field.Ty.(*idltype.Vec).Vec.(*idltype.Defined).Name + body.BlockFunc(func(argBody *Group) { + // Read the vector length: + argBody.List(Id("vecLen"), Err()).Op(":=").Id("decoder").Dot("ReadLength").Call() + argBody.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while reading vector length: %w"), + Err(), + ), + ), + ), + ) + argBody.If(Id("vecLen").Op(">").Id("decoder").Dot("Remaining").Call()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("vector length %d exceeds remaining decoder bytes %d"), + Id("vecLen"), + Id("decoder").Dot("Remaining").Call(), + ), + ), + ), + ) + // Create the vector: + argBody.Id("obj").Dot(goFieldName).Op("=").Make(Index().Id(enumTypeName), Id("vecLen")) + // Read the vector items: + argBody.For( + Id("i").Op(":=").Lit(0), + Id("i").Op("<").Id("vecLen"), + Id("i").Op("++"), + ).BlockFunc(func(forBody *Group) { + forBody.List( + Id("obj").Dot(goFieldName).Index(Id("i")), + Err(), + ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")) + forBody.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Qual(PkgAnchorGoErrors, "NewIndex").Call( + Id("i"), + Err(), + ), + ), + ), + ) + }) + }) + case *idltype.Option: + enumTypeName := field.Ty.(*idltype.Option).Option.(*idltype.Defined).Name + gen_unmarshal_optionalComplexEnum(body, "ReadOption", enumTypeName, exportedArgName) + case *idltype.COption: + enumTypeName := field.Ty.(*idltype.COption).COption.(*idltype.Defined).Name + gen_unmarshal_optionalComplexEnum(body, "ReadCOption", enumTypeName, exportedArgName) + } + } else { + if IsOption(field.Ty) || IsCOption(field.Ty) { + var optionalityReaderName string + switch { + case IsOption(field.Ty): + optionalityReaderName = "ReadOption" + case IsCOption(field.Ty): + optionalityReaderName = "ReadCOption" + } + + body.BlockFunc(func(optGroup *Group) { + // if nil: + optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call() + optGroup.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while reading optionality: %w"), + Err(), + ), + ), + ), + ) + optGroup.If(Id("ok")).Block( + Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(goFieldName)), + If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ), + ), + ), + ) + }) + } else { + body.Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(goFieldName)) + body.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ), + ), + ) + } + } + } +} + +func gen_unmarshal_optionalComplexEnum( + body *Group, + optionalityReaderName string, + enumTypeName string, + exportedArgName string, +) { + body.BlockFunc(func(optGroup *Group) { + optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call() + optGroup.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while reading optionality: %w"), + Err(), + ), + ), + ), + ) + optGroup.If(Id("ok")).Block( + List( + Id("obj").Dot(exportedArgName), + Err(), + ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")), + If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ), + ), + ), + ) + }) +} diff --git a/cmd/generate-bindings/solana/bindgen.go b/cmd/generate-bindings/solana/bindgen.go new file mode 100644 index 00000000..f9f5b98e --- /dev/null +++ b/cmd/generate-bindings/solana/bindgen.go @@ -0,0 +1,155 @@ +package solana + +import ( + "bytes" + "fmt" + "go/token" + "log/slog" + "os" + "path" + "strings" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" + bin "github.com/gagliardetto/binary" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/anchor-go/generator" +) + +func GenerateBindings( + pathToIdl string, + programName string, + outputDir string, +) error { + if pathToIdl == "" { + return fmt.Errorf("pathToIdl is empty") + } + if programName == "" { + return fmt.Errorf("programName is empty") + } + if outputDir == "" { + return fmt.Errorf("outputDir is empty") + } + if err := os.MkdirAll(outputDir, 0o777); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + slog.Info("Starting code generation", + "outputDir", outputDir, + "pathToIdl", pathToIdl, + ) + parsedIdl, err := idl.ParseFromFilepath(pathToIdl) + if err != nil { + return fmt.Errorf("failed to parse IDL: %w", err) + } + if parsedIdl == nil { + return fmt.Errorf("parsedIdl is nil") + } + if err := parsedIdl.Validate(); err != nil { + return fmt.Errorf("invalid IDL: %w", err) + } + if parsedIdl.Address == nil || parsedIdl.Address.IsZero() { + return fmt.Errorf("address is empty in idl file: %s", pathToIdl) + } + slog.Info("Using IDL address as program ID", "address", parsedIdl.Address.String()) + + parsedIdl.Metadata.Name = bin.ToSnakeForSighash(parsedIdl.Metadata.Name) + // check that the name is not a reserved keyword: + if parsedIdl.Metadata.Name != "" { + if tools.IsReservedKeyword(parsedIdl.Metadata.Name) { + slog.Warn("The IDL metadata.name is a reserved Go keyword: adding a suffix to avoid conflicts.", + "name", parsedIdl.Metadata.Name, + "reservedKeyword", token.Lookup(parsedIdl.Metadata.Name).String(), + ) + // Add a suffix to the name to avoid conflicts with Go reserved keywords: + parsedIdl.Metadata.Name += "_program" + } + if !tools.IsValidIdent(parsedIdl.Metadata.Name) { + // add a prefix to the name to avoid conflicts with Go reserved keywords: + parsedIdl.Metadata.Name = "my_" + parsedIdl.Metadata.Name + } + } + + packageName, err := normalizeGoPackageName(programName) + if err != nil { + return err + } + if err := generator.ValidateIDLDerivedIdentifiers(parsedIdl); err != nil { + return fmt.Errorf("IDL contains names that cannot be mapped to valid Go identifiers: %w", err) + } + + options := generator.GeneratorOptions{ + OutputDir: outputDir, + Package: packageName, + ProgramName: programName, + ProgramId: parsedIdl.Address, + } + + slog.Info("Parsed IDL successfully", + "version", parsedIdl.Metadata.Version, + "name", parsedIdl.Metadata.Name, + "address", parsedIdl.Address, + "programId", parsedIdl.Address.String(), + "instructionsCount", len(parsedIdl.Instructions), + "accountsCount", len(parsedIdl.Accounts), + "eventsCount", len(parsedIdl.Events), + "typesCount", len(parsedIdl.Types), + "constantsCount", len(parsedIdl.Constants), + "errorsCount", len(parsedIdl.Errors), + ) + + gen := generator.NewGenerator(parsedIdl, &options) + generatedFiles, err := gen.Generate() + if err != nil { + return fmt.Errorf("failed to generate: %w", err) + } + + for _, file := range generatedFiles.Files { + assetFilename := file.Name + assetFilepath := path.Join(options.OutputDir, assetFilename) + + var buf bytes.Buffer + if err := file.File.Render(&buf); err != nil { + return fmt.Errorf("failed to render generated file %q: %w", assetFilename, err) + } + + slog.Info("Writing file", + "filepath", assetFilepath, + "name", file.Name, + "modPath", options.ModPath, + ) + if err := os.WriteFile(assetFilepath, buf.Bytes(), 0o600); err != nil { + return fmt.Errorf("failed to write file %q: %w", assetFilepath, err) + } + } + slog.Info("Generation completed successfully", + "outputDir", options.OutputDir, + "modPath", options.ModPath, + "package", options.Package, + "programName", options.ProgramName, + ) + return nil +} + +// normalizeGoPackageName maps a contract filename stem or program label to a valid Go package name. +func normalizeGoPackageName(name string) (string, error) { + if strings.TrimSpace(name) == "" { + return "", fmt.Errorf("contract/program name for Go package is empty") + } + var b strings.Builder + for _, r := range strings.ToLower(name) { + if r == '-' { + b.WriteByte('_') + } else { + b.WriteRune(r) + } + } + out := b.String() + if !tools.IsValidIdent(out) { + return "", fmt.Errorf("invalid Go package name after normalization (from contract/program name %q): %q is not a valid Go identifier; use only letters, digits, and underscores, and do not start with a digit", name, out) + } + if tools.IsReservedKeyword(out) { + return "", fmt.Errorf("invalid Go package name: normalized name %q is a Go reserved keyword (from contract/program name %q)", out, name) + } + return out, nil +} diff --git a/cmd/generate-bindings/solana/bindings_test.go b/cmd/generate-bindings/solana/bindings_test.go new file mode 100644 index 00000000..d7349319 --- /dev/null +++ b/cmd/generate-bindings/solana/bindings_test.go @@ -0,0 +1,179 @@ +package solana_test + +import ( + "context" + "testing" + + "github.com/gagliardetto/solana-go" + "github.com/test-go/testify/require" + "google.golang.org/protobuf/proto" + + ocr3types "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + solanasdk "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + solanasdkmock "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/mock" + "github.com/smartcontractkit/cre-sdk-go/cre/testutils" + consensusmock "github.com/smartcontractkit/cre-sdk-go/internal_testing/capabilities/consensus/mock" + + datastorage "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana/testdata/data_storage" +) + +const anyChainSelector = uint64(1337) + +func TestGeneratedBindingsCodec(t *testing.T) { + codec := datastorage.Codec{} + + t.Run("encode functions", func(t *testing.T) { + // structs + userData := datastorage.UserData{ + Key: "testKey", + Value: "testValue", + } + _, err := codec.EncodeUserDataStruct(userData) + require.NoError(t, err) + + testPrivKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + testPubKey := testPrivKey.PublicKey() + + logAccess := datastorage.AccessLogged{ + Caller: testPubKey, + Message: "testMessage", + } + _, err = codec.EncodeAccessLoggedStruct(logAccess) + require.NoError(t, err) + + readData := datastorage.DataAccount{ + Sender: testPubKey.String(), + Key: "testKey", + Value: "testValue", + } + _, err = codec.EncodeDataAccountStruct(readData) + require.NoError(t, err) + + storeData := datastorage.DynamicEvent{ + Key: "testKey", + UserData: userData, + Sender: testPubKey.String(), + Metadata: []byte("testMetadata"), + MetadataArray: [][]byte{}, + } + _, err = codec.EncodeDynamicEventStruct(storeData) + require.NoError(t, err) + + storeUserData := datastorage.UpdateReserves{ + TotalMinted: 100, + TotalReserve: uint64(200), + } + _, err = codec.EncodeUpdateReservesStruct(storeUserData) + require.NoError(t, err) + }) +} + +func TestWriteReportMethods(t *testing.T) { + client := &solanasdk.Client{ChainSelector: anyChainSelector} + ds, err := datastorage.NewDataStorage(client) + require.NoError(t, err, "Failed to create DataStorage instance") + + report := ocr3types.Metadata{ + Version: 1, + ExecutionID: "1234567890123456789012345678901234567890123456789012345678901234", + Timestamp: 1620000000, + DONID: 1, + DONConfigVersion: 1, + WorkflowID: "1234567890123456789012345678901234567890123456789012345678901234", + WorkflowName: "12", + WorkflowOwner: "1234567890123456789012345678901234567890", + ReportID: "1234", + } + + rawReport, err := report.Encode() + require.NoError(t, err) + + consensusCap, err := consensusmock.NewConsensusCapability(t) + require.NoError(t, err, "Failed to create Consensus capability") + consensusCap.Report = func(_ context.Context, input *sdk.ReportRequest) (*sdk.ReportResponse, error) { + return &sdk.ReportResponse{ + RawReport: rawReport, + }, nil + } + + solanaCap, err := solanasdkmock.NewClientCapability(anyChainSelector, t) + require.NoError(t, err, "Failed to create Solana client capability") + solanaCap.WriteReport = func(_ context.Context, req *solanasdk.WriteReportRequest) (*solanasdk.WriteReportReply, error) { + return &solanasdk.WriteReportReply{ + TxStatus: solanasdk.TxStatus_TX_STATUS_SUCCESS, + TxSignature: []byte{0x01, 0x02, 0x03, 0x04}, + }, nil + } + + runtime := testutils.NewRuntime(t, testutils.Secrets{}) + + reply := ds.WriteReportFromUserData(runtime, datastorage.UserData{ + Key: "testKey", + Value: "testValue", + }, nil, nil) + require.NoError(t, err, "WriteReportDataStorageUserData should not return an error") + response, err := reply.Await() + require.NoError(t, err, "Awaiting WriteReportDataStorageUserData reply should not return an error") + require.NotNil(t, response, "Response from WriteReportDataStorageUserData should not be nil") + require.True(t, proto.Equal(&solanasdk.WriteReportReply{ + TxStatus: solanasdk.TxStatus_TX_STATUS_SUCCESS, + TxSignature: []byte{0x01, 0x02, 0x03, 0x04}, + }, response), "Response should match expected WriteReportReply") +} + +func TestZeroArgInstructionRoundTrip(t *testing.T) { + zeroArgInstructions := []struct { + name string + buildFn func() (solana.Instruction, error) + expectedTyp datastorage.Instruction + }{ + { + name: "get_multiple_reserves", + buildFn: datastorage.NewGetMultipleReservesInstruction, + expectedTyp: &datastorage.GetMultipleReservesInstruction{}, + }, + { + name: "get_reserves", + buildFn: datastorage.NewGetReservesInstruction, + expectedTyp: &datastorage.GetReservesInstruction{}, + }, + { + name: "get_tuple_reserves", + buildFn: datastorage.NewGetTupleReservesInstruction, + expectedTyp: &datastorage.GetTupleReservesInstruction{}, + }, + } + + for _, tc := range zeroArgInstructions { + t.Run(tc.name, func(t *testing.T) { + ix, err := tc.buildFn() + require.NoError(t, err, "building instruction should succeed") + + data, err := ix.Data() + require.NoError(t, err) + require.Len(t, data, 8, "zero-arg instruction data must be exactly the 8-byte discriminator") + + parsed, err := datastorage.ParseInstructionWithoutAccounts(data) + require.NoError(t, err, "ParseInstruction must accept the discriminator-only data produced by the builder") + require.IsType(t, tc.expectedTyp, parsed) + }) + } +} + +func TestEncodeStruct(t *testing.T) { + client := &solanasdk.Client{ChainSelector: anyChainSelector} + ds, err := datastorage.NewDataStorage(client) + require.NoError(t, err, "Failed to create DataStorage instance") + + str := datastorage.DataAccount{ + Key: "testKey", + Value: "testValue", + Sender: "testSender", + } + + encoded, err := ds.Codec.EncodeDataAccountStruct(str) + require.NoError(t, err, "Encoding DataStorageDataAccount should not return an error") + require.NotNil(t, encoded, "Encoded data should not be nil") +} diff --git a/cmd/generate-bindings/solana/gen.go b/cmd/generate-bindings/solana/gen.go new file mode 100644 index 00000000..b74413f3 --- /dev/null +++ b/cmd/generate-bindings/solana/gen.go @@ -0,0 +1,2 @@ +//go:generate go run ./testdata/gen +package solana diff --git a/cmd/generate-bindings/solana/gen_test.go b/cmd/generate-bindings/solana/gen_test.go new file mode 100644 index 00000000..1a97f327 --- /dev/null +++ b/cmd/generate-bindings/solana/gen_test.go @@ -0,0 +1,17 @@ +package solana_test + +import ( + "testing" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana" +) + +func TestGenerateBindings(t *testing.T) { + if err := solana.GenerateBindings( + "./testdata/contracts/idl/data_storage.json", + "data_storage", + "./testdata/data_storage", + ); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/generate-bindings/solana/solana.go b/cmd/generate-bindings/solana/solana.go new file mode 100644 index 00000000..1ca8ec58 --- /dev/null +++ b/cmd/generate-bindings/solana/solana.go @@ -0,0 +1,258 @@ +package solana + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/smartcontractkit/cre-cli/internal/constants" + "github.com/smartcontractkit/cre-cli/internal/runtime" + "github.com/smartcontractkit/cre-cli/internal/validation" +) + +type Inputs struct { + ProjectRoot string `validate:"required,dir" cli:"--project-root"` + Language string `validate:"required,oneof=go" cli:"--language"` + IdlPath string `validate:"required,path_read" cli:"--idl"` + OutPath string `validate:"required" cli:"--out"` +} + +func New(runtimeContext *runtime.Context) *cobra.Command { + var generateBindingsCmd = &cobra.Command{ + Use: "solana", + Short: "Generate bindings from contract IDL", + Long: `This command generates bindings from contract IDL files. +Supports Solana chain family and Go language. +Each contract gets its own package subdirectory to avoid naming conflicts. +For example, data_storage.json generates bindings in generated/data_storage/ package.`, + Example: " cre generate-bindings-solana", + RunE: func(cmd *cobra.Command, args []string) error { + handler := newHandler(runtimeContext) + inputs, err := handler.ResolveInputs(runtimeContext.Viper) + if err != nil { + return err + } + if err := handler.ValidateInputs(inputs); err != nil { + return err + } + return handler.Execute(inputs) + }, + } + + generateBindingsCmd.Flags().StringP("project-root", "p", "", "Path to project root directory (defaults to current directory)") + generateBindingsCmd.Flags().StringP("language", "l", "go", "Target language (go)") + generateBindingsCmd.Flags().StringP("idl", "i", "", "Path to IDL directory (defaults to contracts/solana/src/idl/)") + generateBindingsCmd.Flags().StringP("out", "o", "", "Path to output directory (defaults to contracts/solana/src/generated/)") + + return generateBindingsCmd +} + +type handler struct { + log *zerolog.Logger +} + +func newHandler(ctx *runtime.Context) *handler { + return &handler{ + log: ctx.Logger, + } +} + +func (h *handler) ResolveInputs(v *viper.Viper) (Inputs, error) { + // Get current working directory as default project root + currentDir, err := os.Getwd() + if err != nil { + return Inputs{}, fmt.Errorf("failed to get current working directory: %w", err) + } + + // Resolve project root with fallback to current directory + projectRoot := v.GetString("project-root") + if projectRoot == "" { + projectRoot = currentDir + } + + contractsPath := filepath.Join(projectRoot, "contracts") + if _, err := os.Stat(contractsPath); err != nil { + return Inputs{}, fmt.Errorf("contracts folder not found in project root: %s", contractsPath) + } + + // Language defaults are handled by StringP + language := v.GetString("language") + + // Resolve IDL path with fallback to contracts/solana/src/idl/ + idlPath := v.GetString("idl") + if idlPath == "" { + idlPath = filepath.Join(projectRoot, "contracts", "solana", "src", "idl") + } + + // Resolve output path with fallback to contracts/solana/src/generated/ + outPath := v.GetString("out") + if outPath == "" { + outPath = filepath.Join(projectRoot, "contracts", "solana", "src", "generated") + } + + return Inputs{ + ProjectRoot: projectRoot, + Language: language, + IdlPath: idlPath, + OutPath: outPath, + }, nil +} + +func (h *handler) ValidateInputs(inputs Inputs) error { + validate, err := validation.NewValidator() + if err != nil { + return fmt.Errorf("failed to initialize validator: %w", err) + } + + if err = validate.Struct(inputs); err != nil { + return validate.ParseValidationErrors(err) + } + + // Additional validation for Idl path + if _, err := os.Stat(inputs.IdlPath); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("IDL path does not exist: %s", inputs.IdlPath) + } + return fmt.Errorf("failed to access IDL path: %w", err) + } + + // Validate that if IdlPath is a directory, it contains .json files + if info, err := os.Stat(inputs.IdlPath); err == nil && info.IsDir() { + files, err := filepath.Glob(filepath.Join(inputs.IdlPath, "*.json")) + if err != nil { + return fmt.Errorf("failed to check for IDL files in directory: %w", err) + } + if len(files) == 0 { + return fmt.Errorf("no .json files found in directory: %s", inputs.IdlPath) + } + } + + return nil +} + +func (h *handler) processIdlDirectory(inputs Inputs) error { + // Read all .json files in the directory + files, err := filepath.Glob(filepath.Join(inputs.IdlPath, "*.json")) + if err != nil { + return fmt.Errorf("failed to find IDL files: %w", err) + } + + if len(files) == 0 { + return fmt.Errorf("no .json files found in directory: %s", inputs.IdlPath) + } + + // Process each IDL file + for _, idlFile := range files { + // Extract contract name from filename (remove .json extension) + contractName := filepath.Base(idlFile) + contractName = contractName[:len(contractName)-5] // Remove .json extension + + // Create per-contract output directory + contractOutDir := filepath.Join(inputs.OutPath, contractName) + if err := os.MkdirAll(contractOutDir, 0755); err != nil { + return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) + } + + // Create output file path in contract-specific directory + outputFile := filepath.Join(contractOutDir, contractName+".go") + + fmt.Printf("Processing IDL file: %s, contract: %s, output: %s\n", idlFile, contractName, outputFile) + + err = GenerateBindings( + idlFile, + contractName, + contractOutDir, + ) + if err != nil { + return fmt.Errorf("failed to generate bindings for %s: %w", idlFile, err) + } + } + + return nil +} + +func (h *handler) processSingleIdl(inputs Inputs) error { + // Extract contract name from IDL file path + contractName := filepath.Base(inputs.IdlPath) + if filepath.Ext(contractName) == ".json" { + contractName = contractName[:len(contractName)-5] // Remove .json extension + } + + // Create per-contract output directory + contractOutDir := filepath.Join(inputs.OutPath, contractName) + if err := os.MkdirAll(contractOutDir, 0755); err != nil { + return fmt.Errorf("failed to create contract output directory %s: %w", contractOutDir, err) + } + + fmt.Printf("Processing single IDL file: %s, contract: %s, output: %s\n", inputs.IdlPath, contractName, contractOutDir) + + return GenerateBindings( + inputs.IdlPath, + contractName, + contractOutDir, + ) +} + +func (h *handler) Execute(inputs Inputs) error { + // Validate language + switch inputs.Language { + case "go": + // Language supported, continue + default: + return fmt.Errorf("unsupported language: %s", inputs.Language) + } + + // Create output directory if it doesn't exist + if err := os.MkdirAll(inputs.OutPath, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + // Check if IDL path is a directory or file + info, err := os.Stat(inputs.IdlPath) + if err != nil { + return fmt.Errorf("failed to access IDL path: %w", err) + } + + if info.IsDir() { + if err := h.processIdlDirectory(inputs); err != nil { + return err + } + } else { + if err := h.processSingleIdl(inputs); err != nil { + return err + } + } + + err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go@"+constants.SdkVersion) + if err != nil { + return err + } + err = runCommand(inputs.ProjectRoot, "go", "get", "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana@"+constants.SolanaCapabilitiesVersion) + if err != nil { + return err + } + err = runCommand(inputs.ProjectRoot, "go", "mod", "tidy") + if err != nil { + return err + } + + return nil +} + +// runCommand executes a command in a specified directory +func runCommand(dir string, command string, args ...string) error { + cmd := exec.Command(command, args...) + cmd.Dir = dir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run %s: %w", command, err) + } + + return nil +} diff --git a/cmd/generate-bindings/solana/solana_test.go b/cmd/generate-bindings/solana/solana_test.go new file mode 100644 index 00000000..e3eebfa9 --- /dev/null +++ b/cmd/generate-bindings/solana/solana_test.go @@ -0,0 +1,285 @@ +package solana + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/cre-cli/internal/runtime" +) + +func TestResolveSolanaInputs_DefaultFallbacks(t *testing.T) { + // Create a temporary directory for testing + tempDir, err := os.MkdirTemp("", "generate-bindings-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create required contracts directory and go.mod + contractsDir := filepath.Join(tempDir, "contracts") + err = os.MkdirAll(contractsDir, 0755) + require.NoError(t, err) + + goModPath := filepath.Join(contractsDir, "go.mod") + err = os.WriteFile(goModPath, []byte("module test/contracts\n\ngo 1.20\n"), 0600) + require.NoError(t, err) + + // Change to temp directory + originalDir, err := os.Getwd() + require.NoError(t, err) + defer func() { + if err := os.Chdir(originalDir); err != nil { + t.Errorf("Failed to restore original directory: %v", err) + } + }() + + err = os.Chdir(tempDir) + require.NoError(t, err) + + // Test with minimal input + v := viper.New() + v.Set("language", "go") // Default from StringP + + runtimeCtx := &runtime.Context{} + handler := newHandler(runtimeCtx) + + inputs, err := handler.ResolveInputs(v) + require.NoError(t, err) + + // Use filepath.EvalSymlinks to handle macOS /var vs /private/var symlink issues + expectedRoot, _ := filepath.EvalSymlinks(tempDir) + actualRoot, _ := filepath.EvalSymlinks(inputs.ProjectRoot) + assert.Equal(t, expectedRoot, actualRoot) + assert.Equal(t, "go", inputs.Language) + expectedIdl, _ := filepath.EvalSymlinks(filepath.Join(tempDir, "contracts", "solana", "src", "idl")) + actualIdl, _ := filepath.EvalSymlinks(inputs.IdlPath) + assert.Equal(t, expectedIdl, actualIdl) + expectedOut, _ := filepath.EvalSymlinks(filepath.Join(tempDir, "contracts", "solana", "src", "generated")) + actualOut, _ := filepath.EvalSymlinks(inputs.OutPath) + assert.Equal(t, expectedOut, actualOut) +} + +func TestResolveSolanaInputs_CustomOutPath(t *testing.T) { + tempDir, err := os.MkdirTemp("", "generate-bindings-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + contractsDir := filepath.Join(tempDir, "contracts") + err = os.MkdirAll(contractsDir, 0755) + require.NoError(t, err) + + originalDir, err := os.Getwd() + require.NoError(t, err) + defer func() { + if err := os.Chdir(originalDir); err != nil { + t.Errorf("Failed to restore original directory: %v", err) + } + }() + + err = os.Chdir(tempDir) + require.NoError(t, err) + + customOut := filepath.Join(tempDir, "my-custom-output") + + v := viper.New() + v.Set("language", "go") + v.Set("out", customOut) + + runtimeCtx := &runtime.Context{} + handler := newHandler(runtimeCtx) + + inputs, err := handler.ResolveInputs(v) + require.NoError(t, err) + + assert.Equal(t, customOut, inputs.OutPath) +} + +func TestProcessSolanaSingleIdl(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "generate-bindings-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + idlDir := filepath.Join(tempDir, "idl") + outDir := filepath.Join(tempDir, "generated") + + err = os.MkdirAll(idlDir, 0755) + require.NoError(t, err) + + // Create a simple IDL file + simpleIdl := `{ + "address": "ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL", + "metadata": { + "name": "simple_contract", + "version": "0.1.0", + "spec": "0.1.0" + }, + "instructions": [ + { + "name": "initialize", + "discriminator": [175, 175, 109, 31, 13, 152, 155, 237], + "accounts": [], + "args": [] + } + ], + "accounts": [], + "types": [] +}` + + idlFile := filepath.Join(idlDir, "simple_contract.json") + err = os.WriteFile(idlFile, []byte(simpleIdl), 0600) + require.NoError(t, err) + + // Create contracts directory with go.mod for module path detection + contractsDir := filepath.Join(tempDir, "contracts") + err = os.MkdirAll(contractsDir, 0755) + require.NoError(t, err) + + goModPath := filepath.Join(contractsDir, "go.mod") + err = os.WriteFile(goModPath, []byte("module test/contracts\n\ngo 1.20\n"), 0600) + require.NoError(t, err) + + inputs := Inputs{ + ProjectRoot: tempDir, + Language: "go", + IdlPath: idlFile, + OutPath: outDir, + } + + runtimeCtx := &runtime.Context{} + handler := newHandler(runtimeCtx) + + // Process the single IDL file + err = handler.processSingleIdl(inputs) + + // We expect this might fail due to missing dependencies or generator issues, + // but we can verify that the contract directory was created + if err != nil { + t.Logf("Expected error occurred: %v", err) + } + + // Verify that the contract directory was created + contractDir := filepath.Join(outDir, "simple_contract") + assert.DirExists(t, contractDir, "Expected contract directory to be created at %s", contractDir) +} + +func TestProcessSolanaIdlDirectory(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "generate-bindings-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + idlDir := filepath.Join(tempDir, "idl") + outDir := filepath.Join(tempDir, "generated") + + err = os.MkdirAll(idlDir, 0755) + require.NoError(t, err) + + // Create multiple simple IDL files + simpleIdl1 := `{ + "address": "ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL", + "metadata": { + "name": "contract_one", + "version": "0.1.0", + "spec": "0.1.0" + }, + "instructions": [ + { + "name": "initialize", + "discriminator": [175, 175, 109, 31, 13, 152, 155, 237], + "accounts": [], + "args": [] + } + ], + "accounts": [], + "types": [] +}` + + simpleIdl2 := `{ + "address": "FDL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfM", + "metadata": { + "name": "contract_two", + "version": "0.1.0", + "spec": "0.1.0" + }, + "instructions": [ + { + "name": "execute", + "discriminator": [100, 100, 100, 31, 13, 152, 155, 237], + "accounts": [], + "args": [] + } + ], + "accounts": [], + "types": [] +}` + + err = os.WriteFile(filepath.Join(idlDir, "contract_one.json"), []byte(simpleIdl1), 0600) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(idlDir, "contract_two.json"), []byte(simpleIdl2), 0600) + require.NoError(t, err) + + // Create contracts directory with go.mod for module path detection + contractsDir := filepath.Join(tempDir, "contracts") + err = os.MkdirAll(contractsDir, 0755) + require.NoError(t, err) + + goModPath := filepath.Join(contractsDir, "go.mod") + err = os.WriteFile(goModPath, []byte("module test/contracts\n\ngo 1.20\n"), 0600) + require.NoError(t, err) + + inputs := Inputs{ + ProjectRoot: tempDir, + Language: "go", + IdlPath: idlDir, + OutPath: outDir, + } + + runtimeCtx := &runtime.Context{} + handler := newHandler(runtimeCtx) + + // Process the IDL directory + err = handler.processIdlDirectory(inputs) + + // We expect this might fail due to missing dependencies or generator issues, + // but we can verify that the contract directories were created + if err != nil { + t.Logf("Expected error occurred: %v", err) + } + + // Verify that per-contract directories were created + contract1Dir := filepath.Join(outDir, "contract_one") + contract2Dir := filepath.Join(outDir, "contract_two") + assert.DirExists(t, contract1Dir, "Expected contract directory to be created at %s", contract1Dir) + assert.DirExists(t, contract2Dir, "Expected contract directory to be created at %s", contract2Dir) +} + +func TestProcessSolanaIdlDirectory_NoIdlFiles(t *testing.T) { + // Create a temporary directory structure + tempDir, err := os.MkdirTemp("", "generate-bindings-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + idlDir := filepath.Join(tempDir, "idl") + outDir := filepath.Join(tempDir, "generated") + + err = os.MkdirAll(idlDir, 0755) + require.NoError(t, err) + + inputs := Inputs{ + ProjectRoot: tempDir, + Language: "go", + IdlPath: idlDir, + OutPath: outDir, + } + + runtimeCtx := &runtime.Context{} + handler := newHandler(runtimeCtx) + + err = handler.processIdlDirectory(inputs) + require.Error(t, err) + assert.Contains(t, err.Error(), "no .json files found") +} diff --git a/cmd/generate-bindings/solana/testdata/contracts/idl/data_storage.json b/cmd/generate-bindings/solana/testdata/contracts/idl/data_storage.json new file mode 100644 index 00000000..2ff93a7a --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/contracts/idl/data_storage.json @@ -0,0 +1,511 @@ +{ + "address": "ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL", + "metadata": { + "name": "data_storage", + "version": "0.1.0", + "spec": "0.1.0", + "description": "Created with Anchor" + }, + "instructions": [ + { + "name": "get_multiple_reserves", + "discriminator": [ + 104, + 122, + 140, + 104, + 175, + 151, + 70, + 42 + ], + "accounts": [], + "args": [], + "returns": { + "vec": { + "defined": { + "name": "UpdateReserves" + } + } + } + }, + { + "name": "get_reserves", + "discriminator": [ + 121, + 140, + 237, + 84, + 218, + 105, + 48, + 17 + ], + "accounts": [], + "args": [], + "returns": { + "defined": { + "name": "UpdateReserves" + } + } + }, + { + "name": "get_tuple_reserves", + "discriminator": [ + 189, + 83, + 186, + 20, + 127, + 80, + 109, + 49 + ], + "accounts": [], + "args": [] + }, + { + "name": "initialize_data_account", + "discriminator": [ + 9, + 64, + 78, + 49, + 71, + 193, + 15, + 250 + ], + "accounts": [ + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + }, + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "system_program", + "address": "11111111111111111111111111111111" + } + ], + "args": [ + { + "name": "input", + "type": { + "defined": { + "name": "UserData" + } + } + } + ] + }, + { + "name": "log_access", + "discriminator": [ + 196, + 55, + 194, + 24, + 5, + 224, + 161, + 204 + ], + "accounts": [ + { + "name": "user", + "signer": true + } + ], + "args": [ + { + "name": "message", + "type": "string" + } + ] + }, + { + "name": "on_report", + "discriminator": [ + 214, + 173, + 18, + 221, + 173, + 148, + 151, + 208 + ], + "accounts": [ + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + }, + { + "name": "system_program", + "address": "11111111111111111111111111111111" + } + ], + "args": [ + { + "name": "_metadata", + "type": "bytes" + }, + { + "name": "payload", + "type": "bytes" + } + ] + }, + { + "name": "update_key_value_data", + "discriminator": [ + 67, + 137, + 144, + 35, + 210, + 126, + 254, + 79 + ], + "accounts": [ + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + } + ], + "args": [ + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string" + } + ] + }, + { + "name": "update_user_data", + "discriminator": [ + 11, + 13, + 114, + 150, + 194, + 224, + 192, + 78 + ], + "accounts": [ + { + "name": "user", + "writable": true, + "signer": true + }, + { + "name": "data_account", + "writable": true, + "pda": { + "seeds": [ + { + "kind": "const", + "value": [ + 100, + 97, + 116, + 97, + 95, + 97, + 99, + 99, + 111, + 117, + 110, + 116 + ] + }, + { + "kind": "account", + "path": "user" + } + ] + } + } + ], + "args": [ + { + "name": "input", + "type": { + "defined": { + "name": "UserData" + } + } + } + ] + } + ], + "accounts": [ + { + "name": "DataAccount", + "discriminator": [ + 85, + 240, + 182, + 158, + 76, + 7, + 18, + 233 + ] + } + ], + "events": [ + { + "name": "AccessLogged", + "discriminator": [ + 243, + 53, + 225, + 71, + 64, + 120, + 109, + 25 + ] + }, + { + "name": "DynamicEvent", + "discriminator": [ + 236, + 145, + 224, + 161, + 9, + 222, + 218, + 237 + ] + }, + { + "name": "NoFields", + "discriminator": [ + 160, + 156, + 94, + 85, + 77, + 122, + 98, + 240 + ] + } + ], + "errors": [ + { + "code": 6000, + "name": "DataNotFound", + "msg": "data not found" + } + ], + "types": [ + { + "name": "AccessLogged", + "type": { + "kind": "struct", + "fields": [ + { + "name": "caller", + "type": "pubkey" + }, + { + "name": "message", + "type": "string" + } + ] + } + }, + { + "name": "DataAccount", + "type": { + "kind": "struct", + "fields": [ + { + "name": "sender", + "type": "string" + }, + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string" + } + ] + } + }, + { + "name": "DynamicEvent", + "type": { + "kind": "struct", + "fields": [ + { + "name": "key", + "type": "string" + }, + { + "name": "user_data", + "type": { + "defined": { + "name": "UserData" + } + } + }, + { + "name": "sender", + "type": "string" + }, + { + "name": "metadata", + "type": "bytes" + }, + { + "name": "metadata_array", + "type": { + "vec": "bytes" + } + } + ] + } + }, + { + "name": "NoFields", + "type": { + "kind": "struct", + "fields": [] + } + }, + { + "name": "UpdateReserves", + "type": { + "kind": "struct", + "fields": [ + { + "name": "total_minted", + "type": "u64" + }, + { + "name": "total_reserve", + "type": "u64" + } + ] + } + }, + { + "name": "UserData", + "type": { + "kind": "struct", + "fields": [ + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string" + } + ] + } + } + ] +} \ No newline at end of file diff --git a/cmd/generate-bindings/solana/testdata/contracts/source/data_storage.rs b/cmd/generate-bindings/solana/testdata/contracts/source/data_storage.rs new file mode 100644 index 00000000..dc5b02b3 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/contracts/source/data_storage.rs @@ -0,0 +1,251 @@ +use anchor_lang::prelude::*; + +declare_id!("ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL"); + +#[program] +pub mod data_storage { + use super::*; + + // simulate + pub fn get_reserves(_ctx: Context) -> Result { + Ok(UpdateReserves { + total_minted: 100, + total_reserve: 200, + }) + } + + // simulate + pub fn get_multiple_reserves( + _ctx: Context, + ) -> Result> { + let reserves = vec![ + UpdateReserves { + total_minted: 100, + total_reserve: 200, + }, + UpdateReserves { + total_minted: 300, + total_reserve: 400, + }, + ]; + + Ok(reserves) + } + + // simulate + pub fn get_tuple_reserves(_ctx: Context) -> Result<(u64, u64)> { + Ok((100, 200)) + } + + pub fn initialize_data_account(ctx: Context, input: UserData) -> Result<()> { + ctx.accounts.data_account.sender = ctx.accounts.user.key().to_string(); + ctx.accounts.data_account.key = input.key; + ctx.accounts.data_account.value = input.value; + Ok(()) + } + + // no event + pub fn update_key_value_data( + ctx: Context, + key: String, + value: String, + ) -> Result<()> { + let acc = &mut ctx.accounts.data_account; + + acc.sender = ctx.accounts.user.key().to_string(); + acc.key = key.clone(); + acc.value = value.clone(); + + let user_data = UserData { + key: key.clone(), + value: value.clone(), + }; + + emit!(DynamicEvent { + key: key, + user_data: user_data, + sender: ctx.accounts.user.key().to_string(), + metadata: vec![1, 2, 3], + metadata_array: vec![], + }); + + Ok(()) + } + + pub fn update_user_data(ctx: Context, input: UserData) -> Result<()> { + let acc = &mut ctx.accounts.data_account; + + acc.sender = ctx.accounts.user.key().to_string(); + acc.key = input.key.clone(); + acc.value = input.value.clone(); + + let user_data_cloned = input.clone(); + + emit!(DynamicEvent { + key: input.key, + user_data: user_data_cloned, + sender: ctx.accounts.user.key().to_string(), + metadata: vec![1, 2, 3], + metadata_array: vec![], + }); + + Ok(()) + } + + pub fn log_access(ctx: Context, message: String) -> Result<()> { + emit!(AccessLogged { + caller: ctx.accounts.user.key(), + message, + }); + Ok(()) + } + + pub fn on_report(ctx: Context, _metadata: Vec, payload: Vec) -> Result<()> { + // decode payload into UserData + let mut bytes: &[u8] = &payload; + let user = UserData::deserialize(&mut bytes)?; // requires AnchorDeserialize on UserData + + // update mapping-equivalent: this user's PDA + let acc = &mut ctx.accounts.data_account; + acc.sender = ctx.accounts.user.key().to_string(); + acc.key = user.key.clone(); + acc.value = user.value.clone(); + + let user_cloned = user.clone(); + + // emit event + emit!(DynamicEvent { + sender: ctx.accounts.user.key().to_string(), + key: user.key, + user_data: user_cloned, + metadata: vec![1, 2, 3], + metadata_array: vec![], + }); + + Ok(()) + } + + pub fn handle_forwarder_report( + _ctx: Context, + _report: ForwarderReport, + ) -> Result<()> { + // TODO: implement forwarding logic here + Ok(()) + } +} + +// read data from here +#[account] +pub struct DataAccount { + pub sender: String, + pub key: String, + pub value: String, +} + +#[derive(Accounts)] +pub struct Initialize<'info> { + #[account( + init, + payer = user, + space = 8 + + (4 + 64) // sender max 64 + + (4 + 64) // key max 64 + + (4 + 256) // value max 256 + + 1, // bump + seeds = [b"data_account", user.key().as_ref()], // seed for deterministic PDA + bump + )] + pub data_account: Account<'info, DataAccount>, + + #[account(mut)] + pub user: Signer<'info>, + pub system_program: Program<'info, System>, +} + +#[derive(Accounts)] +pub struct UpdateData<'info> { + #[account(mut)] + pub user: Signer<'info>, + + // PDA: one account per user, same seeds as Initialize + #[account( + mut, + seeds = [b"data_account", user.key().as_ref()], + bump, + )] + pub data_account: Account<'info, DataAccount>, +} + +// just use to have a complex event type ? +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug, PartialEq)] +pub struct UserData { + pub key: String, + pub value: String, +} + +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug, PartialEq)] +pub struct ForwarderReport { + pub account_hash: Vec, + pub payload: Vec, +} + +#[event] +pub struct DynamicEvent { + pub key: String, + pub user_data: UserData, + pub sender: String, + pub metadata: Vec, + pub metadata_array: Vec>, +} + +#[event] +pub struct AccessLogged { + pub caller: Pubkey, + pub message: String, +} + +#[event] +pub struct NoFields {} + +#[error_code] +pub enum DataError { + #[msg("data not found")] + DataNotFound = 0, +} + +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)] +pub struct UpdateReserves { + pub total_minted: u64, + pub total_reserve: u64, +} + +// empty contexts +#[derive(Accounts)] +pub struct GetReserves {} +#[derive(Accounts)] +pub struct GetMultipleReserves {} +#[derive(Accounts)] +pub struct GetTupleReserves {} + +#[derive(Accounts)] +pub struct HandleForwarderReport {} + +#[derive(Accounts)] +pub struct LogAccess<'info> { + pub user: Signer<'info>, +} + +#[derive(Accounts)] +pub struct OnReport<'info> { + #[account(mut)] + pub user: Signer<'info>, + + #[account( + mut, + seeds = [b"data_account", user.key().as_ref()], + bump, + )] + pub data_account: Account<'info, DataAccount>, + + pub system_program: Program<'info, System>, +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/accounts.go b/cmd/generate-bindings/solana/testdata/data_storage/accounts.go new file mode 100644 index 00000000..f5711951 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/accounts.go @@ -0,0 +1,50 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains parsers for the accounts defined in the IDL. +// Code generated by https://github.com/smartcontractkit/cre-cli. DO NOT EDIT. + +package data_storage + +import ( + "fmt" + binary "github.com/gagliardetto/binary" +) + +func ParseAnyAccount(accountData []byte) (any, error) { + decoder := binary.NewBorshDecoder(accountData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek account discriminator: %w", err) + } + switch discriminator { + case Account_DataAccount: + value := new(DataAccount) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account as DataAccount: %w", err) + } + return value, nil + default: + return nil, fmt.Errorf("unknown discriminator: %s", binary.FormatDiscriminator(discriminator)) + } +} + +func ParseAccount_DataAccount(accountData []byte) (*DataAccount, error) { + decoder := binary.NewBorshDecoder(accountData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Account_DataAccount { + return nil, fmt.Errorf("expected discriminator %v, got %s", Account_DataAccount, binary.FormatDiscriminator(discriminator)) + } + acc := new(DataAccount) + err = acc.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account of type DataAccount: %w", err) + } + return acc, nil +} + +func (c *Codec) DecodeDataAccount(data []byte) (*DataAccount, error) { + return ParseAccount_DataAccount(data) +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/constants.go b/cmd/generate-bindings/solana/testdata/data_storage/constants.go new file mode 100644 index 00000000..0c192cb2 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/constants.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains constants. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/constructor.go b/cmd/generate-bindings/solana/testdata/data_storage/constructor.go new file mode 100644 index 00000000..65dd788f --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/constructor.go @@ -0,0 +1,88 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains the constructor for the program. + +package data_storage + +import ( + "bytes" + "encoding/binary" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + solana "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + bindings "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/bindings" + cre "github.com/smartcontractkit/cre-sdk-go/cre" +) + +var IDL = "{\"address\":\"ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL\",\"metadata\":{\"name\":\"data_storage\",\"version\":\"0.1.0\",\"spec\":\"0.1.0\",\"description\":\"Created with Anchor\"},\"instructions\":[{\"name\":\"get_multiple_reserves\",\"discriminator\":[104,122,140,104,175,151,70,42],\"accounts\":[],\"args\":[],\"returns\":{\"vec\":{\"defined\":{\"name\":\"UpdateReserves\"}}}},{\"name\":\"get_reserves\",\"discriminator\":[121,140,237,84,218,105,48,17],\"accounts\":[],\"args\":[],\"returns\":{\"defined\":{\"name\":\"UpdateReserves\"}}},{\"name\":\"get_tuple_reserves\",\"discriminator\":[189,83,186,20,127,80,109,49],\"accounts\":[],\"args\":[]},{\"name\":\"initialize_data_account\",\"discriminator\":[9,64,78,49,71,193,15,250],\"accounts\":[{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}},{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"system_program\",\"address\":\"11111111111111111111111111111111\"}],\"args\":[{\"name\":\"input\",\"type\":{\"defined\":{\"name\":\"UserData\"}}}]},{\"name\":\"log_access\",\"discriminator\":[196,55,194,24,5,224,161,204],\"accounts\":[{\"name\":\"user\",\"signer\":true}],\"args\":[{\"name\":\"message\",\"type\":\"string\"}]},{\"name\":\"on_report\",\"discriminator\":[214,173,18,221,173,148,151,208],\"accounts\":[{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}},{\"name\":\"system_program\",\"address\":\"11111111111111111111111111111111\"}],\"args\":[{\"name\":\"_metadata\",\"type\":\"bytes\"},{\"name\":\"payload\",\"type\":\"bytes\"}]},{\"name\":\"update_key_value_data\",\"discriminator\":[67,137,144,35,210,126,254,79],\"accounts\":[{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}}],\"args\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":\"string\"}]},{\"name\":\"update_user_data\",\"discriminator\":[11,13,114,150,194,224,192,78],\"accounts\":[{\"name\":\"user\",\"writable\":true,\"signer\":true},{\"name\":\"data_account\",\"writable\":true,\"pda\":{\"seeds\":[{\"kind\":\"const\",\"value\":[100,97,116,97,95,97,99,99,111,117,110,116]},{\"kind\":\"account\",\"path\":\"user\"}]}}],\"args\":[{\"name\":\"input\",\"type\":{\"defined\":{\"name\":\"UserData\"}}}]}],\"accounts\":[{\"name\":\"DataAccount\",\"discriminator\":[85,240,182,158,76,7,18,233]}],\"events\":[{\"name\":\"AccessLogged\",\"discriminator\":[243,53,225,71,64,120,109,25]},{\"name\":\"DynamicEvent\",\"discriminator\":[236,145,224,161,9,222,218,237]},{\"name\":\"NoFields\",\"discriminator\":[160,156,94,85,77,122,98,240]}],\"errors\":[{\"code\":6000,\"name\":\"DataNotFound\",\"msg\":\"data not found\"}],\"types\":[{\"name\":\"AccessLogged\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"caller\",\"type\":\"pubkey\"},{\"name\":\"message\",\"type\":\"string\"}]}},{\"name\":\"DataAccount\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"sender\",\"type\":\"string\"},{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":\"string\"}]}},{\"name\":\"DynamicEvent\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"user_data\",\"type\":{\"defined\":{\"name\":\"UserData\"}}},{\"name\":\"sender\",\"type\":\"string\"},{\"name\":\"metadata\",\"type\":\"bytes\"},{\"name\":\"metadata_array\",\"type\":{\"vec\":\"bytes\"}}]}},{\"name\":\"NoFields\",\"type\":{\"kind\":\"struct\",\"fields\":[]}},{\"name\":\"UpdateReserves\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"total_minted\",\"type\":\"u64\"},{\"name\":\"total_reserve\",\"type\":\"u64\"}]}},{\"name\":\"UserData\",\"type\":{\"kind\":\"struct\",\"fields\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":\"string\"}]}}]}" + +type DataStorage struct { + client *solana.Client + Codec DataStorageCodec +} + +type Codec struct{} + +func NewDataStorage(client *solana.Client) (*DataStorage, error) { + return &DataStorage{ + Codec: &Codec{}, + client: client, + }, nil +} + +// EncodeBorshVecU32 returns Anchor/Borsh encoding of a Vec whose elements are opaque byte payloads. // Each [][]byte element must already be fully serialized for one Vec item on the wire. // Layout: little-endian u32 length followed by concatenated element payloads. +func EncodeBorshVecU32(elements [][]byte) ([]byte, error) { + buf := bytes.NewBuffer(nil) + if err := binary.Write(buf, binary.LittleEndian, uint32(len(elements))); err != nil { + return nil, err + } + for _, elem := range elements { + _, err := buf.Write(elem) + if err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} + +// WriteReportFromBorshEncodedVec publishes through the CRE signer using a forwarder payload built from // Borsh Vec semantics (EncodeBorshVecU32). Compose each elementPayload for your program (e.g. one encoded struct per row). // Pass computeConfig = nil to use the host default Solana compute budget. +func (c *DataStorage) WriteReportFromBorshEncodedVec(runtime cre.Runtime, elementPayloads [][]byte, remainingAccounts []*solana.AccountMeta, computeConfig *solana.ComputeConfig) cre.Promise[*solana.WriteReportReply] { + payload, err := EncodeBorshVecU32(elementPayloads) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: payload, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +type DataStorageCodec interface { + DecodeDataAccount(data []byte) (*DataAccount, error) + EncodeAccessLoggedStruct(in AccessLogged) ([]byte, error) + EncodeDataAccountStruct(in DataAccount) ([]byte, error) + EncodeDynamicEventStruct(in DynamicEvent) ([]byte, error) + EncodeNoFieldsStruct(in NoFields) ([]byte, error) + EncodeUpdateReservesStruct(in UpdateReserves) ([]byte, error) + EncodeUserDataStruct(in UserData) ([]byte, error) +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/discriminators.go b/cmd/generate-bindings/solana/testdata/data_storage/discriminators.go new file mode 100644 index 00000000..251e3d0b --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/discriminators.go @@ -0,0 +1,30 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains the discriminators for accounts and events defined in the IDL. + +package data_storage + +import binary "github.com/gagliardetto/binary" + +// Account discriminators +var ( + Account_DataAccount = binary.TypeID{85, 240, 182, 158, 76, 7, 18, 233} +) + +// Event discriminators +var ( + Event_AccessLogged = binary.TypeID{243, 53, 225, 71, 64, 120, 109, 25} + Event_DynamicEvent = binary.TypeID{236, 145, 224, 161, 9, 222, 218, 237} + Event_NoFields = binary.TypeID{160, 156, 94, 85, 77, 122, 98, 240} +) + +// Instruction discriminators +var ( + Instruction_GetMultipleReserves = binary.TypeID{104, 122, 140, 104, 175, 151, 70, 42} + Instruction_GetReserves = binary.TypeID{121, 140, 237, 84, 218, 105, 48, 17} + Instruction_GetTupleReserves = binary.TypeID{189, 83, 186, 20, 127, 80, 109, 49} + Instruction_InitializeDataAccount = binary.TypeID{9, 64, 78, 49, 71, 193, 15, 250} + Instruction_LogAccess = binary.TypeID{196, 55, 194, 24, 5, 224, 161, 204} + Instruction_OnReport = binary.TypeID{214, 173, 18, 221, 173, 148, 151, 208} + Instruction_UpdateKeyValueData = binary.TypeID{67, 137, 144, 35, 210, 126, 254, 79} + Instruction_UpdateUserData = binary.TypeID{11, 13, 114, 150, 194, 224, 192, 78} +) diff --git a/cmd/generate-bindings/solana/testdata/data_storage/errors.go b/cmd/generate-bindings/solana/testdata/data_storage/errors.go new file mode 100644 index 00000000..576b057b --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/errors.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains errors. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/events.go b/cmd/generate-bindings/solana/testdata/data_storage/events.go new file mode 100644 index 00000000..804e0344 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/events.go @@ -0,0 +1,93 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains parsers for the events defined in the IDL. + +package data_storage + +import ( + "fmt" + binary "github.com/gagliardetto/binary" +) + +func ParseAnyEvent(eventData []byte) (any, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek event discriminator: %w", err) + } + switch discriminator { + case Event_AccessLogged: + value := new(AccessLogged) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event as AccessLogged: %w", err) + } + return value, nil + case Event_DynamicEvent: + value := new(DynamicEvent) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event as DynamicEvent: %w", err) + } + return value, nil + case Event_NoFields: + value := new(NoFields) + err := value.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event as NoFields: %w", err) + } + return value, nil + default: + return nil, fmt.Errorf("unknown discriminator: %s", binary.FormatDiscriminator(discriminator)) + } +} + +func ParseEvent_AccessLogged(eventData []byte) (*AccessLogged, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Event_AccessLogged { + return nil, fmt.Errorf("expected discriminator %v, got %s", Event_AccessLogged, binary.FormatDiscriminator(discriminator)) + } + event := new(AccessLogged) + err = event.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event of type AccessLogged: %w", err) + } + return event, nil +} + +func ParseEvent_DynamicEvent(eventData []byte) (*DynamicEvent, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Event_DynamicEvent { + return nil, fmt.Errorf("expected discriminator %v, got %s", Event_DynamicEvent, binary.FormatDiscriminator(discriminator)) + } + event := new(DynamicEvent) + err = event.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event of type DynamicEvent: %w", err) + } + return event, nil +} + +func ParseEvent_NoFields(eventData []byte) (*NoFields, error) { + decoder := binary.NewBorshDecoder(eventData) + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return nil, fmt.Errorf("failed to peek discriminator: %w", err) + } + if discriminator != Event_NoFields { + return nil, fmt.Errorf("expected discriminator %v, got %s", Event_NoFields, binary.FormatDiscriminator(discriminator)) + } + event := new(NoFields) + err = event.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal event of type NoFields: %w", err) + } + return event, nil +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/fetchers.go b/cmd/generate-bindings/solana/testdata/data_storage/fetchers.go new file mode 100644 index 00000000..606a2030 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/fetchers.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains fetcher functions. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/instructions.go b/cmd/generate-bindings/solana/testdata/data_storage/instructions.go new file mode 100644 index 00000000..1a88b7ab --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/instructions.go @@ -0,0 +1,1204 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains instructions and instruction parsers. + +package data_storage + +import ( + "bytes" + "fmt" + errors "github.com/gagliardetto/anchor-go/errors" + binary "github.com/gagliardetto/binary" + solanago "github.com/gagliardetto/solana-go" +) + +// Builds a "get_multiple_reserves" instruction. +func NewGetMultipleReservesInstruction() (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_GetMultipleReserves[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + accounts__ := solanago.AccountMetaSlice{} + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "get_reserves" instruction. +func NewGetReservesInstruction() (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_GetReserves[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + accounts__ := solanago.AccountMetaSlice{} + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "get_tuple_reserves" instruction. +func NewGetTupleReservesInstruction() (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_GetTupleReserves[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + accounts__ := solanago.AccountMetaSlice{} + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "initialize_data_account" instruction. +func NewInitializeDataAccountInstruction( + // Params: + inputParam UserData, + + // Accounts: + dataAccountAccount solanago.PublicKey, + userAccount solanago.PublicKey, + systemProgramAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_InitializeDataAccount[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `inputParam`: + err = enc__.Encode(inputParam) + if err != nil { + return nil, errors.NewField("inputParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + // Account 1 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 2 "system_program": Read-only, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(systemProgramAccount, false, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "log_access" instruction. +func NewLogAccessInstruction( + // Params: + messageParam string, + + // Accounts: + userAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_LogAccess[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `messageParam`: + err = enc__.Encode(messageParam) + if err != nil { + return nil, errors.NewField("messageParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Read-only, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, false, true)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "on_report" instruction. +func NewOnReportInstruction( + // Params: + metadataParam []byte, + payloadParam []byte, + + // Accounts: + userAccount solanago.PublicKey, + dataAccountAccount solanago.PublicKey, + systemProgramAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_OnReport[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `metadataParam`: + err = enc__.Encode(metadataParam) + if err != nil { + return nil, errors.NewField("metadataParam", err) + } + // Serialize `payloadParam`: + err = enc__.Encode(payloadParam) + if err != nil { + return nil, errors.NewField("payloadParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 1 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + // Account 2 "system_program": Read-only, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(systemProgramAccount, false, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "update_key_value_data" instruction. +func NewUpdateKeyValueDataInstruction( + // Params: + keyParam string, + valueParam string, + + // Accounts: + userAccount solanago.PublicKey, + dataAccountAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_UpdateKeyValueData[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `keyParam`: + err = enc__.Encode(keyParam) + if err != nil { + return nil, errors.NewField("keyParam", err) + } + // Serialize `valueParam`: + err = enc__.Encode(valueParam) + if err != nil { + return nil, errors.NewField("valueParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 1 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +// Builds a "update_user_data" instruction. +func NewUpdateUserDataInstruction( + // Params: + inputParam UserData, + + // Accounts: + userAccount solanago.PublicKey, + dataAccountAccount solanago.PublicKey, +) (solanago.Instruction, error) { + buf__ := new(bytes.Buffer) + enc__ := binary.NewBorshEncoder(buf__) + + // Encode the instruction discriminator. + err := enc__.WriteBytes(Instruction_UpdateUserData[:], false) + if err != nil { + return nil, fmt.Errorf("failed to write instruction discriminator: %w", err) + } + { + // Serialize `inputParam`: + err = enc__.Encode(inputParam) + if err != nil { + return nil, errors.NewField("inputParam", err) + } + } + accounts__ := solanago.AccountMetaSlice{} + + // Add the accounts to the instruction. + { + // Account 0 "user": Writable, Signer, Required + accounts__.Append(solanago.NewAccountMeta(userAccount, true, true)) + // Account 1 "data_account": Writable, Non-signer, Required + accounts__.Append(solanago.NewAccountMeta(dataAccountAccount, true, false)) + } + + // Create the instruction. + return solanago.NewInstruction( + ProgramID, + accounts__, + buf__.Bytes(), + ), nil +} + +type GetMultipleReservesInstruction struct{} + +func (obj *GetMultipleReservesInstruction) GetDiscriminator() []byte { + return Instruction_GetMultipleReserves[:] +} + +// UnmarshalWithDecoder unmarshals the GetMultipleReservesInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *GetMultipleReservesInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "GetMultipleReservesInstruction", err) + } + if discriminator != Instruction_GetMultipleReserves { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "GetMultipleReservesInstruction", Instruction_GetMultipleReserves, discriminator) + } + return nil +} + +func (obj *GetMultipleReservesInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + return []uint8{}, nil +} + +func (obj *GetMultipleReservesInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + return nil +} + +func (obj *GetMultipleReservesInstruction) GetAccountKeys() []solanago.PublicKey { + return []solanago.PublicKey{} +} + +// Unmarshal unmarshals the GetMultipleReservesInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *GetMultipleReservesInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling GetMultipleReservesInstruction: %w", err) + } + return nil +} + +// UnmarshalGetMultipleReservesInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalGetMultipleReservesInstruction(buf []byte) (*GetMultipleReservesInstruction, error) { + obj := new(GetMultipleReservesInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type GetReservesInstruction struct{} + +func (obj *GetReservesInstruction) GetDiscriminator() []byte { + return Instruction_GetReserves[:] +} + +// UnmarshalWithDecoder unmarshals the GetReservesInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *GetReservesInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "GetReservesInstruction", err) + } + if discriminator != Instruction_GetReserves { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "GetReservesInstruction", Instruction_GetReserves, discriminator) + } + return nil +} + +func (obj *GetReservesInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + return []uint8{}, nil +} + +func (obj *GetReservesInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + return nil +} + +func (obj *GetReservesInstruction) GetAccountKeys() []solanago.PublicKey { + return []solanago.PublicKey{} +} + +// Unmarshal unmarshals the GetReservesInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *GetReservesInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling GetReservesInstruction: %w", err) + } + return nil +} + +// UnmarshalGetReservesInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalGetReservesInstruction(buf []byte) (*GetReservesInstruction, error) { + obj := new(GetReservesInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type GetTupleReservesInstruction struct{} + +func (obj *GetTupleReservesInstruction) GetDiscriminator() []byte { + return Instruction_GetTupleReserves[:] +} + +// UnmarshalWithDecoder unmarshals the GetTupleReservesInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *GetTupleReservesInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "GetTupleReservesInstruction", err) + } + if discriminator != Instruction_GetTupleReserves { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "GetTupleReservesInstruction", Instruction_GetTupleReserves, discriminator) + } + return nil +} + +func (obj *GetTupleReservesInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + return []uint8{}, nil +} + +func (obj *GetTupleReservesInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + return nil +} + +func (obj *GetTupleReservesInstruction) GetAccountKeys() []solanago.PublicKey { + return []solanago.PublicKey{} +} + +// Unmarshal unmarshals the GetTupleReservesInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *GetTupleReservesInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling GetTupleReservesInstruction: %w", err) + } + return nil +} + +// UnmarshalGetTupleReservesInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalGetTupleReservesInstruction(buf []byte) (*GetTupleReservesInstruction, error) { + obj := new(GetTupleReservesInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type InitializeDataAccountInstruction struct { + Input UserData `json:"input"` + + // Accounts: + DataAccount solanago.PublicKey `json:"data_account"` + DataAccountWritable bool `json:"data_account_writable"` + User solanago.PublicKey `json:"user"` + UserWritable bool `json:"user_writable"` + UserSigner bool `json:"user_signer"` + SystemProgram solanago.PublicKey `json:"system_program"` +} + +func (obj *InitializeDataAccountInstruction) GetDiscriminator() []byte { + return Instruction_InitializeDataAccount[:] +} + +// UnmarshalWithDecoder unmarshals the InitializeDataAccountInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *InitializeDataAccountInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + var err error + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "InitializeDataAccountInstruction", err) + } + if discriminator != Instruction_InitializeDataAccount { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "InitializeDataAccountInstruction", Instruction_InitializeDataAccount, discriminator) + } + // Deserialize `Input`: + err = decoder.Decode(&obj.Input) + if err != nil { + return err + } + return nil +} + +func (obj *InitializeDataAccountInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + // UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes + decoder := binary.NewBorshDecoder(buf) + indices := make([]uint8, 0) + index := uint8(0) + var err error + // Decode from data_account account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err) + } + indices = append(indices, index) + // Decode from user account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err) + } + indices = append(indices, index) + // Decode from system_program account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "system_program", err) + } + indices = append(indices, index) + return indices, nil +} + +func (obj *InitializeDataAccountInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + // PopulateFromAccountIndices sets account public keys from indices and account keys array + if len(indices) != 3 { + return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 3, len(indices)) + } + indexOffset := 0 + // Set data_account account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1) + } + obj.DataAccount = accountKeys[indices[indexOffset]] + indexOffset++ + // Set user account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1) + } + obj.User = accountKeys[indices[indexOffset]] + indexOffset++ + // Set system_program account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "system_program", len(accountKeys)-1) + } + obj.SystemProgram = accountKeys[indices[indexOffset]] + indexOffset++ + return nil +} + +func (obj *InitializeDataAccountInstruction) GetAccountKeys() []solanago.PublicKey { + keys := make([]solanago.PublicKey, 0) + keys = append(keys, obj.DataAccount) + keys = append(keys, obj.User) + keys = append(keys, obj.SystemProgram) + return keys +} + +// Unmarshal unmarshals the InitializeDataAccountInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *InitializeDataAccountInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling InitializeDataAccountInstruction: %w", err) + } + return nil +} + +// UnmarshalInitializeDataAccountInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalInitializeDataAccountInstruction(buf []byte) (*InitializeDataAccountInstruction, error) { + obj := new(InitializeDataAccountInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type LogAccessInstruction struct { + Message string `json:"message"` + + // Accounts: + User solanago.PublicKey `json:"user"` + UserSigner bool `json:"user_signer"` +} + +func (obj *LogAccessInstruction) GetDiscriminator() []byte { + return Instruction_LogAccess[:] +} + +// UnmarshalWithDecoder unmarshals the LogAccessInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *LogAccessInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + var err error + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "LogAccessInstruction", err) + } + if discriminator != Instruction_LogAccess { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "LogAccessInstruction", Instruction_LogAccess, discriminator) + } + // Deserialize `Message`: + err = decoder.Decode(&obj.Message) + if err != nil { + return err + } + return nil +} + +func (obj *LogAccessInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + // UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes + decoder := binary.NewBorshDecoder(buf) + indices := make([]uint8, 0) + index := uint8(0) + var err error + // Decode from user account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err) + } + indices = append(indices, index) + return indices, nil +} + +func (obj *LogAccessInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + // PopulateFromAccountIndices sets account public keys from indices and account keys array + if len(indices) != 1 { + return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 1, len(indices)) + } + indexOffset := 0 + // Set user account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1) + } + obj.User = accountKeys[indices[indexOffset]] + indexOffset++ + return nil +} + +func (obj *LogAccessInstruction) GetAccountKeys() []solanago.PublicKey { + keys := make([]solanago.PublicKey, 0) + keys = append(keys, obj.User) + return keys +} + +// Unmarshal unmarshals the LogAccessInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *LogAccessInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling LogAccessInstruction: %w", err) + } + return nil +} + +// UnmarshalLogAccessInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalLogAccessInstruction(buf []byte) (*LogAccessInstruction, error) { + obj := new(LogAccessInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type OnReportInstruction struct { + Metadata []byte `json:"_metadata"` + Payload []byte `json:"payload"` + + // Accounts: + User solanago.PublicKey `json:"user"` + UserWritable bool `json:"user_writable"` + UserSigner bool `json:"user_signer"` + DataAccount solanago.PublicKey `json:"data_account"` + DataAccountWritable bool `json:"data_account_writable"` + SystemProgram solanago.PublicKey `json:"system_program"` +} + +func (obj *OnReportInstruction) GetDiscriminator() []byte { + return Instruction_OnReport[:] +} + +// UnmarshalWithDecoder unmarshals the OnReportInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *OnReportInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + var err error + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "OnReportInstruction", err) + } + if discriminator != Instruction_OnReport { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "OnReportInstruction", Instruction_OnReport, discriminator) + } + // Deserialize `Metadata`: + err = decoder.Decode(&obj.Metadata) + if err != nil { + return err + } + // Deserialize `Payload`: + err = decoder.Decode(&obj.Payload) + if err != nil { + return err + } + return nil +} + +func (obj *OnReportInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + // UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes + decoder := binary.NewBorshDecoder(buf) + indices := make([]uint8, 0) + index := uint8(0) + var err error + // Decode from user account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err) + } + indices = append(indices, index) + // Decode from data_account account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err) + } + indices = append(indices, index) + // Decode from system_program account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "system_program", err) + } + indices = append(indices, index) + return indices, nil +} + +func (obj *OnReportInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + // PopulateFromAccountIndices sets account public keys from indices and account keys array + if len(indices) != 3 { + return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 3, len(indices)) + } + indexOffset := 0 + // Set user account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1) + } + obj.User = accountKeys[indices[indexOffset]] + indexOffset++ + // Set data_account account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1) + } + obj.DataAccount = accountKeys[indices[indexOffset]] + indexOffset++ + // Set system_program account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "system_program", len(accountKeys)-1) + } + obj.SystemProgram = accountKeys[indices[indexOffset]] + indexOffset++ + return nil +} + +func (obj *OnReportInstruction) GetAccountKeys() []solanago.PublicKey { + keys := make([]solanago.PublicKey, 0) + keys = append(keys, obj.User) + keys = append(keys, obj.DataAccount) + keys = append(keys, obj.SystemProgram) + return keys +} + +// Unmarshal unmarshals the OnReportInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *OnReportInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling OnReportInstruction: %w", err) + } + return nil +} + +// UnmarshalOnReportInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalOnReportInstruction(buf []byte) (*OnReportInstruction, error) { + obj := new(OnReportInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type UpdateKeyValueDataInstruction struct { + Key string `json:"key"` + Value string `json:"value"` + + // Accounts: + User solanago.PublicKey `json:"user"` + UserWritable bool `json:"user_writable"` + UserSigner bool `json:"user_signer"` + DataAccount solanago.PublicKey `json:"data_account"` + DataAccountWritable bool `json:"data_account_writable"` +} + +func (obj *UpdateKeyValueDataInstruction) GetDiscriminator() []byte { + return Instruction_UpdateKeyValueData[:] +} + +// UnmarshalWithDecoder unmarshals the UpdateKeyValueDataInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *UpdateKeyValueDataInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + var err error + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "UpdateKeyValueDataInstruction", err) + } + if discriminator != Instruction_UpdateKeyValueData { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "UpdateKeyValueDataInstruction", Instruction_UpdateKeyValueData, discriminator) + } + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return err + } + // Deserialize `Value`: + err = decoder.Decode(&obj.Value) + if err != nil { + return err + } + return nil +} + +func (obj *UpdateKeyValueDataInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + // UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes + decoder := binary.NewBorshDecoder(buf) + indices := make([]uint8, 0) + index := uint8(0) + var err error + // Decode from user account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err) + } + indices = append(indices, index) + // Decode from data_account account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err) + } + indices = append(indices, index) + return indices, nil +} + +func (obj *UpdateKeyValueDataInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + // PopulateFromAccountIndices sets account public keys from indices and account keys array + if len(indices) != 2 { + return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 2, len(indices)) + } + indexOffset := 0 + // Set user account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1) + } + obj.User = accountKeys[indices[indexOffset]] + indexOffset++ + // Set data_account account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1) + } + obj.DataAccount = accountKeys[indices[indexOffset]] + indexOffset++ + return nil +} + +func (obj *UpdateKeyValueDataInstruction) GetAccountKeys() []solanago.PublicKey { + keys := make([]solanago.PublicKey, 0) + keys = append(keys, obj.User) + keys = append(keys, obj.DataAccount) + return keys +} + +// Unmarshal unmarshals the UpdateKeyValueDataInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *UpdateKeyValueDataInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling UpdateKeyValueDataInstruction: %w", err) + } + return nil +} + +// UnmarshalUpdateKeyValueDataInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalUpdateKeyValueDataInstruction(buf []byte) (*UpdateKeyValueDataInstruction, error) { + obj := new(UpdateKeyValueDataInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +type UpdateUserDataInstruction struct { + Input UserData `json:"input"` + + // Accounts: + User solanago.PublicKey `json:"user"` + UserWritable bool `json:"user_writable"` + UserSigner bool `json:"user_signer"` + DataAccount solanago.PublicKey `json:"data_account"` + DataAccountWritable bool `json:"data_account_writable"` +} + +func (obj *UpdateUserDataInstruction) GetDiscriminator() []byte { + return Instruction_UpdateUserData[:] +} + +// UnmarshalWithDecoder unmarshals the UpdateUserDataInstruction from Borsh-encoded bytes prefixed with its discriminator. +func (obj *UpdateUserDataInstruction) UnmarshalWithDecoder(decoder *binary.Decoder) error { + var err error + // Read the discriminator and check it against the expected value: + discriminator, err := decoder.ReadDiscriminator() + if err != nil { + return fmt.Errorf("failed to read instruction discriminator for %s: %w", "UpdateUserDataInstruction", err) + } + if discriminator != Instruction_UpdateUserData { + return fmt.Errorf("instruction discriminator mismatch for %s: expected %s, got %s", "UpdateUserDataInstruction", Instruction_UpdateUserData, discriminator) + } + // Deserialize `Input`: + err = decoder.Decode(&obj.Input) + if err != nil { + return err + } + return nil +} + +func (obj *UpdateUserDataInstruction) UnmarshalAccountIndices(buf []byte) ([]uint8, error) { + // UnmarshalAccountIndices decodes account indices from Borsh-encoded bytes + decoder := binary.NewBorshDecoder(buf) + indices := make([]uint8, 0) + index := uint8(0) + var err error + // Decode from user account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "user", err) + } + indices = append(indices, index) + // Decode from data_account account index + index = uint8(0) + err = decoder.Decode(&index) + if err != nil { + return nil, fmt.Errorf("failed to decode %s account index: %w", "data_account", err) + } + indices = append(indices, index) + return indices, nil +} + +func (obj *UpdateUserDataInstruction) PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error { + // PopulateFromAccountIndices sets account public keys from indices and account keys array + if len(indices) != 2 { + return fmt.Errorf("mismatch between expected accounts (%d) and provided indices (%d)", 2, len(indices)) + } + indexOffset := 0 + // Set user account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "user", len(accountKeys)-1) + } + obj.User = accountKeys[indices[indexOffset]] + indexOffset++ + // Set data_account account from index + if indices[indexOffset] >= uint8(len(accountKeys)) { + return fmt.Errorf("account index %d for %s is out of bounds (max: %d)", indices[indexOffset], "data_account", len(accountKeys)-1) + } + obj.DataAccount = accountKeys[indices[indexOffset]] + indexOffset++ + return nil +} + +func (obj *UpdateUserDataInstruction) GetAccountKeys() []solanago.PublicKey { + keys := make([]solanago.PublicKey, 0) + keys = append(keys, obj.User) + keys = append(keys, obj.DataAccount) + return keys +} + +// Unmarshal unmarshals the UpdateUserDataInstruction from Borsh-encoded bytes prefixed with the discriminator. +func (obj *UpdateUserDataInstruction) Unmarshal(buf []byte) error { + var err error + err = obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling UpdateUserDataInstruction: %w", err) + } + return nil +} + +// UnmarshalUpdateUserDataInstruction unmarshals the instruction from Borsh-encoded bytes prefixed with the discriminator. +func UnmarshalUpdateUserDataInstruction(buf []byte) (*UpdateUserDataInstruction, error) { + obj := new(UpdateUserDataInstruction) + var err error + err = obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +// Instruction interface defines common methods for all instruction types +type Instruction interface { + GetDiscriminator() []byte + + UnmarshalWithDecoder(decoder *binary.Decoder) error + + UnmarshalAccountIndices(buf []byte) ([]uint8, error) + + PopulateFromAccountIndices(indices []uint8, accountKeys []solanago.PublicKey) error + + GetAccountKeys() []solanago.PublicKey +} + +// ParseInstruction parses instruction data and optionally populates accounts +// If accountIndicesData is nil or empty, accounts will not be populated +func ParseInstruction(instructionData []byte, accountIndicesData []byte, accountKeys []solanago.PublicKey) (Instruction, error) { + // Validate inputs + if len(instructionData) < 8 { + return nil, fmt.Errorf("instruction data too short: expected at least 8 bytes, got %d", len(instructionData)) + } + // Extract discriminator (TypeID for consistent equality with generated constants) + discriminator := binary.TypeIDFromBytes(instructionData[0:8]) + // Parse based on discriminator + switch discriminator { + case Instruction_GetMultipleReserves: + instruction := new(GetMultipleReservesInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as GetMultipleReservesInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_GetReserves: + instruction := new(GetReservesInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as GetReservesInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_GetTupleReserves: + instruction := new(GetTupleReservesInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as GetTupleReservesInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_InitializeDataAccount: + instruction := new(InitializeDataAccountInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as InitializeDataAccountInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_LogAccess: + instruction := new(LogAccessInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as LogAccessInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_OnReport: + instruction := new(OnReportInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as OnReportInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_UpdateKeyValueData: + instruction := new(UpdateKeyValueDataInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as UpdateKeyValueDataInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + case Instruction_UpdateUserData: + instruction := new(UpdateUserDataInstruction) + decoder := binary.NewBorshDecoder(instructionData) + err := instruction.UnmarshalWithDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal instruction as UpdateUserDataInstruction: %w", err) + } + if accountIndicesData != nil && len(accountIndicesData) > 0 { + indices, err := instruction.UnmarshalAccountIndices(accountIndicesData) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal account indices: %w", err) + } + err = instruction.PopulateFromAccountIndices(indices, accountKeys) + if err != nil { + return nil, fmt.Errorf("failed to populate accounts: %w", err) + } + } + return instruction, nil + default: + return nil, fmt.Errorf("unknown instruction discriminator: %s", binary.FormatDiscriminator(discriminator)) + } +} + +// ParseInstructionTyped parses instruction data and returns a specific instruction type // T must implement the Instruction interface +func ParseInstructionTyped[T Instruction](instructionData []byte, accountIndicesData []byte, accountKeys []solanago.PublicKey) (T, error) { + instruction, err := ParseInstruction(instructionData, accountIndicesData, accountKeys) + if err != nil { + return *new(T), err + } + typed, ok := instruction.(T) + if !ok { + return *new(T), fmt.Errorf("instruction is not of expected type") + } + return typed, nil +} + +// ParseInstructionWithoutAccounts parses instruction data without account information +func ParseInstructionWithoutAccounts(instructionData []byte) (Instruction, error) { + return ParseInstruction(instructionData, nil, []solanago.PublicKey{}) +} + +// ParseInstructionWithAccounts parses instruction data with account information +func ParseInstructionWithAccounts(instructionData []byte, accountIndicesData []byte, accountKeys []solanago.PublicKey) (Instruction, error) { + return ParseInstruction(instructionData, accountIndicesData, accountKeys) +} diff --git a/cmd/generate-bindings/solana/testdata/data_storage/program_id.go b/cmd/generate-bindings/solana/testdata/data_storage/program_id.go new file mode 100644 index 00000000..1e0b7950 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/program_id.go @@ -0,0 +1,8 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains the program ID. + +package data_storage + +import solanago "github.com/gagliardetto/solana-go" + +var ProgramID = solanago.MustPublicKeyFromBase58("ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL") diff --git a/cmd/generate-bindings/solana/testdata/data_storage/tests_test.go b/cmd/generate-bindings/solana/testdata/data_storage/tests_test.go new file mode 100644 index 00000000..704cda06 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/tests_test.go @@ -0,0 +1,4 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains tests. + +package data_storage diff --git a/cmd/generate-bindings/solana/testdata/data_storage/types.go b/cmd/generate-bindings/solana/testdata/data_storage/types.go new file mode 100644 index 00000000..157f5199 --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/data_storage/types.go @@ -0,0 +1,763 @@ +// Code generated by https://github.com/gagliardetto/anchor-go. DO NOT EDIT. +// This file contains parsers for the types defined in the IDL. + +package data_storage + +import ( + "bytes" + "fmt" + errors "github.com/gagliardetto/anchor-go/errors" + binary "github.com/gagliardetto/binary" + solanago "github.com/gagliardetto/solana-go" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + solana "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana" + bindings "github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana/bindings" + cre "github.com/smartcontractkit/cre-sdk-go/cre" +) + +type AccessLogged struct { + Caller solanago.PublicKey `json:"caller"` + Message string `json:"message"` +} + +func (obj AccessLogged) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Caller`: + err = encoder.Encode(obj.Caller) + if err != nil { + return errors.NewField("Caller", err) + } + // Serialize `Message`: + err = encoder.Encode(obj.Message) + if err != nil { + return errors.NewField("Message", err) + } + return nil +} + +func (obj AccessLogged) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding AccessLogged: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *AccessLogged) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Caller`: + err = decoder.Decode(&obj.Caller) + if err != nil { + return errors.NewField("Caller", err) + } + // Deserialize `Message`: + err = decoder.Decode(&obj.Message) + if err != nil { + return errors.NewField("Message", err) + } + return nil +} + +func (obj *AccessLogged) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling AccessLogged: %w", err) + } + return nil +} + +func UnmarshalAccessLogged(buf []byte) (*AccessLogged, error) { + obj := new(AccessLogged) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeAccessLoggedStruct(in AccessLogged) ([]byte, error) { + return in.Marshal() +} + +// WriteReportFromAccessLogged encodes the input struct, hashes the provided accounts, // generates a signed report, and submits it via WriteReport. // // remainingAccounts must follow the keystone-forwarder account layout: // - Index 0: forwarderState – the forwarder program's state account. // - Index 1: forwarderAuthority – PDA derived from seeds // ["forwarder", forwarderState, receiverProgram] under the forwarder program ID. // - Index 2+: receiver-specific accounts required by the target program. // // The full slice is hashed (via CalculateAccountsHash) into the report and forwarded // as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1 // before CPI-ing into the receiver, so they must be present and correctly ordered. +func (c *DataStorage) WriteReportFromAccessLogged( + runtime cre.Runtime, + input AccessLogged, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeAccessLoggedStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +func (c *DataStorage) WriteReportFromAccessLoggeds( + runtime cre.Runtime, + inputs []AccessLogged, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + elements := make([][]byte, len(inputs)) + for i, input := range inputs { + encoded, err := c.Codec.EncodeAccessLoggedStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + elements[i] = encoded + } + return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +} + +type DataAccount struct { + Sender string `json:"sender"` + Key string `json:"key"` + Value string `json:"value"` +} + +func (obj DataAccount) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Sender`: + err = encoder.Encode(obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Serialize `Key`: + err = encoder.Encode(obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Serialize `Value`: + err = encoder.Encode(obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj DataAccount) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding DataAccount: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *DataAccount) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Sender`: + err = decoder.Decode(&obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Deserialize `Value`: + err = decoder.Decode(&obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj *DataAccount) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling DataAccount: %w", err) + } + return nil +} + +func UnmarshalDataAccount(buf []byte) (*DataAccount, error) { + obj := new(DataAccount) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeDataAccountStruct(in DataAccount) ([]byte, error) { + return in.Marshal() +} + +// WriteReportFromDataAccount encodes the input struct, hashes the provided accounts, // generates a signed report, and submits it via WriteReport. // // remainingAccounts must follow the keystone-forwarder account layout: // - Index 0: forwarderState – the forwarder program's state account. // - Index 1: forwarderAuthority – PDA derived from seeds // ["forwarder", forwarderState, receiverProgram] under the forwarder program ID. // - Index 2+: receiver-specific accounts required by the target program. // // The full slice is hashed (via CalculateAccountsHash) into the report and forwarded // as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1 // before CPI-ing into the receiver, so they must be present and correctly ordered. +func (c *DataStorage) WriteReportFromDataAccount( + runtime cre.Runtime, + input DataAccount, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeDataAccountStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +func (c *DataStorage) WriteReportFromDataAccounts( + runtime cre.Runtime, + inputs []DataAccount, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + elements := make([][]byte, len(inputs)) + for i, input := range inputs { + encoded, err := c.Codec.EncodeDataAccountStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + elements[i] = encoded + } + return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +} + +type DynamicEvent struct { + Key string `json:"key"` + UserData UserData `json:"user_data"` + Sender string `json:"sender"` + Metadata []byte `json:"metadata"` + MetadataArray [][]byte `json:"metadata_array"` +} + +func (obj DynamicEvent) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Key`: + err = encoder.Encode(obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Serialize `UserData`: + err = encoder.Encode(obj.UserData) + if err != nil { + return errors.NewField("UserData", err) + } + // Serialize `Sender`: + err = encoder.Encode(obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Serialize `Metadata`: + err = encoder.Encode(obj.Metadata) + if err != nil { + return errors.NewField("Metadata", err) + } + // Serialize `MetadataArray`: + err = encoder.Encode(obj.MetadataArray) + if err != nil { + return errors.NewField("MetadataArray", err) + } + return nil +} + +func (obj DynamicEvent) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding DynamicEvent: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *DynamicEvent) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Deserialize `UserData`: + err = decoder.Decode(&obj.UserData) + if err != nil { + return errors.NewField("UserData", err) + } + // Deserialize `Sender`: + err = decoder.Decode(&obj.Sender) + if err != nil { + return errors.NewField("Sender", err) + } + // Deserialize `Metadata`: + err = decoder.Decode(&obj.Metadata) + if err != nil { + return errors.NewField("Metadata", err) + } + // Deserialize `MetadataArray`: + err = decoder.Decode(&obj.MetadataArray) + if err != nil { + return errors.NewField("MetadataArray", err) + } + return nil +} + +func (obj *DynamicEvent) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling DynamicEvent: %w", err) + } + return nil +} + +func UnmarshalDynamicEvent(buf []byte) (*DynamicEvent, error) { + obj := new(DynamicEvent) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeDynamicEventStruct(in DynamicEvent) ([]byte, error) { + return in.Marshal() +} + +// WriteReportFromDynamicEvent encodes the input struct, hashes the provided accounts, // generates a signed report, and submits it via WriteReport. // // remainingAccounts must follow the keystone-forwarder account layout: // - Index 0: forwarderState – the forwarder program's state account. // - Index 1: forwarderAuthority – PDA derived from seeds // ["forwarder", forwarderState, receiverProgram] under the forwarder program ID. // - Index 2+: receiver-specific accounts required by the target program. // // The full slice is hashed (via CalculateAccountsHash) into the report and forwarded // as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1 // before CPI-ing into the receiver, so they must be present and correctly ordered. +func (c *DataStorage) WriteReportFromDynamicEvent( + runtime cre.Runtime, + input DynamicEvent, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeDynamicEventStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +func (c *DataStorage) WriteReportFromDynamicEvents( + runtime cre.Runtime, + inputs []DynamicEvent, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + elements := make([][]byte, len(inputs)) + for i, input := range inputs { + encoded, err := c.Codec.EncodeDynamicEventStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + elements[i] = encoded + } + return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +} + +type NoFields struct{} + +func (obj NoFields) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + return nil +} + +func (obj NoFields) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding NoFields: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *NoFields) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + return nil +} + +func (obj *NoFields) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling NoFields: %w", err) + } + return nil +} + +func UnmarshalNoFields(buf []byte) (*NoFields, error) { + obj := new(NoFields) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeNoFieldsStruct(in NoFields) ([]byte, error) { + return in.Marshal() +} + +// WriteReportFromNoFields encodes the input struct, hashes the provided accounts, // generates a signed report, and submits it via WriteReport. // // remainingAccounts must follow the keystone-forwarder account layout: // - Index 0: forwarderState – the forwarder program's state account. // - Index 1: forwarderAuthority – PDA derived from seeds // ["forwarder", forwarderState, receiverProgram] under the forwarder program ID. // - Index 2+: receiver-specific accounts required by the target program. // // The full slice is hashed (via CalculateAccountsHash) into the report and forwarded // as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1 // before CPI-ing into the receiver, so they must be present and correctly ordered. +func (c *DataStorage) WriteReportFromNoFields( + runtime cre.Runtime, + input NoFields, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeNoFieldsStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +func (c *DataStorage) WriteReportFromNoFieldss( + runtime cre.Runtime, + inputs []NoFields, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + elements := make([][]byte, len(inputs)) + for i, input := range inputs { + encoded, err := c.Codec.EncodeNoFieldsStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + elements[i] = encoded + } + return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +} + +type UpdateReserves struct { + TotalMinted uint64 `json:"total_minted"` + TotalReserve uint64 `json:"total_reserve"` +} + +func (obj UpdateReserves) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `TotalMinted`: + err = encoder.Encode(obj.TotalMinted) + if err != nil { + return errors.NewField("TotalMinted", err) + } + // Serialize `TotalReserve`: + err = encoder.Encode(obj.TotalReserve) + if err != nil { + return errors.NewField("TotalReserve", err) + } + return nil +} + +func (obj UpdateReserves) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding UpdateReserves: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *UpdateReserves) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `TotalMinted`: + err = decoder.Decode(&obj.TotalMinted) + if err != nil { + return errors.NewField("TotalMinted", err) + } + // Deserialize `TotalReserve`: + err = decoder.Decode(&obj.TotalReserve) + if err != nil { + return errors.NewField("TotalReserve", err) + } + return nil +} + +func (obj *UpdateReserves) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling UpdateReserves: %w", err) + } + return nil +} + +func UnmarshalUpdateReserves(buf []byte) (*UpdateReserves, error) { + obj := new(UpdateReserves) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeUpdateReservesStruct(in UpdateReserves) ([]byte, error) { + return in.Marshal() +} + +// WriteReportFromUpdateReserves encodes the input struct, hashes the provided accounts, // generates a signed report, and submits it via WriteReport. // // remainingAccounts must follow the keystone-forwarder account layout: // - Index 0: forwarderState – the forwarder program's state account. // - Index 1: forwarderAuthority – PDA derived from seeds // ["forwarder", forwarderState, receiverProgram] under the forwarder program ID. // - Index 2+: receiver-specific accounts required by the target program. // // The full slice is hashed (via CalculateAccountsHash) into the report and forwarded // as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1 // before CPI-ing into the receiver, so they must be present and correctly ordered. +func (c *DataStorage) WriteReportFromUpdateReserves( + runtime cre.Runtime, + input UpdateReserves, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeUpdateReservesStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +func (c *DataStorage) WriteReportFromUpdateReservess( + runtime cre.Runtime, + inputs []UpdateReserves, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + elements := make([][]byte, len(inputs)) + for i, input := range inputs { + encoded, err := c.Codec.EncodeUpdateReservesStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + elements[i] = encoded + } + return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +} + +type UserData struct { + Key string `json:"key"` + Value string `json:"value"` +} + +func (obj UserData) MarshalWithEncoder(encoder *binary.Encoder) (err error) { + // Serialize `Key`: + err = encoder.Encode(obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Serialize `Value`: + err = encoder.Encode(obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj UserData) Marshal() ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := binary.NewBorshEncoder(buf) + err := obj.MarshalWithEncoder(encoder) + if err != nil { + return nil, fmt.Errorf("error while encoding UserData: %w", err) + } + return buf.Bytes(), nil +} + +func (obj *UserData) UnmarshalWithDecoder(decoder *binary.Decoder) (err error) { + // Deserialize `Key`: + err = decoder.Decode(&obj.Key) + if err != nil { + return errors.NewField("Key", err) + } + // Deserialize `Value`: + err = decoder.Decode(&obj.Value) + if err != nil { + return errors.NewField("Value", err) + } + return nil +} + +func (obj *UserData) Unmarshal(buf []byte) error { + err := obj.UnmarshalWithDecoder(binary.NewBorshDecoder(buf)) + if err != nil { + return fmt.Errorf("error while unmarshaling UserData: %w", err) + } + return nil +} + +func UnmarshalUserData(buf []byte) (*UserData, error) { + obj := new(UserData) + err := obj.Unmarshal(buf) + if err != nil { + return nil, err + } + return obj, nil +} + +func (c *Codec) EncodeUserDataStruct(in UserData) ([]byte, error) { + return in.Marshal() +} + +// WriteReportFromUserData encodes the input struct, hashes the provided accounts, // generates a signed report, and submits it via WriteReport. // // remainingAccounts must follow the keystone-forwarder account layout: // - Index 0: forwarderState – the forwarder program's state account. // - Index 1: forwarderAuthority – PDA derived from seeds // ["forwarder", forwarderState, receiverProgram] under the forwarder program ID. // - Index 2+: receiver-specific accounts required by the target program. // // The full slice is hashed (via CalculateAccountsHash) into the report and forwarded // as WriteCreReportRequest.RemainingAccounts. The on-chain forwarder strips indices 0 and 1 // before CPI-ing into the receiver, so they must be present and correctly ordered. +func (c *DataStorage) WriteReportFromUserData( + runtime cre.Runtime, + input UserData, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + encodedInput, err := c.Codec.EncodeUserDataStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + encodedAccountList := bindings.CalculateAccountsHash(remainingAccounts) + + fwdReport := bindings.ForwarderReport{ + AccountHash: encodedAccountList, + Payload: encodedInput, + } + encodedFwdReport, err := fwdReport.Marshal() + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + + promise := runtime.GenerateReport(&sdk.ReportRequest{ + EncodedPayload: encodedFwdReport, + EncoderName: "solana", + HashingAlgo: "keccak256", + SigningAlgo: "ecdsa", + }) + + return cre.ThenPromise(promise, func(report *cre.Report) cre.Promise[*solana.WriteReportReply] { + return c.client.WriteReport(runtime, &solana.WriteCreReportRequest{ + ComputeConfig: computeConfig, + Receiver: ProgramID.Bytes(), + RemainingAccounts: remainingAccounts, + Report: report, + }) + }) +} + +func (c *DataStorage) WriteReportFromUserDatas( + runtime cre.Runtime, + inputs []UserData, + remainingAccounts []*solana.AccountMeta, + computeConfig *solana.ComputeConfig, +) cre.Promise[*solana.WriteReportReply] { + elements := make([][]byte, len(inputs)) + for i, input := range inputs { + encoded, err := c.Codec.EncodeUserDataStruct(input) + if err != nil { + return cre.PromiseFromResult[*solana.WriteReportReply](nil, err) + } + elements[i] = encoded + } + return c.WriteReportFromBorshEncodedVec(runtime, elements, remainingAccounts, computeConfig) +} diff --git a/cmd/generate-bindings/solana/testdata/gen/main.go b/cmd/generate-bindings/solana/testdata/gen/main.go new file mode 100644 index 00000000..30dc025e --- /dev/null +++ b/cmd/generate-bindings/solana/testdata/gen/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "log" + + "github.com/smartcontractkit/cre-cli/cmd/generate-bindings/solana" +) + +func main() { + if err := solana.GenerateBindings( + "./testdata/contracts/idl/data_storage.json", + "data_storage", + "./testdata/data_storage", + ); err != nil { + log.Fatal(err) + } +} diff --git a/cmd/root.go b/cmd/root.go index aa149602..4f16e99c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -524,6 +524,8 @@ func isLoadSettings(cmd *cobra.Command) bool { "cre account list-key": {}, "cre init": {}, "cre generate-bindings": {}, + "cre generate-bindings evm": {}, + "cre generate-bindings solana": {}, "cre completion bash": {}, "cre completion fish": {}, "cre completion powershell": {}, @@ -555,28 +557,30 @@ func isLoadSettings(cmd *cobra.Command) bool { func isLoadCredentials(cmd *cobra.Command) bool { // It is not expected to have the credentials loaded when running the following commands var excludedCommands = map[string]struct{}{ - "cre version": {}, - "cre login": {}, - "cre logout": {}, - "cre completion bash": {}, - "cre completion fish": {}, - "cre completion powershell": {}, - "cre completion zsh": {}, - "cre help": {}, - "cre generate-bindings": {}, - "cre update": {}, - "cre workflow": {}, - "cre workflow limits": {}, - "cre workflow limits export": {}, - "cre account": {}, - "cre secrets": {}, - "cre workflow build": {}, - "cre workflow hash": {}, - "cre templates": {}, - "cre templates list": {}, - "cre templates add": {}, - "cre templates remove": {}, - "cre": {}, + "cre version": {}, + "cre login": {}, + "cre logout": {}, + "cre completion bash": {}, + "cre completion fish": {}, + "cre completion powershell": {}, + "cre completion zsh": {}, + "cre help": {}, + "cre generate-bindings": {}, + "cre generate-bindings evm": {}, + "cre generate-bindings solana": {}, + "cre update": {}, + "cre workflow": {}, + "cre workflow limits": {}, + "cre workflow limits export": {}, + "cre account": {}, + "cre secrets": {}, + "cre workflow build": {}, + "cre workflow hash": {}, + "cre templates": {}, + "cre templates list": {}, + "cre templates add": {}, + "cre templates remove": {}, + "cre": {}, } _, exists := excludedCommands[cmd.CommandPath()] diff --git a/docs/cre.md b/docs/cre.md index 3f77ed01..81092ba9 100644 --- a/docs/cre.md +++ b/docs/cre.md @@ -26,7 +26,7 @@ cre [optional flags] ### SEE ALSO * [cre account](cre_account.md) - Manage account and request deploy access -* [cre generate-bindings](cre_generate-bindings.md) - Generate bindings from contract ABI +* [cre generate-bindings](cre_generate-bindings.md) - Generate bindings for contracts * [cre init](cre_init.md) - Initialize a new cre project (recommended starting point) * [cre login](cre_login.md) - Start authentication flow * [cre logout](cre_logout.md) - Revoke authentication tokens and remove local credentials diff --git a/docs/cre_generate-bindings.md b/docs/cre_generate-bindings.md index 9b6bacf9..d1f5befc 100644 --- a/docs/cre_generate-bindings.md +++ b/docs/cre_generate-bindings.md @@ -1,37 +1,15 @@ ## cre generate-bindings -Generate bindings from contract ABI +Generate bindings for contracts ### Synopsis -This command generates bindings from contract ABI files. -Supports EVM chain family with Go and TypeScript languages. -The target language is auto-detected from project files, or can be -specified explicitly with --language. -Each contract gets its own package subdirectory to avoid naming conflicts. -For example, IERC20.abi generates bindings in generated/ierc20/ package. - -Both raw ABI files (*.abi) and JSON artifact files (*.json) are supported. -For JSON files the ABI is read from the top-level "abi" field. - -``` -cre generate-bindings [optional flags] -``` - -### Examples - -``` - cre generate-bindings evm -``` +The generate-bindings command allows you to generate bindings for contracts. ### Options ``` - -a, --abi string Path to ABI directory (defaults to contracts/{chain-family}/src/abi/). Supports *.abi and *.json files - -h, --help help for generate-bindings - -l, --language string Target language: go, typescript (auto-detected from project files when omitted) - -k, --pkg string Base package name (each contract gets its own subdirectory) (default "bindings") - -p, --project-root string Path to project root directory (defaults to current directory) + -h, --help help for generate-bindings ``` ### Options inherited from parent commands @@ -40,6 +18,7 @@ cre generate-bindings [optional flags] --allow-unknown-chains Skip chain-name validation against the chain-selectors registry (for experimental chains) -e, --env string Path to .env file which contains sensitive info --non-interactive Fail instead of prompting; requires all inputs via flags + -R, --project-root string Path to the project root -E, --public-env string Path to .env.public file which contains shared, non-sensitive build config -T, --target string Use target settings from YAML config -v, --verbose Run command in VERBOSE mode @@ -48,4 +27,6 @@ cre generate-bindings [optional flags] ### SEE ALSO * [cre](cre.md) - CRE CLI tool +* [cre generate-bindings evm](cre_generate-bindings_evm.md) - Generate bindings from contract ABI +* [cre generate-bindings solana](cre_generate-bindings_solana.md) - Generate bindings from contract IDL diff --git a/docs/cre_generate-bindings_evm.md b/docs/cre_generate-bindings_evm.md new file mode 100644 index 00000000..bbe4558f --- /dev/null +++ b/docs/cre_generate-bindings_evm.md @@ -0,0 +1,51 @@ +## cre generate-bindings evm + +Generate bindings from contract ABI + +### Synopsis + +This command generates bindings from contract ABI files. +Supports EVM chain family with Go and TypeScript languages. +The target language is auto-detected from project files, or can be +specified explicitly with --language. +Each contract gets its own package subdirectory to avoid naming conflicts. +For example, IERC20.abi generates bindings in generated/ierc20/ package. + +Both raw ABI files (*.abi) and JSON artifact files (*.json) are supported. +For JSON files the ABI is read from the top-level "abi" field. + +``` +cre generate-bindings evm [optional flags] +``` + +### Examples + +``` + cre generate-bindings evm +``` + +### Options + +``` + -a, --abi string Path to ABI directory (defaults to contracts/evm/src/abi/). Supports *.abi and *.json files + -h, --help help for evm + -l, --language string Target language: go, typescript (auto-detected from project files when omitted) + -k, --pkg string Base package name (each contract gets its own subdirectory) (default "bindings") + -p, --project-root string Path to project root directory (defaults to current directory) +``` + +### Options inherited from parent commands + +``` + --allow-unknown-chains Skip chain-name validation against the chain-selectors registry (for experimental chains) + -e, --env string Path to .env file which contains sensitive info + --non-interactive Fail instead of prompting; requires all inputs via flags + -E, --public-env string Path to .env.public file which contains shared, non-sensitive build config + -T, --target string Use target settings from YAML config + -v, --verbose Run command in VERBOSE mode +``` + +### SEE ALSO + +* [cre generate-bindings](cre_generate-bindings.md) - Generate bindings for contracts + diff --git a/docs/cre_generate-bindings_solana.md b/docs/cre_generate-bindings_solana.md new file mode 100644 index 00000000..b95b9d83 --- /dev/null +++ b/docs/cre_generate-bindings_solana.md @@ -0,0 +1,46 @@ +## cre generate-bindings solana + +Generate bindings from contract IDL + +### Synopsis + +This command generates bindings from contract IDL files. +Supports Solana chain family and Go language. +Each contract gets its own package subdirectory to avoid naming conflicts. +For example, data_storage.json generates bindings in generated/data_storage/ package. + +``` +cre generate-bindings solana [optional flags] +``` + +### Examples + +``` + cre generate-bindings-solana +``` + +### Options + +``` + -h, --help help for solana + -i, --idl string Path to IDL directory (defaults to contracts/solana/src/idl/) + -l, --language string Target language (go) (default "go") + -o, --out string Path to output directory (defaults to contracts/solana/src/generated/) + -p, --project-root string Path to project root directory (defaults to current directory) +``` + +### Options inherited from parent commands + +``` + --allow-unknown-chains Skip chain-name validation against the chain-selectors registry (for experimental chains) + -e, --env string Path to .env file which contains sensitive info + --non-interactive Fail instead of prompting; requires all inputs via flags + -E, --public-env string Path to .env.public file which contains shared, non-sensitive build config + -T, --target string Use target settings from YAML config + -v, --verbose Run command in VERBOSE mode +``` + +### SEE ALSO + +* [cre generate-bindings](cre_generate-bindings.md) - Generate bindings for contracts + diff --git a/go.mod b/go.mod index f2b862ff..a1b4b70d 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,13 @@ require ( github.com/charmbracelet/bubbletea v1.3.6 github.com/charmbracelet/huh v0.8.0 github.com/charmbracelet/lipgloss v1.1.0 + github.com/dave/jennifer v1.7.1 + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc github.com/denisbrodbeck/machineid v1.0.1 github.com/ethereum/go-ethereum v1.17.3 + github.com/gagliardetto/anchor-go v1.0.0 + github.com/gagliardetto/binary v0.8.0 + github.com/gagliardetto/solana-go v1.13.0 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.30.2 @@ -27,13 +32,14 @@ require ( github.com/smartcontractkit/chainlink-common v0.11.2-0.20260520194751-11a4f360f4e2 github.com/smartcontractkit/chainlink-common/keystore v1.1.0 github.com/smartcontractkit/chainlink-evm/gethwrappers v0.0.0-20260512150409-b4068bf735e6 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260521152427-d3f6dc93de42 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260522145417-85c85baa73cf github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20260323124644-faea187e6997 github.com/smartcontractkit/chainlink-testing-framework/seth v1.51.5 github.com/smartcontractkit/chainlink/deployment v0.0.0-20260521170940-67f9a4b233f8 github.com/smartcontractkit/chainlink/v2 v2.29.1-cre-beta.0.0.20260521170940-67f9a4b233f8 github.com/smartcontractkit/cre-sdk-go v1.11.0 github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm v1.0.0-beta.12 + github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.0-beta.1 github.com/smartcontractkit/mcms v0.45.0 github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20251120172354-e8ec0386b06c github.com/spf13/cobra v1.10.2 @@ -42,6 +48,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/test-go/testify v1.1.4 go.uber.org/zap v1.28.0 + golang.org/x/mod v0.36.0 golang.org/x/term v0.43.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v2 v2.4.0 @@ -135,7 +142,6 @@ require ( github.com/creachadair/jrpc2 v1.2.0 // indirect github.com/creachadair/mds v0.13.4 // indirect github.com/danieljoos/wincred v1.2.1 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dchest/siphash v1.2.3 // indirect github.com/deckarep/golang-set/v2 v2.9.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.1 // indirect @@ -159,8 +165,6 @@ require ( github.com/fsnotify/fsnotify v1.10.1 // indirect github.com/fxamacker/cbor/v2 v2.9.2 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect - github.com/gagliardetto/binary v0.8.0 // indirect - github.com/gagliardetto/solana-go v1.13.0 // indirect github.com/gagliardetto/treeout v0.1.4 // indirect github.com/getsentry/sentry-go v0.27.0 // indirect github.com/gin-contrib/sessions v0.0.5 // indirect @@ -347,6 +351,7 @@ require ( github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/tklauser/go-sysconf v0.4.0 // indirect github.com/tklauser/numcpus v0.12.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect @@ -397,7 +402,6 @@ require ( golang.org/x/arch v0.11.0 // indirect golang.org/x/crypto v0.51.0 // indirect golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect - golang.org/x/mod v0.36.0 // indirect golang.org/x/net v0.54.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sync v0.20.0 // indirect diff --git a/go.sum b/go.sum index 85e83cf5..798e43aa 100644 --- a/go.sum +++ b/go.sum @@ -339,6 +339,8 @@ github.com/danieljoos/wincred v1.2.1 h1:dl9cBrupW8+r5250DYkYxocLeZ1Y4vB1kxgtjxw8 github.com/danieljoos/wincred v1.2.1/go.mod h1:uGaFL9fDn3OLTvzCGulzE+SzjEe5NGlh5FdCcyfPwps= github.com/danielkov/gin-helmet v0.0.0-20171108135313-1387e224435e h1:5jVSh2l/ho6ajWhSPNN84eHEdq3dp0T7+f6r3Tc6hsk= github.com/danielkov/gin-helmet v0.0.0-20171108135313-1387e224435e/go.mod h1:IJgIiGUARc4aOr4bOQ85klmjsShkEEfiRc6q/yBSfo8= +github.com/dave/jennifer v1.7.1 h1:B4jJJDHelWcDhlRQxWeo0Npa/pYKBLrirAQoTN45txo= +github.com/dave/jennifer v1.7.1/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -429,6 +431,8 @@ github.com/fxamacker/cbor/v2 v2.9.2 h1:X4Ksno9+x3cz0TZv69ec1hxP/+tymuR8PXQJyDwfh github.com/fxamacker/cbor/v2 v2.9.2/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/gagliardetto/anchor-go v1.0.0 h1:YNt9I/9NOrNzz5uuzfzByAcbp39Ft07w63iPqC/wi34= +github.com/gagliardetto/anchor-go v1.0.0/go.mod h1:X6c9bx9JnmwNiyy8hmV5pAsq1c/zzPvkdzeq9/qmlCg= github.com/gagliardetto/binary v0.8.0 h1:U9ahc45v9HW0d15LoN++vIXSJyqR/pWw8DDlhd7zvxg= github.com/gagliardetto/binary v0.8.0/go.mod h1:2tfj51g5o9dnvsc+fL3Jxr22MuWzYXwx9wEoN0XQ7/c= github.com/gagliardetto/gofuzz v1.2.2 h1:XL/8qDMzcgvR4+CyRQW9UGdwPRPMHVJfqQ/uMvSUuQw= @@ -1192,8 +1196,8 @@ github.com/smartcontractkit/chainlink-protos/chainlink-ccv/message-discovery v0. github.com/smartcontractkit/chainlink-protos/chainlink-ccv/message-discovery v0.0.0-20251211142334-5c3421fe2c8d/go.mod h1:ATjAPIVJibHRcIfiG47rEQkUIOoYa6KDvWj3zwCAw6g= github.com/smartcontractkit/chainlink-protos/chainlink-ccv/verifier v0.0.0-20251211142334-5c3421fe2c8d h1:AJy55QJ/pBhXkZjc7N+ATnWfxrcjq9BI9DmdtdjwDUQ= github.com/smartcontractkit/chainlink-protos/chainlink-ccv/verifier v0.0.0-20251211142334-5c3421fe2c8d/go.mod h1:5JdppgngCOUS76p61zCinSCgOhPeYQ+OcDUuome5THQ= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260521152427-d3f6dc93de42 h1:tyEgGOaYa7PBsmDIIvctOQPUc5c4Jtd/p4X1i96K+yo= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260521152427-d3f6dc93de42/go.mod h1:vTFHTCbLui4Vn8fTmAadfE3rdnvfrDwOmMujmW857D0= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260522145417-85c85baa73cf h1:9nKluBQ0GBgnOokB8FCU1dmgZXDh22u9UPPMWFdKaYE= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260522145417-85c85baa73cf/go.mod h1:vTFHTCbLui4Vn8fTmAadfE3rdnvfrDwOmMujmW857D0= github.com/smartcontractkit/chainlink-protos/data-feeds v0.1.1-0.20260501174546-2e8846986b36 h1:SG+wAsNyAcA6Kk19ljuxi3HK9Ll2lpHik8OKoY4x7A0= github.com/smartcontractkit/chainlink-protos/data-feeds v0.1.1-0.20260501174546-2e8846986b36/go.mod h1:vL1bDgPSJjV0EqHYs4dDlR+EEE0cJchgvGLYXhwIjXY= github.com/smartcontractkit/chainlink-protos/job-distributor v0.18.0 h1:q+VDPcxWrj5k9QizSYfUOSMnDH3Sd5HvbPguZOgfXTY= @@ -1238,6 +1242,8 @@ github.com/smartcontractkit/cre-sdk-go v1.11.0 h1:E3MG0j8O9qDv6lDz71HPD3/WRKh/PX github.com/smartcontractkit/cre-sdk-go v1.11.0/go.mod h1:8SDE/e+eDAFpbRjRyKnIalUkQk9BcNbo2aLnda9BM48= github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm v1.0.0-beta.12 h1:cNUnDjtg88OXcmNQRbGgpRZ5/ZU5SqLp0wrHCxHn4M8= github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/evm v1.0.0-beta.12/go.mod h1:UNizO1xcv154THYM2hufcrWAp2rfuErC7Yeq8oJUrbk= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.0-beta.1 h1:VTpFjEhGpFJKnl+B5vH/obnWUjRKeo1dSISP+FIL/eY= +github.com/smartcontractkit/cre-sdk-go/capabilities/blockchain/solana v0.1.0-beta.1/go.mod h1:IHCd8wnWeL9yLVHkOTB2MxIuFE+Zil8Gw2jrrfPPU9Y= github.com/smartcontractkit/freeport v0.1.3-0.20250828155247-add56fa28aad h1:lgHxTHuzJIF3Vj6LSMOnjhqKgRqYW+0MV2SExtCYL1Q= github.com/smartcontractkit/freeport v0.1.3-0.20250828155247-add56fa28aad/go.mod h1:T4zH9R8R8lVWKfU7tUvYz2o2jMv1OpGCdpY2j2QZXzU= github.com/smartcontractkit/grpc-proxy v0.0.0-20240830132753-a7e17fec5ab7 h1:12ijqMM9tvYVEm+nR826WsrNi6zCKpwBhuApq127wHs= @@ -1322,6 +1328,7 @@ github.com/theodesp/go-heaps v0.0.0-20190520121037-88e35354fe0a h1:YuO+afVc3eqrj github.com/theodesp/go-heaps v0.0.0-20190520121037-88e35354fe0a/go.mod h1:/sfW47zCZp9FrtGcWyo1VjbgDaodxX9ovZvgLb/MxaA= github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -1329,6 +1336,8 @@ github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.4.0 h1:7H0uAN+7RkwWRaxhYXDLqa5V3LPrJeV8wmD9dRUgPQU= github.com/tklauser/go-sysconf v0.4.0/go.mod h1:8mTNWyog7H+MpKijp4VmKJAd2bbYQ2zuUwkYRbUArPI= diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 065446e4..35cc5946 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -58,10 +58,11 @@ const ( WorkflowLanguageWasm = "wasm" // SDK dependency versions (used by generate-bindings and go module init) - SdkVersion = "v1.11.0" - EVMCapabilitiesVersion = "v1.0.0-beta.12" - HTTPCapabilitiesVersion = "v1.3.0" - CronCapabilitiesVersion = "v1.3.0" + SdkVersion = "v1.11.0" + EVMCapabilitiesVersion = "v1.0.0-beta.12" + HTTPCapabilitiesVersion = "v1.3.0" + CronCapabilitiesVersion = "v1.3.0" + SolanaCapabilitiesVersion = "v0.1.1-0.20260210120110-1f2d5201a23f" TestAddress = "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266" TestAddress2 = "0x70997970C51812dc3A010C7d01b50e0d17dc79C8"