diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index e4daa146146..c006ba366b6 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -650,7 +650,10 @@ func spawnSSHClient(ctx context.Context, client *databricks.WorkspaceClient, use sshCmd.Stdin = os.Stdin sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr + // Tee ssh's stderr so the user still sees it while we retain the tail to inspect after exit. + // A host-key-verification failure is reported only on stderr, so we need a copy to detect it. + stderrTail := &tailWriter{maxBytes: hostKeyStderrTailBytes} + sshCmd.Stderr = io.MultiWriter(os.Stderr, stderrTail) err = sshCmd.Run() // ssh reserves exit code 255 for its own connection-level failures (a remote command's exit @@ -659,7 +662,9 @@ func spawnSSHClient(ctx context.Context, client *databricks.WorkspaceClient, use // own logs — fetch them from the /logs endpoint and show them instead of leaving the user // with ssh's opaque "Connection closed" message. if exitErr, ok := errors.AsType[*exec.ExitError](err); ok && exitErr.ExitCode() == 255 { - if logs := fetchServerErrorLogs(ctx, client, clusterID, serverPort, opts.Liteswap); logs != "" { + if hint := hostKeyChangedHint(stderrTail.String(), hostName, opts.UserKnownHostsFile); hint != "" { + cmdio.LogString(ctx, cmdio.Yellow(ctx, hint)) + } else if logs := fetchServerErrorLogs(ctx, client, clusterID, serverPort, opts.Liteswap); logs != "" { cmdio.LogString(ctx, cmdio.Yellow(ctx, "The SSH connection closed unexpectedly. Recent SSH server errors:")) cmdio.LogString(ctx, truncateTail(logs, maxRunFailureTraceBytes)) } else { @@ -854,6 +859,49 @@ func truncateTail(s string, maxBytes int) string { return " ...\n" + s[len(s)-maxBytes:] } +// hostKeyStderrTailBytes bounds how much of ssh's stderr we retain to detect a host-key failure. 
// The host-key warning block ssh prints is well under this, so the tail always captures it.
const hostKeyStderrTailBytes = 4096

// tailWriter is an io.Writer that retains only the last maxBytes bytes written to it, so we can
// inspect an external command's recent stderr without buffering an unbounded amount.
//
// A non-positive maxBytes retains nothing. tailWriter does no locking of its own.
type tailWriter struct {
	maxBytes int    // upper bound on len(buf)
	buf      []byte // most recently written bytes, oldest first
}

// Write records p, dropping the oldest retained bytes so that at most maxBytes are kept.
// It never fails and always reports len(p), so a fan-out writer in front of it never sees a
// short write.
func (w *tailWriter) Write(p []byte) (int, error) {
	n := len(p)
	if w.maxBytes <= 0 {
		// Nothing may be retained; also guards the slice arithmetic below against a negative bound.
		w.buf = w.buf[:0]
		return n, nil
	}
	if n >= w.maxBytes {
		// p alone fills the window: keep only its tail instead of appending all of it first.
		w.buf = append(w.buf[:0], p[n-w.maxBytes:]...)
		return n, nil
	}
	if overflow := len(w.buf) + n - w.maxBytes; overflow > 0 {
		// Slide the surviving bytes to the front so the backing array is reused rather than
		// reallocated every time the window wraps.
		w.buf = w.buf[:copy(w.buf, w.buf[overflow:])]
	}
	w.buf = append(w.buf, p...)
	return n, nil
}

// String returns a copy of the retained tail as a string.
func (w *tailWriter) String() string {
	return string(w.buf)
}

// hostKeyChangedHint returns advice for clearing a stale known_hosts entry when ssh's stderr
// shows a host-key-verification failure, or "" if the failure was something else. A cluster that
// has been recreated keeps the same connection name but gets a new host key, so the old entry no
// longer matches and ssh aborts the connection.
//
// stderr is the captured (tail of) ssh's stderr, hostName is the name to pass to ssh-keygen -R,
// and knownHostsFile is an optional custom known_hosts path ("" means ssh-keygen's default).
func hostKeyChangedHint(stderr, hostName, knownHostsFile string) string {
	// "Host key verification failed." is OpenSSH's fixed message for this case; matching it is the
	// only signal ssh gives (the "don't branch on err.Error()" rule is about Go errors, not the
	// output of an external program).
	if !strings.Contains(stderr, "Host key verification failed") {
		return ""
	}
	cmd := "ssh-keygen -R " + hostName
	if knownHostsFile != "" {
		// ssh-keygen -R defaults to ~/.ssh/known_hosts, so name the custom file explicitly.
		file := knownHostsFile
		if strings.ContainsAny(file, " \t") {
			// The user copy-pastes this command; an unquoted path with whitespace would be split
			// into several arguments. Double quotes are understood by POSIX shells, cmd.exe and
			// PowerShell alike.
			file = `"` + file + `"`
		}
		cmd += " -f " + file
	}
	return "The host key for " + hostName + " has changed. " +
		"Remove the stale entry and reconnect:\n " + cmd
}
+Host key verification failed.` + +func TestHostKeyChangedHint(t *testing.T) { + tests := []struct { + name string + stderr string + hostName string + knownHostsFile string + wantContains []string + wantEmpty bool + }{ + { + name: "host key failure", + stderr: hostKeyFailureStderr, + hostName: "databricks-cpu-6e7644d0", + wantContains: []string{"databricks-cpu-6e7644d0", "ssh-keygen -R databricks-cpu-6e7644d0"}, + }, + { + name: "host key failure with custom known_hosts file", + stderr: hostKeyFailureStderr, + hostName: "databricks-cpu-6e7644d0", + knownHostsFile: "/tmp/known_hosts", + wantContains: []string{"ssh-keygen -R databricks-cpu-6e7644d0 -f /tmp/known_hosts"}, + }, + { + name: "unrelated failure", + stderr: "kex_exchange_identification: Connection closed by remote host", + hostName: "databricks-cpu-6e7644d0", + wantEmpty: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hostKeyChangedHint(tt.stderr, tt.hostName, tt.knownHostsFile) + if tt.wantEmpty { + assert.Empty(t, got) + return + } + for _, want := range tt.wantContains { + assert.Contains(t, got, want) + } + }) + } +} + +func TestTailWriterRetainsTail(t *testing.T) { + t.Run("retains only the tail", func(t *testing.T) { + w := &tailWriter{maxBytes: 4} + n, err := w.Write([]byte("abcdefgh")) + require.NoError(t, err) + assert.Equal(t, 8, n) + assert.Equal(t, "efgh", w.String()) + }) + + t.Run("preserves a short write", func(t *testing.T) { + w := &tailWriter{maxBytes: 4} + _, err := w.Write([]byte("ab")) + require.NoError(t, err) + assert.Equal(t, "ab", w.String()) + }) +}