Skip to content

Commit 9610f6a

Browse files
committed
feat(golang): emit models to a different package
Closes #835. Adds four Go codegen options that let users place table model structs and enums in a Go package distinct from the queries package, optionally shared across multiple SQL blocks: - output_models_path: directory for the models file - output_models_package: Go package name for the models file - output_models_import: import path for the models package - output_models_emit: when false, skip emitting models.go and only reference an externally-emitted models package When the models package differs from the queries package, query files import the models package and reference types as `model.User` instead of `User`. Embedded model structs in synthetic Row structs and enum types in synthetic Params structs are qualified consistently. End-to-end fixtures cover both shapes: a single block with models in a sibling directory, and two query blocks that share one models package. https://claude.ai/code/session_01Nj4LbGbqFBTU1EPKKNHtNn
1 parent 977ac6d commit 9610f6a

28 files changed

Lines changed: 871 additions & 43 deletions

File tree

internal/cmd/generate.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,35 @@ func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSett
204204
// out is specified by the user, not a plugin
205205
absout := filepath.Join(g.dir, out)
206206

207+
// When the Go codegen is configured to emit the models file into a
208+
// separate package directory, route that file to its own absolute path.
209+
// This is the only file allowed to live outside of `out`.
210+
var (
211+
modelsFileName string
212+
modelsAbsout string
213+
modelsAbsfile string
214+
)
215+
if sql.Gen.Go != nil && sql.Gen.Go.OutputModelsPath != "" && sql.Gen.Go.ModelsEmitEnabled() {
216+
modelsFileName = sql.Gen.Go.OutputModelsFileName
217+
if modelsFileName == "" {
218+
modelsFileName = "models.go"
219+
}
220+
modelsAbsout = filepath.Join(g.dir, sql.Gen.Go.OutputModelsPath)
221+
modelsAbsfile = filepath.Join(modelsAbsout, modelsFileName)
222+
}
223+
207224
for n, source := range files {
225+
if modelsFileName != "" && n == modelsFileName {
226+
// Models file routed to a separate package directory.
227+
if strings.Contains(modelsAbsfile, "..") {
228+
return fmt.Errorf("invalid file output path: %s", modelsAbsfile)
229+
}
230+
if !strings.HasPrefix(modelsAbsfile, modelsAbsout) {
231+
return fmt.Errorf("invalid file output path: %s", modelsAbsfile)
232+
}
233+
g.output[modelsAbsfile] = source
234+
continue
235+
}
208236
filename := filepath.Join(g.dir, out, n)
209237
// filepath.Join calls filepath.Clean which should remove all "..", but
210238
// double check to make sure

internal/codegen/golang/gen.go

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ import (
1717
)
1818

1919
type tmplCtx struct {
20-
Q string
21-
Package string
22-
SQLDriver opts.SQLDriver
23-
Enums []Enum
24-
Structs []Struct
25-
GoQueries []Query
26-
SqlcVersion string
20+
Q string
21+
Package string
22+
ModelsPackage string
23+
SQLDriver opts.SQLDriver
24+
Enums []Enum
25+
Structs []Struct
26+
GoQueries []Query
27+
SqlcVersion string
2728

2829
// TODO: Race conditions
2930
SourceName string
@@ -120,13 +121,13 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
120121

121122
enums := buildEnums(req, options)
122123
structs := buildStructs(req, options)
123-
queries, err := buildQueries(req, options, structs)
124+
queries, err := buildQueries(req, options, enums, structs)
124125
if err != nil {
125126
return nil, err
126127
}
127128

128129
if options.OmitUnusedStructs {
129-
enums, structs = filterUnusedStructs(enums, structs, queries)
130+
enums, structs = filterUnusedStructs(enums, structs, queries, options.ModelsTypeQualifier())
130131
}
131132

132133
if err := validate(options, enums, structs, queries); err != nil {
@@ -186,6 +187,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
186187
SQLDriver: parseDriver(options.SqlPackage),
187188
Q: "`",
188189
Package: options.Package,
190+
ModelsPackage: options.ModelsPackage(),
189191
Enums: enums,
190192
Structs: structs,
191193
SqlcVersion: req.SqlcVersion,
@@ -292,8 +294,10 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
292294
if err := execute(dbFileName, "dbFile"); err != nil {
293295
return nil, err
294296
}
295-
if err := execute(modelsFileName, "modelsFile"); err != nil {
296-
return nil, err
297+
if options.ModelsEmitEnabled() {
298+
if err := execute(modelsFileName, "modelsFile"); err != nil {
299+
return nil, err
300+
}
297301
}
298302
if options.EmitInterface {
299303
if err := execute(querierFileName, "interfaceFile"); err != nil {
@@ -367,25 +371,35 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
367371
return nil
368372
}
369373

370-
func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
374+
func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query, qualifier string) ([]Enum, []Struct) {
371375
keepTypes := make(map[string]struct{})
372376

377+
keep := func(t string) {
378+
keepTypes[t] = struct{}{}
379+
// Also store the bare type name so that lookups against
380+
// bare struct/enum names match even when types have been
381+
// qualified with the models package prefix (e.g. "model.User").
382+
if bare := stripQualifier(t, qualifier); bare != t {
383+
keepTypes[bare] = struct{}{}
384+
}
385+
}
386+
373387
for _, query := range queries {
374388
if !query.Arg.isEmpty() {
375-
keepTypes[query.Arg.Type()] = struct{}{}
389+
keep(query.Arg.Type())
376390
if query.Arg.IsStruct() {
377391
for _, field := range query.Arg.Struct.Fields {
378-
keepTypes[field.Type] = struct{}{}
392+
keep(field.Type)
379393
}
380394
}
381395
}
382396
if query.hasRetType() {
383-
keepTypes[query.Ret.Type()] = struct{}{}
397+
keep(query.Ret.Type())
384398
if query.Ret.IsStruct() {
385399
for _, field := range query.Ret.Struct.Fields {
386-
keepTypes[strings.TrimPrefix(field.Type, "[]")] = struct{}{}
400+
keep(strings.TrimPrefix(field.Type, "[]"))
387401
for _, embedField := range field.EmbedFields {
388-
keepTypes[embedField.Type] = struct{}{}
402+
keep(embedField.Type)
389403
}
390404
}
391405
}

internal/codegen/golang/imports.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,17 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
243243
}
244244
}
245245

246+
// Models package import. When models live in a separate Go package and
247+
// any type in this file references a qualified model type, we need to
248+
// import the models package. The qualifier is derived from the
249+
// configured models package name, so the alias matches references like
250+
// `model.User`.
251+
if options.ModelsAreExternal() {
252+
if uses(options.ModelsTypeQualifier()) {
253+
pkg[ImportSpec{Path: options.OutputModelsImport, ID: options.OutputModelsPackage}] = struct{}{}
254+
}
255+
}
256+
246257
return std, pkg
247258
}
248259

internal/codegen/golang/opts/options.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ type Options struct {
3737
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
3838
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
3939
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
40+
OutputModelsPath string `json:"output_models_path,omitempty" yaml:"output_models_path"`
41+
OutputModelsPackage string `json:"output_models_package,omitempty" yaml:"output_models_package"`
42+
OutputModelsImport string `json:"output_models_import,omitempty" yaml:"output_models_import"`
43+
OutputModelsEmit *bool `json:"output_models_emit,omitempty" yaml:"output_models_emit"`
4044
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
4145
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
4246
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
@@ -154,5 +158,74 @@ func ValidateOpts(opts *Options) error {
154158
return fmt.Errorf("invalid options: query parameter limit must not be negative")
155159
}
156160

161+
if err := validateModelsOptions(opts); err != nil {
162+
return err
163+
}
164+
165+
return nil
166+
}
167+
168+
// ModelsEmitEnabled reports whether this codegen block should write the
169+
// models file. Defaults to true when the option is unset.
170+
func (o *Options) ModelsEmitEnabled() bool {
171+
if o.OutputModelsEmit == nil {
172+
return true
173+
}
174+
return *o.OutputModelsEmit
175+
}
176+
177+
// ModelsPackage returns the Go package name to use for model types. When the
178+
// caller has not configured a separate models package, this is the same as
179+
// Package.
180+
func (o *Options) ModelsPackage() string {
181+
if o.OutputModelsPackage != "" {
182+
return o.OutputModelsPackage
183+
}
184+
return o.Package
185+
}
186+
187+
// ModelsAreExternal reports whether model types live in a different Go
188+
// package than the queries package. When true, query files must import the
189+
// models package and reference types as `<pkg>.Type`.
190+
func (o *Options) ModelsAreExternal() bool {
191+
return o.OutputModelsPackage != "" && o.OutputModelsPackage != o.Package
192+
}
193+
194+
// ModelsTypeQualifier returns the prefix to use when referencing a model
195+
// type from a query file (e.g. "model."). Empty string when no qualifier is
196+
// needed.
197+
func (o *Options) ModelsTypeQualifier() string {
198+
if o.ModelsAreExternal() {
199+
return o.OutputModelsPackage + "."
200+
}
201+
return ""
202+
}
203+
204+
func validateModelsOptions(opts *Options) error {
205+
hasAnyModelsOpt := opts.OutputModelsPath != "" ||
206+
opts.OutputModelsPackage != "" ||
207+
opts.OutputModelsImport != "" ||
208+
opts.OutputModelsEmit != nil
209+
210+
if !hasAnyModelsOpt {
211+
return nil
212+
}
213+
214+
if opts.OutputModelsImport == "" {
215+
return fmt.Errorf("invalid options: output_models_import is required when any output_models_* option is set")
216+
}
217+
218+
if opts.ModelsEmitEnabled() && opts.OutputModelsPath == "" {
219+
return fmt.Errorf("invalid options: output_models_path is required when emitting models to a separate package")
220+
}
221+
222+
if opts.OutputModelsPackage == "" {
223+
return fmt.Errorf("invalid options: output_models_package is required when any output_models_* option is set")
224+
}
225+
226+
if !opts.ModelsEmitEnabled() && opts.OutputModelsPackage == opts.Package {
227+
return fmt.Errorf("invalid options: output_models_emit is false but output_models_package matches package; nothing to import")
228+
}
229+
157230
return nil
158231
}

internal/codegen/golang/qualify.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package golang
2+
3+
import "strings"
4+
5+
// stripQualifier removes a leading slice/pointer prefix and the given
6+
// `pkg.` qualifier from a Go type expression. When the qualifier is empty
7+
// or absent from the type, the input is returned unchanged.
8+
func stripQualifier(t, qualifier string) string {
9+
if qualifier == "" {
10+
return t
11+
}
12+
prefix := ""
13+
rest := t
14+
for {
15+
if strings.HasPrefix(rest, "[]") {
16+
prefix += "[]"
17+
rest = rest[2:]
18+
continue
19+
}
20+
if strings.HasPrefix(rest, "*") {
21+
prefix += "*"
22+
rest = rest[1:]
23+
continue
24+
}
25+
break
26+
}
27+
if strings.HasPrefix(rest, qualifier) {
28+
return prefix + rest[len(qualifier):]
29+
}
30+
return t
31+
}
32+
33+
// modelTypeSet is the set of Go type names that live in the models file.
34+
type modelTypeSet map[string]struct{}
35+
36+
// buildModelTypeSet returns the set of type names that are declared in
37+
// models.go for the current codegen invocation.
38+
func buildModelTypeSet(enums []Enum, structs []Struct) modelTypeSet {
39+
set := make(modelTypeSet, len(enums)*4+len(structs))
40+
for _, e := range enums {
41+
set[e.Name] = struct{}{}
42+
set["Null"+e.Name] = struct{}{}
43+
}
44+
for _, s := range structs {
45+
if s.IsModel {
46+
set[s.Name] = struct{}{}
47+
}
48+
}
49+
return set
50+
}
51+
52+
// qualifyType prefixes a Go type expression with `qualifier` when the bare
53+
// type name belongs to `models`. Slice and pointer prefixes are preserved.
54+
// When qualifier is empty (i.e. models live in the queries package), the
55+
// input is returned unchanged.
56+
func qualifyType(t string, models modelTypeSet, qualifier string) string {
57+
if qualifier == "" || t == "" || len(models) == 0 {
58+
return t
59+
}
60+
prefix := ""
61+
rest := t
62+
for {
63+
if strings.HasPrefix(rest, "[]") {
64+
prefix += "[]"
65+
rest = rest[2:]
66+
continue
67+
}
68+
if strings.HasPrefix(rest, "*") {
69+
prefix += "*"
70+
rest = rest[1:]
71+
continue
72+
}
73+
break
74+
}
75+
if _, ok := models[rest]; ok {
76+
return prefix + qualifier + rest
77+
}
78+
return t
79+
}

internal/codegen/golang/query.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ type QueryValue struct {
1818
Typ string
1919
SQLDriver opts.SQLDriver
2020

21+
// ModelQualifier prefixes references to model types when the models file
22+
// lives in a different Go package (e.g. "model."). Empty otherwise.
23+
ModelQualifier string
24+
2125
// Column is kept so late in the generation process around to differentiate
2226
// between mysql slices and pg arrays
2327
Column *plugin.Column
@@ -88,6 +92,14 @@ func (v QueryValue) Type() string {
8892
return v.Typ
8993
}
9094
if v.Struct != nil {
95+
// Model structs (table-derived) live in the models file. When that
96+
// file is generated into a different Go package, references from
97+
// query files must be qualified. Synthetic structs (Params/Row)
98+
// are defined in the same query file as their use, so they stay
99+
// bare.
100+
if v.Struct.IsModel && v.ModelQualifier != "" {
101+
return v.ModelQualifier + v.Struct.Name
102+
}
91103
return v.Struct.Name
92104
}
93105
panic("no type for QueryValue: " + v.Name)

0 commit comments

Comments
 (0)