diff --git a/cmd/root/api.go b/cmd/root/api.go index 62bd8141e..5af574d9d 100644 --- a/cmd/root/api.go +++ b/cmd/root/api.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log/slog" + "net" "os" "time" @@ -92,6 +93,7 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) (commandErr defer lnCleanup() out.Println("Listening on", ln.Addr().String()) + warnIfNotLoopback(out, ln.Addr()) slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String()) @@ -123,3 +125,22 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) (commandErr return s.Serve(ctx, ln) } + +// warnIfNotLoopback prints a security warning when the API server is bound to +// an address other than loopback. The default --listen value is 127.0.0.1, so +// reaching this code path means the operator was explicit about exposing the +// API; we just remind them that the API has no authentication. +func warnIfNotLoopback(out *cli.Printer, addr net.Addr) { + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + // Unix sockets and named pipes rely on filesystem permissions. + return + } + if tcpAddr.IP.IsLoopback() { + return + } + out.Println("WARNING: API server is listening on a non-loopback address.") + out.Println(" The API has no authentication; anyone able to reach") + out.Println(" this address can run agents and access all sessions.") + slog.Warn("API server bound to non-loopback address", "addr", tcpAddr.String()) +} diff --git a/pkg/config/sources.go b/pkg/config/sources.go index efdd7734e..9467bc0b7 100644 --- a/pkg/config/sources.go +++ b/pkg/config/sources.go @@ -8,12 +8,15 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" "net/url" "os" "path/filepath" "slices" "strings" + "syscall" + "time" "github.com/docker/docker-agent/pkg/content" "github.com/docker/docker-agent/pkg/environment" @@ -192,6 +195,11 @@ func hasLocalArtifact(store *content.Store, storeKey string) bool { type urlSource struct { url string envProvider environment.Provider + // unsafe disables the HTTPS-only and SSRF dial-time checks. It is set + // only by the test-only constructor newURLSourceForTest (defined in + // sources_test.go), which exists because tests use httptest.NewServer + // (plain HTTP, 127.0.0.1). + unsafe bool } // NewURLSource creates a new URL source. If envProvider is non-nil, it will be used @@ -217,6 +225,12 @@ func getURLCacheDir() string { } func (a urlSource) Read(ctx context.Context) ([]byte, error) { + if !a.unsafe { + if err := validateAgentURL(a.url); err != nil { + return nil, err + } + } + cacheDir := getURLCacheDir() urlHash := hashURL(a.url) cachePath := filepath.Join(cacheDir, urlHash) @@ -241,7 +255,12 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) { // Add GitHub token authorization for GitHub URLs a.addGitHubAuth(ctx, req) - resp, err := httpclient.NewHTTPClient(ctx).Do(req) + client := httpclient.NewHTTPClient(ctx) + if !a.unsafe { + client = ssrfSafeHTTPClient() + } + + resp, err := client.Do(req) if err != nil { // Network error - try to use cached version if cachedData, cacheErr := os.ReadFile(cachePath); cacheErr == nil { @@ -344,3 +363,96 @@ func hashURL(rawURL string) string { func IsURLReference(input string) bool { return strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") } + +// validateAgentURL enforces that an agent URL uses HTTPS. SSRF protection +// (rejecting connections to loopback / private / link-local addresses) is +// done at dial time by ssrfSafeHTTPClient so that DNS rebinding cannot be +// used to bypass it. +func validateAgentURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL %q: %w", rawURL, err) + } + if u.Scheme != "https" { + return fmt.Errorf("refusing to load agent from %q: only https:// URLs are allowed (got scheme %q)", rawURL, u.Scheme) + } + if u.Host == "" { + return fmt.Errorf("invalid URL %q: missing host", rawURL) + } + return nil +} + +// ssrfSafeHTTPClient returns an http.Client whose dialer rejects connections +// to non-public IP ranges (loopback, private, link-local, multicast, +// unspecified). The check happens after DNS resolution and before the TCP +// handshake, so DNS rebinding to a private IP is also blocked. +// +// Redirects are re-validated through CheckRedirect so that an https:// +// origin cannot transparently downgrade to http:// or to a different scheme. +// SSRF protection on the redirect target is provided by the same dialer. +func ssrfSafeHTTPClient() *http.Client { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Control: ssrfDialControl, + } + return &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + CheckRedirect: ssrfCheckRedirect, + } +} + +// ssrfCheckRedirect is the http.Client CheckRedirect hook used by +// ssrfSafeHTTPClient. It rejects redirects to non-https URLs (defeating +// TLS downgrade) and bounds the redirect chain. SSRF on the redirect +// target itself is enforced by the dialer's Control hook. +func ssrfCheckRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + if req.URL.Scheme != "https" { + return fmt.Errorf("refusing redirect to non-https URL %q", req.URL.Redacted()) + } + return nil +} + +// ssrfDialControl is invoked by net.Dialer after DNS resolution but before the +// TCP handshake. It rejects addresses that are not safe to fetch from over +// the public internet. +func ssrfDialControl(_, address string, _ syscall.RawConn) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("parsing dial address %q: %w", address, err) + } + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("refusing to dial %q: not a valid IP", host) + } + if !isPublicIP(ip) { + return fmt.Errorf("refusing to dial non-public address %s", ip) + } + return nil +} + +// isPublicIP reports whether ip is a routable public address. It rejects +// loopback (127/8, ::1), RFC1918 private ranges, link-local (incl. the +// 169.254.169.254 cloud metadata endpoint), multicast and the unspecified +// address (0.0.0.0, ::). +func isPublicIP(ip net.IP) bool { + return !ip.IsLoopback() && + !ip.IsPrivate() && + !ip.IsLinkLocalUnicast() && + !ip.IsLinkLocalMulticast() && + !ip.IsMulticast() && + !ip.IsUnspecified() +} diff --git a/pkg/config/sources_test.go b/pkg/config/sources_test.go index d13706aa1..a6559c52c 100644 --- a/pkg/config/sources_test.go +++ b/pkg/config/sources_test.go @@ -1,6 +1,7 @@ package config import ( + "net" "net/http" "net/http/httptest" "net/url" @@ -21,6 +22,18 @@ import ( "github.com/docker/docker-agent/pkg/remote" ) +// newURLSourceForTest constructs a urlSource that bypasses the HTTPS-only and +// SSRF dial-time checks. It is defined here, in a _test.go file, so it is +// not compiled into release binaries. Tests use it because httptest.NewServer +// binds to 127.0.0.1 over plain HTTP. +func newURLSourceForTest(rawURL string, envProvider environment.Provider) Source { + return &urlSource{ + url: rawURL, + envProvider: envProvider, + unsafe: true, + } +} + func TestOCISource_DigestReference_ServesFromCache(t *testing.T) { t.Parallel() @@ -70,7 +83,7 @@ func TestURLSource_Read(t *testing.T) { })) t.Cleanup(server.Close) - source := NewURLSource(server.URL, nil) + source := newURLSourceForTest(server.URL, nil) assert.Equal(t, server.URL, source.Name()) assert.Empty(t, source.ParentDir()) @@ -108,7 +121,7 @@ func TestURLSource_Read_HTTPError(t *testing.T) { _ = os.Remove(cachePath) _ = os.Remove(etagPath) - _, err := NewURLSource(server.URL, nil).Read(t.Context()) + _, err := newURLSourceForTest(server.URL, nil).Read(t.Context()) require.Error(t, err) }) } @@ -117,7 +130,7 @@ func TestURLSource_Read_HTTPError(t *testing.T) { func TestURLSource_Read_ConnectionError(t *testing.T) { t.Parallel() - _, err := NewURLSource("http://invalid.invalid/config.yaml", nil).Read(t.Context()) + _, err := newURLSourceForTest("http://invalid.invalid/config.yaml", nil).Read(t.Context()) require.Error(t, err) } @@ -130,7 +143,7 @@ func TestURLSource_Read_CachesContent(t *testing.T) { })) t.Cleanup(server.Close) - source := NewURLSource(server.URL, nil) + source := newURLSourceForTest(server.URL, nil) // First read should fetch and cache data, err := source.Read(t.Context()) @@ -188,7 +201,7 @@ func TestURLSource_Read_UsesETagForConditionalRequest(t *testing.T) { _ = os.Remove(etagPath) }) - source := NewURLSource(server.URL, nil) + source := newURLSourceForTest(server.URL, nil) // Read should use cached content via 304 response data, err := source.Read(t.Context()) @@ -213,7 +226,7 @@ func TestURLSource_Read_FallsBackToCacheOnNetworkError(t *testing.T) { _ = os.Remove(cachePath) }) - source := NewURLSource(agentURL, nil) + source := newURLSourceForTest(agentURL, nil) // Read should fall back to cached content data, err := source.Read(t.Context()) @@ -241,7 +254,7 @@ func TestURLSource_Read_FallsBackToCacheOnHTTPError(t *testing.T) { _ = os.Remove(cachePath) }) - source := NewURLSource(server.URL, nil) + source := newURLSourceForTest(server.URL, nil) // Read should fall back to cached content data, err := source.Read(t.Context()) @@ -279,7 +292,7 @@ func TestURLSource_Read_UpdatesCacheWhenContentChanges(t *testing.T) { _ = os.Remove(etagPath) }) - source := NewURLSource(server.URL, nil) + source := newURLSourceForTest(server.URL, nil) // First read data, err := source.Read(t.Context()) @@ -300,6 +313,144 @@ func TestURLSource_Read_UpdatesCacheWhenContentChanges(t *testing.T) { assert.Equal(t, "updated content update", string(cachedData)) } +func TestURLSource_Read_RejectsHTTP(t *testing.T) { + t.Parallel() + + _, err := NewURLSource("http://example.com/agent.yaml", nil).Read(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "only https://") +} + +func TestURLSource_Read_RejectsLocalAddresses(t *testing.T) { + t.Parallel() + + // Hosts whose only resolution is a non-public IP must be refused at + // dial time. We test the SSRF dialer via the HTTPS code path even + // though the TLS handshake will never complete, because the dial is + // aborted before any bytes are sent. + tests := []string{ + "https://127.0.0.1/agent.yaml", // loopback + "https://[::1]/agent.yaml", // IPv6 loopback + "https://10.0.0.1/agent.yaml", // RFC1918 + "https://192.168.1.1/agent.yaml", // RFC1918 + "https://169.254.169.254/agent.yaml", // AWS/GCP/Azure metadata + "https://0.0.0.0/agent.yaml", // unspecified + } + for _, rawURL := range tests { + t.Run(rawURL, func(t *testing.T) { + t.Parallel() + + // Clear any cached content so the dial is actually attempted. + urlCacheDir := getURLCacheDir() + urlHash := hashURL(rawURL) + _ = os.Remove(filepath.Join(urlCacheDir, urlHash)) + _ = os.Remove(filepath.Join(urlCacheDir, urlHash+".etag")) + + _, err := NewURLSource(rawURL, nil).Read(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "non-public address") + }) + } +} + +func TestIsPublicIP(t *testing.T) { + t.Parallel() + + tests := []struct { + ip string + isPublic bool + }{ + {"8.8.8.8", true}, + {"1.1.1.1", true}, + {"2001:4860:4860::8888", true}, // public IPv6 + {"127.0.0.1", false}, + {"::1", false}, + {"10.0.0.1", false}, + {"172.16.0.1", false}, + {"192.168.0.1", false}, + {"169.254.169.254", false}, // AWS/GCP/Azure metadata + {"fe80::1", false}, // link-local IPv6 + {"224.0.0.1", false}, // multicast + {"0.0.0.0", false}, + {"::", false}, + } + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + t.Parallel() + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip, "failed to parse IP %q", tt.ip) + assert.Equal(t, tt.isPublic, isPublicIP(ip)) + }) + } +} + +func TestSSRFCheckRedirect(t *testing.T) { + t.Parallel() + + mustParse := func(s string) *url.URL { + u, err := url.Parse(s) + require.NoError(t, err) + return u + } + + tests := []struct { + name string + target string + via int // length of the via slice + wantErr string // empty means no error + }{ + {"https redirect allowed", "https://example.com/agent.yaml", 1, ""}, + {"http redirect rejected", "http://example.com/agent.yaml", 1, "non-https"}, + {"file redirect rejected", "file:///etc/passwd", 1, "non-https"}, + {"javascript redirect rejected", "javascript:alert(1)", 1, "non-https"}, + {"ftp redirect rejected", "ftp://example.com/x", 1, "non-https"}, + {"redirect loop bounded", "https://example.com/agent.yaml", 10, "10 redirects"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := &http.Request{URL: mustParse(tt.target)} + via := make([]*http.Request, tt.via) + err := ssrfCheckRedirect(req, via) + if tt.wantErr == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestURLSource_Read_RejectsHTTPRedirect(t *testing.T) { + // Not parallel - clears cache. + + // HTTPS origin that 302s to plain http. We use httptest.NewTLSServer so + // the production ssrfSafeHTTPClient gets to exercise CheckRedirect on a + // real Location header. The dial-time SSRF check would reject 127.0.0.1 + // before the redirect target is fetched, but CheckRedirect runs first + // and gives us the precise downgrade error message. + httpsSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Redirect(w, &http.Request{}, "http://example.com/downgraded", http.StatusFound) + })) + t.Cleanup(httpsSrv.Close) + + // Trust the test server's self-signed cert by injecting it into the + // Go default cert pool would be invasive; instead, exercise the + // CheckRedirect hook directly via ssrfCheckRedirect (covered above) + // and assert that the production fetch path errors out for an https + // origin pointing at a non-trusted CA. Either way, the request must + // not silently follow to http://. + agentURL := httpsSrv.URL + "/agent.yaml" + urlCacheDir := getURLCacheDir() + urlHash := hashURL(agentURL) + _ = os.Remove(filepath.Join(urlCacheDir, urlHash)) + _ = os.Remove(filepath.Join(urlCacheDir, urlHash+".etag")) + + _, err := NewURLSource(agentURL, nil).Read(t.Context()) + require.Error(t, err) +} + func TestIsURLReference(t *testing.T) { t.Parallel() @@ -365,7 +516,7 @@ func TestURLSource_Read_WithGitHubAuth(t *testing.T) { }) // For non-GitHub URLs, auth should not be added even with token available - source := NewURLSource(server.URL, envProvider) + source := newURLSourceForTest(server.URL, envProvider) _, err := source.Read(t.Context()) require.NoError(t, err) assert.Empty(t, receivedAuth, "non-GitHub URLs should not receive auth header") @@ -397,7 +548,7 @@ func TestURLSource_Read_WithGitHubAuth_GitHubURL(t *testing.T) { // URL with GitHub host in path (not hostname) should NOT receive auth // This prevents token leakage to attacker-controlled domains maliciousURL := server.URL + "/" + host + "/path/to/file" - source := NewURLSource(maliciousURL, envProvider) + source := newURLSourceForTest(maliciousURL, envProvider) _, err := source.Read(t.Context()) require.NoError(t, err) @@ -419,7 +570,7 @@ func TestURLSource_Read_WithGitHubAuth_NoToken(t *testing.T) { // Create a mock env provider without a GitHub token envProvider := environment.NewNoEnvProvider() - source := NewURLSource(server.URL, envProvider) + source := newURLSourceForTest(server.URL, envProvider) _, err := source.Read(t.Context()) require.NoError(t, err) assert.Empty(t, receivedAuth, "should not add auth header when token is missing") @@ -436,7 +587,7 @@ func TestURLSource_Read_WithGitHubAuth_NoEnvProvider(t *testing.T) { t.Cleanup(server.Close) // No env provider - source := NewURLSource(server.URL, nil) + source := newURLSourceForTest(server.URL, nil) _, err := source.Read(t.Context()) require.NoError(t, err) assert.Empty(t, receivedAuth, "should not add auth header without env provider") diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 23dbf766c..78d6c079a 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -401,11 +401,11 @@ func (sm *SessionManager) generateTitle(ctx context.Context, sess *session.Sessi } func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.Session, agentFilename, currentAgent string, rc *config.RuntimeConfig) (runtime.Runtime, *sessiontitle.Generator, error) { - rt, exists := sm.runtimeSessions.Load(sess.ID) - if exists && rt.runtime != nil { - return rt.runtime, rt.titleGen, nil - } - + // Caller (RunSession) holds sm.mux and has already verified that no + // active runtime exists for this session. This function is purely a + // constructor: it must not touch sm.runtimeSessions, otherwise it would + // briefly publish a half-initialised activeRuntimes (e.g. without the + // cancel func) that other goroutines could observe. t, err := sm.loadTeam(ctx, agentFilename, rc) if err != nil { return nil, nil, err @@ -433,12 +433,6 @@ func (sm *SessionManager) runtimeForSession(ctx context.Context, sess *session.S titleGen := sessiontitle.New(agt.Model(), agt.FallbackModels()...) - sm.runtimeSessions.Store(sess.ID, &activeRuntimes{ - runtime: run, - session: sess, - titleGen: titleGen, - }) - slog.Debug("Runtime created for session", "session_id", sess.ID) return run, titleGen, nil