From ae0177608673125ac3c90f42a504de767683cf7d Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Fri, 1 May 2026 12:36:05 +0000 Subject: [PATCH] test(proxy): integration tests for session correlation audit and header agreement Add integration tests that verify the core invariants of session correlation across the proxy, auditor, and forwarded request headers working together. These tests fill the gap identified during review of the session correlation PR stack (#196, #197, #198) where unit tests verified each component in isolation but did not verify them in concert. New test file: proxy/proxy_session_correlation_integration_test.go Tests added: - LLMRequestAuditAndHeadersAgree: audit sequence number matches the forwarded header value on inject-target requests. - NonLLMRequestAuditedWithoutHeaders: allowed non-inject-target requests are audited but carry no correlation headers. - DeniedRequestAuditedNeverForwarded: denied requests consume a sequence number but are never forwarded. - MixedRequestsSequenceOrdering: interleaved LLM, non-LLM, and denied requests all advance the counter monotonically. - SequenceGapRevealsAgenticLoop: gap between two LLM sequence numbers precisely equals intermediate tool-use requests. - SpoofedHeadersOverwrittenWithCorrectSequence: client-supplied headers are replaced and the audit event still agrees. - DisabledCorrelationNoHeadersNoPreallocatedSequence: disabled correlation means no headers and no pre-allocated sequence. - ConcurrentRequestsUniqueSequenceNumbers: concurrent requests each get a unique, dense sequence number. --- ...xy_session_correlation_integration_test.go | 578 ++++++++++++++++++ 1 file changed, 578 insertions(+) create mode 100644 proxy/proxy_session_correlation_integration_test.go diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go new file mode 100644 index 0000000..b8035b7 --- /dev/null +++ b/proxy/proxy_session_correlation_integration_test.go @@ -0,0 +1,578 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// multiRequestCapturingBackend records the headers from every request it +// receives, not just the last one. This is needed by integration tests +// that send multiple requests to the same backend and want to verify +// each one independently. +type multiRequestCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + all []http.Header +} + +func newMultiRequestCapturingBackend() *multiRequestCapturingBackend { + mcb := &multiRequestCapturingBackend{} + mcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mcb.mu.Lock() + mcb.all = append(mcb.all, r.Header.Clone()) + mcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return mcb +} + +func (m *multiRequestCapturingBackend) close() { m.server.Close() } + +func (m *multiRequestCapturingBackend) requestCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.all) +} + +func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { + m.mu.Lock() + defer m.mu.Unlock() + return m.all[i].Clone() +} + +// sessionCorrelationIntegrationSetup holds the shared objects for an +// integration test: the proxy, auditor, backend(s), and sequence +// counter. Tests build one via newSessionCorrelationIntegrationSetup +// and tear it down with stop. +type sessionCorrelationIntegrationSetup struct { + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + llmBackend *multiRequestCapturingBackend + otherBackend *multiRequestCapturingBackend +} + +func (s *sessionCorrelationIntegrationSetup) stop() { + s.pt.Stop() + if s.llmBackend != nil { + s.llmBackend.close() + } + if s.otherBackend != nil { + s.otherBackend.close() + } +} + +// newSessionCorrelationIntegrationSetup builds a proxy that allows +// traffic to two httptest backends: one that matches an inject target +// (simulating an LLM provider) and one that does not (simulating a +// generic allowed domain like github.com). Both backends capture all +// received request headers. A capturingAuditor records every audit +// event for later inspection. +func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { + t.Helper() + + llm := newMultiRequestCapturingBackend() + other := newMultiRequestCapturingBackend() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + // Both httptest backends resolve to 127.0.0.1, so a domain-only + // inject target would match both. We use a path glob on the LLM + // paths (/v1/*) to limit header injection to LLM requests. + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + // Allow both backends. + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + // Only requests matching the LLM path receive headers. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + + return &sessionCorrelationIntegrationSetup{ + pt: pt, + auditor: aud, + seq: seq, + llmBackend: llm, + otherBackend: other, + } +} + +// ---------- Integration Tests ---------- + +// TestIntegration_LLMRequestAuditAndHeadersAgree verifies the core +// correlation invariant: when an allowed request hits an inject target, +// the sequence number in the audit event equals the sequence number in +// the forwarded header. +func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { + const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Audit event. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.True(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // Forwarded headers. + require.Equal(t, 1, s.llmBackend.requestCount()) + hdr := s.llmBackend.headersAt(0) + assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // The two must agree. + assert.Equal(t, + strconv.FormatUint(*events[0].SequenceNumber, 10), + hdr.Get(config.DefaultSequenceNumberHeaderName), + "audit event and forwarded header must carry the same sequence number", + ) +} + +// TestIntegration_NonLLMRequestAuditedWithoutHeaders verifies that an +// allowed request to a domain that is NOT an inject target still gets +// audited (with a sequence number) but does NOT receive correlation +// headers. +func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { + s := newSessionCorrelationIntegrationSetup(t, "test-session") + defer s.stop() + + resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Audit event recorded. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.True(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // No correlation headers on the backend. + require.Equal(t, 1, s.otherBackend.requestCount()) + hdr := s.otherBackend.headersAt(0) + assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName), + "non-inject-target requests must not carry session ID header") + assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName), + "non-inject-target requests must not carry sequence number header") +} + +// TestIntegration_DeniedRequestAuditedNeverForwarded verifies that a +// request denied by the rules engine is audited (consuming a sequence +// number) but is never forwarded to any backend. +func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { + // Create a setup with a custom deny-all proxy, but keep the same + // pattern of shared sequence counter and auditor. + llm := newMultiRequestCapturingBackend() + defer llm.close() + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + // No allowed domains: deny everything. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session"), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(llm.server.URL + "/exfil") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusForbidden, resp.StatusCode) + + // Audit event recorded. + events := aud.getRequests() + require.Len(t, events, 1) + require.False(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // Backend never hit. + assert.Equal(t, 0, llm.requestCount(), + "denied requests must not be forwarded to the backend") +} + +// TestIntegration_MixedRequestsSequenceOrdering sends a realistic +// sequence of LLM, non-LLM, and denied requests, then verifies: +// 1. Sequence numbers increase monotonically across all request types. +// 2. Only inject-target requests carry correlation headers. +// 3. The sequence numbers in headers match the audit events. +// 4. The gap between two LLM requests' sequence numbers reveals the +// intermediate non-LLM and denied activity. +func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { + const sessionID = "mixed-test-session" + + // Two allowed backends (LLM and "github"), one denied domain. + llm := newMultiRequestCapturingBackend() + defer llm.close() + + other := newMultiRequestCapturingBackend() + defer other.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + // Only LLM is an inject target. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + // Request 0: LLM (allowed, inject target). + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 1: non-LLM (allowed, no inject). + resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 2: denied (nothing is allowed for evil.example.com). + resp, err = pt.proxyClient.Get("http://evil.example.com/exfil") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusForbidden, resp.StatusCode) + + // Request 3: LLM again. + resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // -- Verify audit events -- + events := aud.getRequests() + require.Len(t, events, 4, "expected exactly four audit events") + + expectedSeq := []uint64{0, 1, 2, 3} + expectedAllowed := []bool{true, true, false, true} + for i, ev := range events { + require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i) + assert.Equal(t, expectedSeq[i], *ev.SequenceNumber, + "event %d: wrong sequence number", i) + assert.Equal(t, expectedAllowed[i], ev.Allowed, + "event %d: wrong allowed flag", i) + } + + // -- Verify LLM backend headers -- + require.Equal(t, 2, llm.requestCount(), + "LLM backend should have received exactly two requests") + + firstLLMHdr := llm.headersAt(0) + assert.Equal(t, sessionID, firstLLMHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", firstLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + "first LLM request must have sequence 0") + + secondLLMHdr := llm.headersAt(1) + assert.Equal(t, sessionID, secondLLMHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "3", secondLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + "second LLM request must have sequence 3") + + // -- Verify non-LLM backend has no correlation headers -- + require.Equal(t, 1, other.requestCount()) + otherHdr := other.headersAt(0) + assert.Empty(t, otherHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Empty(t, otherHdr.Get(config.DefaultSequenceNumberHeaderName)) + + // -- Verify the gap reveals intermediate activity -- + // The gap between the two LLM sequence numbers (0 and 3) means + // that sequence numbers 1 and 2 were consumed by non-LLM + // activity, matching audit events 1 (non-LLM allowed) and 2 + // (denied). + firstLLMSeq := *events[0].SequenceNumber + secondLLMSeq := *events[3].SequenceNumber + gap := secondLLMSeq - firstLLMSeq - 1 + assert.Equal(t, uint64(2), gap, + "gap between LLM requests should reveal 2 intermediate events") +} + +// TestIntegration_SequenceGapRevealsAgenticLoop sends two LLM requests +// with several non-LLM requests in between, simulating an agentic loop +// where the model triggers tool-use HTTP calls between prompts. The +// test verifies that the gap in LLM sequence numbers precisely +// reflects the count of intermediate boundary events. +func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { + const sessionID = "agentic-loop-session" + + llm := newMultiRequestCapturingBackend() + defer llm.close() + + other := newMultiRequestCapturingBackend() + defer other.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + // First LLM prompt (seq 0). + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + + // Agentic loop: three tool-use HTTP calls. + for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { + resp, err = pt.proxyClient.Get(other.server.URL + p) + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + } + + // Second LLM prompt (seq 4). + resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + + // Verify LLM sequence headers. + require.Equal(t, 2, llm.requestCount()) + assert.Equal(t, "0", llm.headersAt(0).Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, "4", llm.headersAt(1).Get(config.DefaultSequenceNumberHeaderName)) + + // The gap between sequence numbers 0 and 4 is 3, matching the + // three tool-use requests in between. + events := aud.getRequests() + require.Len(t, events, 5) + + firstLLMSeq := *events[0].SequenceNumber + secondLLMSeq := *events[4].SequenceNumber + gap := secondLLMSeq - firstLLMSeq - 1 + assert.Equal(t, uint64(3), gap, + "gap between prompts should equal number of tool-use requests") + + // Verify the intermediate events are the tool-use requests. + for i := 1; i <= 3; i++ { + require.NotNil(t, events[i].SequenceNumber) + assert.Equal(t, uint64(i), *events[i].SequenceNumber) + assert.True(t, events[i].Allowed) + } +} + +// TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence +// verifies that when a jailed client sets its own correlation headers, +// the proxy replaces them with the real session ID and the real +// sequence number, and the audit event still agrees with the header. +func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { + const sessionID = "real-session-uuid" + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) + require.NoError(t, err) + req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session") + req.Header.Set(config.DefaultSequenceNumberHeaderName, "9999") + + resp, err := s.pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Backend received real values, not spoofed. + require.Equal(t, 1, s.llmBackend.requestCount()) + hdr := s.llmBackend.headersAt(0) + assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // Audit event agrees with header. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, + strconv.FormatUint(*events[0].SequenceNumber, 10), + hdr.Get(config.DefaultSequenceNumberHeaderName), + ) +} + +// TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence +// verifies that when session correlation is disabled, the proxy does +// not inject headers and does not pre-allocate sequence numbers (the +// auditor falls back to its own counter instead). +func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { + llm := newMultiRequestCapturingBackend() + defer llm.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + // Correlation disabled; no sequence counter. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("should-not-appear"), + // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. + WithAuditor(aud), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // No correlation headers. + require.Equal(t, 1, llm.requestCount()) + hdr := llm.headersAt(0) + assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // Audit event recorded but without a pre-allocated sequence + // number (nil), because no SequenceCounter was provided. + events := aud.getRequests() + require.Len(t, events, 1) + assert.Nil(t, events[0].SequenceNumber, + "no sequence counter means no pre-allocated sequence number") +} + +// TestIntegration_ConcurrentRequestsUniqueSequenceNumbers sends +// multiple requests concurrently and verifies that every request +// receives a unique sequence number, and that the set of numbers is +// dense (no gaps, no duplicates). +func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { + const sessionID = "concurrent-session" + const numRequests = 10 + + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + var wg sync.WaitGroup + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + assert.NoError(t, err) + if resp != nil { + resp.Body.Close() //nolint:errcheck + } + }() + } + wg.Wait() + + // Every request should have been audited. + events := s.auditor.getRequests() + require.Len(t, events, numRequests) + + // Collect all sequence numbers and verify uniqueness. + seen := make(map[uint64]bool, numRequests) + for i, ev := range events { + require.NotNil(t, ev.SequenceNumber, + "event %d: sequence number must not be nil", i) + assert.False(t, seen[*ev.SequenceNumber], + "event %d: duplicate sequence number %d", i, *ev.SequenceNumber) + seen[*ev.SequenceNumber] = true + } + + // The set should be exactly {0, 1, ..., numRequests-1}. + for i := uint64(0); i < numRequests; i++ { + assert.True(t, seen[i], + "sequence number %d is missing from the set", i) + } + + // Every header should also carry a matching sequence number. + require.Equal(t, numRequests, s.llmBackend.requestCount()) + headerSeqs := make(map[string]bool, numRequests) + for i := 0; i < numRequests; i++ { + hdr := s.llmBackend.headersAt(i) + seqStr := hdr.Get(config.DefaultSequenceNumberHeaderName) + assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) + headerSeqs[seqStr] = true + } + for i := uint64(0); i < numRequests; i++ { + assert.True(t, headerSeqs[fmt.Sprintf("%d", i)], + "header sequence number %d is missing", i) + } +}