From 748fc00533897d7c62ae8ec3d7c7dfa76cf07b9a Mon Sep 17 00:00:00 2001 From: Anton Nekipelov <226657+anton-107@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:40:56 +0200 Subject: [PATCH] Suggest ssh-keygen -R on SSH connect host key mismatch When a Databricks compute is recreated it keeps the same deterministic SSH connection name but gets a new host key, so the stale known_hosts entry trips OpenSSH's strict checking and `ssh connect` exits 255 with "Host key verification failed." Until now that landed in the generic "container is likely missing an OpenSSH server" branch, which is misleading and offers no fix. Tee ssh's stderr through a bounded tail buffer so we can detect the host-key failure after exit, and when it occurs print an actionable hint telling the user to run `ssh-keygen -R <host>` (with `-f <file>` when --user-known-hosts-file is set) and reconnect. Co-authored-by: Isaac --- experimental/ssh/internal/client/client.go | 52 +++++++++++++- .../internal/client/client_internal_test.go | 68 +++++++++++++++++++ 2 files changed, 118 insertions(+), 2 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index e4daa146146..c006ba366b6 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -650,7 +650,10 @@ func spawnSSHClient(ctx context.Context, client *databricks.WorkspaceClient, use sshCmd.Stdin = os.Stdin sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr + // Tee ssh's stderr so the user still sees it while we retain the tail to inspect after exit. + // A host-key-verification failure is reported only on stderr, so we need a copy to detect it. 
+ stderrTail := &tailWriter{maxBytes: hostKeyStderrTailBytes} + sshCmd.Stderr = io.MultiWriter(os.Stderr, stderrTail) err = sshCmd.Run() // ssh reserves exit code 255 for its own connection-level failures (a remote command's exit @@ -659,7 +662,9 @@ func spawnSSHClient(ctx context.Context, client *databricks.WorkspaceClient, use // own logs — fetch them from the /logs endpoint and show them instead of leaving the user // with ssh's opaque "Connection closed" message. if exitErr, ok := errors.AsType[*exec.ExitError](err); ok && exitErr.ExitCode() == 255 { - if logs := fetchServerErrorLogs(ctx, client, clusterID, serverPort, opts.Liteswap); logs != "" { + if hint := hostKeyChangedHint(stderrTail.String(), hostName, opts.UserKnownHostsFile); hint != "" { + cmdio.LogString(ctx, cmdio.Yellow(ctx, hint)) + } else if logs := fetchServerErrorLogs(ctx, client, clusterID, serverPort, opts.Liteswap); logs != "" { cmdio.LogString(ctx, cmdio.Yellow(ctx, "The SSH connection closed unexpectedly. Recent SSH server errors:")) cmdio.LogString(ctx, truncateTail(logs, maxRunFailureTraceBytes)) } else { @@ -854,6 +859,49 @@ func truncateTail(s string, maxBytes int) string { return " ...\n" + s[len(s)-maxBytes:] } +// hostKeyStderrTailBytes bounds how much of ssh's stderr we retain to detect a host-key failure. +// The host-key warning block ssh prints is well under this, so the tail always captures it. +const hostKeyStderrTailBytes = 4096 + +// tailWriter retains the last maxBytes written to it, so we can inspect an external command's +// recent stderr without buffering an unbounded amount. +type tailWriter struct { + maxBytes int + buf []byte +} + +func (w *tailWriter) Write(p []byte) (int, error) { + w.buf = append(w.buf, p...) 
+ if len(w.buf) > w.maxBytes { + w.buf = w.buf[len(w.buf)-w.maxBytes:] + } + return len(p), nil +} + +func (w *tailWriter) String() string { + return string(w.buf) +} + +// hostKeyChangedHint returns advice for clearing a stale known_hosts entry when ssh's stderr +// shows a host-key-verification failure, or "" if the failure was something else. A cluster that +// has been recreated keeps the same connection name but gets a new host key, so the old entry no +// longer matches and ssh aborts the connection. +func hostKeyChangedHint(stderr, hostName, knownHostsFile string) string { + // "Host key verification failed." is OpenSSH's fixed message for this case; matching it is the + // only signal ssh gives (the "don't branch on err.Error()" rule is about Go errors, not the + // output of an external program). + if !strings.Contains(stderr, "Host key verification failed") { + return "" + } + cmd := "ssh-keygen -R " + hostName + if knownHostsFile != "" { + // ssh-keygen -R defaults to ~/.ssh/known_hosts, so name the custom file explicitly. + cmd += " -f " + knownHostsFile + } + return "The host key for " + hostName + " has changed. 
" + + "Remove the stale entry and reconnect:\n " + cmd +} + func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { sessionID := opts.SessionIdentifier() // For dedicated clusters, use clusterID; for serverless, it will be read from metadata diff --git a/experimental/ssh/internal/client/client_internal_test.go b/experimental/ssh/internal/client/client_internal_test.go index 4e592a0eab3..af6d021c2c1 100644 --- a/experimental/ssh/internal/client/client_internal_test.go +++ b/experimental/ssh/internal/client/client_internal_test.go @@ -146,3 +146,71 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) { assert.Contains(t, err.Error(), "ssh server bootstrap job failed") assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.") } + +// hostKeyFailureStderr is the relevant tail of ssh's stderr when strict checking aborts a +// connection because the remote host key changed. +const hostKeyFailureStderr = `@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +Host key for databricks-cpu-6e7644d0 has changed and you have requested strict checking. 
+Host key verification failed.` + +func TestHostKeyChangedHint(t *testing.T) { + tests := []struct { + name string + stderr string + hostName string + knownHostsFile string + wantContains []string + wantEmpty bool + }{ + { + name: "host key failure", + stderr: hostKeyFailureStderr, + hostName: "databricks-cpu-6e7644d0", + wantContains: []string{"databricks-cpu-6e7644d0", "ssh-keygen -R databricks-cpu-6e7644d0"}, + }, + { + name: "host key failure with custom known_hosts file", + stderr: hostKeyFailureStderr, + hostName: "databricks-cpu-6e7644d0", + knownHostsFile: "/tmp/known_hosts", + wantContains: []string{"ssh-keygen -R databricks-cpu-6e7644d0 -f /tmp/known_hosts"}, + }, + { + name: "unrelated failure", + stderr: "kex_exchange_identification: Connection closed by remote host", + hostName: "databricks-cpu-6e7644d0", + wantEmpty: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hostKeyChangedHint(tt.stderr, tt.hostName, tt.knownHostsFile) + if tt.wantEmpty { + assert.Empty(t, got) + return + } + for _, want := range tt.wantContains { + assert.Contains(t, got, want) + } + }) + } +} + +func TestTailWriterRetainsTail(t *testing.T) { + t.Run("retains only the tail", func(t *testing.T) { + w := &tailWriter{maxBytes: 4} + n, err := w.Write([]byte("abcdefgh")) + require.NoError(t, err) + assert.Equal(t, 8, n) + assert.Equal(t, "efgh", w.String()) + }) + + t.Run("preserves a short write", func(t *testing.T) { + w := &tailWriter{maxBytes: 4} + _, err := w.Write([]byte("ab")) + require.NoError(t, err) + assert.Equal(t, "ab", w.String()) + }) +}