Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,10 @@ func spawnSSHClient(ctx context.Context, client *databricks.WorkspaceClient, use

sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout
sshCmd.Stderr = os.Stderr
// Tee ssh's stderr so the user still sees it while we retain the tail to inspect after exit.
// A host-key-verification failure is reported only on stderr, so we need a copy to detect it.
stderrTail := &tailWriter{maxBytes: hostKeyStderrTailBytes}
sshCmd.Stderr = io.MultiWriter(os.Stderr, stderrTail)

err = sshCmd.Run()
// ssh reserves exit code 255 for its own connection-level failures (a remote command's exit
Expand All @@ -659,7 +662,9 @@ func spawnSSHClient(ctx context.Context, client *databricks.WorkspaceClient, use
// own logs — fetch them from the /logs endpoint and show them instead of leaving the user
// with ssh's opaque "Connection closed" message.
if exitErr, ok := errors.AsType[*exec.ExitError](err); ok && exitErr.ExitCode() == 255 {
if logs := fetchServerErrorLogs(ctx, client, clusterID, serverPort, opts.Liteswap); logs != "" {
if hint := hostKeyChangedHint(stderrTail.String(), hostName, opts.UserKnownHostsFile); hint != "" {
cmdio.LogString(ctx, cmdio.Yellow(ctx, hint))
} else if logs := fetchServerErrorLogs(ctx, client, clusterID, serverPort, opts.Liteswap); logs != "" {
cmdio.LogString(ctx, cmdio.Yellow(ctx, "The SSH connection closed unexpectedly. Recent SSH server errors:"))
cmdio.LogString(ctx, truncateTail(logs, maxRunFailureTraceBytes))
} else {
Expand Down Expand Up @@ -854,6 +859,49 @@ func truncateTail(s string, maxBytes int) string {
return " ...\n" + s[len(s)-maxBytes:]
}

// hostKeyStderrTailBytes is the cap (4 KiB) on how much of ssh's stderr is kept for host-key
// failure detection. The warning block ssh prints for that failure is well under this size, so
// the retained tail always includes it.
const hostKeyStderrTailBytes = 4 * 1024

// tailWriter is a writer that remembers only the most recent maxBytes bytes it was given. It
// lets us look at the end of an external command's stderr without holding on to all of it.
type tailWriter struct {
	// maxBytes is the upper bound on how many bytes are kept.
	maxBytes int
	// buf holds the retained tail, oldest byte first.
	buf []byte
}

// Write appends p to the retained bytes and then drops the oldest bytes that exceed maxBytes.
// It always reports all of p as written and never returns an error.
func (w *tailWriter) Write(p []byte) (int, error) {
	w.buf = append(w.buf, p...)
	// overflow is how many leading bytes no longer fit within the bound.
	if overflow := len(w.buf) - w.maxBytes; overflow > 0 {
		w.buf = w.buf[overflow:]
	}
	return len(p), nil
}

// String returns the currently retained bytes as a string.
func (w *tailWriter) String() string {
	return string(w.buf)
}

// hostKeyChangedHint returns advice for clearing a stale known_hosts entry when ssh's stderr
// shows a host-key-verification failure, or "" if the failure was something else. A cluster that
// has been recreated keeps the same connection name but gets a new host key, so the old entry no
// longer matches and ssh aborts the connection.
//
// stderr is the captured ssh stderr text to inspect, hostName is the entry to remove, and
// knownHostsFile is the custom known_hosts path ("" means ssh-keygen's default file is used).
// The returned hint is a short message followed by a ready-to-paste ssh-keygen command.
//
// NOTE(review): ssh may print the same "Host key verification failed" line for verification
// failures other than a changed key — confirm whether the "has changed" wording needs narrowing.
func hostKeyChangedHint(stderr, hostName, knownHostsFile string) string {
	// "Host key verification failed." is OpenSSH's fixed message for this case; matching it is the
	// only signal ssh gives (the "don't branch on err.Error()" rule is about Go errors, not the
	// output of an external program).
	if !strings.Contains(stderr, "Host key verification failed") {
		return ""
	}
	cmd := "ssh-keygen -R " + hostName
	if knownHostsFile != "" {
		path := knownHostsFile
		// The command is meant to be copy-pasted into a shell: an unquoted path containing
		// whitespace would be split into several arguments, so wrap it in double quotes.
		// Paths without whitespace are left untouched.
		if strings.ContainsAny(path, " \t") {
			path = `"` + path + `"`
		}
		// ssh-keygen -R defaults to ~/.ssh/known_hosts, so name the custom file explicitly.
		cmd += " -f " + path
	}
	return "The host key for " + hostName + " has changed. " +
		"Remove the stale entry and reconnect:\n " + cmd
}

func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) {
sessionID := opts.SessionIdentifier()
// For dedicated clusters, use clusterID; for serverless, it will be read from metadata
Expand Down
68 changes: 68 additions & 0 deletions experimental/ssh/internal/client/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,71 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) {
assert.Contains(t, err.Error(), "ssh server bootstrap job failed")
assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.")
}

// hostKeyFailureStderr is the relevant tail of ssh's stderr when strict checking aborts a
// connection because the remote host key changed. It is a test fixture: the warning banner,
// the line naming the host, and the final "Host key verification failed." line, which is the
// text hostKeyChangedHint looks for.
const hostKeyFailureStderr = `@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
Host key for databricks-cpu-6e7644d0 has changed and you have requested strict checking.
Host key verification failed.`

// TestHostKeyChangedHint runs hostKeyChangedHint over a table of stderr samples. For a host-key
// failure the hint must contain the expected fragments (host name and ssh-keygen command,
// including -f for a custom known_hosts file); for an unrelated failure it must be empty.
func TestHostKeyChangedHint(t *testing.T) {
	cases := []struct {
		name           string
		stderr         string
		hostName       string
		knownHostsFile string
		wantContains   []string
		wantEmpty      bool
	}{
		{
			name:         "host key failure",
			stderr:       hostKeyFailureStderr,
			hostName:     "databricks-cpu-6e7644d0",
			wantContains: []string{"databricks-cpu-6e7644d0", "ssh-keygen -R databricks-cpu-6e7644d0"},
		},
		{
			name:           "host key failure with custom known_hosts file",
			stderr:         hostKeyFailureStderr,
			hostName:       "databricks-cpu-6e7644d0",
			knownHostsFile: "/tmp/known_hosts",
			wantContains:   []string{"ssh-keygen -R databricks-cpu-6e7644d0 -f /tmp/known_hosts"},
		},
		{
			name:      "unrelated failure",
			stderr:    "kex_exchange_identification: Connection closed by remote host",
			hostName:  "databricks-cpu-6e7644d0",
			wantEmpty: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			hint := hostKeyChangedHint(tc.stderr, tc.hostName, tc.knownHostsFile)
			if !tc.wantEmpty {
				// Every expected fragment must appear somewhere in the hint.
				for _, fragment := range tc.wantContains {
					assert.Contains(t, hint, fragment)
				}
				return
			}
			assert.Empty(t, hint)
		})
	}
}

// TestTailWriterRetainsTail checks the two sides of tailWriter's bound: a write longer than
// maxBytes keeps only its last maxBytes bytes (while still reporting the full length as
// written), and a write shorter than maxBytes is kept whole.
func TestTailWriterRetainsTail(t *testing.T) {
	t.Run("retains only the tail", func(t *testing.T) {
		tail := &tailWriter{maxBytes: 4}

		written, err := tail.Write([]byte("abcdefgh"))

		require.NoError(t, err)
		assert.Equal(t, 8, written)
		assert.Equal(t, "efgh", tail.String())
	})

	t.Run("preserves a short write", func(t *testing.T) {
		tail := &tailWriter{maxBytes: 4}

		_, err := tail.Write([]byte("ab"))

		require.NoError(t, err)
		assert.Equal(t, "ab", tail.String())
	})
}
Loading