From 6c80f238285d94fe58aa52a455917d10d2b31b7b Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 11 Jun 2026 18:52:04 +0000 Subject: [PATCH 1/8] experimental/air: scaffold AI runtime CLI command package Add the experimental `air` command group as the Go port surface for the Python `air` CLI. Every subcommand (run, status, list, logs, cancel, register-image) is registered as a stub that returns a not-implemented error; the real implementations land in later milestones. The package lives under experimental/air/cmd (imported as aircmd), matching the layout of the other experimental features (aitools, genie, postgres); cmd/experimental/ keeps only the dispatcher. TEST_PACKAGES in Taskfile.yml gains ./experimental/air/... so the unit tests keep running after the move. Includes unit tests for the command-tree wiring and the not-implemented stubs, plus an acceptance test exercising the stubs end-to-end. Co-authored-by: Isaac --- Taskfile.yml | 2 +- .../experimental/air/help/out.test.toml | 3 ++ acceptance/experimental/air/help/output.txt | 29 ++++++++++++++ acceptance/experimental/air/help/script | 5 +++ acceptance/experimental/air/help/test.toml | 3 ++ .../air/unimplemented/out.test.toml | 3 ++ .../experimental/air/unimplemented/output.txt | 36 +++++++++++++++++ .../experimental/air/unimplemented/script | 19 +++++++++ .../experimental/air/unimplemented/test.toml | 3 ++ cmd/experimental/experimental.go | 2 + experimental/air/cmd/air.go | 36 +++++++++++++++++ experimental/air/cmd/air_test.go | 22 +++++++++++ experimental/air/cmd/cancel.go | 39 +++++++++++++++++++ experimental/air/cmd/list.go | 31 +++++++++++++++ experimental/air/cmd/logs.go | 34 ++++++++++++++++ experimental/air/cmd/register_image.go | 33 ++++++++++++++++ experimental/air/cmd/run.go | 36 +++++++++++++++++ experimental/air/cmd/status.go | 19 +++++++++ experimental/air/cmd/stubs_test.go | 31 +++++++++++++++ 19 files changed, 385 insertions(+), 1 deletion(-) create mode 100644 
acceptance/experimental/air/help/out.test.toml create mode 100644 acceptance/experimental/air/help/output.txt create mode 100644 acceptance/experimental/air/help/script create mode 100644 acceptance/experimental/air/help/test.toml create mode 100644 acceptance/experimental/air/unimplemented/out.test.toml create mode 100644 acceptance/experimental/air/unimplemented/output.txt create mode 100644 acceptance/experimental/air/unimplemented/script create mode 100644 acceptance/experimental/air/unimplemented/test.toml create mode 100644 experimental/air/cmd/air.go create mode 100644 experimental/air/cmd/air_test.go create mode 100644 experimental/air/cmd/cancel.go create mode 100644 experimental/air/cmd/list.go create mode 100644 experimental/air/cmd/logs.go create mode 100644 experimental/air/cmd/register_image.go create mode 100644 experimental/air/cmd/run.go create mode 100644 experimental/air/cmd/status.go create mode 100644 experimental/air/cmd/stubs_test.go diff --git a/Taskfile.yml b/Taskfile.yml index d72140290e2..32cb14d0c43 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -4,7 +4,7 @@ vars: # Absolute path so tasks with `dir:` (lint-go-tools, lint-go-codegen) can use it. GO_TOOL: go tool -modfile={{.ROOT_DIR}}/tools/go.mod EXE_EXT: '{{if eq OS "windows"}}.exe{{end}}' - TEST_PACKAGES: ./acceptance/internal ./libs/... ./internal/... ./cmd/... ./bundle/... ./experimental/ssh/... . + TEST_PACKAGES: ./acceptance/internal ./libs/... ./internal/... ./cmd/... ./bundle/... ./experimental/air/... ./experimental/ssh/... . ACCEPTANCE_TEST_FILTER: "" # Single brace-expansion glob covering every //go:embed target in the repo, # computed by grepping `//go:embed` directives. 
Evaluated lazily by Task so diff --git a/acceptance/experimental/air/help/out.test.toml b/acceptance/experimental/air/help/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/help/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/help/output.txt b/acceptance/experimental/air/help/output.txt new file mode 100644 index 00000000000..cf7e5af634c --- /dev/null +++ b/acceptance/experimental/air/help/output.txt @@ -0,0 +1,29 @@ + +=== help +>>> [CLI] experimental air --help +Run and manage AI runtime training workloads on Databricks serverless GPU compute. + +This command set is the Go port of the standalone Python "air" CLI. It is +experimental and may change in future versions. + +Usage: + databricks experimental air [command] + +Available Commands: + cancel Cancel one or more runs + list List recent runs + logs Stream or fetch logs for a run + register-image Mirror a Docker image into the workspace registry + run Submit a training workload from a YAML config + status Show status and configuration for a run + +Flags: + -h, --help help for air + +Global Flags: + --debug enable debug logging + -o, --output type output type: text or json (default text) + -p, --profile string ~/.databrickscfg profile + -t, --target string bundle target to use (if applicable) + +Use "databricks experimental air [command] --help" for more information about a command. diff --git a/acceptance/experimental/air/help/script b/acceptance/experimental/air/help/script new file mode 100644 index 00000000000..cd67a6fc1b1 --- /dev/null +++ b/acceptance/experimental/air/help/script @@ -0,0 +1,5 @@ +# Pin the command tree so any change to a subcommand or its short description +# shows up as a diff here. 
+ +title "help" +trace $CLI experimental air --help diff --git a/acceptance/experimental/air/help/test.toml b/acceptance/experimental/air/help/test.toml new file mode 100644 index 00000000000..49709b578ef --- /dev/null +++ b/acceptance/experimental/air/help/test.toml @@ -0,0 +1,3 @@ +# --help prints without authenticating, so no server stubs are needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/unimplemented/out.test.toml b/acceptance/experimental/air/unimplemented/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/unimplemented/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt new file mode 100644 index 00000000000..3dc88de3b77 --- /dev/null +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -0,0 +1,36 @@ + +=== run +>>> [CLI] experimental air run +Error: `air run` is not implemented yet + +Exit code: 1 + +=== status +>>> [CLI] experimental air status 123 +Error: `air status` is not implemented yet + +Exit code: 1 + +=== list +>>> [CLI] experimental air list +Error: `air list` is not implemented yet + +Exit code: 1 + +=== logs +>>> [CLI] experimental air logs 123 +Error: `air logs` is not implemented yet + +Exit code: 1 + +=== cancel +>>> [CLI] experimental air cancel 123 +Error: `air cancel` is not implemented yet + +Exit code: 1 + +=== register-image +>>> [CLI] experimental air register-image my-image:latest +Error: `air register-image` is not implemented yet + +Exit code: 1 diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script new file mode 100644 index 00000000000..83397b4b741 --- /dev/null +++ b/acceptance/experimental/air/unimplemented/script @@ -0,0 +1,19 @@ +# Each stub must fail with "not implemented"; errcode 
records the exit code. + +title "run" +errcode trace $CLI experimental air run + +title "status" +errcode trace $CLI experimental air status 123 + +title "list" +errcode trace $CLI experimental air list + +title "logs" +errcode trace $CLI experimental air logs 123 + +title "cancel" +errcode trace $CLI experimental air cancel 123 + +title "register-image" +errcode trace $CLI experimental air register-image my-image:latest diff --git a/acceptance/experimental/air/unimplemented/test.toml b/acceptance/experimental/air/unimplemented/test.toml new file mode 100644 index 00000000000..c233c30a86c --- /dev/null +++ b/acceptance/experimental/air/unimplemented/test.toml @@ -0,0 +1,3 @@ +# Stubs fail locally before any API call, so no server stubs needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/cmd/experimental/experimental.go b/cmd/experimental/experimental.go index 8d9827c5c94..d87c893abc5 100644 --- a/cmd/experimental/experimental.go +++ b/cmd/experimental/experimental.go @@ -1,6 +1,7 @@ package experimental import ( + aircmd "github.com/databricks/cli/experimental/air/cmd" aitoolscmd "github.com/databricks/cli/experimental/aitools/cmd" geniecmd "github.com/databricks/cli/experimental/genie/cmd" postgrescmd "github.com/databricks/cli/experimental/postgres/cmd" @@ -22,6 +23,7 @@ These commands provide early access to new features that are still under development. They may change or be removed in future versions without notice.`, } + cmd.AddCommand(aircmd.New()) cmd.AddCommand(aitoolscmd.NewAitoolsCmd()) cmd.AddCommand(geniecmd.NewGenieCmd()) cmd.AddCommand(postgrescmd.New()) diff --git a/experimental/air/cmd/air.go b/experimental/air/cmd/air.go new file mode 100644 index 00000000000..3f9122c828c --- /dev/null +++ b/experimental/air/cmd/air.go @@ -0,0 +1,36 @@ +package aircmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +// New returns the root command for the experimental AI runtime CLI. 
+// +// Milestone 0: scaffolds the command group with every subcommand registered as a +// stub (not yet implemented), pending the port from the Python `air` CLI. +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "air", + Short: "Run and manage AI runtime training workloads", + Long: `Run and manage AI runtime training workloads on Databricks serverless GPU compute. + +This command set is the Go port of the standalone Python "air" CLI. It is +experimental and may change in future versions.`, + } + + cmd.AddCommand(newRunCommand()) + cmd.AddCommand(newStatusCommand()) + cmd.AddCommand(newListCommand()) + cmd.AddCommand(newLogsCommand()) + cmd.AddCommand(newCancelCommand()) + cmd.AddCommand(newRegisterImageCommand()) + + return cmd +} + +// notImplemented returns the placeholder error used by milestone-0 stubs. +func notImplemented(name string) error { + return fmt.Errorf("`air %s` is not implemented yet", name) +} diff --git a/experimental/air/cmd/air_test.go b/experimental/air/cmd/air_test.go new file mode 100644 index 00000000000..26268690850 --- /dev/null +++ b/experimental/air/cmd/air_test.go @@ -0,0 +1,22 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestNewRegistersAllSubcommands asserts the `air` command wires up every +// expected subcommand, so none is accidentally dropped from New. 
+func TestNewRegistersAllSubcommands(t *testing.T) { + registered := make(map[string]bool) + for _, c := range New().Commands() { + registered[c.Name()] = true + } + + want := []string{"run", "status", "list", "logs", "cancel", "register-image"} + for _, name := range want { + assert.True(t, registered[name], "subcommand %q is not registered", name) + } + assert.Len(t, registered, len(want), "unexpected number of subcommands") +} diff --git a/experimental/air/cmd/cancel.go b/experimental/air/cmd/cancel.go new file mode 100644 index 00000000000..ad7fffc7125 --- /dev/null +++ b/experimental/air/cmd/cancel.go @@ -0,0 +1,39 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newCancelCommand() *cobra.Command { + var ( + all bool + yes bool + ) + + cmd := &cobra.Command{ + Use: "cancel [RUN_ID...]", + Short: "Cancel one or more runs", + Long: `Cancel one or more runs by ID, or cancel all of your active runs with --all.`, + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("cancel") + }, + } + + cmd.Flags().BoolVar(&all, "all", false, "Cancel all of your active runs") + cmd.Flags().BoolVarP(&yes, "yes", "y", false, "Skip the confirmation prompt") + + // Require exactly one of: one or more RUN_IDs, or --all. Cobra parses flags + // before running this, so `all` reflects the user's input. 
+ cmd.Args = func(cmd *cobra.Command, args []string) error { + switch { + case all && len(args) > 0: + return &root.InvalidArgsError{Command: cmd, Message: "cannot combine RUN_ID arguments with --all"} + case !all && len(args) == 0: + return &root.InvalidArgsError{Command: cmd, Message: "provide at least one RUN_ID, or use --all"} + } + return nil + } + + return cmd +} diff --git a/experimental/air/cmd/list.go b/experimental/air/cmd/list.go new file mode 100644 index 00000000000..bf24cff9b23 --- /dev/null +++ b/experimental/air/cmd/list.go @@ -0,0 +1,31 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newListCommand() *cobra.Command { + var ( + limit int + active bool + allUsers bool + filters []string + ) + + cmd := &cobra.Command{ + Use: "list", + Args: root.NoArgs, + Short: "List recent runs", + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("list") + }, + } + + cmd.Flags().IntVar(&limit, "limit", 20, "Maximum number of runs to show") + cmd.Flags().BoolVar(&active, "active", false, "Show only active runs") + cmd.Flags().BoolVar(&allUsers, "all-users", false, "Show runs from all users") + cmd.Flags().StringArrayVar(&filters, "filter", nil, "Filter runs, e.g. 
experiment=foo* (repeatable)") + + return cmd +} diff --git a/experimental/air/cmd/logs.go b/experimental/air/cmd/logs.go new file mode 100644 index 00000000000..4dbbe41c278 --- /dev/null +++ b/experimental/air/cmd/logs.go @@ -0,0 +1,34 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newLogsCommand() *cobra.Command { + var ( + node int + lines int + retry int + downloadTo string + review bool + ) + + cmd := &cobra.Command{ + Use: "logs RUN_ID", + Args: root.ExactArgs(1), + Short: "Stream or fetch logs for a run", + Long: `Stream logs from an active run, or fetch logs from a completed run.`, + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("logs") + }, + } + + cmd.Flags().IntVar(&node, "node", 0, "Fetch logs from this node") + cmd.Flags().IntVar(&lines, "lines", 10000, "For completed runs, print the last N lines") + cmd.Flags().IntVar(&retry, "retry", -1, "View logs from a specific retry attempt; -1 means latest") + cmd.Flags().StringVar(&downloadTo, "download-to", "", "Download all logs to this directory instead of printing") + cmd.Flags().BoolVar(&review, "review", false, "Download logs from all nodes and filter for error signatures") + + return cmd +} diff --git a/experimental/air/cmd/register_image.go b/experimental/air/cmd/register_image.go new file mode 100644 index 00000000000..a5be3df408b --- /dev/null +++ b/experimental/air/cmd/register_image.go @@ -0,0 +1,33 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newRegisterImageCommand() *cobra.Command { + var ( + scope string + key string + interactiveAuth bool + tagPolicy string + timeoutMinutes int + ) + + cmd := &cobra.Command{ + Use: "register-image IMAGE_URL", + Args: root.ExactArgs(1), + Short: "Mirror a Docker image into the workspace registry", + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("register-image") + }, + } + + 
cmd.Flags().StringVar(&scope, "scope", "", "Databricks secret scope holding registry credentials") + cmd.Flags().StringVar(&key, "key", "", "Databricks secret key holding registry credentials") + cmd.Flags().BoolVar(&interactiveAuth, "interactive-authenticate", false, "Prompt for registry credentials and store them as a secret") + cmd.Flags().StringVar(&tagPolicy, "tag-policy", "auto", "Image resolution policy: auto or latest") + cmd.Flags().IntVar(&timeoutMinutes, "timeout-minutes", 60, "Timeout to wait for the image to become available") + + return cmd +} diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go new file mode 100644 index 00000000000..0bc3d1fd94b --- /dev/null +++ b/experimental/air/cmd/run.go @@ -0,0 +1,36 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newRunCommand() *cobra.Command { + var ( + file string + watch bool + overrides []string + dryRun bool + idempotencyKey string + ) + + cmd := &cobra.Command{ + Use: "run", + Args: root.NoArgs, + Short: "Submit a training workload from a YAML config", + Long: `Submit a training workload to Databricks serverless GPU compute. + +The workload is described by a YAML config file (see --file).`, + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("run") + }, + } + + cmd.Flags().StringVarP(&file, "file", "f", "", "Path to the workload YAML config") + cmd.Flags().BoolVar(&watch, "watch", false, "Stream logs until the run completes") + cmd.Flags().StringArrayVar(&overrides, "override", nil, "Override a YAML field, e.g. 
compute.num_accelerators=8 (repeatable)") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Validate the config without submitting") + cmd.Flags().StringVar(&idempotencyKey, "idempotency-key", "", "Return the existing run if this key was already used") + + return cmd +} diff --git a/experimental/air/cmd/status.go b/experimental/air/cmd/status.go new file mode 100644 index 00000000000..a0db0619331 --- /dev/null +++ b/experimental/air/cmd/status.go @@ -0,0 +1,19 @@ +package aircmd + +import ( + "github.com/databricks/cli/cmd/root" + "github.com/spf13/cobra" +) + +func newStatusCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "status RUN_ID", + Args: root.ExactArgs(1), + Short: "Show status and configuration for a run", + RunE: func(cmd *cobra.Command, args []string) error { + return notImplemented("status") + }, + } + + return cmd +} diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go new file mode 100644 index 00000000000..8ffd197973f --- /dev/null +++ b/experimental/air/cmd/stubs_test.go @@ -0,0 +1,31 @@ +package aircmd + +import ( + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestStubCommandsReturnNotImplemented asserts each unimplemented subcommand +// fails with a "not implemented" error. Drop a command here once it lands. 
+func TestStubCommandsReturnNotImplemented(t *testing.T) { + stubs := map[string]*cobra.Command{ + "run": newRunCommand(), + "status": newStatusCommand(), + "list": newListCommand(), + "logs": newLogsCommand(), + "cancel": newCancelCommand(), + "register-image": newRegisterImageCommand(), + } + + for name, cmd := range stubs { + t.Run(name, func(t *testing.T) { + require.NotNil(t, cmd.RunE, "command should define RunE") + err := cmd.RunE(cmd, nil) + assert.EqualError(t, err, fmt.Sprintf("`air %s` is not implemented yet", name)) + }) + } +} From 059bd61ca8e04f4e2ecd4f3c6c92c45f98208b99 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Fri, 12 Jun 2026 23:10:23 +0000 Subject: [PATCH 2/8] experimental/air: rename `status` subcommand to `get` Rename the run-details subcommand from `status` to `get`, matching the Python air CLI's current `air get run` naming (it replaced `get status`). Renames the file, constructor, command name, and updates the stub/help/unimplemented tests and goldens accordingly. 
Co-authored-by: Isaac --- acceptance/experimental/air/help/output.txt | 2 +- acceptance/experimental/air/unimplemented/output.txt | 6 +++--- acceptance/experimental/air/unimplemented/script | 4 ++-- experimental/air/cmd/air.go | 2 +- experimental/air/cmd/air_test.go | 2 +- experimental/air/cmd/{status.go => get.go} | 8 ++++---- experimental/air/cmd/stubs_test.go | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) rename experimental/air/cmd/{status.go => get.go} (59%) diff --git a/acceptance/experimental/air/help/output.txt b/acceptance/experimental/air/help/output.txt index cf7e5af634c..3a0f86e164f 100644 --- a/acceptance/experimental/air/help/output.txt +++ b/acceptance/experimental/air/help/output.txt @@ -11,11 +11,11 @@ Usage: Available Commands: cancel Cancel one or more runs + get Show details for a run list List recent runs logs Stream or fetch logs for a run register-image Mirror a Docker image into the workspace registry run Submit a training workload from a YAML config - status Show status and configuration for a run Flags: -h, --help help for air diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 3dc88de3b77..4a07a38a378 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -5,9 +5,9 @@ Error: `air run` is not implemented yet Exit code: 1 -=== status ->>> [CLI] experimental air status 123 -Error: `air status` is not implemented yet +=== get +>>> [CLI] experimental air get 123 +Error: `air get` is not implemented yet Exit code: 1 diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index 83397b4b741..2ed885c0e66 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -3,8 +3,8 @@ title "run" errcode trace $CLI experimental air run -title "status" -errcode trace $CLI experimental 
air status 123 +title "get" +errcode trace $CLI experimental air get 123 title "list" errcode trace $CLI experimental air list diff --git a/experimental/air/cmd/air.go b/experimental/air/cmd/air.go index 3f9122c828c..81ffb2dd346 100644 --- a/experimental/air/cmd/air.go +++ b/experimental/air/cmd/air.go @@ -21,7 +21,7 @@ experimental and may change in future versions.`, } cmd.AddCommand(newRunCommand()) - cmd.AddCommand(newStatusCommand()) + cmd.AddCommand(newGetCommand()) cmd.AddCommand(newListCommand()) cmd.AddCommand(newLogsCommand()) cmd.AddCommand(newCancelCommand()) diff --git a/experimental/air/cmd/air_test.go b/experimental/air/cmd/air_test.go index 26268690850..7efac253a2b 100644 --- a/experimental/air/cmd/air_test.go +++ b/experimental/air/cmd/air_test.go @@ -14,7 +14,7 @@ func TestNewRegistersAllSubcommands(t *testing.T) { registered[c.Name()] = true } - want := []string{"run", "status", "list", "logs", "cancel", "register-image"} + want := []string{"run", "get", "list", "logs", "cancel", "register-image"} for _, name := range want { assert.True(t, registered[name], "subcommand %q is not registered", name) } diff --git a/experimental/air/cmd/status.go b/experimental/air/cmd/get.go similarity index 59% rename from experimental/air/cmd/status.go rename to experimental/air/cmd/get.go index a0db0619331..0ab0b8226bf 100644 --- a/experimental/air/cmd/status.go +++ b/experimental/air/cmd/get.go @@ -5,13 +5,13 @@ import ( "github.com/spf13/cobra" ) -func newStatusCommand() *cobra.Command { +func newGetCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "status RUN_ID", + Use: "get RUN_ID", Args: root.ExactArgs(1), - Short: "Show status and configuration for a run", + Short: "Show details for a run", RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("status") + return notImplemented("get") }, } diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index 8ffd197973f..a6e24177f33 100644 --- 
a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -14,7 +14,7 @@ import ( func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ "run": newRunCommand(), - "status": newStatusCommand(), + "get": newGetCommand(), "list": newListCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), From 2ccd0697c1c4ac9373685286c96eab0024686860 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 21:38:13 +0000 Subject: [PATCH 3/8] experimental/air: implement the `air get` command Implement the read-only run-details command (renamed from `status` to `get`). It fetches a job run via the Jobs API and renders the run's status, start time, duration, retries, experiment, accelerators, dashboard URL, MLflow deep-link, and a foreach/sweep summary. Output is the air-style {v, ts, data} JSON envelope under -o json, or a text view. Renames the command-level identifiers (status -> get) while keeping the run's "status" field/label. Adds format/mlflow/sweep/output helpers with unit tests and an acceptance test, and drops `get` from the not-implemented stub coverage. 
Co-authored-by: Isaac --- acceptance/experimental/air/get/out.test.toml | 3 + acceptance/experimental/air/get/output.txt | 36 +++ acceptance/experimental/air/get/script | 8 + acceptance/experimental/air/get/test.toml | 40 ++++ .../experimental/air/unimplemented/output.txt | 6 - .../experimental/air/unimplemented/script | 3 - experimental/air/cmd/format.go | 154 +++++++++++++ experimental/air/cmd/format_test.go | 131 +++++++++++ experimental/air/cmd/get.go | 172 +++++++++++++- experimental/air/cmd/get_test.go | 211 ++++++++++++++++++ experimental/air/cmd/mlflow.go | 65 ++++++ experimental/air/cmd/mlflow_test.go | 64 ++++++ experimental/air/cmd/output.go | 39 ++++ experimental/air/cmd/output_test.go | 13 ++ experimental/air/cmd/stubs_test.go | 1 - experimental/air/cmd/sweep.go | 76 +++++++ experimental/air/cmd/sweep_test.go | 81 +++++++ 17 files changed, 1091 insertions(+), 12 deletions(-) create mode 100644 acceptance/experimental/air/get/out.test.toml create mode 100644 acceptance/experimental/air/get/output.txt create mode 100644 acceptance/experimental/air/get/script create mode 100644 acceptance/experimental/air/get/test.toml create mode 100644 experimental/air/cmd/format.go create mode 100644 experimental/air/cmd/format_test.go create mode 100644 experimental/air/cmd/get_test.go create mode 100644 experimental/air/cmd/mlflow.go create mode 100644 experimental/air/cmd/mlflow_test.go create mode 100644 experimental/air/cmd/output.go create mode 100644 experimental/air/cmd/output_test.go create mode 100644 experimental/air/cmd/sweep.go create mode 100644 experimental/air/cmd/sweep_test.go diff --git a/acceptance/experimental/air/get/out.test.toml b/acceptance/experimental/air/get/out.test.toml new file mode 100644 index 00000000000..d6187dcb046 --- /dev/null +++ b/acceptance/experimental/air/get/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/get/output.txt 
b/acceptance/experimental/air/get/output.txt new file mode 100644 index 00000000000..6ce803659b4 --- /dev/null +++ b/acceptance/experimental/air/get/output.txt @@ -0,0 +1,36 @@ + +=== get (text) +>>> [CLI] experimental air get 123 +Run ID: 123 +Status: SUCCESS +Submitted: [TIMESTAMP] +Duration: 12s +Retries: 0 +Experiment: my-exp +User: user@example.com +Accelerators: 8x H100 +MLflow: [DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0 +Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 + +=== get (json) +>>> [CLI] experimental air get 123 -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "run_id": "123", + "status": "SUCCESS", + "started_at": "[TIMESTAMP]", + "duration_seconds": 12, + "attempt_number": 0, + "experiment_name": "my-exp", + "dashboard_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "mlflow_url": "[DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0" + } +} + +=== invalid run id +>>> [CLI] experimental air get notanumber +Error: invalid RUN_ID "notanumber": must be a positive integer + +Exit code: 1 diff --git a/acceptance/experimental/air/get/script b/acceptance/experimental/air/get/script new file mode 100644 index 00000000000..e0ea8d10f85 --- /dev/null +++ b/acceptance/experimental/air/get/script @@ -0,0 +1,8 @@ +title "get (text)" +trace $CLI experimental air get 123 + +title "get (json)" +trace $CLI experimental air get 123 -o json + +title "invalid run id" +errcode trace $CLI experimental air get notanumber diff --git a/acceptance/experimental/air/get/test.toml b/acceptance/experimental/air/get/test.toml new file mode 100644 index 00000000000..b6219b87f07 --- /dev/null +++ b/acceptance/experimental/air/get/test.toml @@ -0,0 +1,40 @@ +# This command does not deploy a bundle, so no engine matrix is needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] + +# The SDK occasionally probes host reachability with a HEAD request; stub it so +# the test is deterministic. 
+[[Server]] +Pattern = "HEAD /" +Response.Body = '' + +# A single GenAI-compute run with an experiment, GPUs, and a creator. +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get" +Response.Body = ''' +{ + "run_id": 123, + "run_page_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "creator_user_name": "user@example.com", + "start_time": 1700000000000, + "end_time": 1700000012000, + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}, + "tasks": [ + { + "task_key": "train", + "attempt_number": 0, + "gen_ai_compute_task": { + "mlflow_experiment_name": "/Users/user@example.com/my-exp", + "compute": {"gpu_type": "GPU_8xH100", "num_gpus": 8} + } + } + ] +} +''' + +# MLflow identifiers for the deep-link (runs/get-output is not modeled by the typed SDK). +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get-output" +Response.Body = ''' +{"gen_ai_compute_output": {"run_info": {"mlflow_experiment_id": "exp1", "mlflow_run_id": "run1"}}} +''' diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 4a07a38a378..0a86360c78f 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -5,12 +5,6 @@ Error: `air run` is not implemented yet Exit code: 1 -=== get ->>> [CLI] experimental air get 123 -Error: `air get` is not implemented yet - -Exit code: 1 - === list >>> [CLI] experimental air list Error: `air list` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index 2ed885c0e66..e6e8d33ef9d 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -3,9 +3,6 @@ title "run" errcode trace $CLI experimental air run -title "get" -errcode trace $CLI experimental air get 123 - title "list" errcode trace $CLI experimental air list diff --git a/experimental/air/cmd/format.go 
b/experimental/air/cmd/format.go new file mode 100644 index 00000000000..88f620ee7c3 --- /dev/null +++ b/experimental/air/cmd/format.go @@ -0,0 +1,154 @@ +package aircmd + +import ( + "fmt" + "strings" + "time" + + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// gpuDisplayNames maps the GPU identifiers returned by the backend to the short +// names we show to users. Unknown identifiers are shown unchanged. +var gpuDisplayNames = map[string]string{ + "h100_80gb": "H100", + "a10": "A10", + "GPU_1xA10": "A10", + "GPU_8xH100": "H100", + "GPU_1xH100": "H100", +} + +// runStatus returns the single status word to show for a run. The backend +// reports two values: a lifecycle state (e.g. PENDING, RUNNING) and, once the +// run has finished, a result state (e.g. SUCCESS, FAILED). The result state is +// the more meaningful one, so we prefer it when it is set. +func runStatus(state *jobs.RunState) string { + if state == nil { + return "UNKNOWN" + } + if state.ResultState != "" { + return string(state.ResultState) + } + if state.LifeCycleState != "" { + return string(state.LifeCycleState) + } + return "UNKNOWN" +} + +// startedAt converts the run's start time (epoch milliseconds) to an RFC 3339 +// UTC string, or returns nil if the run has not started yet. +func startedAt(run *jobs.Run) *string { + if run.StartTime == 0 { + return nil + } + s := time.UnixMilli(run.StartTime).UTC().Format(time.RFC3339) + return &s +} + +// durationSeconds returns how long the run has taken, in whole seconds, or nil +// if it has not started. For a finished run this is the elapsed time; for a +// still-running run it is the time since it started. +func durationSeconds(run *jobs.Run) *int64 { + if run.StartTime == 0 { + return nil + } + + var endMillis int64 + switch { + case run.RunDuration > 0: + // The backend already computed the duration for us. 
+ d := run.RunDuration / 1000 + return &d + case run.EndTime > 0: + endMillis = run.EndTime + default: + // Still running: measure against the current time. + endMillis = time.Now().UnixMilli() + } + + d := (endMillis - run.StartTime) / 1000 + return &d +} + +// formatDuration turns a number of seconds into a compact human string such as +// "1h 2m 3s". Zero-valued units are dropped, but a lone "0s" is kept so the +// result is never empty. +func formatDuration(totalSeconds int64) string { + hours := totalSeconds / 3600 + minutes := (totalSeconds % 3600) / 60 + seconds := totalSeconds % 60 + + var parts []string + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + return strings.Join(parts, " ") +} + +// latestAttemptNumber returns the retry count of the run's most recent task. +// Tasks start at attempt 0, so a value of 0 means the run has not been retried. +func latestAttemptNumber(run *jobs.Run) int { + if len(run.Tasks) == 0 { + return 0 + } + return run.Tasks[len(run.Tasks)-1].AttemptNumber +} + +// experimentName returns the MLflow experiment name for the run, or nil if there +// isn't one. Experiment names are often stored under a user's home folder (e.g. +// "/Users/me@example.com/my-experiment"); we strip that prefix so users see just +// the experiment name they chose. +func experimentName(run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.MlflowExperimentName == "" { + return nil + } + name := stripExperimentUserPrefix(task.MlflowExperimentName) + return &name +} + +// stripExperimentUserPrefix removes a leading "/Users/<user>/" from an +// experiment name, leaving the remainder. Names without that prefix are returned +// unchanged. 
+func stripExperimentUserPrefix(name string) string { + if !strings.HasPrefix(name, "/Users/") { + return name + } + // Split into ["", "Users", "<user>", "<rest>"]; keep "<rest>". + parts := strings.SplitN(name, "/", 4) + if len(parts) == 4 { + return parts[3] + } + return name +} + +// accelerators returns a short description of the GPUs the run uses, such as +// "8x H100", or an empty string if the run has no GPU compute attached. +func accelerators(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.Compute == nil || task.Compute.NumGpus == 0 { + return "" + } + return fmt.Sprintf("%dx %s", task.Compute.NumGpus, gpuDisplayName(task.Compute.GpuType)) +} + +// gpuDisplayName returns the friendly name for a GPU identifier, falling back to +// the identifier itself when it is not one we recognize. +func gpuDisplayName(gpuType string) string { + if name, ok := gpuDisplayNames[gpuType]; ok { + return name + } + return gpuType +} diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go new file mode 100644 index 00000000000..c3e2e865b81 --- /dev/null +++ b/experimental/air/cmd/format_test.go @@ -0,0 +1,131 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatDuration(t *testing.T) { + cases := []struct { + seconds int64 + want string + }{ + {0, "0s"}, + {45, "45s"}, + {60, "1m"}, + {63, "1m 3s"}, + {3600, "1h"}, + {3723, "1h 2m 3s"}, + {7260, "2h 1m"}, + } + for _, c := range cases { + assert.Equal(t, c.want, formatDuration(c.seconds)) + } +} + +func TestStripExperimentUserPrefix(t *testing.T) { + cases := []struct { + name string + want string + }{ + {"/Users/me@example.com/my-experiment", "my-experiment"}, + {"/Users/me@example.com/nested/path", "nested/path"}, + {"my-experiment", "my-experiment"}, + {"/Shared/team-experiment", 
"/Shared/team-experiment"}, + {"/Users/me@example.com", "/Users/me@example.com"}, + } + for _, c := range cases { + assert.Equal(t, c.want, stripExperimentUserPrefix(c.name)) + } +} + +func TestGpuDisplayName(t *testing.T) { + assert.Equal(t, "H100", gpuDisplayName("h100_80gb")) + assert.Equal(t, "A10", gpuDisplayName("GPU_1xA10")) + assert.Equal(t, "A10", gpuDisplayName("a10")) + assert.Equal(t, "H100", gpuDisplayName("GPU_8xH100")) + assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) + // Unknown identifiers pass through unchanged. + assert.Equal(t, "b200", gpuDisplayName("b200")) + assert.Equal(t, "", gpuDisplayName("")) +} + +func TestRunStatusPrefersResultState(t *testing.T) { + // Result state wins once the run has finished. + assert.Equal(t, "SUCCESS", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateTerminated, + ResultState: jobs.RunResultStateSuccess, + })) + // Before completion only the lifecycle state is set. + assert.Equal(t, "RUNNING", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateRunning, + })) + // Non-nil state with neither field set, and nil state. + assert.Equal(t, "UNKNOWN", runStatus(&jobs.RunState{})) + assert.Equal(t, "UNKNOWN", runStatus(nil)) +} + +func TestStartedAt(t *testing.T) { + // Not started yet. + assert.Nil(t, startedAt(&jobs.Run{})) + // 1700000000000 ms == 2023-11-14T22:13:20Z. + got := startedAt(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, got) + assert.Equal(t, "2023-11-14T22:13:20Z", *got) +} + +func TestDurationSeconds(t *testing.T) { + // Not started yet. + assert.Nil(t, durationSeconds(&jobs.Run{})) + + // Backend-provided duration wins (milliseconds → seconds). + d := durationSeconds(&jobs.Run{StartTime: 1700000000000, RunDuration: 5000}) + require.NotNil(t, d) + assert.Equal(t, int64(5), *d) + + // Finished run with no RunDuration: end - start. 
+ d = durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000012000}) + require.NotNil(t, d) + assert.Equal(t, int64(12), *d) + + // Still running: measured against the current time, so positive. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, d) + assert.Positive(t, *d) +} + +func TestLatestAttemptNumber(t *testing.T) { + assert.Equal(t, 0, latestAttemptNumber(&jobs.Run{})) + run := &jobs.Run{Tasks: []jobs.RunTask{{AttemptNumber: 0}, {AttemptNumber: 2}}} + assert.Equal(t, 2, latestAttemptNumber(run)) +} + +func TestExperimentName(t *testing.T) { + assert.Nil(t, experimentName(&jobs.Run{})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: ""}, + }}})) + got := experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}) + require.NotNil(t, got) + assert.Equal(t, "exp", *got) +} + +func TestAccelerators(t *testing.T) { + assert.Equal(t, "", accelerators(&jobs.Run{})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{}, + }}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, + }}})) + assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}}, + }}})) +} diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 0ab0b8226bf..cc486b722f8 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -1,19 +1,187 @@ package aircmd import ( + "context" + "errors" + "fmt" + "io" + "strconv" + 
"github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/spf13/cobra" ) +// getData is the payload printed by `air get`. The json-tagged fields form +// the machine-readable output; fields tagged `json:"-"` are shown only in the +// human-readable text view. +type getData struct { + RunID string `json:"run_id"` + Status string `json:"status"` + StartedAt *string `json:"started_at"` + DurationSeconds *int64 `json:"duration_seconds"` + AttemptNumber int `json:"attempt_number"` + ExperimentName *string `json:"experiment_name"` + DashboardURL string `json:"dashboard_url"` + MLflowURL *string `json:"mlflow_url"` + + // Duration is the human-readable form of DurationSeconds, e.g. "12m 3s". + Duration string `json:"-"` + // Accelerators describes the run's GPUs, e.g. "8x H100". + Accelerators string `json:"-"` + // User is the run's creator. Text-only; JSON omits it, matching `air get --json`. + User string `json:"-"` + // Sweep replaces the single-run view for foreach runs. Text-only; JSON omits it. + Sweep *sweepInfo `json:"-"` +} + +// getTemplate is the text-mode layout. It reads from the JSON envelope, so +// every field is reached through ".Data". Optional rows are hidden when empty. 
+const getTemplate = `{{- if .Data.Sweep -}} +Sweep Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +Total: {{.Data.Sweep.Total}} +Completed: {{.Data.Sweep.Completed}} +Succeeded: {{.Data.Sweep.Succeeded}} +Failed: {{.Data.Sweep.Failed}} +Active: {{.Data.Sweep.Active}} +{{- if .Data.Sweep.Tasks}} + +Sweep Tasks: +{{printf " %-24s %-14s %-12s %s" "TASK" "RUN ID" "STATUS" "EXPERIMENT"}} +{{- range .Data.Sweep.Tasks}} +{{printf " %-24s %-14s %-12s %s" .TaskKey .RunID .Status .Experiment}} +{{- end}} +{{- end}} +{{- else -}} +Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +{{- if .Data.StartedAt}} +Submitted: {{.Data.StartedAt}} +{{- end}} +{{- if .Data.Duration}} +Duration: {{.Data.Duration}} +{{- end}} +Retries: {{.Data.AttemptNumber}} +{{- if .Data.ExperimentName}} +Experiment: {{.Data.ExperimentName}} +{{- end}} +{{- if .Data.User}} +User: {{.Data.User}} +{{- end}} +{{- if .Data.Accelerators}} +Accelerators: {{.Data.Accelerators}} +{{- end}} +{{- if .Data.MLflowURL}} +MLflow: {{.Data.MLflowURL}} +{{- end}} +Dashboard: {{.Data.DashboardURL}} +{{- end}} +` + func newGetCommand() *cobra.Command { cmd := &cobra.Command{ Use: "get RUN_ID", Args: root.ExactArgs(1), Short: "Show details for a run", - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("get") + Annotations: map[string]string{ + "template": getTemplate, }, } + cmd.PreRunE = root.MustWorkspaceClient + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + + runID, err := strconv.ParseInt(args[0], 10, 64) + if err != nil || runID <= 0 { + return fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0]) + } + + run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}) + if err != nil { + // The backend returns this when the run ID is unknown to the user. 
+ if errors.Is(err, apierr.ErrResourceDoesNotExist) { + return fmt.Errorf("run %d not found: check the run ID and that it is a job run ID", runID) + } + return fmt.Errorf("failed to get status for run %d: %w", runID, err) + } + + data := buildGetData(run) + data.MLflowURL = mlflowURL(ctx, w, run) + if task := findForEachTask(run); task != nil { + data.Sweep = buildSweepInfo(ctx, w, task) + } + + // Text mode shows the training-config YAML before the status, mirroring + // `air get`. JSON output omits it, matching `air get --json`. + if root.OutputType(cmd) == flags.OutputText { + if path := yamlConfigPath(run); path != "" { + printConfigYAML(ctx, w, path) + } + } + return renderEnvelope(ctx, data) + } + + return cmd } + +// buildGetData extracts the fields we display from a run. +func buildGetData(run *jobs.Run) getData { + data := getData{ + RunID: strconv.FormatInt(run.RunId, 10), + Status: runStatus(run.State), + StartedAt: startedAt(run), + DurationSeconds: durationSeconds(run), + AttemptNumber: latestAttemptNumber(run), + ExperimentName: experimentName(run), + DashboardURL: run.RunPageUrl, + Accelerators: accelerators(run), + User: run.CreatorUserName, + } + if data.DurationSeconds != nil { + data.Duration = formatDuration(*data.DurationSeconds) + } + return data +} + +// yamlConfigPath returns the run's training-config YAML path, or "" if none. +func yamlConfigPath(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil { + return "" + } + return task.YamlParametersFilePath +} + +// printConfigYAML downloads the run's training-config YAML and prints it. It is +// best-effort: a failure is surfaced as a warning but does not fail the command. 
+func printConfigYAML(ctx context.Context, w *databricks.WorkspaceClient, path string) { + r, err := w.Workspace.Download(ctx, path) + if err != nil { + log.Warnf(ctx, "air get: could not download training config %s: %v", path, err) + return + } + defer r.Close() + + content, err := io.ReadAll(r) + if err != nil { + log.Warnf(ctx, "air get: could not read training config %s: %v", path, err) + return + } + + cmdio.LogString(ctx, "Training Configuration:") + cmdio.LogString(ctx, string(content)) + cmdio.LogString(ctx, "") +} diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go new file mode 100644 index 00000000000..6dfdc54db7b --- /dev/null +++ b/experimental/air/cmd/get_test.go @@ -0,0 +1,211 @@ +package aircmd + +import ( + "bytes" + "io" + "strings" + "testing" + "text/template" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// renderGet renders the status template against the JSON envelope, exactly as +// the command does, so the test covers the real template branches. 
+func renderGet(t *testing.T, data getData) string { + t.Helper() + tmpl, err := template.New("status").Parse(getTemplate) + require.NoError(t, err) + var buf bytes.Buffer + require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) + return buf.String() +} + +func TestGetTemplateSingleRun(t *testing.T) { + out := renderGet(t, getData{ + RunID: "123", + Status: "RUNNING", + User: "me@example.com", + DashboardURL: "https://example.test/run/123", + }) + assert.Contains(t, out, "Run ID: 123") + assert.Contains(t, out, "Status: RUNNING") + assert.Contains(t, out, "User:") + assert.Contains(t, out, "me@example.com") + assert.Contains(t, out, "Dashboard: https://example.test/run/123") + assert.NotContains(t, out, "Sweep") +} + +func TestGetRunInvalidID(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"abc"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid RUN_ID") +} + +func TestGetRunNotFound(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( + nil, apierr.ErrResourceDoesNotExist) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"5"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "run 5 not found") +} + +func TestPrintConfigYAML(t *testing.T) { + t.Run("downloads and prints", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + // The mock asserts Download is called with the resolved path. + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/cfg.yaml"). 
+ Return(io.NopCloser(strings.NewReader("epochs: 3\n")), nil) + + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/cfg.yaml") + }) + + t.Run("download failure is non-fatal", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/missing.yaml"). + Return(nil, apierr.ErrResourceDoesNotExist) + + // Must not panic: a failed config fetch is best-effort. + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/missing.yaml") + }) +} + +func TestYAMLConfigPath(t *testing.T) { + // No tasks, or a task without GenAiComputeTask, yields no path. + assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) + assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + + run := &jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, + }}} + assert.Equal(t, "/Workspace/cfg.yaml", yamlConfigPath(run)) +} + +func TestGetTemplateSweep(t *testing.T) { + out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{ + Total: 4, Completed: 2, Succeeded: 1, Failed: 1, Active: 2, + Tasks: []sweepTask{ + {TaskKey: "iter_0", RunID: "789", Status: "SUCCESS", Experiment: "my-exp"}, + {TaskKey: "iter_1", RunID: "790", Status: "FAILED", Experiment: "my-exp"}, + }, + }, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.Contains(t, out, "Sweep Tasks:") + assert.Contains(t, out, "iter_0") + assert.Contains(t, out, "iter_1") + assert.Contains(t, out, "FAILED") + assert.Contains(t, out, "my-exp") + // The single-run rows must not appear in the sweep view. + assert.NotContains(t, out, "Dashboard:") +} + +func TestGetTemplateSweepNoTasks(t *testing.T) { + // A sweep whose iterations haven't materialized yet: counts show, but the + // task table header is hidden. 
+ out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{Total: 4, Active: 4}, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.NotContains(t, out, "Sweep Tasks:") +} + +func TestGetTemplateMinimal(t *testing.T) { + // Only the always-present rows render; optional rows are hidden when empty. + out := renderGet(t, getData{RunID: "1", Status: "PENDING", DashboardURL: "https://example.test/1"}) + assert.Contains(t, out, "Run ID: 1") + assert.Contains(t, out, "Status: PENDING") + assert.Contains(t, out, "Retries: 0") + assert.Contains(t, out, "Dashboard: https://example.test/1") + for _, hidden := range []string{"Submitted:", "Duration:", "Experiment:", "User:", "Accelerators:", "MLflow:"} { + assert.NotContains(t, out, hidden) + } +} + +func TestGetTemplateAllFields(t *testing.T) { + started := "2023-11-14T22:13:20Z" + exp := "exp" + mlflow := "https://example.test/ml/exp/1" + out := renderGet(t, getData{ + RunID: "1", + Status: "SUCCESS", + StartedAt: &started, + Duration: "12s", + AttemptNumber: 2, + ExperimentName: &exp, + User: "me@example.com", + Accelerators: "8x H100", + MLflowURL: &mlflow, + DashboardURL: "https://example.test/1", + }) + for _, want := range []string{ + "Submitted: 2023-11-14T22:13:20Z", + "Duration: 12s", + "Retries: 2", + "Experiment: exp", + "User: me@example.com", + "Accelerators: 8x H100", + "MLflow: https://example.test/ml/exp/1", + "Dashboard: https://example.test/1", + } { + assert.Contains(t, out, want) + } +} + +func TestBuildStatusData(t *testing.T) { + run := &jobs.Run{ + RunId: 123, + RunPageUrl: "https://example.test/run/123", + CreatorUserName: "me@example.com", + StartTime: 1700000000000, + EndTime: 1700000012000, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + Tasks: []jobs.RunTask{{ + AttemptNumber: 1, + GenAiComputeTask: &jobs.GenAiComputeTask{ + MlflowExperimentName: "/Users/me@example.com/exp", + Compute: 
&jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}, + }, + }}, + } + d := buildGetData(run) + assert.Equal(t, "123", d.RunID) + assert.Equal(t, "SUCCESS", d.Status) + assert.Equal(t, 1, d.AttemptNumber) + assert.Equal(t, "https://example.test/run/123", d.DashboardURL) + assert.Equal(t, "me@example.com", d.User) + assert.Equal(t, "8x H100", d.Accelerators) + assert.Equal(t, "12s", d.Duration) + require.NotNil(t, d.ExperimentName) + assert.Equal(t, "exp", *d.ExperimentName) + require.NotNil(t, d.DurationSeconds) + assert.Equal(t, int64(12), *d.DurationSeconds) +} diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go new file mode 100644 index 00000000000..97d085b0128 --- /dev/null +++ b/experimental/air/cmd/mlflow.go @@ -0,0 +1,65 @@ +package aircmd + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// getRunOutputResponse is the slice of the jobs runs/get-output response we care +// about. The MLflow identifiers live under a gen_ai_compute_output field that +// the typed SDK does not model, so we call the endpoint directly and parse just +// these fields. +type getRunOutputResponse struct { + GenAiComputeOutput *struct { + RunInfo *struct { + MlflowExperimentID string `json:"mlflow_experiment_id"` + MlflowRunID string `json:"mlflow_run_id"` + } `json:"run_info"` + } `json:"gen_ai_compute_output"` +} + +// mlflowURL returns a link to the run's MLflow logs, or nil if it can't be +// built. The link is a convenience, so any failure here (missing task, endpoint +// error, run not yet started) is logged and treated as "no link" rather than +// failing the whole command. 
+func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + // The MLflow output is attached to the task run, not the parent job run. + taskRunID := run.Tasks[0].RunId + + apiClient, err := client.New(w.Config) + if err != nil { + log.Debugf(ctx, "air get: could not build API client for MLflow link: %v", err) + return nil + } + + var out getRunOutputResponse + err = apiClient.Do(ctx, http.MethodGet, "/api/2.2/jobs/runs/get-output", + nil, map[string]any{"run_id": taskRunID}, nil, &out) + if err != nil { + log.Debugf(ctx, "air get: could not fetch run output for MLflow link: %v", err) + return nil + } + + if out.GenAiComputeOutput == nil || out.GenAiComputeOutput.RunInfo == nil { + return nil + } + info := out.GenAiComputeOutput.RunInfo + if info.MlflowExperimentID == "" || info.MlflowRunID == "" { + return nil + } + + host := strings.TrimRight(w.Config.Host, "/") + url := fmt.Sprintf("%s/ml/experiments/%s/runs/%s/artifacts/logs/node_0", + host, info.MlflowExperimentID, info.MlflowRunID) + return &url +} diff --git a/experimental/air/cmd/mlflow_test.go b/experimental/air/cmd/mlflow_test.go new file mode 100644 index 00000000000..bbc4fef9822 --- /dev/null +++ b/experimental/air/cmd/mlflow_test.go @@ -0,0 +1,64 @@ +package aircmd + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestWorkspaceClient builds a WorkspaceClient pointed at a mock HTTP server. +// mlflowURL calls the runs/get-output REST endpoint directly (the field it needs +// is not modeled by the typed SDK), so it must be exercised over HTTP. 
+func newTestWorkspaceClient(t *testing.T, host string) *databricks.WorkspaceClient { + t.Helper() + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: host, Token: "token"}) + require.NoError(t, err) + return w +} + +// runOutputServer serves the given runs/get-output body and a stub for the SDK's +// well-known config discovery request. *hit is set when get-output is called. +func runOutputServer(t *testing.T, body string, hit *bool) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/2.2/jobs/runs/get-output" { + *hit = true + _, _ = w.Write([]byte(body)) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestMLflowURL(t *testing.T) { + ctx := t.Context() + run := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}}} + + t.Run("builds the deep-link on success", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{"gen_ai_compute_output":{"run_info":{"mlflow_experiment_id":"E1","mlflow_run_id":"R1"}}}`, &hit) + + got := mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run) + require.NotNil(t, got) + assert.True(t, hit, "runs/get-output should have been called") + assert.Equal(t, srv.URL+"/ml/experiments/E1/runs/R1/artifacts/logs/node_0", *got) + }) + + t.Run("nil when the run has no MLflow info", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{}`, &hit) + assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run)) + }) + + t.Run("nil when the run has no tasks", func(t *testing.T) { + // Returns before any HTTP call, so the host is never contacted. 
+ assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) + }) +} diff --git a/experimental/air/cmd/output.go b/experimental/air/cmd/output.go new file mode 100644 index 00000000000..3da766a7d4f --- /dev/null +++ b/experimental/air/cmd/output.go @@ -0,0 +1,39 @@ +package aircmd + +import ( + "context" + "time" + + "github.com/databricks/cli/libs/cmdio" +) + +// envelopeVersion is the envelope's format-version marker. The Python `air` CLI +// hardcodes it to 1; it lets consumers detect a future incompatible change to +// the envelope shape. +const envelopeVersion = 1 + +// envelope is the JSON shape that the AI runtime CLI prints: +// +// { "v": 1, "ts": "2024-01-15T14:30:45Z", "data": { ... } } +// +// It mirrors the envelope used by the original Python `air` CLI so existing +// consumers keep working after the port to Go. +type envelope struct { + // V is the envelope format-version marker (always 1). + V int `json:"v"` + // TS is the wall-clock time the response was produced, in RFC 3339 UTC. + // It is an absolute timestamp, not an elapsed duration. + TS string `json:"ts"` + // Data is the command-specific payload. + Data any `json:"data"` +} + +// renderEnvelope wraps data in the JSON envelope and prints it. +// Fields that should appear only in text output are tagged `json:"-"` on the payload struct. 
+func renderEnvelope(ctx context.Context, data any) error { + return cmdio.Render(ctx, envelope{ + V: envelopeVersion, + TS: time.Now().UTC().Format(time.RFC3339), + Data: data, + }) +} diff --git a/experimental/air/cmd/output_test.go b/experimental/air/cmd/output_test.go new file mode 100644 index 00000000000..73a5572c3f5 --- /dev/null +++ b/experimental/air/cmd/output_test.go @@ -0,0 +1,13 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/require" +) + +func TestRenderEnvelope(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + require.NoError(t, renderEnvelope(ctx, getData{RunID: "1", Status: "RUNNING"})) +} diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index a6e24177f33..5e35bcdcd14 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -14,7 +14,6 @@ import ( func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ "run": newRunCommand(), - "get": newGetCommand(), "list": newListCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), diff --git a/experimental/air/cmd/sweep.go b/experimental/air/cmd/sweep.go new file mode 100644 index 00000000000..b346f43f1b6 --- /dev/null +++ b/experimental/air/cmd/sweep.go @@ -0,0 +1,76 @@ +package aircmd + +import ( + "context" + "strconv" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// sweepInfo summarizes a "foreach" run, which fans a single config out into many +// iterations (a hyperparameter sweep). It is shown only in text output. +type sweepInfo struct { + Total int + Succeeded int + Failed int + Active int + Completed int + Tasks []sweepTask +} + +// sweepTask is one iteration of a sweep. 
+type sweepTask struct { + TaskKey string + RunID string + Status string + Experiment string +} + +// findForEachTask returns the run's foreach task if it has one, or nil. A run is +// a sweep when one of its tasks fans out into iterations. +func findForEachTask(run *jobs.Run) *jobs.RunTask { + for i := range run.Tasks { + if run.Tasks[i].ForEachTask != nil { + return &run.Tasks[i] + } + } + return nil +} + +// buildSweepInfo gathers the iteration counts and per-iteration rows for a +// sweep. The counts come from the task we already have; the individual +// iterations require a second lookup. If that lookup fails we still return the +// counts (logging the failure) so the user sees the summary. +func buildSweepInfo(ctx context.Context, w *databricks.WorkspaceClient, task *jobs.RunTask) *sweepInfo { + info := &sweepInfo{} + if task.ForEachTask.Stats != nil && task.ForEachTask.Stats.TaskRunStats != nil { + stats := task.ForEachTask.Stats.TaskRunStats + info.Total = stats.TotalIterations + info.Succeeded = stats.SucceededIterations + info.Failed = stats.FailedIterations + info.Active = stats.ActiveIterations + info.Completed = stats.CompletedIterations + } + + // The iterations are returned as part of a run lookup on the foreach task. 
+ iterated, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: task.RunId}) + if err != nil { + log.Debugf(ctx, "air get: could not fetch sweep iterations: %v", err) + return info + } + + for _, it := range iterated.Iterations { + row := sweepTask{ + TaskKey: it.TaskKey, + RunID: strconv.FormatInt(it.RunId, 10), + Status: runStatus(it.State), + } + if it.GenAiComputeTask != nil && it.GenAiComputeTask.MlflowExperimentName != "" { + row.Experiment = stripExperimentUserPrefix(it.GenAiComputeTask.MlflowExperimentName) + } + info.Tasks = append(info.Tasks, row) + } + return info +} diff --git a/experimental/air/cmd/sweep_test.go b/experimental/air/cmd/sweep_test.go new file mode 100644 index 00000000000..10134c0df42 --- /dev/null +++ b/experimental/air/cmd/sweep_test.go @@ -0,0 +1,81 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestFindForEachTask(t *testing.T) { + // No tasks at all. + assert.Nil(t, findForEachTask(&jobs.Run{})) + + // A task that is not a foreach. + assert.Nil(t, findForEachTask(&jobs.Run{Tasks: []jobs.RunTask{{TaskKey: "a"}}})) + + // The foreach task is found even when it isn't first. 
+ run := &jobs.Run{Tasks: []jobs.RunTask{ + {TaskKey: "a"}, + {TaskKey: "sweep", ForEachTask: &jobs.RunForEachTask{}}, + }} + got := findForEachTask(run) + require.NotNil(t, got) + assert.Equal(t, "sweep", got.TaskKey) +} + +func sweepTaskFixture() *jobs.RunTask { + return &jobs.RunTask{ + RunId: 99, + ForEachTask: &jobs.RunForEachTask{ + Stats: &jobs.ForEachStats{TaskRunStats: &jobs.ForEachTaskTaskRunStats{ + TotalIterations: 4, + SucceededIterations: 1, + FailedIterations: 1, + ActiveIterations: 2, + CompletedIterations: 2, + }}, + }, + } +} + +func TestBuildSweepInfo(t *testing.T) { + ctx := t.Context() + + t.Run("counts and iteration rows", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return( + &jobs.Run{Iterations: []jobs.RunTask{{ + TaskKey: "iter_0", + RunId: 100, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}, nil) + + info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture()) + assert.Equal(t, 4, info.Total) + assert.Equal(t, 2, info.Completed) + assert.Equal(t, 1, info.Succeeded) + assert.Equal(t, 1, info.Failed) + assert.Equal(t, 2, info.Active) + require.Len(t, info.Tasks, 1) + assert.Equal(t, "iter_0", info.Tasks[0].TaskKey) + assert.Equal(t, "100", info.Tasks[0].RunID) + assert.Equal(t, "SUCCESS", info.Tasks[0].Status) + assert.Equal(t, "exp", info.Tasks[0].Experiment) + }) + + t.Run("iteration lookup failure still returns counts", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return( + nil, apierr.ErrResourceDoesNotExist) + + info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture()) + assert.Equal(t, 4, info.Total) + assert.Empty(t, info.Tasks) + }) +} From 89042d08bf1ebfb6352abd1078a597e880ec21bb Mon Sep 17 00:00:00 2001 
From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:27:54 +0000 Subject: [PATCH 4/8] experimental/air: rename stale TestBuildStatusData to TestBuildGetData Co-authored-by: Isaac --- experimental/air/cmd/get_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 6dfdc54db7b..64fd1aa1f68 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -17,11 +17,11 @@ import ( "github.com/stretchr/testify/require" ) -// renderGet renders the status template against the JSON envelope, exactly as +// renderGet renders the get template against the JSON envelope, exactly as // the command does, so the test covers the real template branches. func renderGet(t *testing.T, data getData) string { t.Helper() - tmpl, err := template.New("status").Parse(getTemplate) + tmpl, err := template.New("get").Parse(getTemplate) require.NoError(t, err) var buf bytes.Buffer require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) @@ -180,7 +180,7 @@ func TestGetTemplateAllFields(t *testing.T) { } } -func TestBuildStatusData(t *testing.T) { +func TestBuildGetData(t *testing.T) { run := &jobs.Run{ RunId: 123, RunPageUrl: "https://example.test/run/123", From c99239ca432a5001ff6ae5f58439399d262916eb Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:30:07 +0000 Subject: [PATCH 5/8] experimental/air: apply testifylint fixes in get/format tests Co-authored-by: Isaac --- experimental/air/cmd/format_test.go | 10 +++++----- experimental/air/cmd/get_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go index c3e2e865b81..583f3cc3111 100644 --- a/experimental/air/cmd/format_test.go +++ b/experimental/air/cmd/format_test.go @@ -50,7 +50,7 @@ func TestGpuDisplayName(t *testing.T) { assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) // 
Unknown identifiers pass through unchanged. assert.Equal(t, "b200", gpuDisplayName("b200")) - assert.Equal(t, "", gpuDisplayName("")) + assert.Empty(t, gpuDisplayName("")) } func TestRunStatusPrefersResultState(t *testing.T) { @@ -117,12 +117,12 @@ func TestExperimentName(t *testing.T) { } func TestAccelerators(t *testing.T) { - assert.Equal(t, "", accelerators(&jobs.Run{})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + assert.Empty(t, accelerators(&jobs.Run{})) + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{}, }}})) - assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + assert.Empty(t, accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, }}})) assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go index 64fd1aa1f68..b6d6d0baab9 100644 --- a/experimental/air/cmd/get_test.go +++ b/experimental/air/cmd/get_test.go @@ -93,8 +93,8 @@ func TestPrintConfigYAML(t *testing.T) { func TestYAMLConfigPath(t *testing.T) { // No tasks, or a task without GenAiComputeTask, yields no path. 
- assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) - assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Empty(t, yamlConfigPath(&jobs.Run{})) + assert.Empty(t, yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) run := &jobs.Run{Tasks: []jobs.RunTask{{ GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, From fa3a1a29d33c48695d6591434443fbcb5fe6555a Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Sun, 14 Jun 2026 23:32:03 +0000 Subject: [PATCH 6/8] experimental/air: add GPU accelerator type and compute config model Add compute.go: the gpuType model and compute-block validation the upcoming `air run` config layer depends on. Defines the canonical GPU_* accelerator types, parseGPUType (exact, case-sensitive), gpusPerNode (partition counts), and computeConfig.validate (positive count, multiple-of-per-node, mutually exclusive node_pool_id/pool_name). Co-authored-by: Isaac --- experimental/air/cmd/compute.go | 88 +++++++++++++++++++++++++++ experimental/air/cmd/compute_test.go | 89 ++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 experimental/air/cmd/compute.go create mode 100644 experimental/air/cmd/compute_test.go diff --git a/experimental/air/cmd/compute.go b/experimental/air/cmd/compute.go new file mode 100644 index 00000000000..02e1916520d --- /dev/null +++ b/experimental/air/cmd/compute.go @@ -0,0 +1,88 @@ +package aircmd + +import ( + "errors" + "fmt" + "strings" +) + +// gpuType is a wire-facing accelerator type submitted to the training service. +// The number in the name is the partition count (e.g. GPU_8xH100 is 8 GPUs). +type gpuType string + +const ( + gpuType1xA10 gpuType = "GPU_1xA10" + gpuType8xH100 gpuType = "GPU_8xH100" + gpuType1xH100 gpuType = "GPU_1xH100" +) + +// gpuTypes lists every valid type. Used for validation error messages. 
+var gpuTypes = []gpuType{gpuType1xA10, gpuType1xH100, gpuType8xH100} + +func validGPUTypesHint() string { + names := make([]string, len(gpuTypes)) + for i, g := range gpuTypes { + names[i] = string(g) + } + return "valid types are: " + strings.Join(names, ", ") +} + +// parseGPUType resolves a YAML accelerator_type string to a gpuType. The match is +// exact: the server's lookup is case-sensitive. +func parseGPUType(value string) (gpuType, error) { + switch gpuType(value) { + case gpuType1xA10, gpuType8xH100, gpuType1xH100: + return gpuType(value), nil + } + return "", fmt.Errorf("invalid GPU type %q: %s", value, validGPUTypesHint()) +} + +// gpusPerNode returns the per-node GPU count, which is the partition count from +// the name (GPU_1xH100 -> 1, GPU_8xH100 -> 8). num_accelerators must be a +// round multiple of this since accelerators are allocated in whole nodes. +func gpusPerNode(g gpuType) (int, error) { + switch g { + case gpuType1xA10, gpuType1xH100: + return 1, nil + case gpuType8xH100: + return 8, nil + } + // Unreachable: callers resolve g through parseGPUType first, which rejects + // unknown types. Kept as a defensive guard. + return 0, fmt.Errorf("invalid GPU type %q", string(g)) +} + +// computeConfig is the `compute` block of the run YAML: which accelerators to +// use and how many. +type computeConfig struct { + NumAccelerators int `yaml:"num_accelerators"` + AcceleratorType string `yaml:"accelerator_type"` + NodePoolID string `yaml:"node_pool_id"` + PoolName string `yaml:"pool_name"` +} + +// validate checks the compute block against the backend's constraints. 
+func (c computeConfig) validate() error { + g, err := parseGPUType(c.AcceleratorType) + if err != nil { + return fmt.Errorf("compute.accelerator_type: %w", err) + } + + if c.NumAccelerators <= 0 { + return fmt.Errorf("compute.num_accelerators must be positive, got %d", c.NumAccelerators) + } + + perNode, err := gpusPerNode(g) + if err != nil { + return err + } + if c.NumAccelerators%perNode != 0 { + return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators) + } + + if c.NodePoolID != "" && c.PoolName != "" { + return errors.New("compute: cannot specify both node_pool_id and pool_name") + } + + return nil +} diff --git a/experimental/air/cmd/compute_test.go b/experimental/air/cmd/compute_test.go new file mode 100644 index 00000000000..ad91d470861 --- /dev/null +++ b/experimental/air/cmd/compute_test.go @@ -0,0 +1,89 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseGPUType(t *testing.T) { + tests := []struct { + in string + want gpuType + }{ + {"GPU_1xA10", gpuType1xA10}, + {"GPU_8xH100", gpuType8xH100}, + {"GPU_1xH100", gpuType1xH100}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + got, err := parseGPUType(tt.in) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseGPUTypeInvalid(t *testing.T) { + // Wrong casing is rejected rather than fixed up; legacy types (h100_80gb, a10) + // can no longer be submitted; unknown types are rejected. 
+ for _, in := range []string{"gpu_1xa10", "GPU_1XA10", "GPU_2xH100", "h100_80gb", "a10", "b200", ""} { + t.Run(in, func(t *testing.T) { + _, err := parseGPUType(in) + require.Error(t, err) + assert.Contains(t, err.Error(), "valid types are") + }) + } +} + +func TestGPUsPerNode(t *testing.T) { + tests := []struct { + in gpuType + want int + }{ + {gpuType1xA10, 1}, + {gpuType1xH100, 1}, + {gpuType8xH100, 8}, + } + for _, tt := range tests { + t.Run(string(tt.in), func(t *testing.T) { + got, err := gpusPerNode(tt.in) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } + + _, err := gpusPerNode(gpuType("nonsense")) + require.Error(t, err) +} + +func TestComputeConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg computeConfig + wantErr string // substring; empty means the config is valid + }{ + {"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""}, + {"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""}, + {"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""}, + {"with node pool", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "pool-123"}, ""}, + {"with pool name", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", PoolName: "my-pool"}, ""}, + {"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"}, + {"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"}, + {"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"}, + {"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"}, + {"both pool fields", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "p", PoolName: "n"}, "both"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := 
tt.cfg.validate() + if tt.wantErr == "" { + require.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} From 62be1a1cd17a5fc94139d0a13abdc66d56c493ba Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Tue, 16 Jun 2026 18:41:30 +0000 Subject: [PATCH 7/8] experimental/air: drop node pool / pool name compute fields The training compute config no longer supports pool placement, so remove the node_pool_id and pool_name fields and the validation that rejected setting both. Co-authored-by: Isaac --- experimental/air/cmd/compute.go | 7 ------- experimental/air/cmd/compute_test.go | 3 --- 2 files changed, 10 deletions(-) diff --git a/experimental/air/cmd/compute.go b/experimental/air/cmd/compute.go index 02e1916520d..07013c53906 100644 --- a/experimental/air/cmd/compute.go +++ b/experimental/air/cmd/compute.go @@ -1,7 +1,6 @@ package aircmd import ( - "errors" "fmt" "strings" ) @@ -57,8 +56,6 @@ func gpusPerNode(g gpuType) (int, error) { type computeConfig struct { NumAccelerators int `yaml:"num_accelerators"` AcceleratorType string `yaml:"accelerator_type"` - NodePoolID string `yaml:"node_pool_id"` - PoolName string `yaml:"pool_name"` } // validate checks the compute block against the backend's constraints. 
@@ -80,9 +77,5 @@ func (c computeConfig) validate() error { return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators) } - if c.NodePoolID != "" && c.PoolName != "" { - return errors.New("compute: cannot specify both node_pool_id and pool_name") - } - return nil } diff --git a/experimental/air/cmd/compute_test.go b/experimental/air/cmd/compute_test.go index ad91d470861..3464afbe9ea 100644 --- a/experimental/air/cmd/compute_test.go +++ b/experimental/air/cmd/compute_test.go @@ -67,13 +67,10 @@ func TestComputeConfigValidate(t *testing.T) { {"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""}, {"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""}, {"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""}, - {"with node pool", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "pool-123"}, ""}, - {"with pool name", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", PoolName: "my-pool"}, ""}, {"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"}, {"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"}, {"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"}, {"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"}, - {"both pool fields", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "p", PoolName: "n"}, "both"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 39db9d84da7b64bb4b3d2a21d4a9dc8742a667af Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 18 Jun 2026 17:36:11 +0000 Subject: [PATCH 8/8] experimental/air: add run config schema and structural validation Port the run YAML schema and its 
structural validation from the Python CLI's sdk/config.py: the top-level runConfig plus the environment, docker_image, code_source/snapshot/git, and permission blocks. loadRunConfig decodes a YAML file with KnownFields (mirroring pydantic extra="forbid") and runs the validation pass. "Structural" covers types, required fields, and format/cross-field rules that need no workspace access. Online checks (compute pool resolution, GPU availability), git/filesystem checks, _bases_ composition, and CLI --override handling are deferred to later milestones. Two deliberate divergences from the Python schema, both following from the training-service-only port: the compute pool fields were already dropped, and the top-level priority field is dropped here since it is a node-pool queue-ordering knob with no meaning for serverless workloads. Co-authored-by: Isaac --- experimental/air/cmd/runconfig.go | 446 +++++++++++++++++++++++++ experimental/air/cmd/runconfig_load.go | 40 +++ experimental/air/cmd/runconfig_test.go | 407 ++++++++++++++++++++++ 3 files changed, 893 insertions(+) create mode 100644 experimental/air/cmd/runconfig.go create mode 100644 experimental/air/cmd/runconfig_load.go create mode 100644 experimental/air/cmd/runconfig_test.go diff --git a/experimental/air/cmd/runconfig.go b/experimental/air/cmd/runconfig.go new file mode 100644 index 00000000000..c0678c17cb5 --- /dev/null +++ b/experimental/air/cmd/runconfig.go @@ -0,0 +1,446 @@ +package aircmd + +import ( + "errors" + "fmt" + "regexp" + "slices" + "strings" + + "go.yaml.in/yaml/v3" +) + +// This file ports the run YAML schema and its structural validation from the +// Python CLI's sdk/config.py. "Structural" means types, required fields, and +// format/cross-field rules that need no workspace access. Online checks (compute +// pool resolution, GPU availability) and git/filesystem checks run at launch time +// and are intentionally not ported here. 
+// +// Divergences from the Python schema, both consequences of the training-service- +// only port: +// - compute.node_pool_id / compute.pool_name are dropped (see compute.go). +// - the top-level `priority` field is dropped: it is a node-pool queue-ordering +// knob with no meaning for serverless (poolless) workloads. + +// REGEX_TASK_KEY_CHARS: ASCII alphanumeric, hyphen, underscore only (no periods). +// Explicit ASCII class, not \w: in Python \w also matches Unicode letters that the +// ASCII-only Jobs API task_key rejects. (Go's RE2 \w is already ASCII-only.) +var taskKeyRe = regexp.MustCompile(`^[A-Za-z0-9_-]+$`) + +// gitRefRe guards branch/remote names against command injection: only safe ref +// characters match. NOTE(review): a leading '-' still matches (option injection). +var gitRefRe = regexp.MustCompile(`^[\w./-]+$`) + +// runConfig is the top-level run YAML schema: experiment_name + compute / +// environment / code_source plus the command and run options. +type runConfig struct { + ExperimentName string `yaml:"experiment_name"` + Compute *computeConfig `yaml:"compute"` + Environment *environmentConfig `yaml:"environment"` + Command *string `yaml:"command"` + EnvVariables map[string]string `yaml:"env_variables"` + Secrets map[string]string `yaml:"secrets"` + CodeSource *codeSourceConfig `yaml:"code_source"` + // MaxRetries defaults to 3 when unset; default-filling is a normalization + // concern handled at launch, so a nil pointer is left as-is here. + MaxRetries *int `yaml:"max_retries"` + TimeoutMinutes *int `yaml:"timeout_minutes"` + IdempotencyToken *string `yaml:"idempotency_token"` + Parameters map[string]any `yaml:"parameters"` + MLflowRunName *string `yaml:"mlflow_run_name"` + MLflowExperimentDirectory *string `yaml:"mlflow_experiment_directory"` + Permissions []permission `yaml:"permissions"` + UsagePolicyName *string `yaml:"usage_policy_name"` +} + +// validate runs structural validation over the whole config, returning the first +// failure. Fields are checked in declaration order to keep error output stable.
+func (c *runConfig) validate() error { + if err := validateExperimentName(c.ExperimentName); err != nil { + return err + } + + if c.Compute == nil { + return errors.New("compute: section is required") + } + if err := c.Compute.validate(); err != nil { + return err + } + + if c.Environment != nil { + if err := c.Environment.validate(); err != nil { + return err + } + } + + // command is optional in the type system but required in practice, matching + // the Python validate_script_fields model validator. + if c.Command == nil { + return errors.New("command is required") + } + if err := validateCommand(*c.Command); err != nil { + return err + } + + if err := validateSecretRefs(c.Secrets); err != nil { + return err + } + + if c.CodeSource != nil { + if err := c.CodeSource.validate(); err != nil { + return err + } + } + + if c.MaxRetries != nil && *c.MaxRetries < 0 { + return fmt.Errorf("max_retries must be >= 0, got %d", *c.MaxRetries) + } + + if c.TimeoutMinutes != nil && *c.TimeoutMinutes < 1 { + return fmt.Errorf("timeout_minutes must be >= 1, got %d", *c.TimeoutMinutes) + } + + if c.IdempotencyToken != nil { + v := strings.TrimSpace(*c.IdempotencyToken) + if v == "" { + return errors.New("idempotency_token cannot be empty") + } + if len(v) > 64 { + return errors.New("idempotency_token must be 64 characters or less") + } + } + + if c.MLflowRunName != nil { + v := strings.TrimSpace(*c.MLflowRunName) + if v == "" { + return errors.New("mlflow_run_name cannot be empty") + } + if !taskKeyRe.MatchString(v) { + return fmt.Errorf("invalid mlflow_run_name %q: only alphanumeric characters, hyphens, and underscores are allowed", v) + } + } + + if c.MLflowExperimentDirectory != nil { + v := strings.TrimSpace(*c.MLflowExperimentDirectory) + if v == "" { + return errors.New("mlflow_experiment_directory cannot be empty") + } + // MLflow experiments live under the workspace tree. 
+ if !strings.HasPrefix(v, "/Workspace") { + return fmt.Errorf("mlflow_experiment_directory must start with '/Workspace', got: %s", v) + } + } + + for i := range c.Permissions { + if err := c.Permissions[i].validate(); err != nil { + return err + } + } + + if c.UsagePolicyName != nil { + v := strings.TrimSpace(*c.UsagePolicyName) + if v == "" { + return errors.New("usage_policy_name must not be empty") + } + // 127 matches the server-side max_length on the policy name filter. + if len(v) > 127 { + return fmt.Errorf("usage_policy_name must be at most 127 characters, got %d", len(v)) + } + } + + return nil +} + +// validateExperimentName enforces the Databricks Jobs API task_key constraints: +// the experiment_name becomes a task key, which caps at 100 characters and allows +// only alphanumerics, hyphens, and underscores. +func validateExperimentName(v string) error { + if v == "" { + return errors.New("experiment_name cannot be empty") + } + if len(v) > 100 { + return fmt.Errorf("experiment_name must be 100 characters or less (got %d); this is the Jobs API task_key length limit", len(v)) + } + if !taskKeyRe.MatchString(v) { + return fmt.Errorf("invalid experiment_name %q: only alphanumeric characters, hyphens (-), and underscores (_) are allowed", v) + } + return nil +} + +// validateCommand enforces command is non-empty and within the line-count cap. +func validateCommand(v string) error { + if strings.TrimSpace(v) == "" { + return errors.New("command cannot be empty") + } + lineCount := strings.Count(v, "\n") + 1 + if lineCount > 1000 { + return fmt.Errorf("command is too long (%d lines); maximum is 1000 lines — move complex logic into a script in your code_source", lineCount) + } + return nil +} + +// validateSecretRefs checks that secret references use the "scope/key" format. 
+func validateSecretRefs(secrets map[string]string) error { + for varName, ref := range secrets { + parts := strings.Split(ref, "/") + if len(parts) != 2 { + return fmt.Errorf("invalid secret reference %q for variable %q: expected format 'scope/key' (e.g., my_scope/hf_token)", ref, varName) + } + if parts[0] == "" || parts[1] == "" { + return fmt.Errorf("invalid secret reference %q for variable %q: scope and key cannot be empty", ref, varName) + } + } + return nil +} + +// environmentConfig is the `environment` block: dependencies and/or a custom +// docker image. +type environmentConfig struct { + Dependencies dependencies `yaml:"dependencies"` + Version stringOrInt `yaml:"version"` + DockerImage *dockerImageConfig `yaml:"docker_image"` +} + +func (e *environmentConfig) validate() error { + // docker_image is exclusive with dependencies/version: the image already pins + // the full runtime. + if e.DockerImage != nil { + var conflicting []string + if e.Dependencies.set { + conflicting = append(conflicting, "dependencies") + } + if e.Version.set { + conflicting = append(conflicting, "version") + } + if len(conflicting) > 0 { + return fmt.Errorf("when 'docker_image' is specified under 'environment', these fields are not allowed: %s", strings.Join(conflicting, ", ")) + } + return e.DockerImage.validate() + } + + // version pins the client image version, which is only meaningful for an + // inline (list) dependency set — a requirements.yaml file carries its own. 
+ if e.Version.set { + if e.Dependencies.set && !e.Dependencies.isList { + return errors.New("'environment.version' is only valid with inline dependencies (a list); when 'dependencies' points to a requirements.yaml file, set the version inside that file") + } + if !e.Dependencies.set { + return errors.New("'environment.version' requires inline 'dependencies' (a list of packages)") + } + } + + return nil +} + +// dependencies is environment.dependencies, which is polymorphic: a string is a +// path to a requirements.yaml file; a list is an inline package list. +type dependencies struct { + set bool + isList bool + path string + list []string +} + +func (d *dependencies) UnmarshalYAML(node *yaml.Node) error { + switch node.Kind { + case yaml.ScalarNode: + d.set, d.isList = true, false + return node.Decode(&d.path) + case yaml.SequenceNode: + d.set, d.isList = true, true + return node.Decode(&d.list) + default: + return errors.New("environment.dependencies must be a string path or a list of packages") + } +} + +// stringOrInt holds a scalar that may be a string or an integer in YAML +// (environment.version). The raw text is kept; integer-format validation is a +// launch-time concern. +type stringOrInt struct { + set bool + raw string +} + +func (s *stringOrInt) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.ScalarNode { + return errors.New("environment.version must be a string or integer") + } + s.set = true + s.raw = node.Value + return nil +} + +// dockerImageConfig is environment.docker_image. +type dockerImageConfig struct { + URL string `yaml:"url"` +} + +func (d *dockerImageConfig) validate() error { + if strings.TrimSpace(d.URL) == "" { + return errors.New("docker_image.url cannot be empty") + } + return nil +} + +// codeSourceConfig is the `code_source` block. Only the "snapshot" type exists. 
+type codeSourceConfig struct { + Type string `yaml:"type"` + Snapshot *snapshotSourceConfig `yaml:"snapshot"` +} + +func (c *codeSourceConfig) validate() error { + if c.Type != "snapshot" { + return fmt.Errorf("code_source.type must be 'snapshot', got %q", c.Type) + } + if c.Snapshot == nil { + return errors.New("code_source.type='snapshot' requires a snapshot configuration") + } + return c.Snapshot.validate() +} + +// snapshotSourceConfig describes a local directory to tar and upload. +type snapshotSourceConfig struct { + RootPath string `yaml:"root_path"` + RemoteVolume *string `yaml:"remote_volume"` + Git *gitRef `yaml:"git"` + IncludePaths []string `yaml:"include_paths"` +} + +func (s *snapshotSourceConfig) validate() error { + if strings.TrimSpace(s.RootPath) == "" { + return errors.New("code_source.snapshot.root_path cannot be empty") + } + + if s.RemoteVolume != nil && !strings.HasPrefix(*s.RemoteVolume, "/Volumes/") { + return errors.New("code_source.snapshot.remote_volume must start with '/Volumes/'") + } + + // A non-nil but empty include_paths is an explicit mistake (omit it instead). + if s.IncludePaths != nil && len(s.IncludePaths) == 0 { + return errors.New("code_source.snapshot.include_paths cannot be an empty list; either omit it or provide paths") + } + for _, p := range s.IncludePaths { + p = strings.TrimSpace(p) + if p == "" { + return errors.New("code_source.snapshot.include_paths entry cannot be empty") + } + if strings.HasPrefix(p, "/") { + return fmt.Errorf("code_source.snapshot.include_paths must be relative paths, got: %s", p) + } + // No parent traversal: snapshots must stay within root_path. + if slices.Contains(strings.Split(p, "/"), "..") { + return fmt.Errorf("code_source.snapshot.include_paths cannot contain '..' traversal, got: %s", p) + } + } + + if s.Git != nil { + return s.Git.validate() + } + return nil +} + +// gitRef pins a snapshot to a specific git ref. 
branch and commit are mutually +// exclusive; remote is only meaningful with branch. +type gitRef struct { + Branch *string `yaml:"branch"` + Commit *string `yaml:"commit"` + Remote gitRemote `yaml:"remote"` +} + +func (g *gitRef) validate() error { + if g.Branch != nil && !gitRefRe.MatchString(*g.Branch) { + return fmt.Errorf("invalid git.branch format %q: only alphanumeric characters, hyphens, dots, slashes, and underscores are allowed", *g.Branch) + } + if g.Remote.isString { + if g.Remote.name == "" { + return errors.New("git.remote string cannot be empty; use 'true' to auto-detect") + } + if !gitRefRe.MatchString(g.Remote.name) { + return fmt.Errorf("invalid git.remote name %q: only alphanumeric characters, hyphens, dots, slashes, and underscores are allowed", g.Remote.name) + } + } + + if g.Branch == nil && g.Commit == nil { + return errors.New("git: must specify either 'branch' or 'commit'") + } + if g.Branch != nil && g.Commit != nil { + return errors.New("git: 'branch' and 'commit' are mutually exclusive — specify only one") + } + if g.Remote.truthy() && g.Branch == nil { + return errors.New("git.remote requires git.branch (only valid with branch refs)") + } + return nil +} + +// gitRemote is git.remote: false (default, use local HEAD), true (auto-detect the +// remote), or a remote name string. +type gitRemote struct { + set bool + isString bool + name string + enabled bool +} + +func (r *gitRemote) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.ScalarNode { + return errors.New("git.remote must be a boolean or a remote name string") + } + r.set = true + if node.Tag == "!!bool" { + return node.Decode(&r.enabled) + } + r.isString = true + r.name = node.Value + return nil +} + +// truthy reports whether remote requests a remote fetch (mirrors Python's +// truthiness of the bool|str union). 
+func (r *gitRemote) truthy() bool { + if r.isString { + return r.name != "" + } + return r.enabled +} + +// permission is a DABs-compatible permission grant: exactly one principal plus a +// level. +type permission struct { + UserName *string `yaml:"user_name"` + GroupName *string `yaml:"group_name"` + ServicePrincipalName *string `yaml:"service_principal_name"` + // Level is a databricks PermissionLevel (e.g. CAN_VIEW, CAN_MANAGE). Enum + // membership is validated server-side; here we only require it to be set. + Level string `yaml:"level"` +} + +func (p *permission) validate() error { + principals := map[string]*string{ + "user_name": p.UserName, + "group_name": p.GroupName, + "service_principal_name": p.ServicePrincipalName, + } + var set []string + for name, val := range principals { + if val != nil { + set = append(set, name) + } + } + switch len(set) { + case 0: + return errors.New("permissions: one of 'user_name', 'group_name', or 'service_principal_name' must be specified") + case 1: + name := set[0] + if strings.TrimSpace(*principals[name]) == "" { + return fmt.Errorf("permissions: '%s' cannot be empty", name) + } + default: + return errors.New("permissions: only one of 'user_name', 'group_name', or 'service_principal_name' can be specified") + } + + if strings.TrimSpace(p.Level) == "" { + return errors.New("permissions: 'level' is required") + } + return nil +} diff --git a/experimental/air/cmd/runconfig_load.go b/experimental/air/cmd/runconfig_load.go new file mode 100644 index 00000000000..4cdbd283089 --- /dev/null +++ b/experimental/air/cmd/runconfig_load.go @@ -0,0 +1,40 @@ +package aircmd + +import ( + "errors" + "fmt" + "io" + "os" + + "go.yaml.in/yaml/v3" +) + +// loadRunConfig reads a run YAML config file, decodes it into the schema, and +// runs structural validation. Unknown keys are rejected (KnownFields), mirroring +// the Python schema's extra="forbid". 
+// +// The `_bases_` composition feature and CLI `--override` handling are not yet +// ported; a config using `_bases_` is currently rejected as an unknown field. +func loadRunConfig(path string) (*runConfig, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + dec := yaml.NewDecoder(f) + dec.KnownFields(true) + + var cfg runConfig + if err := dec.Decode(&cfg); err != nil { + if errors.Is(err, io.EOF) { + return nil, fmt.Errorf("config %s is empty", path) + } + return nil, fmt.Errorf("invalid config %s: %w", path, err) + } + + if err := cfg.validate(); err != nil { + return nil, err + } + return &cfg, nil +} diff --git a/experimental/air/cmd/runconfig_test.go b/experimental/air/cmd/runconfig_test.go new file mode 100644 index 00000000000..06501e6ea6f --- /dev/null +++ b/experimental/air/cmd/runconfig_test.go @@ -0,0 +1,407 @@ +package aircmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// writeConfig writes content to a temp YAML file and returns its path. +func writeConfig(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return path +} + +// minimalConfig is the smallest valid config: the three required pieces. 
+// minimalConfig is the smallest YAML document these tests expect
+// loadRunConfig to accept: an experiment name, a command, and a compute
+// section. Several tests append extra keys to it, so it must end in a newline.
+const minimalConfig = `
+experiment_name: my-run
+command: python train.py
+compute:
+  accelerator_type: GPU_1xH100
+  num_accelerators: 1
+`
+
+// TestLoadRunConfig_Minimal checks that the three required sections decode
+// into the expected fields.
+func TestLoadRunConfig_Minimal(t *testing.T) {
+	cfg, err := loadRunConfig(writeConfig(t, minimalConfig))
+	require.NoError(t, err)
+	assert.Equal(t, "my-run", cfg.ExperimentName)
+	require.NotNil(t, cfg.Command)
+	assert.Equal(t, "python train.py", *cfg.Command)
+	require.NotNil(t, cfg.Compute)
+	assert.Equal(t, "GPU_1xH100", cfg.Compute.AcceleratorType)
+	assert.Equal(t, 1, cfg.Compute.NumAccelerators)
+}
+
+// TestLoadRunConfig_FullFeatured loads a config that sets every supported
+// top-level key and spot-checks the nested and polymorphic fields, including
+// the str|int environment.version given as an integer.
+func TestLoadRunConfig_FullFeatured(t *testing.T) {
+	cfg, err := loadRunConfig(writeConfig(t, `
+experiment_name: full_run
+command: |
+  python train.py
+  echo done
+compute:
+  accelerator_type: GPU_8xH100
+  num_accelerators: 16
+environment:
+  dependencies:
+    - torch==2.3.0
+    - numpy
+  version: 5
+env_variables:
+  FOO: bar
+secrets:
+  HF_TOKEN: my_scope/hf_token
+code_source:
+  type: snapshot
+  snapshot:
+    root_path: project_root/src
+    remote_volume: /Volumes/main/default/code
+    git:
+      branch: main
+      remote: origin
+    include_paths:
+      - src
+      - configs/train.yaml
+max_retries: 5
+timeout_minutes: 120
+idempotency_token: abc-123
+mlflow_run_name: full_run_v2
+mlflow_experiment_directory: /Workspace/Users/me/exp
+usage_policy_name: my-policy
+permissions:
+  - group_name: users
+    level: CAN_VIEW
+  - user_name: alice@example.com
+    level: CAN_MANAGE
+`))
+	require.NoError(t, err)
+	// Guard every pointer before dereferencing it so a decoding regression
+	// fails this test cleanly instead of panicking the whole test binary.
+	require.NotNil(t, cfg.Compute)
+	assert.Equal(t, gpuType8xH100, gpuType(cfg.Compute.AcceleratorType))
+	require.NotNil(t, cfg.Environment)
+	assert.True(t, cfg.Environment.Dependencies.isList)
+	assert.Equal(t, []string{"torch==2.3.0", "numpy"}, cfg.Environment.Dependencies.list)
+	assert.True(t, cfg.Environment.Version.set)
+	assert.Equal(t, "5", cfg.Environment.Version.raw)
+	require.NotNil(t, cfg.CodeSource)
+	require.NotNil(t, cfg.CodeSource.Snapshot)
+	require.NotNil(t, cfg.CodeSource.Snapshot.Git)
+	require.NotNil(t, cfg.CodeSource.Snapshot.Git.Branch)
+	assert.Equal(t, "main", *cfg.CodeSource.Snapshot.Git.Branch)
+	assert.True(t, cfg.CodeSource.Snapshot.Git.Remote.isString)
+	assert.Equal(t, "origin", cfg.CodeSource.Snapshot.Git.Remote.name)
+	assert.Len(t, cfg.Permissions, 2)
+}
+
+// TestLoadRunConfig_PolymorphicFields exercises the str|list (dependencies)
+// and bool|str (git remote) unions decoded by custom UnmarshalYAML. The
+// str|int union (environment.version) is covered by
+// TestLoadRunConfig_FullFeatured.
+func TestLoadRunConfig_PolymorphicFields(t *testing.T) {
+	t.Run("dependencies as string path", func(t *testing.T) {
+		cfg, err := loadRunConfig(writeConfig(t, minimalConfig+`
+environment:
+  dependencies: requirements.yaml
+`))
+		require.NoError(t, err)
+		require.NotNil(t, cfg.Environment)
+		assert.True(t, cfg.Environment.Dependencies.set)
+		assert.False(t, cfg.Environment.Dependencies.isList)
+		assert.Equal(t, "requirements.yaml", cfg.Environment.Dependencies.path)
+	})
+
+	t.Run("git remote as bool true", func(t *testing.T) {
+		cfg, err := loadRunConfig(writeConfig(t, minimalConfig+`
+code_source:
+  type: snapshot
+  snapshot:
+    root_path: .
+    git:
+      branch: main
+      remote: true
+`))
+		require.NoError(t, err)
+		require.NotNil(t, cfg.CodeSource)
+		require.NotNil(t, cfg.CodeSource.Snapshot)
+		require.NotNil(t, cfg.CodeSource.Snapshot.Git)
+		r := cfg.CodeSource.Snapshot.Git.Remote
+		assert.False(t, r.isString)
+		assert.True(t, r.enabled)
+		assert.True(t, r.truthy())
+	})
+
+	t.Run("git remote defaults to false when unset", func(t *testing.T) {
+		cfg, err := loadRunConfig(writeConfig(t, minimalConfig+`
+code_source:
+  type: snapshot
+  snapshot:
+    root_path: .
+    git:
+      commit: deadbeef
+`))
+		require.NoError(t, err)
+		require.NotNil(t, cfg.CodeSource)
+		require.NotNil(t, cfg.CodeSource.Snapshot)
+		require.NotNil(t, cfg.CodeSource.Snapshot.Git)
+		assert.False(t, cfg.CodeSource.Snapshot.Git.Remote.truthy())
+	})
+}
+
+// TestLoadRunConfig_UnknownFieldRejected appends an unrecognized key to
+// minimalConfig and expects loading to fail with an error naming that key.
+func TestLoadRunConfig_UnknownFieldRejected(t *testing.T) {
+	tests := []struct {
+		name    string
+		extra   string
+		errFrag string
+	}{
+		{"top-level typo", "extra_field: nope\n", "extra_field"},
+		// priority was intentionally dropped from the schema (pool-only concept).
+		{"dropped priority field", "priority: 100\n", "priority"},
+		// _bases_ composition is not yet ported, so it surfaces as unknown.
+		{"unported _bases_", "_bases_: [base.yaml]\n", "_bases_"},
+		{"nested typo", "environment:\n bogus: 1\n", "bogus"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := loadRunConfig(writeConfig(t, minimalConfig+tt.extra))
+			require.Error(t, err)
+			assert.Contains(t, err.Error(), tt.errFrag)
+		})
+	}
+}
+
+// TestLoadRunConfig_Errors feeds complete but invalid documents through
+// loadRunConfig and checks that each one fails with the expected message
+// fragment.
+func TestLoadRunConfig_Errors(t *testing.T) {
+	tests := []struct {
+		name    string
+		yaml    string
+		errFrag string
+	}{
+		{
+			"missing experiment_name",
+			"command: x\ncompute:\n accelerator_type: GPU_1xH100\n num_accelerators: 1\n",
+			"experiment_name cannot be empty",
+		},
+		{
+			"experiment_name bad chars",
+			"experiment_name: my.run\ncommand: x\ncompute:\n accelerator_type: GPU_1xH100\n num_accelerators: 1\n",
+			"invalid experiment_name",
+		},
+		{
+			"missing compute",
+			"experiment_name: r\ncommand: x\n",
+			"compute: section is required",
+		},
+		{
+			"missing command",
+			"experiment_name: r\ncompute:\n accelerator_type: GPU_1xH100\n num_accelerators: 1\n",
+			"command is required",
+		},
+		{
+			"bad gpu type",
+			"experiment_name: r\ncommand: x\ncompute:\n accelerator_type: a100\n num_accelerators: 1\n",
+			"invalid GPU type",
+		},
+		{
+			"num_accelerators not a multiple",
+			"experiment_name: r\ncommand: x\ncompute:\n accelerator_type: GPU_8xH100\n num_accelerators: 3\n",
+			"must be a multiple of 8",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := loadRunConfig(writeConfig(t, tt.yaml))
+			require.Error(t, err)
+			assert.Contains(t, err.Error(), tt.errFrag)
+		})
+	}
+}
+
+// TestRunConfigValidate_FieldRules unit-tests validation rules directly, away
+// from YAML decoding, to keep each rule's failure mode explicit.
+func TestRunConfigValidate_FieldRules(t *testing.T) { + str := func(s string) *string { return &s } + intp := func(i int) *int { return &i } + base := func() *runConfig { + return &runConfig{ + ExperimentName: "r", + Command: str("x"), + Compute: &computeConfig{AcceleratorType: "GPU_1xH100", NumAccelerators: 1}, + } + } + + tests := []struct { + name string + mutate func(c *runConfig) + errFrag string + }{ + {"ok baseline", func(c *runConfig) {}, ""}, + {"empty command", func(c *runConfig) { c.Command = str(" ") }, "command cannot be empty"}, + {"negative max_retries", func(c *runConfig) { c.MaxRetries = intp(-1) }, "max_retries must be >= 0"}, + {"zero timeout", func(c *runConfig) { c.TimeoutMinutes = intp(0) }, "timeout_minutes must be >= 1"}, + {"empty idempotency", func(c *runConfig) { c.IdempotencyToken = str(" ") }, "idempotency_token cannot be empty"}, + {"long idempotency", func(c *runConfig) { c.IdempotencyToken = str(string(make([]byte, 65))) }, "64 characters or less"}, + {"bad mlflow_run_name", func(c *runConfig) { c.MLflowRunName = str("bad name") }, "invalid mlflow_run_name"}, + {"bad experiment dir", func(c *runConfig) { c.MLflowExperimentDirectory = str("/Users/me") }, "must start with '/Workspace'"}, + {"empty usage policy", func(c *runConfig) { c.UsagePolicyName = str(" ") }, "usage_policy_name must not be empty"}, + {"bad secret ref", func(c *runConfig) { c.Secrets = map[string]string{"T": "noslash"} }, "expected format 'scope/key'"}, + {"empty secret scope", func(c *runConfig) { c.Secrets = map[string]string{"T": "/key"} }, "scope and key cannot be empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := base() + tt.mutate(c) + err := c.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestEnvironmentConfigValidate(t *testing.T) { + tests := []struct { + name string + env environmentConfig + errFrag 
string + }{ + { + "docker image alone ok", + environmentConfig{DockerImage: &dockerImageConfig{URL: "org/repo:tag"}}, + "", + }, + { + "docker image with deps conflicts", + environmentConfig{ + DockerImage: &dockerImageConfig{URL: "org/repo:tag"}, + Dependencies: dependencies{set: true, isList: true, list: []string{"torch"}}, + }, + "not allowed: dependencies", + }, + { + "empty docker url", + environmentConfig{DockerImage: &dockerImageConfig{URL: " "}}, + "docker_image.url cannot be empty", + }, + { + "version with file deps", + environmentConfig{ + Version: stringOrInt{set: true, raw: "5"}, + Dependencies: dependencies{set: true, isList: false, path: "req.yaml"}, + }, + "only valid with inline dependencies", + }, + { + "version without deps", + environmentConfig{Version: stringOrInt{set: true, raw: "5"}}, + "requires inline 'dependencies'", + }, + { + "version with inline deps ok", + environmentConfig{ + Version: stringOrInt{set: true, raw: "5"}, + Dependencies: dependencies{set: true, isList: true, list: []string{"torch"}}, + }, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.env.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestGitRefValidate(t *testing.T) { + str := func(s string) *string { return &s } + tests := []struct { + name string + ref gitRef + errFrag string + }{ + {"branch only ok", gitRef{Branch: str("main")}, ""}, + {"commit only ok", gitRef{Commit: str("abc123")}, ""}, + {"branch with remote ok", gitRef{Branch: str("main"), Remote: gitRemote{set: true, enabled: true}}, ""}, + {"neither branch nor commit", gitRef{}, "must specify either 'branch' or 'commit'"}, + {"both branch and commit", gitRef{Branch: str("main"), Commit: str("abc")}, "mutually exclusive"}, + {"remote without branch", gitRef{Commit: str("abc"), Remote: gitRemote{set: true, isString: true, name: "origin"}}, "requires 
git.branch"}, + {"bad branch chars", gitRef{Branch: str("bad branch")}, "invalid git.branch"}, + {"empty remote string", gitRef{Branch: str("main"), Remote: gitRemote{set: true, isString: true, name: ""}}, "cannot be empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.ref.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestSnapshotSourceConfigValidate(t *testing.T) { + tests := []struct { + name string + snap snapshotSourceConfig + errFrag string + }{ + {"ok", snapshotSourceConfig{RootPath: "src"}, ""}, + {"empty root_path", snapshotSourceConfig{RootPath: " "}, "root_path cannot be empty"}, + {"bad volume", snapshotSourceConfig{RootPath: "src", RemoteVolume: new("/mnt/x")}, "must start with '/Volumes/'"}, + {"empty include list", snapshotSourceConfig{RootPath: "src", IncludePaths: []string{}}, "cannot be an empty list"}, + {"absolute include", snapshotSourceConfig{RootPath: "src", IncludePaths: []string{"/etc"}}, "must be relative"}, + {"traversal include", snapshotSourceConfig{RootPath: "src", IncludePaths: []string{"../x"}}, "'..' 
traversal"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.snap.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestPermissionValidate(t *testing.T) { + str := func(s string) *string { return &s } + tests := []struct { + name string + perm permission + errFrag string + }{ + {"ok user", permission{UserName: str("alice@example.com"), Level: "CAN_VIEW"}, ""}, + {"no principal", permission{Level: "CAN_VIEW"}, "must be specified"}, + {"two principals", permission{UserName: str("a"), GroupName: str("g"), Level: "CAN_VIEW"}, "only one of"}, + {"empty principal", permission{UserName: str(" "), Level: "CAN_VIEW"}, "cannot be empty"}, + {"missing level", permission{GroupName: str("users")}, "'level' is required"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.perm.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestLoadRunConfig_FileErrors(t *testing.T) { + t.Run("missing file", func(t *testing.T) { + _, err := loadRunConfig(filepath.Join(t.TempDir(), "nope.yaml")) + assert.Error(t, err) + }) + t.Run("empty file", func(t *testing.T) { + _, err := loadRunConfig(writeConfig(t, "")) + require.Error(t, err) + assert.Contains(t, err.Error(), "is empty") + }) +}