Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 12 additions & 60 deletions experimental/aitools/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
Expand All @@ -23,9 +22,6 @@ import (
)

const (
// sqlFileExtension is the file extension used to auto-detect SQL files.
sqlFileExtension = ".sql"

// pollIntervalInitial is the starting interval between status polls.
pollIntervalInitial = 1 * time.Second

Expand Down Expand Up @@ -204,65 +200,21 @@ interactive table browser. Use --output csv to export results as CSV.`,
}

// resolveSQLs collects SQL statements from --file paths, positional args, and
// stdin. The returned slice preserves source order: --file paths first (in flag
// order), then positional args (in arg order), then stdin (only if no other
// source produced anything). Each SQL is run through cleanSQL.
// stdin via sqlcli.Collect, then runs each through cleanSQL (the warehouse
// statement API doesn't care about line comments, so we strip them up front
// to normalise the wire payload). Returns just the SQL strings so the rest of
// this command's flow stays unchanged; the Source labels sqlcli adds are
// dropped on the floor (this command surfaces statement_id, not source).
func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []string) ([]string, error) {
var raws []string

for _, path := range filePaths {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read SQL file %s: %w", path, err)
}
raws = append(raws, string(data))
}

for _, arg := range args {
// If the argument looks like a .sql file, try to read it.
// Only fall through to literal SQL if the file doesn't exist.
// Surface other errors (permission denied, etc.) directly.
if strings.HasSuffix(arg, sqlFileExtension) {
data, err := os.ReadFile(arg)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("read SQL file: %w", err)
}
if err == nil {
raws = append(raws, string(data))
continue
}
}
raws = append(raws, arg)
}

if len(raws) == 0 {
// No --file and no positional args: try reading from stdin if it's piped.
// If stdin was overridden (e.g. cmd.SetIn in tests), always read from it.
// Otherwise, only read if stdin is not a TTY (i.e. piped input).
in := cmd.InOrStdin()
_, isOsFile := in.(*os.File)
if isOsFile && cmdio.IsPromptSupported(ctx) {
return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin")
}
data, err := io.ReadAll(in)
if err != nil {
return nil, fmt.Errorf("read stdin: %w", err)
}
raws = append(raws, string(data))
inputs, err := sqlcli.Collect(ctx, cmd.InOrStdin(), args, filePaths, sqlcli.CollectOptions{Cleaner: cleanSQL})
if err != nil {
return nil, err
}

cleaned := make([]string, 0, len(raws))
for i, raw := range raws {
c := cleanSQL(raw)
if c == "" {
if len(raws) == 1 {
return nil, errors.New("SQL statement is empty after removing comments and blank lines")
}
return nil, fmt.Errorf("SQL statement #%d is empty after removing comments and blank lines", i+1)
}
cleaned = append(cleaned, c)
out := make([]string, len(inputs))
for i, in := range inputs {
out[i] = in.SQL
}
return cleaned, nil
return out, nil
}

// runBatch executes multiple SQL statements in parallel and renders the result
Expand Down
6 changes: 3 additions & 3 deletions experimental/aitools/cmd/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ func TestResolveSQLsUnreadableSQLFileReturnsError(t *testing.T) {
cmd := newTestCmd()
_, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "read SQL file")
assert.Contains(t, err.Error(), "permission denied")
}

func TestResolveSQLsFromStdin(t *testing.T) {
Expand Down Expand Up @@ -579,14 +579,14 @@ func TestResolveSQLsBatchEmptyAtIndexReturnsIndexedError(t *testing.T) {
cmd := newTestCmd()
_, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "-- comment only", "SELECT 3"}, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "SQL statement #2 is empty")
assert.Contains(t, err.Error(), "argv[2] is empty")
}

func TestResolveSQLsMissingFileReturnsError(t *testing.T) {
cmd := newTestCmd()
_, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{"/nonexistent/path/query.sql"})
require.Error(t, err)
assert.Contains(t, err.Error(), "read SQL file")
assert.Contains(t, err.Error(), "no such file")
}

func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) {
Expand Down
123 changes: 123 additions & 0 deletions experimental/libs/sqlcli/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package sqlcli

import (
"context"
"errors"
"fmt"
"io"
"os"
"strings"

"github.com/databricks/cli/libs/cmdio"
)

// SQLFileExtension is the file suffix that triggers the .sql autodetect on a
// positional argument: if `databricks ... query foo.sql` exists on disk, the
// argument is read as a SQL file; otherwise it's treated as literal SQL.
const SQLFileExtension = ".sql"

// Input is one SQL statement to execute, paired with a label identifying its
// origin so multi-input renderers and error messages can refer back to "which
// of the N inputs failed".
type Input struct {
// SQL is the cleaned statement text. Always non-empty (Collect rejects
// inputs that clean to empty).
SQL string
// Source is a human-readable label: "--file PATH", "argv[N]", or "stdin".
Source string
}

// CollectOptions controls per-command behavior. The zero value is fine for
// commands that just want plain trimmed input.
type CollectOptions struct {
// Cleaner is applied to each raw SQL after read (and before the empty
// check). The default is strings.TrimSpace; aitools passes a richer
// cleaner that strips SQL comments and surrounding quotes. Postgres
// passes the default because its multi-statement scanner needs comments
// preserved.
Cleaner func(string) string
}

// Collect assembles the ordered list of inputs from --file paths, positional
// arguments, and stdin.
//
// Order is files-first, then positionals. Cobra/pflag does not preserve the
// user's interleaved CLI spelling: it collects all --file flags into one
// slice and all positionals into another, so callers cannot honour
// `--file q1.sql "SELECT 1" --file q2.sql` as written.
//
// Stdin is read only when neither --file nor positional input was provided,
// and only when stdin is not a prompt-capable TTY (otherwise we'd block
// waiting for input the user did not realise they had to type).
//
// Errors when:
// - A --file path can't be read or cleans to empty.
// - A positional that looks like a .sql file but read fails with a non-
// "does not exist" error (e.g. permission denied).
// - A positional cleans to empty.
// - Stdin is the only source and it's empty / blocked on a TTY.
func Collect(ctx context.Context, in io.Reader, args, files []string, opts CollectOptions) ([]Input, error) {
cleaner := opts.Cleaner
if cleaner == nil {
cleaner = strings.TrimSpace
}

var inputs []Input

for _, path := range files {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read --file %q: %w", path, err)
}
sql := cleaner(string(data))
if sql == "" {
return nil, fmt.Errorf("--file %q is empty", path)
}
inputs = append(inputs, Input{SQL: sql, Source: "--file " + path})
}

for i, arg := range args {
// .sql autodetect: if the positional ends in .sql AND the file
// exists, read it as a SQL file. Other read errors (permission
// denied) surface directly. If the file does not exist, fall
// through and treat the positional as literal SQL — useful when
// the user passes a string that happens to end with ".sql".
if strings.HasSuffix(arg, SQLFileExtension) {
data, err := os.ReadFile(arg)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("read positional %q: %w", arg, err)
}
if err == nil {
sql := cleaner(string(data))
if sql == "" {
return nil, fmt.Errorf("positional %q is empty", arg)
}
inputs = append(inputs, Input{SQL: sql, Source: arg})
continue
}
}
sql := cleaner(arg)
if sql == "" {
return nil, fmt.Errorf("argv[%d] is empty", i+1)
}
inputs = append(inputs, Input{SQL: sql, Source: fmt.Sprintf("argv[%d]", i+1)})
}

if len(inputs) == 0 {
_, isOsFile := in.(*os.File)
if isOsFile && cmdio.IsPromptSupported(ctx) {
return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin")
}
data, err := io.ReadAll(in)
if err != nil {
return nil, fmt.Errorf("read stdin: %w", err)
}
sql := cleaner(string(data))
if sql == "" {
return nil, errors.New("no SQL provided")
}
inputs = append(inputs, Input{SQL: sql, Source: "stdin"})
}

return inputs, nil
}
124 changes: 124 additions & 0 deletions experimental/libs/sqlcli/input_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package sqlcli

import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func writeTemp(t *testing.T, name, contents string) string {
t.Helper()
dir := t.TempDir()
p := filepath.Join(dir, name)
require.NoError(t, os.WriteFile(p, []byte(contents), 0o644))
return p
}

func TestCollect_PositionalOnly(t *testing.T) {
got, err := Collect(t.Context(), strings.NewReader(""), []string{"SELECT 1"}, nil, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "SELECT 1", got[0].SQL)
assert.Equal(t, "argv[1]", got[0].Source)
}

func TestCollect_MultiplePositionals(t *testing.T) {
got, err := Collect(t.Context(), strings.NewReader(""), []string{"SELECT 1", "SELECT 2"}, nil, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 2)
assert.Equal(t, "SELECT 1", got[0].SQL)
assert.Equal(t, "SELECT 2", got[1].SQL)
}

func TestCollect_FileOnly(t *testing.T) {
p := writeTemp(t, "q.sql", "SELECT * FROM t")
got, err := Collect(t.Context(), strings.NewReader(""), nil, []string{p}, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "SELECT * FROM t", got[0].SQL)
assert.Contains(t, got[0].Source, "--file")
}

func TestCollect_FilesFirstThenPositionals(t *testing.T) {
p1 := writeTemp(t, "a.sql", "SELECT 1")
p2 := writeTemp(t, "b.sql", "SELECT 2")
got, err := Collect(t.Context(), strings.NewReader(""), []string{"SELECT 3"}, []string{p1, p2}, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 3)
assert.Equal(t, "SELECT 1", got[0].SQL)
assert.Equal(t, "SELECT 2", got[1].SQL)
assert.Equal(t, "SELECT 3", got[2].SQL)
}

func TestCollect_DotSQLAutoDetect(t *testing.T) {
p := writeTemp(t, "data.sql", "SELECT 42")
got, err := Collect(t.Context(), strings.NewReader(""), []string{p}, nil, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "SELECT 42", got[0].SQL)
}

func TestCollect_DotSQLNotExistingFallsThroughToLiteral(t *testing.T) {
got, err := Collect(t.Context(), strings.NewReader(""), []string{"/nonexistent/path.sql"}, nil, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "/nonexistent/path.sql", got[0].SQL)
}

func TestCollect_StdinOnly(t *testing.T) {
got, err := Collect(t.Context(), strings.NewReader("SELECT 1\n"), nil, nil, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "SELECT 1", got[0].SQL)
assert.Equal(t, "stdin", got[0].Source)
}

func TestCollect_StdinIgnoredWhenPositionalsPresent(t *testing.T) {
got, err := Collect(t.Context(), strings.NewReader("FROM STDIN"), []string{"SELECT 1"}, nil, CollectOptions{})
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "SELECT 1", got[0].SQL)
}

func TestCollect_EmptyStdinErrors(t *testing.T) {
_, err := Collect(t.Context(), strings.NewReader(""), nil, nil, CollectOptions{})
assert.ErrorContains(t, err, "no SQL provided")
}

func TestCollect_EmptyFileErrors(t *testing.T) {
p := writeTemp(t, "empty.sql", "")
_, err := Collect(t.Context(), strings.NewReader(""), nil, []string{p}, CollectOptions{})
assert.ErrorContains(t, err, "is empty")
}

func TestCollect_EmptyPositional(t *testing.T) {
_, err := Collect(t.Context(), strings.NewReader(""), []string{" "}, nil, CollectOptions{})
assert.ErrorContains(t, err, "is empty")
}

func TestCollect_CustomCleanerStripsComments(t *testing.T) {
cleaner := func(s string) string {
// Naive comment stripper: drop lines starting with --
var lines []string
for line := range strings.SplitSeq(s, "\n") {
line = strings.TrimSpace(line)
if line != "" && !strings.HasPrefix(line, "--") {
lines = append(lines, line)
}
}
return strings.Join(lines, "\n")
}
got, err := Collect(
t.Context(), strings.NewReader(""),
[]string{"-- ignored\nSELECT 1\n-- also ignored"},
nil,
CollectOptions{Cleaner: cleaner},
)
require.NoError(t, err)
require.Len(t, got, 1)
assert.Equal(t, "SELECT 1", got[0].SQL)
}
Loading
Loading