diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex_enum_encode_test.go b/cmd/generate-bindings/solana/anchor-go/generator/complex_enum_encode_test.go new file mode 100644 index 00000000..b92501e1 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex_enum_encode_test.go @@ -0,0 +1,86 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// complexEnumIDL returns a minimal IDL containing a two-variant complex enum +// ("MyAction") suitable for exercising gen_complexEnum codegen. +func complexEnumIDL() *idl.Idl { + enumType := &idl.IdlTypeDefTyEnum{ + Variants: idl.VariantSlice{ + { + Name: "Transfer", + Fields: idl.Some[idl.IdlDefinedFields](idl.IdlDefinedFieldsNamed{ + {Name: "amount", Ty: &idltype.U64{}}, + }), + }, + { + Name: "Burn", + Fields: idl.Some[idl.IdlDefinedFields](idl.IdlDefinedFieldsNamed{ + {Name: "quantity", Ty: &idltype.U32{}}, + }), + }, + }, + } + return &idl.Idl{ + Types: idl.IdTypeDef_slice{ + { + Name: "MyAction", + Ty: enumType, + }, + }, + } +} + +func genComplexEnumSource(t *testing.T) string { + t.Helper() + idlData := complexEnumIDL() + gen := &Generator{ + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + } + + enumType := idlData.Types[0].Ty.(*idl.IdlTypeDefTyEnum) + code, err := gen.gen_complexEnum("MyAction", nil, *enumType) + require.NoError(t, err) + + f := jen.NewFile("test") + f.Add(code) + return f.GoString() +} + +func TestComplexEnumEncode_nilInterfaceReturnsError(t *testing.T) { + src := genComplexEnumSource(t) + + assert.Contains(t, src, "case nil:", "encoder must reject nil interface values") + assert.Contains(t, src, `cannot encode nil value`, "nil case must return a descriptive error") +} + +func TestComplexEnumEncode_defaultArmReturnsError(t *testing.T) { + src := genComplexEnumSource(t) + + assert.Contains(t, src, "default:", "encoder must reject unknown variant types") + assert.Contains(t, src, `unknown variant type`, "default case must return a descriptive error") +} + +func TestComplexEnumEncode_nilPointerGuardPerVariant(t *testing.T) { + src := genComplexEnumSource(t) + + assert.Contains(t, src, "realvalue == nil", "each variant case must guard against typed nil pointers") + assert.Contains(t, src, `cannot encode nil *MyAction_Transfer`, + "Transfer variant must have a nil-pointer error message") + assert.Contains(t, src, `cannot encode nil *MyAction_Burn`, + "Burn variant must have a nil-pointer error message") + + nilGuards := strings.Count(src, "realvalue == nil") + assert.Equal(t, 2, nilGuards, "must have exactly one nil-pointer guard per variant") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/types.go b/cmd/generate-bindings/solana/anchor-go/generator/types.go index d3822fb1..04a267f2 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/types.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/types.go @@ -226,17 +226,33 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD argBody.List(Id("tmp")).Op(":=").Id(formatEnumContainerName(enumTypeName)).Block() argBody.Switch(Id("realvalue").Op(":=").Id("value").Op(".").Parens(Type())). BlockFunc(func(switchGroup *Group) { - // TODO: maybe it's from idl.Accounts ??? + switchGroup.Case(Nil()). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Qual("fmt", "Errorf").Call(Lit(enumTypeName + ": cannot encode nil value")), + ) + }) + interfaceType := g.idl.Types.ByName(enumTypeName) for variantIndex, variant := range interfaceType.Ty.(*idl.IdlTypeDefTyEnum).Variants { variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) switchGroup.Case(Op("*").Id(variantTypeNameStruct)). BlockFunc(func(caseGroup *Group) { + caseGroup.If(Id("realvalue").Op("==").Nil()).Block( + Return(Qual("fmt", "Errorf").Call(Lit(enumTypeName+": cannot encode nil *"+variantTypeNameStruct))), + ) caseGroup.Id("tmp").Dot("Enum").Op("=").Lit(variantIndex) caseGroup.Id("tmp").Dot(tools.ToCamelUpper(variant.Name)).Op("=").Op("*").Id("realvalue") }) } + + switchGroup.Default(). + BlockFunc(func(caseGroup *Group) { + caseGroup.Return( + Qual("fmt", "Errorf").Call(Lit(enumTypeName+": unknown variant type %T"), Id("value")), + ) + }) }) argBody.Return(Id("encoder").Dot("Encode").Call(Id("tmp")))