From 8dcf3bb8fefc042c4b3edd2b0b06b152758207df Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 16:00:08 +0000 Subject: [PATCH] Extract code generation into internal/api package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a small programmatic interface to sqlc's code generation, inspired by esbuild's Build API. The api package exports three names: func Generate(ctx, GenerateOptions) GenerateResult type GenerateOptions struct { Config io.Reader Stderr io.Writer Write bool Diff bool BaseDir string EnableProcessPlugins bool } type GenerateResult struct { Files map[string]string Errors []error } Everything else in the package — parse, codegen, the plugin shim, the result processor — is unexported. Notable design choices: * Config as io.Reader. No Dir/File fields. Callers hand sqlc whatever bytes they want; the CLI reads from disk and library callers can construct configs in memory. * BaseDir does double duty: it's the directory relative paths in the config resolve against and the prefix stripped from file paths in parse errors and diff labels. When empty it defaults to the current working directory. * Process plugins are off by default. EnableProcessPlugins is a single bool whose zero value refuses any process plugin in the config. The CLI sets it from the `processplugins` SQLCDEBUG setting (default 1). * Write/Diff replace separate functions. `sqlc generate` is api.Generate{Write: true}, `sqlc compile` is neither, `sqlc diff` is {Diff: true}. cmd.Diff is gone. CLI: each of genCmd, checkCmd, and diffCmd reads the config bytes, chdirs into the config's directory so relative paths resolve, and calls api.Generate with BaseDir set to that directory. 
cmd.Vet and cmd.Push keep their own helpers (parse, processQuerySets, codeGenRequest) since they have surface beyond what api covers; both packages skip joining their dir parameter when the path is already absolute, so configs with absolute paths flow through both. Endtoend tests: TestExamples and TestReplay call api.Generate directly. A mutatedConfigBytes helper parses the test's config, optionally applies a mutation (the managed-db context adds servers + sets database.managed), forces version "2" so v1 configs round-trip cleanly, and re-encodes as YAML. When mutated, the bytes are also dropped to a temp file alongside the original so cmd.Vet (which still takes a path) can use it. Per-test environment variables from exec.json are applied via t.Setenv, and cmd.Env is then populated via opts.ExperimentFromEnv — same path the CLI takes. config.AnalyzerDatabase gained MarshalYAML/MarshalJSON so the parsed Config round-trips through yaml.Marshal cleanly — needed for the test helper. --- internal/api/api.go | 11 ++ internal/api/codegen.go | 108 +++++++++++++ internal/api/diff.go | 109 +++++++++++++ internal/api/generate.go | 249 +++++++++++++++++++++++++++++ internal/api/parse.go | 74 +++++++++ internal/api/process.go | 104 ++++++++++++ internal/api/shim.go | 233 +++++++++++++++++++++++++++ internal/cmd/cmd.go | 104 ++++++------ internal/cmd/diff.go | 59 ------- internal/cmd/generate.go | 222 ------------------------- internal/cmd/options.go | 12 +- internal/cmd/process.go | 10 +- internal/cmd/vet.go | 13 +- internal/config/config.go | 24 +++ internal/endtoend/endtoend_test.go | 229 +++++++++++++++++--------- internal/endtoend/vet_test.go | 4 +- 16 files changed, 1140 insertions(+), 425 deletions(-) create mode 100644 internal/api/api.go create mode 100644 internal/api/codegen.go create mode 100644 internal/api/diff.go create mode 100644 internal/api/generate.go create mode 100644 internal/api/parse.go create mode 100644 internal/api/process.go create mode 100644 
internal/api/shim.go delete mode 100644 internal/cmd/diff.go diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000000..c1011de572 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,11 @@ +// Package api is intended to be the future public API for sqlc. +// +// The shape of this package is inspired by esbuild's Build API +// (https://pkg.go.dev/github.com/evanw/esbuild/pkg/api#hdr-Build_API): a small +// surface area of options structs and result structs that lets callers drive +// sqlc programmatically without going through the CLI. +// +// Today the package lives under internal/ while the API stabilises. Once the +// surface settles it is expected to graduate to pkg/api so it can be imported +// by external Go programs. +package api diff --git a/internal/api/codegen.go b/internal/api/codegen.go new file mode 100644 index 0000000000..d733c3b2d4 --- /dev/null +++ b/internal/api/codegen.go @@ -0,0 +1,108 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "runtime/trace" + + "google.golang.org/grpc" + + "github.com/sqlc-dev/sqlc/internal/codegen/golang" + genjson "github.com/sqlc-dev/sqlc/internal/codegen/json" + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/config/convert" + "github.com/sqlc-dev/sqlc/internal/ext" + "github.com/sqlc-dev/sqlc/internal/ext/process" + "github.com/sqlc-dev/sqlc/internal/ext/wasm" + "github.com/sqlc-dev/sqlc/internal/plugin" +) + +func findPlugin(conf config.Config, name string) (*config.Plugin, error) { + for _, plug := range conf.Plugins { + if plug.Name == name { + return &plug, nil + } + } + return nil, fmt.Errorf("plugin not found") +} + +func codegen(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) { + defer trace.StartRegion(ctx, "codegen").End() + req := codeGenRequest(result, combo) + var handler 
grpc.ClientConnInterface + var out string + switch { + case sql.Plugin != nil: + out = sql.Plugin.Out + plug, err := findPlugin(combo.Global, sql.Plugin.Plugin) + if err != nil { + return "", nil, fmt.Errorf("plugin not found: %s", err) + } + + switch { + case plug.Process != nil: + handler = &process.Runner{ + Cmd: plug.Process.Cmd, + Env: plug.Env, + Format: plug.Process.Format, + } + case plug.WASM != nil: + handler = &wasm.Runner{ + URL: plug.WASM.URL, + SHA256: plug.WASM.SHA256, + Env: plug.Env, + } + default: + return "", nil, fmt.Errorf("unsupported plugin type") + } + + opts, err := convert.YAMLtoJSON(sql.Plugin.Options) + if err != nil { + return "", nil, fmt.Errorf("invalid plugin options: %w", err) + } + req.PluginOptions = opts + + global, found := combo.Global.Options[plug.Name] + if found { + opts, err := convert.YAMLtoJSON(global) + if err != nil { + return "", nil, fmt.Errorf("invalid global options: %w", err) + } + req.GlobalOptions = opts + } + + case sql.Gen.Go != nil: + out = combo.Go.Out + handler = ext.HandleFunc(golang.Generate) + opts, err := json.Marshal(sql.Gen.Go) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.PluginOptions = opts + + if combo.Global.Overrides.Go != nil { + opts, err := json.Marshal(combo.Global.Overrides.Go) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.GlobalOptions = opts + } + + case sql.Gen.JSON != nil: + out = combo.JSON.Out + handler = ext.HandleFunc(genjson.Generate) + opts, err := json.Marshal(sql.Gen.JSON) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.PluginOptions = opts + + default: + return "", nil, fmt.Errorf("missing language backend") + } + client := plugin.NewCodegenServiceClient(handler) + resp, err := client.Generate(ctx, req) + return out, resp, err +} diff --git a/internal/api/diff.go b/internal/api/diff.go new file mode 100644 index 0000000000..1fc49511da --- /dev/null +++ 
b/internal/api/diff.go @@ -0,0 +1,109 @@ +package api + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime/trace" + "sort" + + "github.com/cubicdaiya/gonp" +) + +func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer) error { + defer trace.StartRegion(ctx, "writefiles").End() + for filename, source := range files { + if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil { + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + return err + } + if err := os.WriteFile(filename, []byte(source), 0644); err != nil { + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + return err + } + } + return nil +} + +func diffFiles(ctx context.Context, baseDir string, files map[string]string, stderr io.Writer) error { + defer trace.StartRegion(ctx, "checkfiles").End() + var errored bool + + if baseDir == "" { + baseDir, _ = os.Getwd() + } + + keys := make([]string, 0, len(files)) + for k := range files { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, filename := range keys { + source := files[filename] + if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { + errored = true + continue + } + existing, err := os.ReadFile(filename) + if err != nil { + errored = true + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + continue + } + d := gonp.New(getLines(existing), getLines([]byte(source))) + d.Compose() + uniHunks := filterHunks(d.UnifiedHunks()) + + if len(uniHunks) > 0 { + errored = true + label := filename + if baseDir != "" { + if rel, err := filepath.Rel(baseDir, filename); err == nil { + label = "/" + rel + } + } + fmt.Fprintf(stderr, "--- a%s\n", label) + fmt.Fprintf(stderr, "+++ b%s\n", label) + d.FprintUniHunks(stderr, uniHunks) + } + } + if errored { + return errors.New("diff found") + } + return nil +} + +func getLines(f []byte) []string { + fp := bytes.NewReader(f) + scanner := bufio.NewScanner(fp) + lines := make([]string, 0) + for scanner.Scan() { + lines = 
append(lines, scanner.Text()) + } + return lines +} + +func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] { + var out []gonp.UniHunk[T] + for i, uniHunk := range uniHunks { + var changed bool + for _, e := range uniHunk.GetChanges() { + switch e.GetType() { + case gonp.SesDelete: + changed = true + case gonp.SesAdd: + changed = true + } + } + if changed { + out = append(out, uniHunks[i]) + } + } + return out +} diff --git a/internal/api/generate.go b/internal/api/generate.go new file mode 100644 index 0000000000..e5ebe80e2f --- /dev/null +++ b/internal/api/generate.go @@ -0,0 +1,249 @@ +package api + +import ( + "context" + "errors" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" +) + +// GenerateOptions controls a single Generate invocation. Relative paths +// declared in the configuration are resolved against BaseDir, which +// defaults to the current working directory when empty; see the BaseDir +// field for details. +type GenerateOptions struct { + // Config is the sqlc configuration as a YAML or JSON document. Required. + Config io.Reader + + // Stderr receives diagnostic output. If nil, output is discarded. + Stderr io.Writer + + // Write writes the generated files to disk after a successful generate. + // Failures are reported via GenerateResult.Errors. + Write bool + + // Diff compares each generated file against any existing file on disk and + // writes a unified diff for differences to Stderr. If any differences are + // found, an error is appended to GenerateResult.Errors. + Diff bool + + // BaseDir is the directory relative paths in Config are resolved against, + // and the prefix stripped from file paths shown in parse errors and diff + // labels. When empty, BaseDir defaults to the current working directory.
+ BaseDir string + + // EnableProcessPlugins controls whether the configuration may invoke + // process-based plugins. When false (the zero value), Generate fails + // before parsing or codegen runs if the configuration declares any + // process plugin. Process plugins execute arbitrary local commands, so + // callers must opt in explicitly. The sqlc CLI populates this from + // SQLCDEBUG: it defaults to true, and SQLCDEBUG=processplugins=0 turns + // it off. + EnableProcessPlugins bool +} + +// GenerateResult is the outcome of a Generate call. +type GenerateResult struct { + // Files maps absolute output paths to generated file contents. + Files map[string]string + + // Errors collects any errors encountered. A non-empty Errors slice means + // generation did not fully succeed. + Errors []error +} + +// Generate parses the sqlc configuration referenced by opts and runs every +// configured codegen target. +func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { + stderr := opts.Stderr + if stderr == nil { + stderr = io.Discard + } + + res := GenerateResult{Files: map[string]string{}} + + if opts.Config == nil { + err := errors.New("GenerateOptions.Config is required") + fmt.Fprintln(stderr, err) + res.Errors = append(res.Errors, err) + return res + } + + conf, err := config.ParseConfig(opts.Config) + if err != nil { + switch err { + case config.ErrMissingVersion: + fmt.Fprint(stderr, errMessageNoVersion) + case config.ErrUnknownVersion: + fmt.Fprint(stderr, errMessageUnknownVersion) + case config.ErrNoPackages: + fmt.Fprint(stderr, errMessageNoPackages) + } + fmt.Fprintf(stderr, "error parsing config: %s\n", err) + res.Errors = append(res.Errors, err) + return res + } + + if err := config.Validate(&conf); err != nil { + fmt.Fprintf(stderr, "error validating config: %s\n", err) + res.Errors = append(res.Errors, err) + return res + } + + if !opts.EnableProcessPlugins { + for _, plug := range conf.Plugins { + if plug.Process != nil { + err := 
fmt.Errorf("process plugin %q declared but EnableProcessPlugins is false", plug.Name) + fmt.Fprintf(stderr, "error validating config: %s\n", err) + res.Errors = append(res.Errors, err) + return res + } + } + } + + g := &generator{output: map[string]string{}, baseDir: opts.BaseDir} + + if err := processQuerySets(ctx, g, &conf, opts.BaseDir, stderr); err != nil { + res.Errors = append(res.Errors, err) + return res + } + + res.Files = g.output + + if opts.Write { + if err := writeFiles(ctx, res.Files, stderr); err != nil { + res.Errors = append(res.Errors, err) + } + } + + if opts.Diff { + if err := diffFiles(ctx, opts.BaseDir, res.Files, stderr); err != nil { + res.Errors = append(res.Errors, err) + } + } + + return res +} + +const errMessageNoVersion = `The configuration must have a version number. +Set the version to 1 or 2 at the top of the config: + +{ + "version": "1" + ... +} +` + +const errMessageUnknownVersion = `The configuration has an invalid version number. +The supported version can only be "1" or "2". 
+` + +const errMessageNoPackages = `No packages are configured` + +type generator struct { + m sync.Mutex + baseDir string + output map[string]string +} + +func (g *generator) Pairs(ctx context.Context, conf *config.Config) []outputPair { + var pairs []outputPair + for _, sql := range conf.SQL { + if sql.Gen.Go != nil { + pairs = append(pairs, outputPair{ + SQL: sql, + Gen: config.SQLGen{Go: sql.Gen.Go}, + }) + } + if sql.Gen.JSON != nil { + pairs = append(pairs, outputPair{ + SQL: sql, + Gen: config.SQLGen{JSON: sql.Gen.JSON}, + }) + } + for i := range sql.Codegen { + pairs = append(pairs, outputPair{ + SQL: sql, + Plugin: &sql.Codegen[i], + }) + } + } + return pairs +} + +func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) error { + out, resp, err := codegen(ctx, combo, sql, result) + if err != nil { + return err + } + files := map[string]string{} + for _, file := range resp.Files { + files[file.Name] = string(file.Contents) + } + g.m.Lock() + defer g.m.Unlock() + + absout := resolvePath(g.baseDir, out) + + // When the Go codegen is configured to emit the models file into a + // separate package directory, route that file to its own absolute path. + // This is the only file allowed to live outside of `out`. + var ( + modelsFileName string + modelsAbsout string + modelsAbsfile string + ) + if sql.Gen.Go != nil && sql.Gen.Go.OutputModelsPath != "" && sql.Gen.Go.ModelsEmitEnabled() { + modelsFileName = sql.Gen.Go.OutputModelsFileName + if modelsFileName == "" { + modelsFileName = "models.go" + } + modelsAbsout = resolvePath(g.baseDir, sql.Gen.Go.OutputModelsPath) + modelsAbsfile = filepath.Join(modelsAbsout, modelsFileName) + } + + for n, source := range files { + if modelsFileName != "" && n == modelsFileName { + // Models file routed to a separate package directory. 
+ if strings.Contains(modelsAbsfile, "..") { + return fmt.Errorf("invalid file output path: %s", modelsAbsfile) + } + if !strings.HasPrefix(modelsAbsfile, modelsAbsout) { + return fmt.Errorf("invalid file output path: %s", modelsAbsfile) + } + g.output[modelsAbsfile] = source + continue + } + filename := resolvePath(g.baseDir, filepath.Join(out, n)) + if strings.Contains(filename, "..") { + return fmt.Errorf("invalid file output path: %s", filename) + } + if !strings.HasPrefix(filename, absout) { + return fmt.Errorf("invalid file output path: %s", filename) + } + g.output[filename] = source + } + return nil +} + +// resolvePath joins p with baseDir when p is relative. baseDir is treated as +// the current working directory when empty. +func resolvePath(baseDir, p string) string { + if filepath.IsAbs(p) { + return p + } + if baseDir == "" { + abs, err := filepath.Abs(p) + if err == nil { + return abs + } + return p + } + return filepath.Join(baseDir, p) +} diff --git a/internal/api/parse.go b/internal/api/parse.go new file mode 100644 index 0000000000..960cf5e60a --- /dev/null +++ b/internal/api/parse.go @@ -0,0 +1,74 @@ +package api + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "runtime/trace" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/multierr" + "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/sqlcdebug" +) + +var debugDumpCatalog = sqlcdebug.New("dumpcatalog") + +func printFileErr(stderr io.Writer, baseDir string, fileErr *multierr.FileError) { + if baseDir == "" { + if wd, err := os.Getwd(); err == nil { + baseDir = wd + } + } + filename := fileErr.Filename + if baseDir != "" { + if rel, err := filepath.Rel(baseDir, fileErr.Filename); err == nil { + filename = rel + } + } + fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) +} + +func parse(ctx 
context.Context, name, baseDir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { + defer trace.StartRegion(ctx, "parse").End() + c, err := compiler.NewCompiler(sql, combo, parserOpts) + defer func() { + if c != nil { + c.Close(ctx) + } + }() + if err != nil { + fmt.Fprintf(stderr, "error creating compiler: %s\n", err) + return nil, true + } + if err := c.ParseCatalog(sql.Schema); err != nil { + fmt.Fprintf(stderr, "# package %s\n", name) + if parserErr, ok := err.(*multierr.Error); ok { + for _, fileErr := range parserErr.Errs() { + printFileErr(stderr, baseDir, fileErr) + } + } else { + fmt.Fprintf(stderr, "error parsing schema: %s\n", err) + } + return nil, true + } + if debugDumpCatalog.Value() == "1" { + debug.Dump(c.Catalog()) + } + if err := c.ParseQueries(sql.Queries, parserOpts); err != nil { + fmt.Fprintf(stderr, "# package %s\n", name) + if parserErr, ok := err.(*multierr.Error); ok { + for _, fileErr := range parserErr.Errs() { + printFileErr(stderr, baseDir, fileErr) + } + } else { + fmt.Fprintf(stderr, "error parsing queries: %s\n", err) + } + return nil, true + } + return c.Result(), false +} diff --git a/internal/api/process.go b/internal/api/process.go new file mode 100644 index 0000000000..86826ec509 --- /dev/null +++ b/internal/api/process.go @@ -0,0 +1,104 @@ +package api + +import ( + "bytes" + "context" + "fmt" + "io" + "runtime" + "runtime/trace" + + "golang.org/x/sync/errgroup" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/opts" +) + +type outputPair struct { + Gen config.SQLGen + Plugin *config.Codegen + + config.SQL +} + +type resultProcessor interface { + Pairs(context.Context, *config.Config) []outputPair + ProcessResult(context.Context, config.CombinedSettings, outputPair, *compiler.Result) error +} + +func processQuerySets(ctx context.Context, rp resultProcessor, conf 
*config.Config, baseDir string, stderr io.Writer) error { + errored := false + + pairs := rp.Pairs(ctx, conf) + grp, gctx := errgroup.WithContext(ctx) + grp.SetLimit(runtime.GOMAXPROCS(0)) + + stderrs := make([]bytes.Buffer, len(pairs)) + + for i, pair := range pairs { + sql := pair + errout := &stderrs[i] + + grp.Go(func() error { + combo := config.Combine(*conf, sql.SQL) + if sql.Plugin != nil { + combo.Codegen = *sql.Plugin + } + + absSchema := make([]string, 0, len(sql.Schema)) + for _, s := range sql.Schema { + absSchema = append(absSchema, resolvePath(baseDir, s)) + } + sql.Schema = absSchema + + absQueries := make([]string, 0, len(sql.Queries)) + for _, q := range sql.Queries { + absQueries = append(absQueries, resolvePath(baseDir, q)) + } + sql.Queries = absQueries + + var name, lang string + parseOpts := opts.Parser{} + + switch { + case sql.Gen.Go != nil: + name = combo.Go.Package + lang = "golang" + + case sql.Plugin != nil: + lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin) + name = sql.Plugin.Plugin + } + + packageRegion := trace.StartRegion(gctx, "package") + trace.Logf(gctx, "", "name=%s plugin=%s", name, lang) + + result, failed := parse(gctx, name, baseDir, sql.SQL, combo, parseOpts, errout) + if failed { + packageRegion.End() + errored = true + return nil + } + if err := rp.ProcessResult(gctx, combo, sql, result); err != nil { + fmt.Fprintf(errout, "# package %s\n", name) + fmt.Fprintf(errout, "error generating code: %s\n", err) + errored = true + } + packageRegion.End() + return nil + }) + } + if err := grp.Wait(); err != nil { + return err + } + if errored { + for i := range stderrs { + if _, err := io.Copy(stderr, &stderrs[i]); err != nil { + return err + } + } + return fmt.Errorf("errored") + } + return nil +} diff --git a/internal/api/shim.go b/internal/api/shim.go new file mode 100644 index 0000000000..9638d1f126 --- /dev/null +++ b/internal/api/shim.go @@ -0,0 +1,233 @@ +package api + +import ( + 
"github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/config/convert" + "github.com/sqlc-dev/sqlc/internal/info" + "github.com/sqlc-dev/sqlc/internal/plugin" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings { + return &plugin.Settings{ + Version: cs.Global.Version, + Engine: string(cs.Package.Engine), + Schema: []string(cs.Package.Schema), + Queries: []string(cs.Package.Queries), + Codegen: pluginCodegen(cs, cs.Codegen), + } +} + +func pluginCodegen(cs config.CombinedSettings, s config.Codegen) *plugin.Codegen { + opts, err := convert.YAMLtoJSON(s.Options) + if err != nil { + panic(err) + } + cg := &plugin.Codegen{ + Out: s.Out, + Plugin: s.Plugin, + Options: opts, + } + for _, p := range cs.Global.Plugins { + if p.Name == s.Plugin { + cg.Env = p.Env + cg.Process = pluginProcess(p) + cg.Wasm = pluginWASM(p) + return cg + } + } + return cg +} + +func pluginProcess(p config.Plugin) *plugin.Codegen_Process { + if p.Process != nil { + return &plugin.Codegen_Process{ + Cmd: p.Process.Cmd, + } + } + return nil +} + +func pluginWASM(p config.Plugin) *plugin.Codegen_WASM { + if p.WASM != nil { + return &plugin.Codegen_WASM{ + Url: p.WASM.URL, + Sha256: p.WASM.SHA256, + } + } + return nil +} + +func pluginCatalog(c *catalog.Catalog) *plugin.Catalog { + var schemas []*plugin.Schema + for _, s := range c.Schemas { + var enums []*plugin.Enum + var cts []*plugin.CompositeType + for _, typ := range s.Types { + switch typ := typ.(type) { + case *catalog.Enum: + enums = append(enums, &plugin.Enum{ + Name: typ.Name, + Comment: typ.Comment, + Vals: typ.Vals, + }) + case *catalog.CompositeType: + cts = append(cts, &plugin.CompositeType{ + Name: typ.Name, + Comment: typ.Comment, + }) + } + } + var tables []*plugin.Table + for _, t := range s.Tables { + var columns []*plugin.Column + for _, c := range t.Columns { + l := -1 
+ if c.Length != nil { + l = *c.Length + } + columns = append(columns, &plugin.Column{ + Name: c.Name, + Type: &plugin.Identifier{ + Catalog: c.Type.Catalog, + Schema: c.Type.Schema, + Name: c.Type.Name, + }, + Comment: c.Comment, + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + Table: &plugin.Identifier{ + Catalog: t.Rel.Catalog, + Schema: t.Rel.Schema, + Name: t.Rel.Name, + }, + }) + } + tables = append(tables, &plugin.Table{ + Rel: &plugin.Identifier{ + Catalog: t.Rel.Catalog, + Schema: t.Rel.Schema, + Name: t.Rel.Name, + }, + Columns: columns, + Comment: t.Comment, + }) + } + schemas = append(schemas, &plugin.Schema{ + Comment: s.Comment, + Name: s.Name, + Tables: tables, + Enums: enums, + CompositeTypes: cts, + }) + } + return &plugin.Catalog{ + Name: c.Name, + DefaultSchema: c.DefaultSchema, + Comment: c.Comment, + Schemas: schemas, + } +} + +func pluginQueries(r *compiler.Result) []*plugin.Query { + var out []*plugin.Query + for _, q := range r.Queries { + var params []*plugin.Parameter + var columns []*plugin.Column + for _, c := range q.Columns { + columns = append(columns, pluginQueryColumn(c)) + } + for _, p := range q.Params { + params = append(params, pluginQueryParam(p)) + } + var iit *plugin.Identifier + if q.InsertIntoTable != nil { + iit = &plugin.Identifier{ + Catalog: q.InsertIntoTable.Catalog, + Schema: q.InsertIntoTable.Schema, + Name: q.InsertIntoTable.Name, + } + } + out = append(out, &plugin.Query{ + Name: q.Metadata.Name, + Cmd: q.Metadata.Cmd, + Text: q.SQL, + Comments: q.Metadata.Comments, + Columns: columns, + Params: params, + Filename: q.Metadata.Filename, + InsertIntoTable: iit, + }) + } + return out +} + +func pluginQueryColumn(c *compiler.Column) *plugin.Column { + l := -1 + if c.Length != nil { + l = *c.Length + } + out := &plugin.Column{ + Name: c.Name, + OriginalName: c.OriginalName, + Comment: c.Comment, + NotNull: c.NotNull, + Unsigned: c.Unsigned, + 
IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + IsNamedParam: c.IsNamedParam, + IsFuncCall: c.IsFuncCall, + IsSqlcSlice: c.IsSqlcSlice, + } + + if c.Type != nil { + out.Type = &plugin.Identifier{ + Catalog: c.Type.Catalog, + Schema: c.Type.Schema, + Name: c.Type.Name, + } + } else { + out.Type = &plugin.Identifier{ + Name: c.DataType, + } + } + + if c.Table != nil { + out.Table = &plugin.Identifier{ + Catalog: c.Table.Catalog, + Schema: c.Table.Schema, + Name: c.Table.Name, + } + } + + if c.EmbedTable != nil { + out.EmbedTable = &plugin.Identifier{ + Catalog: c.EmbedTable.Catalog, + Schema: c.EmbedTable.Schema, + Name: c.EmbedTable.Name, + } + } + + return out +} + +func pluginQueryParam(p compiler.Parameter) *plugin.Parameter { + return &plugin.Parameter{ + Number: int32(p.Number), + Column: pluginQueryColumn(p.Column), + } +} + +func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.GenerateRequest { + return &plugin.GenerateRequest{ + Settings: pluginSettings(r, settings), + Catalog: pluginCatalog(r.Catalog), + Queries: pluginQueries(r), + SqlcVersion: info.Version, + } +} diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 4079b3c1d3..4fd1b6840a 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -1,7 +1,6 @@ package cmd import ( - "bufio" "bytes" "context" "errors" @@ -12,11 +11,11 @@ import ( "path/filepath" "runtime/trace" - "github.com/cubicdaiya/gonp" "github.com/spf13/cobra" "github.com/spf13/pflag" "gopkg.in/yaml.v3" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/info" "github.com/sqlc-dev/sqlc/internal/opts" @@ -184,6 +183,33 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) { } } +// loadConfig reads the sqlc config into memory and chdirs the process to its +// directory so relative paths in the config resolve correctly. 
Returns the +// config bytes and the absolute config directory (for +// api.GenerateOptions.BaseDir). +func loadConfig(stderr io.Writer, dir, name string) ([]byte, string) { + configPath, _, err := readConfig(stderr, dir, name) + if err != nil { + os.Exit(1) + } + configPath, err = filepath.Abs(configPath) + if err != nil { + fmt.Fprintf(stderr, "error resolving config path: %s\n", err) + os.Exit(1) + } + data, err := os.ReadFile(configPath) + if err != nil { + fmt.Fprintf(stderr, "error reading %s: %s\n", configPath, err) + os.Exit(1) + } + configDir := filepath.Dir(configPath) + if err := os.Chdir(configDir); err != nil { + fmt.Fprintf(stderr, "error changing directory: %s\n", err) + os.Exit(1) + } + return data, configDir +} + var genCmd = &cobra.Command{ Use: "generate", Short: "Generate source code from SQL", @@ -191,21 +217,17 @@ var genCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - output, err := Generate(cmd.Context(), dir, name, &Options{ - Env: ParseEnv(cmd), - Stderr: stderr, + data, baseDir := loadConfig(stderr, dir, name) + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Config: bytes.NewReader(data), + Stderr: stderr, + Write: true, + BaseDir: baseDir, + EnableProcessPlugins: debugProcessPlugins.Value() != "0", }) - if err != nil { + if len(res.Errors) > 0 { os.Exit(1) } - defer trace.StartRegion(cmd.Context(), "writefiles").End() - for filename, source := range output { - os.MkdirAll(filepath.Dir(filename), 0755) - if err := os.WriteFile(filename, []byte(source), 0644); err != nil { - fmt.Fprintf(stderr, "%s: %s\n", filename, err) - return err - } - } return nil }, } @@ -217,46 +239,20 @@ var checkCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "compile").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - _, err := Generate(cmd.Context(), dir, name, &Options{ - Env: 
ParseEnv(cmd), - Stderr: stderr, + data, baseDir := loadConfig(stderr, dir, name) + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Config: bytes.NewReader(data), + Stderr: stderr, + BaseDir: baseDir, + EnableProcessPlugins: debugProcessPlugins.Value() != "0", }) - if err != nil { + if len(res.Errors) > 0 { os.Exit(1) } return nil }, } -func getLines(f []byte) []string { - fp := bytes.NewReader(f) - scanner := bufio.NewScanner(fp) - lines := make([]string, 0) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - return lines -} - -func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] { - var out []gonp.UniHunk[T] - for i, uniHunk := range uniHunks { - var changed bool - for _, e := range uniHunk.GetChanges() { - switch e.GetType() { - case gonp.SesDelete: - changed = true - case gonp.SesAdd: - changed = true - } - } - if changed { - out = append(out, uniHunks[i]) - } - } - return out -} - var diffCmd = &cobra.Command{ Use: "diff", Short: "Compare the generated files to the existing files", @@ -264,11 +260,15 @@ var diffCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "diff").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - opts := &Options{ - Env: ParseEnv(cmd), - Stderr: stderr, - } - if err := Diff(cmd.Context(), dir, name, opts); err != nil { + data, baseDir := loadConfig(stderr, dir, name) + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Config: bytes.NewReader(data), + Stderr: stderr, + Diff: true, + BaseDir: baseDir, + EnableProcessPlugins: debugProcessPlugins.Value() != "0", + }) + if len(res.Errors) > 0 { os.Exit(1) } return nil diff --git a/internal/cmd/diff.go b/internal/cmd/diff.go deleted file mode 100644 index 8998971a37..0000000000 --- a/internal/cmd/diff.go +++ /dev/null @@ -1,59 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - "os" - "runtime/trace" - "sort" - "strings" - - "github.com/cubicdaiya/gonp" -) - -func Diff(ctx 
context.Context, dir, name string, opts *Options) error { - stderr := opts.Stderr - output, err := Generate(ctx, dir, name, opts) - if err != nil { - return err - } - defer trace.StartRegion(ctx, "checkfiles").End() - var errored bool - - keys := make([]string, 0, len(output)) - for k, _ := range output { - kk := k - keys = append(keys, kk) - } - sort.Strings(keys) - - for _, filename := range keys { - source := output[filename] - if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { - errored = true - // stdout message - continue - } - existing, err := os.ReadFile(filename) - if err != nil { - errored = true - fmt.Fprintf(stderr, "%s: %s\n", filename, err) - continue - } - diff := gonp.New(getLines(existing), getLines([]byte(source))) - diff.Compose() - uniHunks := filterHunks(diff.UnifiedHunks()) - - if len(uniHunks) > 0 { - errored = true - fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir)) - fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir)) - diff.FprintUniHunks(stderr, uniHunks) - } - } - if errored { - return errors.New("diff found") - } - return nil -} diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 6d2a0297b6..078b029e04 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -2,30 +2,18 @@ package cmd import ( "context" - "encoding/json" "errors" "fmt" "io" "os" "path/filepath" "runtime/trace" - "strings" - "sync" - "google.golang.org/grpc" - - "github.com/sqlc-dev/sqlc/internal/codegen/golang" - genjson "github.com/sqlc-dev/sqlc/internal/codegen/json" "github.com/sqlc-dev/sqlc/internal/compiler" "github.com/sqlc-dev/sqlc/internal/config" - "github.com/sqlc-dev/sqlc/internal/config/convert" "github.com/sqlc-dev/sqlc/internal/debug" - "github.com/sqlc-dev/sqlc/internal/ext" - "github.com/sqlc-dev/sqlc/internal/ext/process" - "github.com/sqlc-dev/sqlc/internal/ext/wasm" "github.com/sqlc-dev/sqlc/internal/multierr" "github.com/sqlc-dev/sqlc/internal/opts" - 
"github.com/sqlc-dev/sqlc/internal/plugin" "github.com/sqlc-dev/sqlc/internal/sqlcdebug" ) @@ -54,15 +42,6 @@ func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) { fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) } -func findPlugin(conf config.Config, name string) (*config.Plugin, error) { - for _, plug := range conf.Plugins { - if plug.Name == name { - return &plug, nil - } - } - return nil, fmt.Errorf("plugin not found") -} - func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, error) { configPath := "" if filename != "" { @@ -130,128 +109,6 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, return configPath, &conf, nil } -func Generate(ctx context.Context, dir, filename string, o *Options) (map[string]string, error) { - e := o.Env - stderr := o.Stderr - - configPath, conf, err := o.ReadConfig(dir, filename) - if err != nil { - return nil, err - } - - base := filepath.Base(configPath) - if err := config.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - return nil, err - } - - if err := e.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - return nil, err - } - - g := &generator{ - dir: dir, - output: map[string]string{}, - } - - if err := processQuerySets(ctx, g, conf, dir, o); err != nil { - return nil, err - } - - return g.output, nil -} - -type generator struct { - m sync.Mutex - dir string - output map[string]string -} - -func (g *generator) Pairs(ctx context.Context, conf *config.Config) []OutputPair { - var pairs []OutputPair - for _, sql := range conf.SQL { - if sql.Gen.Go != nil { - pairs = append(pairs, OutputPair{ - SQL: sql, - Gen: config.SQLGen{Go: sql.Gen.Go}, - }) - } - if sql.Gen.JSON != nil { - pairs = append(pairs, OutputPair{ - SQL: sql, - Gen: config.SQLGen{JSON: sql.Gen.JSON}, - }) - } - for i := range sql.Codegen { - pairs = 
append(pairs, OutputPair{ - SQL: sql, - Plugin: &sql.Codegen[i], - }) - } - } - return pairs -} - -func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) error { - out, resp, err := codegen(ctx, combo, sql, result) - if err != nil { - return err - } - files := map[string]string{} - for _, file := range resp.Files { - files[file.Name] = string(file.Contents) - } - g.m.Lock() - - // out is specified by the user, not a plugin - absout := filepath.Join(g.dir, out) - - // When the Go codegen is configured to emit the models file into a - // separate package directory, route that file to its own absolute path. - // This is the only file allowed to live outside of `out`. - var ( - modelsFileName string - modelsAbsout string - modelsAbsfile string - ) - if sql.Gen.Go != nil && sql.Gen.Go.OutputModelsPath != "" && sql.Gen.Go.ModelsEmitEnabled() { - modelsFileName = sql.Gen.Go.OutputModelsFileName - if modelsFileName == "" { - modelsFileName = "models.go" - } - modelsAbsout = filepath.Join(g.dir, sql.Gen.Go.OutputModelsPath) - modelsAbsfile = filepath.Join(modelsAbsout, modelsFileName) - } - - for n, source := range files { - if modelsFileName != "" && n == modelsFileName { - // Models file routed to a separate package directory. 
- if strings.Contains(modelsAbsfile, "..") { - return fmt.Errorf("invalid file output path: %s", modelsAbsfile) - } - if !strings.HasPrefix(modelsAbsfile, modelsAbsout) { - return fmt.Errorf("invalid file output path: %s", modelsAbsfile) - } - g.output[modelsAbsfile] = source - continue - } - filename := filepath.Join(g.dir, out, n) - // filepath.Join calls filepath.Clean which should remove all "..", but - // double check to make sure - if strings.Contains(filename, "..") { - return fmt.Errorf("invalid file output path: %s", filename) - } - // The output file must be contained inside the output directory - if !strings.HasPrefix(filename, absout) { - return fmt.Errorf("invalid file output path: %s", filename) - } - g.output[filename] = source - } - g.m.Unlock() - return nil -} - func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() c, err := compiler.NewCompiler(sql, combo, parserOpts) @@ -291,82 +148,3 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C } return c.Result(), false } - -func codegen(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) { - defer trace.StartRegion(ctx, "codegen").End() - req := codeGenRequest(result, combo) - var handler grpc.ClientConnInterface - var out string - switch { - case sql.Plugin != nil: - out = sql.Plugin.Out - plug, err := findPlugin(combo.Global, sql.Plugin.Plugin) - if err != nil { - return "", nil, fmt.Errorf("plugin not found: %s", err) - } - - switch { - case plug.Process != nil: - handler = &process.Runner{ - Cmd: plug.Process.Cmd, - Env: plug.Env, - Format: plug.Process.Format, - } - case plug.WASM != nil: - handler = &wasm.Runner{ - URL: plug.WASM.URL, - SHA256: plug.WASM.SHA256, - Env: plug.Env, - } - default: - return "", nil, 
fmt.Errorf("unsupported plugin type") - } - - opts, err := convert.YAMLtoJSON(sql.Plugin.Options) - if err != nil { - return "", nil, fmt.Errorf("invalid plugin options: %w", err) - } - req.PluginOptions = opts - - global, found := combo.Global.Options[plug.Name] - if found { - opts, err := convert.YAMLtoJSON(global) - if err != nil { - return "", nil, fmt.Errorf("invalid global options: %w", err) - } - req.GlobalOptions = opts - } - - case sql.Gen.Go != nil: - out = combo.Go.Out - handler = ext.HandleFunc(golang.Generate) - opts, err := json.Marshal(sql.Gen.Go) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.PluginOptions = opts - - if combo.Global.Overrides.Go != nil { - opts, err := json.Marshal(combo.Global.Overrides.Go) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.GlobalOptions = opts - } - - case sql.Gen.JSON != nil: - out = combo.JSON.Out - handler = ext.HandleFunc(genjson.Generate) - opts, err := json.Marshal(sql.Gen.JSON) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.PluginOptions = opts - - default: - return "", nil, fmt.Errorf("missing language backend") - } - client := plugin.NewCodegenServiceClient(handler) - resp, err := client.Generate(ctx, req) - return out, resp, err -} diff --git a/internal/cmd/options.go b/internal/cmd/options.go index 02d3614f4e..7a1fab33f0 100644 --- a/internal/cmd/options.go +++ b/internal/cmd/options.go @@ -12,18 +12,8 @@ type Options struct { // TODO: Move these to a command-specific struct Tags []string Against string - - // Testing only - MutateConfig func(*config.Config) } func (o *Options) ReadConfig(dir, filename string) (string, *config.Config, error) { - path, conf, err := readConfig(o.Stderr, dir, filename) - if err != nil { - return path, conf, err - } - if o.MutateConfig != nil { - o.MutateConfig(conf) - } - return path, conf, nil + return readConfig(o.Stderr, dir, filename) } diff --git 
a/internal/cmd/process.go b/internal/cmd/process.go index e90450cf5f..363d5683c8 100644 --- a/internal/cmd/process.go +++ b/internal/cmd/process.go @@ -73,15 +73,21 @@ func processQuerySets(ctx context.Context, rp ResultProcessor, conf *config.Conf } // TODO: This feels like a hack that will bite us later + joinDir := func(p string) string { + if filepath.IsAbs(p) { + return p + } + return filepath.Join(dir, p) + } joined := make([]string, 0, len(sql.Schema)) for _, s := range sql.Schema { - joined = append(joined, filepath.Join(dir, s)) + joined = append(joined, joinDir(s)) } sql.Schema = joined joined = make([]string, 0, len(sql.Queries)) for _, q := range sql.Queries { - joined = append(joined, filepath.Join(dir, q)) + joined = append(joined, joinDir(q)) } sql.Queries = joined diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index 38a808fdb1..7fc489a494 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -474,15 +474,21 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { combo := config.Combine(*c.Conf, s) // TODO: This feels like a hack that will bite us later + joinDir := func(p string) string { + if filepath.IsAbs(p) { + return p + } + return filepath.Join(c.Dir, p) + } joined := make([]string, 0, len(s.Schema)) - for _, s := range s.Schema { - joined = append(joined, filepath.Join(c.Dir, s)) + for _, p := range s.Schema { + joined = append(joined, joinDir(p)) } s.Schema = joined joined = make([]string, 0, len(s.Queries)) for _, q := range s.Queries { - joined = append(joined, filepath.Join(c.Dir, q)) + joined = append(joined, joinDir(q)) } s.Queries = joined diff --git a/internal/config/config.go b/internal/config/config.go index d3e610ef05..19bcda754a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -189,6 +189,30 @@ func (a *AnalyzerDatabase) UnmarshalYAML(unmarshal func(interface{}) error) erro return errors.New("analyzer.database must
be true, false, or \"only\"") } +// MarshalJSON serialises the value back to its source form so configs +// produced via UnmarshalJSON round-trip through json.Marshal. +func (a AnalyzerDatabase) MarshalJSON() ([]byte, error) { + if a.isOnly { + return json.Marshal("only") + } + if a.value == nil { + return json.Marshal(true) + } + return json.Marshal(*a.value) +} + +// MarshalYAML serialises the value back to its source form so configs +// produced via UnmarshalYAML round-trip through yaml.Marshal. +func (a AnalyzerDatabase) MarshalYAML() (interface{}, error) { + if a.isOnly { + return "only", nil + } + if a.value == nil { + return true, nil + } + return *a.value, nil +} + type Analyzer struct { Database AnalyzerDatabase `json:"database" yaml:"database"` } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index f8bb5a6e0f..1b1420b2ff 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -3,6 +3,8 @@ package main import ( "bytes" "context" + "fmt" + "io" "os" osexec "os/exec" "path/filepath" @@ -13,30 +15,16 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "gopkg.in/yaml.v3" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/cmd" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/opts" - "github.com/sqlc-dev/sqlc/internal/sqlcdebug" "github.com/sqlc-dev/sqlc/internal/sqltest/docker" "github.com/sqlc-dev/sqlc/internal/sqltest/native" ) -// withSQLCDEBUG installs the given SQLCDEBUG-formatted string for the -// duration of the test and restores the empty default afterwards. -// -// Callers in TestReplay are sequential, so this does not need a mutex: -// Go's test scheduler does not run TestReplay concurrently with the -// parallel top-level tests in this package. 
-func withSQLCDEBUG(t *testing.T, raw string) func() { - t.Helper() - if raw == "" { - return func() {} - } - sqlcdebug.Update(raw) - return func() { sqlcdebug.Update("") } -} - func lineEndings() cmp.Option { return cmp.Transformer("LineEndings", func(in string) string { // Replace Windows new lines with Unix newlines @@ -74,15 +62,15 @@ func TestExamples(t *testing.T) { t.Parallel() path := filepath.Join(examples, tc) var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, - Stderr: &stderr, - } - output, err := cmd.Generate(ctx, path, "", opts) - if err != nil { + res := api.Generate(ctx, api.GenerateOptions{ + Config: openConfigReader(t, path), + Stderr: &stderr, + BaseDir: path, + }) + if len(res.Errors) > 0 { t.Fatalf("sqlc generate failed: %s", stderr.String()) } - cmpDirectory(t, path, output) + cmpDirectory(t, path, res.Files) }) } } @@ -104,27 +92,116 @@ func BenchmarkExamples(b *testing.B) { tc := replay.Name() b.Run(tc, func(b *testing.B) { path := filepath.Join(examples, tc) + cfg := openConfigBytes(b, path) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, - Stderr: &stderr, - } - cmd.Generate(ctx, path, "", opts) + api.Generate(ctx, api.GenerateOptions{ + Config: bytes.NewReader(cfg), + Stderr: &stderr, + BaseDir: path, + }) } }) } } + +// openConfigReader reads the sqlc config in dir, re-encodes it as a +// version "2" YAML document (via mutatedConfigBytes with no mutation), and +// returns the bytes as an io.Reader for api.GenerateOptions.Config. +func openConfigReader(t testing.TB, dir string) io.Reader { + return bytes.NewReader(openConfigBytes(t, dir)) +} + +func openConfigBytes(t testing.TB, dir string) []byte { + t.Helper() + data, _ := mutatedConfigBytes(t, dir, nil) + return data +} + +// mutatedConfigBytes parses the sqlc config in dir and re-encodes it as YAML.
+// When mutate is non-nil, it is applied to the in-memory Config first and the +// encoded bytes are also written to a temp file alongside the original +// (cleaned up at test end). The filename is returned so callers like cmd.Vet +// that still take a config-file path can use it. Parsing v1 configs converts +// them to a v2-shaped Config; we force version "2" so the result can be +// parsed back by api.Generate. +func mutatedConfigBytes(t testing.TB, dir string, mutate func(*config.Config)) ([]byte, string) { + t.Helper() + original, conf, err := readSqlcConfig(dir) + if err != nil { + t.Fatalf("read sqlc config from %s: %s", dir, err) + } + + conf.Version = "2" + if mutate != nil { + mutate(conf) + } + + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + if err := enc.Encode(conf); err != nil { + t.Fatalf("encode config: %s", err) + } + if err := enc.Close(); err != nil { + t.Fatalf("close yaml encoder: %s", err) + } + data := buf.Bytes() + + if mutate == nil { + return data, "" + } + + f, err := os.CreateTemp(dir, "sqlc.test-*"+filepath.Ext(original)) + if err != nil { + t.Fatalf("create temp config in %s: %s", dir, err) + } + t.Cleanup(func() { os.Remove(f.Name()) }) + if _, err := f.Write(data); err != nil { + f.Close() + t.Fatalf("write temp config %s: %s", f.Name(), err) + } + if err := f.Close(); err != nil { + t.Fatalf("close temp config %s: %s", f.Name(), err) + } + return data, filepath.Base(f.Name()) +} + +func readSqlcConfig(dir string) (string, *config.Config, error) { + for _, name := range []string{"sqlc.yaml", "sqlc.yml", "sqlc.json"} { + path := filepath.Join(dir, name) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + continue + } + return path, nil, err + } + defer f.Close() + conf, err := config.ParseConfig(f) + if err != nil { + return path, nil, fmt.Errorf("parse %s: %w", path, err) + } + return path, &conf, nil + } + return "", nil, fmt.Errorf("no sqlc config found in %s", dir) +} + +// configRef is the result of 
preparing a config for a single TestReplay case. +// reader is for api.Generate; file is for cmd.Vet which still takes a path. +type configRef struct { + reader io.Reader + file string +} + type textContext struct { - Mutate func(*testing.T, string) func(*config.Config) + Config func(*testing.T, string) configRef Enabled func() bool } func TestReplay(t *testing.T) { // Ensure that this environment variable is always set to true when running // end-to-end tests - os.Setenv("SQLC_DUMMY_VALUE", "true") + t.Setenv("SQLC_DUMMY_VALUE", "true") // t.Parallel() ctx := context.Background() @@ -189,45 +266,28 @@ func TestReplay(t *testing.T) { contexts := map[string]textContext{ "base": { - Mutate: func(t *testing.T, path string) func(*config.Config) { return func(c *config.Config) {} }, + Config: func(t *testing.T, dir string) configRef { + data, _ := mutatedConfigBytes(t, dir, nil) + return configRef{reader: bytes.NewReader(data)} + }, Enabled: func() bool { return true }, }, "managed-db": { - Mutate: func(t *testing.T, path string) func(*config.Config) { - return func(c *config.Config) { + Config: func(t *testing.T, dir string) configRef { + data, file := mutatedConfigBytes(t, dir, func(c *config.Config) { // Add all servers - tests will fail if database isn't available c.Servers = []config.Server{ - { - Name: "postgres", - Engine: config.EnginePostgreSQL, - URI: postgresURI, - }, - { - Name: "mysql", - Engine: config.EngineMySQL, - URI: mysqlURI, - }, + {Name: "postgres", Engine: config.EnginePostgreSQL, URI: postgresURI}, + {Name: "mysql", Engine: config.EngineMySQL, URI: mysqlURI}, } - for i := range c.SQL { switch c.SQL[i].Engine { - case config.EnginePostgreSQL: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - case config.EngineMySQL: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - case config.EngineSQLite: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - default: - // pass + case config.EnginePostgreSQL, config.EngineMySQL, 
config.EngineSQLite: + c.SQL[i].Database = &config.Database{Managed: true} } } - } + }) + return configRef{reader: bytes.NewReader(data), file: file} }, Enabled: func() bool { // Enabled if at least one database URI is available @@ -277,27 +337,43 @@ func TestReplay(t *testing.T) { } } - opts := cmd.Options{ + cfg := testctx.Config(t, path) + for k, v := range args.Env { + t.Setenv(k, v) + } + cmdOpts := cmd.Options{ Env: cmd.Env{ - Experiment: opts.ExperimentFromString(args.Env["SQLCEXPERIMENT"]), + Experiment: opts.ExperimentFromEnv(), }, - Stderr: &stderr, - MutateConfig: testctx.Mutate(t, path), + Stderr: &stderr, } - release := withSQLCDEBUG(t, args.Env["SQLCDEBUG"]) - defer release() - switch args.Command { case "diff": - err = cmd.Diff(ctx, path, "", &opts) + res := api.Generate(ctx, api.GenerateOptions{ + Config: cfg.reader, + Stderr: &stderr, + BaseDir: path, + Diff: true, + }) + if len(res.Errors) > 0 { + err = res.Errors[0] + } case "generate": - output, err = cmd.Generate(ctx, path, "", &opts) + res := api.Generate(ctx, api.GenerateOptions{ + Config: cfg.reader, + Stderr: &stderr, + BaseDir: path, + }) + output = res.Files + if len(res.Errors) > 0 { + err = res.Errors[0] + } if err == nil { cmpDirectory(t, path, output) } case "vet": - err = cmd.Vet(ctx, path, "", &opts) + err = cmd.Vet(ctx, path, cfg.file, &cmdOpts) default: t.Fatalf("unknown command") } @@ -341,6 +417,10 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if filepath.Base(path) == "exec.json" { return nil } + // Mutated configs written by mutatedConfigBytes. 
+ if strings.HasPrefix(filepath.Base(path), "sqlc.test-") { + return nil + } if strings.Contains(path, "/kotlin/build") { return nil } @@ -403,13 +483,14 @@ func BenchmarkReplay(b *testing.B) { tc := replay b.Run(tc, func(b *testing.B) { path, _ := filepath.Abs(tc) + cfg := openConfigBytes(b, path) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, - Stderr: &stderr, - } - cmd.Generate(ctx, path, "", opts) + api.Generate(ctx, api.GenerateOptions{ + Config: bytes.NewReader(cfg), + Stderr: &stderr, + BaseDir: path, + }) } }) } diff --git a/internal/endtoend/vet_test.go b/internal/endtoend/vet_test.go index 011c032c2e..2c8a262f4f 100644 --- a/internal/endtoend/vet_test.go +++ b/internal/endtoend/vet_test.go @@ -60,11 +60,11 @@ func TestExamplesVet(t *testing.T) { } if s, found := findSchema(t, filepath.Join(path, "mysql")); found { uri := local.MySQL(t, []string{s}) - os.Setenv(fmt.Sprintf("VET_TEST_EXAMPLES_MYSQL_%s", strings.ToUpper(tc)), uri) + t.Setenv(fmt.Sprintf("VET_TEST_EXAMPLES_MYSQL_%s", strings.ToUpper(tc)), uri) } if s, found := findSchema(t, filepath.Join(path, "postgresql")); found { uri := local.PostgreSQL(t, []string{s}) - os.Setenv(fmt.Sprintf("VET_TEST_EXAMPLES_POSTGRES_%s", strings.ToUpper(tc)), uri) + t.Setenv(fmt.Sprintf("VET_TEST_EXAMPLES_POSTGRES_%s", strings.ToUpper(tc)), uri) } }