From b6c6a1e7274796b066766c01068e178bc04ff3c3 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Mon, 20 Apr 2026 16:19:22 +0000 Subject: [PATCH 1/3] prototype header injection --- audit/log_auditor.go | 31 ++++--- audit/multi_auditor.go | 6 +- audit/multi_auditor_test.go | 8 +- audit/socket_auditor.go | 15 +++- audit/socket_auditor_test.go | 46 ++++++++++- cli/cli.go | 23 ++++++ config/config.go | 82 +++++++++++++------ go.mod | 25 +++--- go.sum | 84 +++++++++---------- landjail/manager.go | 16 ++-- landjail/parent.go | 2 +- nsjail_manager/manager.go | 16 ++-- nsjail_manager/parent.go | 2 +- proxy/proxy.go | 54 +++++++----- proxy/proxy_framework_test.go | 34 ++++---- proxy/proxy_session_test.go | 149 ++++++++++++++++++++++++++++++++++ 16 files changed, 437 insertions(+), 156 deletions(-) create mode 100644 proxy/proxy_session_test.go diff --git a/audit/log_auditor.go b/audit/log_auditor.go index 100a58be..5fd953d9 100644 --- a/audit/log_auditor.go +++ b/audit/log_auditor.go @@ -4,7 +4,8 @@ import "log/slog" // LogAuditor implements proxy.Auditor by logging to slog type LogAuditor struct { - logger *slog.Logger + logger *slog.Logger + sessionID string } // NewLogAuditor creates a new LogAuditor @@ -14,19 +15,27 @@ func NewLogAuditor(logger *slog.Logger) *LogAuditor { } } +// NewLogAuditorWithSession creates a new LogAuditor that includes a session ID on every log line. +func NewLogAuditorWithSession(logger *slog.Logger, sessionID string) *LogAuditor { + return &LogAuditor{ + logger: logger, + sessionID: sessionID, + } +} + // AuditRequest logs the request using structured logging func (a *LogAuditor) AuditRequest(req Request) { + fields := []any{ + "method", req.Method, + "url", req.URL, + "host", req.Host, + } + if a.sessionID != "" { + fields = append(fields, "session_id", a.sessionID) + } if req.Allowed { - a.logger.Info("ALLOW", - "method", req.Method, - "url", req.URL, - "host", req.Host, - "rule", req.Rule) + a.logger.Info("ALLOW", append(fields, "rule", req.Rule)...) } else { - a.logger.Warn("DENY", - "method", req.Method, - "url", req.URL, - "host", req.Host, - ) + a.logger.Warn("DENY", fields...) } } diff --git a/audit/multi_auditor.go b/audit/multi_auditor.go index a9be0171..f32cb485 100644 --- a/audit/multi_auditor.go +++ b/audit/multi_auditor.go @@ -28,8 +28,8 @@ func (m *MultiAuditor) AuditRequest(req Request) { // provided configuration. It always includes a LogAuditor for stderr logging, // and conditionally adds a SocketAuditor if audit logs are enabled and the // workspace agent's log proxy socket exists. -func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string) (Auditor, error) { - stderrAuditor := NewLogAuditor(logger) +func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string, sessionID string) (Auditor, error) { + stderrAuditor := NewLogAuditorWithSession(logger, sessionID) auditors := []Auditor{stderrAuditor} if !disableAuditLogs { @@ -48,7 +48,7 @@ func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs boo } agentWillProxy := !os.IsNotExist(err) if agentWillProxy { - socketAuditor := NewSocketAuditor(logger, logProxySocketPath) + socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID) go socketAuditor.Loop(ctx) auditors = append(auditors, socketAuditor) } else { diff --git a/audit/multi_auditor_test.go b/audit/multi_auditor_test.go index e7f8af98..50045cdf 100644 --- a/audit/multi_auditor_test.go +++ b/audit/multi_auditor_test.go @@ -25,7 +25,7 @@ func TestSetupAuditor_DisabledAuditLogs(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) ctx := context.Background() - auditor, err := SetupAuditor(ctx, logger, true, "") + auditor, err := SetupAuditor(ctx, logger, true, "", "") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -50,7 +50,7 @@ func TestSetupAuditor_EmptySocketPath(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) ctx := context.Background() - _, err := SetupAuditor(ctx, logger, false, "") + _, err := SetupAuditor(ctx, logger, false, "", "") if err == nil { t.Fatal("expected error for empty socket path, got nil") } @@ -62,7 +62,7 @@ func TestSetupAuditor_SocketDoesNotExist(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) ctx := context.Background() - auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path") + auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path", "") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -100,7 +100,7 @@ func TestSetupAuditor_SocketExists(t *testing.T) { t.Fatalf("failed to close temp file: %v", err) } - auditor, err := SetupAuditor(ctx, logger, false, socketPath) + auditor, err := SetupAuditor(ctx, logger, false, socketPath, "") if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/audit/socket_auditor.go b/audit/socket_auditor.go index 4c8c5d88..1ec301d2 100644 --- a/audit/socket_auditor.go +++ b/audit/socket_auditor.go @@ -34,6 +34,9 @@ type SocketAuditor struct { batchSize int batchTimerDuration time.Duration socketPath string + // sessionID is included in every batch sent to the workspace agent so that + // all audit events from a single boundary invocation can be correlated. + sessionID string droppedChannelFull atomic.Int64 droppedBatchFull atomic.Int64 @@ -45,7 +48,7 @@ type SocketAuditor struct { // NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's // boundary log proxy socket after SocketAuditor.Loop is called. The socket path // is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath. -func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor { +func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID string) *SocketAuditor { // This channel buffer size intends to allow enough buffering for bursty // AI agent network requests while a batch is being sent to the workspace // agent. @@ -60,6 +63,7 @@ func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor { batchSize: defaultBatchSize, batchTimerDuration: defaultBatchTimerDuration, socketPath: socketPath, + sessionID: sessionID, } } @@ -100,14 +104,17 @@ type flushErr struct { func (e *flushErr) Error() string { return e.err.Error() } // flush sends the current batch of logs to the given connection. -func flush(conn net.Conn, logs []*agentproto.BoundaryLog) *flushErr { +func flush(conn net.Conn, logs []*agentproto.BoundaryLog, sessionID string) *flushErr { if len(logs) == 0 { return nil } msg := &codec.BoundaryMessage{ Msg: &codec.BoundaryMessage_Logs{ - Logs: &agentproto.ReportBoundaryLogsRequest{Logs: logs}, + Logs: &agentproto.ReportBoundaryLogsRequest{ + Logs: logs, + SessionId: sessionID, + }, }, } if err := codec.WriteMessage(conn, codec.TagV2, msg); err != nil { @@ -188,7 +195,7 @@ func (s *SocketAuditor) Loop(ctx context.Context) { return } - if err := flush(conn, batch); err != nil { + if err := flush(conn, batch, s.sessionID); err != nil { if err.permanent { // Data error: discard batch to avoid infinite retries. s.logger.Warn("dropping batch due to data error on flush attempt", diff --git a/audit/socket_auditor_test.go b/audit/socket_auditor_test.go index 67ebbae2..5a486a1f 100644 --- a/audit/socket_auditor_test.go +++ b/audit/socket_auditor_test.go @@ -451,17 +451,59 @@ func TestSocketAuditor_Loop_ShutdownFlushIncludesDrops(t *testing.T) { func TestFlush_EmptyBatch(t *testing.T) { t.Parallel() - err := flush(nil, nil) + err := flush(nil, nil, "") if err != nil { t.Errorf("expected nil error for empty batch, got %v", err) } - err = flush(nil, []*agentproto.BoundaryLog{}) + err = flush(nil, []*agentproto.BoundaryLog{}, "") if err != nil { t.Errorf("expected nil error for empty slice, got %v", err) } } +func TestSocketAuditor_Loop_SessionIDInBatch(t *testing.T) { + t.Parallel() + + const wantSessionID = "test-boundary-session-uuid" + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + }) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + auditor := &SocketAuditor{ + dial: func() (net.Conn, error) { + return clientConn, nil + }, + logger: logger, + logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize), + batchSize: defaultBatchSize, + batchTimerDuration: time.Hour, // disable timer; flush via batch size + sessionID: wantSessionID, + } + + cr := newConnReader() + go readFromConn(t, serverConn, cr) + go auditor.Loop(t.Context()) + + // Fill a full batch so it flushes immediately. + for i := 0; i < auditor.batchSize; i++ { + auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true}) + } + + select { + case req := <-cr.logs: + if req.SessionId != wantSessionID { + t.Errorf("expected session_id=%q, got %q", wantSessionID, req.SessionId) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for flush") + } +} + // setupSocketAuditor creates a SocketAuditor for tests that only exercise // the queueing behavior (no connection needed). func setupSocketAuditor(t *testing.T) *SocketAuditor { diff --git a/cli/cli.go b/cli/cli.go index 7d1567af..556ff496 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -163,6 +163,27 @@ func BaseCommand(version string) *serpent.Command { Value: &cliConfig.LogProxySocketPath, YAML: "", // CLI only, not loaded from YAML }, + { + Flag: "session-id", + Env: "BOUNDARY_SESSION_ID", + Description: "Session ID to use for this boundary invocation. Generated as a UUID if not provided.", + Value: &cliConfig.SessionID, + YAML: "", // CLI only + }, + { + Flag: "session-id-header", + Env: "BOUNDARY_SESSION_ID_HEADER", + Description: fmt.Sprintf("HTTP header name used to inject the session ID into forwarded requests. Default: %q.", config.DefaultSessionIDHeader), + Value: &cliConfig.SessionIDHeader, + YAML: "session_id_header", + }, + { + Flag: "disable-session-id-header", + Env: "BOUNDARY_DISABLE_SESSION_ID_HEADER", + Description: "Disable injection of the session ID header on forwarded requests.", + Value: &cliConfig.DisableSessionIDHeader, + YAML: "disable_session_id_header", + }, { Flag: "version", Description: "Print version information and exit.", @@ -199,6 +220,8 @@ func BaseCommand(version string) *serpent.Command { return fmt.Errorf("could not set up logging: %v", err) } + logger.Info("boundary session started", "session_id", appConfig.SessionID) + appConfigInJSON, err := json.Marshal(appConfig) if err != nil { return err diff --git a/config/config.go b/config/config.go index d6888344..fd859452 100644 --- a/config/config.go +++ b/config/config.go @@ -5,9 +5,14 @@ import ( "strings" "github.com/coder/serpent" + "github.com/google/uuid" "github.com/spf13/pflag" ) +// DefaultSessionIDHeader is the HTTP header injected by boundary on every +// outgoing forwarded request. +const DefaultSessionIDHeader = "X-Agent-Firewall-Session-Id" + // JailType represents the type of jail to use for network isolation type JailType string @@ -56,35 +61,45 @@ func (a AllowStringsArray) Value() []string { } type CliConfig struct { - Config serpent.YAMLConfigPath `yaml:"-"` - AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file - AllowStrings AllowStringsArray `yaml:"-"` // From CLI flags only - LogLevel serpent.String `yaml:"log_level"` - LogDir serpent.String `yaml:"log_dir"` - ProxyPort serpent.Int64 `yaml:"proxy_port"` - PprofEnabled serpent.Bool `yaml:"pprof_enabled"` - PprofPort serpent.Int64 `yaml:"pprof_port"` - JailType serpent.String `yaml:"jail_type"` - UseRealDNS serpent.Bool `yaml:"use_real_dns"` - NoUserNamespace serpent.Bool `yaml:"no_user_namespace"` - DisableAuditLogs serpent.Bool `yaml:"disable_audit_logs"` - LogProxySocketPath serpent.String `yaml:"log_proxy_socket_path"` + Config serpent.YAMLConfigPath `yaml:"-"` + AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file + AllowStrings AllowStringsArray `yaml:"-"` // From CLI flags only + LogLevel serpent.String `yaml:"log_level"` + LogDir serpent.String `yaml:"log_dir"` + ProxyPort serpent.Int64 `yaml:"proxy_port"` + PprofEnabled serpent.Bool `yaml:"pprof_enabled"` + PprofPort serpent.Int64 `yaml:"pprof_port"` + JailType serpent.String `yaml:"jail_type"` + UseRealDNS serpent.Bool `yaml:"use_real_dns"` + NoUserNamespace serpent.Bool `yaml:"no_user_namespace"` + DisableAuditLogs serpent.Bool `yaml:"disable_audit_logs"` + LogProxySocketPath serpent.String `yaml:"log_proxy_socket_path"` + SessionID serpent.String `yaml:"-"` // CLI only; generated if empty + SessionIDHeader serpent.String `yaml:"session_id_header"` + DisableSessionIDHeader serpent.Bool `yaml:"disable_session_id_header"` } type AppConfig struct { - AllowRules []string - LogLevel string - LogDir string - ProxyPort int64 - PprofEnabled bool - PprofPort int64 - JailType JailType - UseRealDNS bool - NoUserNamespace bool - TargetCMD []string - UserInfo *UserInfo - DisableAuditLogs bool - LogProxySocketPath string + AllowRules []string + LogLevel string + LogDir string + ProxyPort int64 + PprofEnabled bool + PprofPort int64 + JailType JailType + UseRealDNS bool + NoUserNamespace bool + TargetCMD []string + UserInfo *UserInfo + DisableAuditLogs bool + LogProxySocketPath string + // SessionID is a UUID generated at startup that identifies this boundary + // invocation. It is injected as a header on all outgoing HTTP requests and + // included in audit batches sent to the workspace agent. + SessionID string + // SessionIDHeader is the HTTP header name used to carry the session ID. + // An empty value means the header is disabled. + SessionIDHeader string } func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) { @@ -102,6 +117,19 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, er userInfo := GetUserInfo() + sessionID := cfg.SessionID.Value() + if sessionID == "" { + sessionID = uuid.NewString() + } + + sessionIDHeader := "" + if !cfg.DisableSessionIDHeader.Value() { + sessionIDHeader = cfg.SessionIDHeader.Value() + if sessionIDHeader == "" { + sessionIDHeader = DefaultSessionIDHeader + } + } + return AppConfig{ AllowRules: allAllowStrings, LogLevel: cfg.LogLevel.Value(), @@ -116,5 +144,7 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, er UserInfo: userInfo, DisableAuditLogs: cfg.DisableAuditLogs.Value(), LogProxySocketPath: cfg.LogProxySocketPath.Value(), + SessionID: sessionID, + SessionIDHeader: sessionIDHeader, }, nil } diff --git a/go.mod b/go.mod index 0188e843..dabdd40e 100644 --- a/go.mod +++ b/go.mod @@ -1,21 +1,22 @@ module github.com/coder/boundary -go 1.25.7 +go 1.25.8 require ( github.com/cenkalti/backoff/v5 v5.0.3 github.com/coder/coder/v2 v2.10.1-0.20260303212958-f758443f44ac github.com/coder/serpent v0.14.0 + github.com/google/uuid v1.6.0 github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c github.com/miekg/dns v1.1.72 github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 - golang.org/x/sys v0.41.0 + golang.org/x/sys v0.42.0 google.golang.org/protobuf v1.36.11 ) require ( - cdr.dev/slog/v3 v3.0.0-rc1 // indirect + cdr.dev/slog/v3 v3.0.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -38,18 +39,20 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect github.com/zeebo/errs v1.4.0 // indirect - go.opentelemetry.io/otel v1.40.0 // indirect - go.opentelemetry.io/otel/trace v1.40.0 // indirect + go.opentelemetry.io/otel v1.42.0 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect - golang.org/x/crypto v0.48.0 // indirect + golang.org/x/crypto v0.49.0 // indirect golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect - golang.org/x/mod v0.33.0 // indirect - golang.org/x/net v0.50.0 // indirect - golang.org/x/sync v0.19.0 // indirect - golang.org/x/term v0.40.0 // indirect - golang.org/x/tools v0.42.0 // indirect + golang.org/x/mod v0.34.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/term v0.41.0 // indirect + golang.org/x/tools v0.43.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect gopkg.in/yaml.v3 v3.0.1 // indirect kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect storj.io/drpc v0.0.34 // indirect ) + +replace github.com/coder/coder/v2 => /home/coder/code/coder diff --git a/go.sum b/go.sum index d04ae6c6..c1f9b64e 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cdr.dev/slog/v3 v3.0.0-rc1 h1:EN7Zim6GvTpAeHQjI0ERDEfqKbTyXRvgH4UhlzLpvWM= -cdr.dev/slog/v3 v3.0.0-rc1/go.mod h1:iO/OALX1VxlI03mkodCGdVP7pXzd2bRMvu3ePvlJ9ak= +cdr.dev/slog/v3 v3.0.0 h1:kXFUqAqK7ogRKcvo4BnduQVp+Jh0uV1AUKf3NW5FU74= +cdr.dev/slog/v3 v3.0.0/go.mod h1:iO/OALX1VxlI03mkodCGdVP7pXzd2bRMvu3ePvlJ9ak= cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= @@ -28,14 +28,10 @@ github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMx github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= -github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= -github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= -github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= -github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= -github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= -github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= -github.com/coder/coder/v2 v2.10.1-0.20260303212958-f758443f44ac h1:BmOs70CvLwP7dUt245ffapnSfMME0/tmI2M9nMi6O/w= -github.com/coder/coder/v2 v2.10.1-0.20260303212958-f758443f44ac/go.mod h1:ZyamF6ltx68UIQjbzUchKUTedpUIoyQlwCpOk5c+U4E= +github.com/clipperhouse/displaywidth v0.10.0 h1:GhBG8WuerxjFQQYeuZAeVTuyxuX+UraiZGD4HJQ3Y8g= +github.com/clipperhouse/displaywidth v0.10.0/go.mod h1:XqJajYsaiEwkxOj4bowCTMcT1SgvHo9flfF3jQasdbs= +github.com/clipperhouse/uax29/v2 v2.6.0 h1:z0cDbUV+aPASdFb2/ndFnS9ts/WNXgTNNGFoKXuhpos= +github.com/clipperhouse/uax29/v2 v2.6.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 h1:3A0ES21Ke+FxEM8CXx9n47SZOKOpgSE1bbJzlE4qPVs= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0/go.mod h1:5UuS2Ts+nTToAMeOjNlnHFkPahrtDkmpydBen/3wgZc= github.com/coder/serpent v0.14.0 h1:g7vt2zBMp3nWyAvyhvQduaI53Ku65U3wITMi01+/8pU= @@ -131,14 +127,14 @@ github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= -go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= -go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= -go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= -go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= -go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= -go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= -go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= go.uber.org/goleak v1.3.1-0.20240429205332-517bace7cc29 h1:w0QrHuh0hhUZ++UTQaBM2DMdrWQghZ/UsUb+Wb1+8YE= go.uber.org/goleak v1.3.1-0.20240429205332-517bace7cc29/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= @@ -146,14 +142,14 @@ go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= -golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -161,15 +157,15 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -181,16 +177,16 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -198,25 +194,25 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= -golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= -google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d h1:vsOm753cOAMkt76efriTCDKjpCbK18XGHMJHo0JUKhc= -google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:0oz9d7g9QLSdv9/lgbIjowW1JoxMbxmBVNe8i6tORJI= -google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d h1:EocjzKLywydp5uZ5tJ79iP6Q0UjDnyiHkGRWxuPBP8s= -google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:48U2I+QQUYhsFrg2SY6r+nJzeOtjey7j//WBESw+qyQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d h1:t/LOSXPJ9R0B6fnZNyALBRfZBH0Uy0gT+uR+SJ6syqQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= -google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= -google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 h1:JNfk58HZ8lfmXbYK2vx/UvsqIL59TzByCxPIX4TDmsE= +google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:x5julN69+ED4PcFk/XWayw35O0lf/nGa4aNgODCmNmw= +google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5 h1:CogIeEXn4qWYzzQU0QqvYBM8yDF9cFYzDq9ojSpv0Js= +google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/landjail/manager.go b/landjail/manager.go index b38f54f5..26b5b905 100644 --- a/landjail/manager.go +++ b/landjail/manager.go @@ -33,13 +33,15 @@ func NewLandJail( ) (*LandJail, error) { // Create proxy server proxyServer := proxy.NewProxyServer(proxy.Config{ - HTTPPort: int(config.ProxyPort), - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, - PprofEnabled: config.PprofEnabled, - PprofPort: int(config.PprofPort), + HTTPPort: int(config.ProxyPort), + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + PprofEnabled: config.PprofEnabled, + PprofPort: int(config.PprofPort), + SessionID: config.SessionID, + SessionIDHeader: config.SessionIDHeader, }) return &LandJail{ diff --git a/landjail/parent.go b/landjail/parent.go index 4cd77962..d399d3d8 100644 --- a/landjail/parent.go +++ b/landjail/parent.go @@ -27,7 +27,7 @@ func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) // Create auditor - auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath) + auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath, config.SessionID) if err != nil { return fmt.Errorf("failed to setup auditor: %v", err) } diff --git a/nsjail_manager/manager.go b/nsjail_manager/manager.go index b0ddfda1..a1185771 100644 --- a/nsjail_manager/manager.go +++ b/nsjail_manager/manager.go @@ -36,13 +36,15 @@ func NewNSJailManager( ) (*NSJailManager, error) { // Create proxy server proxyServer := proxy.NewProxyServer(proxy.Config{ - HTTPPort: int(config.ProxyPort), - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, - PprofEnabled: config.PprofEnabled, - PprofPort: int(config.PprofPort), + HTTPPort: int(config.ProxyPort), + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + PprofEnabled: config.PprofEnabled, + PprofPort: int(config.PprofPort), + SessionID: config.SessionID, + SessionIDHeader: config.SessionIDHeader, }) return &NSJailManager{ diff --git a/nsjail_manager/parent.go b/nsjail_manager/parent.go index 10554bff..e5d769b0 100644 --- a/nsjail_manager/parent.go +++ b/nsjail_manager/parent.go @@ -28,7 +28,7 @@ func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) // Create auditor - auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath) + auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath, config.SessionID) if err != nil { return fmt.Errorf("failed to setup auditor: %v", err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index d986d509..e7e2b297 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -24,12 +24,14 @@ import ( // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rulesengine.Engine - auditor audit.Auditor - logger *slog.Logger - tlsConfig *tls.Config - httpPort int - started atomic.Bool + ruleEngine rulesengine.Engine + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + started atomic.Bool + sessionID string + sessionIDHeader string // empty means the header is disabled listener net.Listener pprofServer *http.Server @@ -39,25 +41,31 @@ type Server struct { // Config holds configuration for the proxy server type Config struct { - HTTPPort int - RuleEngine rulesengine.Engine - Auditor audit.Auditor - Logger *slog.Logger - TLSConfig *tls.Config - PprofEnabled bool - PprofPort int + HTTPPort int + RuleEngine rulesengine.Engine + Auditor audit.Auditor + Logger *slog.Logger + TLSConfig *tls.Config + PprofEnabled bool + PprofPort int + // SessionID is the per-run UUID injected into forwarded requests. + SessionID string + // SessionIDHeader is the header name to use. Empty disables injection. + SessionIDHeader string } // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, - pprofEnabled: config.PprofEnabled, - pprofPort: config.PprofPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, + sessionID: config.SessionID, + sessionIDHeader: config.SessionIDHeader, } } @@ -335,6 +343,12 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } } + // Stamp the session ID header, overwriting any value the jailed client may + // have set with the same name so that the upstream always sees boundary's ID. + if p.sessionIDHeader != "" && p.sessionID != "" { + newReq.Header.Set(p.sessionIDHeader, p.sessionID) + } + // Make request to destination resp, err := client.Do(newReq) if err != nil { diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index dbfee2bc..d530c834 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -31,16 +31,18 @@ func (m *mockAuditor) AuditRequest(req audit.Request) { // ProxyTest is a high-level test framework for proxy tests type ProxyTest struct { - t *testing.T - server *Server - client *http.Client - proxyClient *http.Client - port int - useCertManager bool - configDir string - startupDelay time.Duration - allowedRules []string - auditor audit.Auditor + t *testing.T + server *Server + client *http.Client + proxyClient *http.Client + port int + useCertManager bool + configDir string + startupDelay time.Duration + allowedRules []string + auditor audit.Auditor + sessionID string + sessionIDHeader string } // ProxyTestOption is a function that configures ProxyTest @@ -152,11 +154,13 @@ func (pt *ProxyTest) Start() *ProxyTest { } pt.server = NewProxyServer(Config{ - HTTPPort: pt.port, - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, + HTTPPort: pt.port, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + SessionID: pt.sessionID, + SessionIDHeader: pt.sessionIDHeader, }) err = pt.server.Start() diff --git a/proxy/proxy_session_test.go b/proxy/proxy_session_test.go new file mode 100644 index 00000000..04f99101 --- /dev/null +++ b/proxy/proxy_session_test.go @@ -0,0 +1,149 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// WithSessionIDHeader sets the session-ID header name on the proxy server under test. +func WithSessionIDHeader(header string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionIDHeader = header + } +} + +// WithSessionID sets the session ID on the proxy server under test. +func WithSessionID(id string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionID = id + } +} + +// recordingHandler captures the headers received by an upstream server. +type recordingHandler struct { + mu sync.Mutex + headers []http.Header +} + +func (r *recordingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.mu.Lock() + r.headers = append(r.headers, req.Header.Clone()) + r.mu.Unlock() + w.WriteHeader(http.StatusOK) +} + +func (r *recordingHandler) last() http.Header { + r.mu.Lock() + defer r.mu.Unlock() + if len(r.headers) == 0 { + return nil + } + return r.headers[len(r.headers)-1] +} + +// TestSessionIDHeader_InjectedOnForwardedRequest verifies that the configured +// session ID header is stamped on requests forwarded to the upstream server, +// and that it overwrites any value the jailed client set. +func TestSessionIDHeader_InjectedOnForwardedRequest(t *testing.T) { + const sessionID = "test-session-uuid-1234" + const headerName = "X-Agent-Firewall-Session-Id" + + recorder := &recordingHandler{} + backend := httptest.NewServer(recorder) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithProxyPort(8085), + WithAllowedDomain(backendURL.Hostname()), + WithSessionID(sessionID), + WithSessionIDHeader(headerName), + ).Start() + defer pt.Stop() + + // Make a request through the proxy; the client does not set the header. + resp, err := pt.proxyClient.Get(backend.URL + "/path") + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := recorder.last() + require.NotNil(t, got) + assert.Equal(t, sessionID, got.Get(headerName), + "upstream should receive the session ID header from boundary") +} + +// TestSessionIDHeader_OverwritesClientValue ensures boundary's value wins even +// when the jailed process has already set the same header. +func TestSessionIDHeader_OverwritesClientValue(t *testing.T) { + const sessionID = "boundary-session-uuid" + const headerName = "X-Agent-Firewall-Session-Id" + const clientValue = "evil-client-forged-value" + + recorder := &recordingHandler{} + backend := httptest.NewServer(recorder) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + require.NoError(t, err) + + pt := NewProxyTest(t, + WithProxyPort(8086), + WithAllowedDomain(backendURL.Hostname()), + WithSessionID(sessionID), + WithSessionIDHeader(headerName), + ).Start() + defer pt.Stop() + + // Build a request that includes the header with a forged value. + req, err := http.NewRequest(http.MethodGet, backend.URL+"/", nil) + require.NoError(t, err) + req.Header.Set(headerName, clientValue) + + resp, err := pt.proxyClient.Do(req) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + got := recorder.last() + require.NotNil(t, got) + assert.Equal(t, sessionID, got.Get(headerName), + "boundary should overwrite the client-supplied header value") +} + +// TestSessionIDHeader_OmittedWhenDisabled verifies that no header is injected +// when SessionIDHeader is empty (i.e. the feature is disabled). +func TestSessionIDHeader_OmittedWhenDisabled(t *testing.T) { + const headerName = "X-Agent-Firewall-Session-Id" + + recorder := &recordingHandler{} + backend := httptest.NewServer(recorder) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + require.NoError(t, err) + + // No WithSessionIDHeader option → sessionIDHeader stays empty → disabled. + pt := NewProxyTest(t, + WithProxyPort(8087), + WithAllowedDomain(backendURL.Hostname()), + WithSessionID("some-id"), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.URL + "/path") + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + got := recorder.last() + require.NotNil(t, got) + assert.Empty(t, got.Get(headerName), + "upstream should not receive the session ID header when the feature is disabled") +} From adf9589239fcff5f99145b03c3ffa92a02ba1d07 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Mon, 20 Apr 2026 16:34:58 +0000 Subject: [PATCH 2/3] rename agent firewall header --- config/config.go | 2 +- proxy/proxy_session_test.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config/config.go b/config/config.go index fd859452..4fd9f191 100644 --- a/config/config.go +++ b/config/config.go @@ -11,7 +11,7 @@ import ( // DefaultSessionIDHeader is the HTTP header injected by boundary on every // outgoing forwarded request. -const DefaultSessionIDHeader = "X-Agent-Firewall-Session-Id" +const DefaultSessionIDHeader = "X-Coder-Agent-Firewall-Session-Id" // JailType represents the type of jail to use for network isolation type JailType string diff --git a/proxy/proxy_session_test.go b/proxy/proxy_session_test.go index 04f99101..1c9f8457 100644 --- a/proxy/proxy_session_test.go +++ b/proxy/proxy_session_test.go @@ -52,7 +52,7 @@ func (r *recordingHandler) last() http.Header { // and that it overwrites any value the jailed client set. func TestSessionIDHeader_InjectedOnForwardedRequest(t *testing.T) { const sessionID = "test-session-uuid-1234" - const headerName = "X-Agent-Firewall-Session-Id" + const headerName = "X-Coder-Agent-Firewall-Session-Id" recorder := &recordingHandler{} backend := httptest.NewServer(recorder) @@ -85,7 +85,7 @@ func TestSessionIDHeader_InjectedOnForwardedRequest(t *testing.T) { // when the jailed process has already set the same header. func TestSessionIDHeader_OverwritesClientValue(t *testing.T) { const sessionID = "boundary-session-uuid" - const headerName = "X-Agent-Firewall-Session-Id" + const headerName = "X-Coder-Agent-Firewall-Session-Id" const clientValue = "evil-client-forged-value" recorder := &recordingHandler{} @@ -121,7 +121,7 @@ func TestSessionIDHeader_OverwritesClientValue(t *testing.T) { // TestSessionIDHeader_OmittedWhenDisabled verifies that no header is injected // when SessionIDHeader is empty (i.e. the feature is disabled). func TestSessionIDHeader_OmittedWhenDisabled(t *testing.T) { - const headerName = "X-Agent-Firewall-Session-Id" + const headerName = "X-Coder-Agent-Firewall-Session-Id" recorder := &recordingHandler{} backend := httptest.NewServer(recorder) From 24e1a0bf0ed7128fb89f4f63e6161ba49d12e7ed Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Mon, 20 Apr 2026 17:48:49 +0000 Subject: [PATCH 3/3] make session id header injection configurable --- cli/cli.go | 13 +++++ config/config.go | 52 +++++++++++-------- landjail/manager.go | 2 + landjail/parent.go | 9 +++- nsjail_manager/manager.go | 2 + nsjail_manager/parent.go | 9 +++- proxy/proxy.go | 29 +++++++++-- proxy/proxy_framework_test.go | 40 ++++++++++----- proxy/proxy_session_test.go | 94 ++++++++++++++++++++++++++++++++++- 9 files changed, 210 insertions(+), 40 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 556ff496..f18efe6d 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -184,6 +184,19 @@ func BaseCommand(version string) *serpent.Command { Value: &cliConfig.DisableSessionIDHeader, YAML: "disable_session_id_header", }, + { + Flag: "session-id-inject-domain", + Env: "BOUNDARY_SESSION_ID_INJECT_DOMAIN", + Description: "Match rule (repeatable) selecting which requests receive the session ID header. Merged with session_id_inject_domains from config file. Uses the same syntax as --allow (e.g. \"domain=dev.coder.com path=/api/v2/aibridge/*\"). If no rules are configured the header is never injected.", + Value: &cliConfig.SessionIDMatch, + YAML: "", // CLI only, not loaded from YAML + }, + { + Flag: "", // No CLI flag, YAML only + Description: "Session ID match rules from config file (YAML only). Merged with --session-id-inject-domain CLI flags.", + Value: &cliConfig.SessionIDMatchList, + YAML: "session_id_inject_domains", + }, { Flag: "version", Description: "Print version information and exit.", diff --git a/config/config.go b/config/config.go index 4fd9f191..e5c2ac95 100644 --- a/config/config.go +++ b/config/config.go @@ -62,8 +62,8 @@ func (a AllowStringsArray) Value() []string { type CliConfig struct { Config serpent.YAMLConfigPath `yaml:"-"` - AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file - AllowStrings AllowStringsArray `yaml:"-"` // From CLI flags only + AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file + AllowStrings AllowStringsArray `yaml:"-"` // From CLI flags only LogLevel serpent.String `yaml:"log_level"` LogDir serpent.String `yaml:"log_dir"` ProxyPort serpent.Int64 `yaml:"proxy_port"` @@ -74,9 +74,11 @@ type CliConfig struct { NoUserNamespace serpent.Bool `yaml:"no_user_namespace"` DisableAuditLogs serpent.Bool `yaml:"disable_audit_logs"` LogProxySocketPath serpent.String `yaml:"log_proxy_socket_path"` - SessionID serpent.String `yaml:"-"` // CLI only; generated if empty + SessionID serpent.String `yaml:"-"` // CLI only; generated if empty SessionIDHeader serpent.String `yaml:"session_id_header"` DisableSessionIDHeader serpent.Bool `yaml:"disable_session_id_header"` + SessionIDMatchList serpent.StringArray `yaml:"session_id_inject_domains"` // From config file + SessionIDMatch AllowStringsArray `yaml:"-"` // From CLI flags only } type AppConfig struct { @@ -94,12 +96,17 @@ type AppConfig struct { DisableAuditLogs bool LogProxySocketPath string // SessionID is a UUID generated at startup that identifies this boundary - // invocation. It is injected as a header on all outgoing HTTP requests and - // included in audit batches sent to the workspace agent. + // invocation. It is injected as a header on matching outgoing HTTP requests + // and included in audit batches sent to the workspace agent. SessionID string // SessionIDHeader is the HTTP header name used to carry the session ID. // An empty value means the header is disabled. SessionIDHeader string + // SessionIDMatchRules is the merged list of match-rule strings from the + // YAML session_id_matches key and the --session-id-match CLI flag. The + // session ID header is only injected on requests that match at least one + // rule. Empty means never inject. + SessionIDMatchRules []string } func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) { @@ -130,21 +137,26 @@ func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, er } } + // Merge session-ID match rules: YAML list first, then CLI flags (same + // pattern as allAllowStrings above). + allMatchStrings := append(cfg.SessionIDMatchList.Value(), cfg.SessionIDMatch.Value()...) + return AppConfig{ - AllowRules: allAllowStrings, - LogLevel: cfg.LogLevel.Value(), - LogDir: cfg.LogDir.Value(), - ProxyPort: cfg.ProxyPort.Value(), - PprofEnabled: cfg.PprofEnabled.Value(), - PprofPort: cfg.PprofPort.Value(), - JailType: jailType, - UseRealDNS: cfg.UseRealDNS.Value(), - NoUserNamespace: cfg.NoUserNamespace.Value(), - TargetCMD: targetCMD, - UserInfo: userInfo, - DisableAuditLogs: cfg.DisableAuditLogs.Value(), - LogProxySocketPath: cfg.LogProxySocketPath.Value(), - SessionID: sessionID, - SessionIDHeader: sessionIDHeader, + AllowRules: allAllowStrings, + LogLevel: cfg.LogLevel.Value(), + LogDir: cfg.LogDir.Value(), + ProxyPort: cfg.ProxyPort.Value(), + PprofEnabled: cfg.PprofEnabled.Value(), + PprofPort: cfg.PprofPort.Value(), + JailType: jailType, + UseRealDNS: cfg.UseRealDNS.Value(), + NoUserNamespace: cfg.NoUserNamespace.Value(), + TargetCMD: targetCMD, + UserInfo: userInfo, + DisableAuditLogs: cfg.DisableAuditLogs.Value(), + LogProxySocketPath: cfg.LogProxySocketPath.Value(), + SessionID: sessionID, + SessionIDHeader: sessionIDHeader, + SessionIDMatchRules: allMatchStrings, }, nil } diff --git a/landjail/manager.go b/landjail/manager.go index 26b5b905..d65befa3 100644 --- a/landjail/manager.go +++ b/landjail/manager.go @@ -26,6 +26,7 @@ type LandJail struct { func NewLandJail( ruleEngine rulesengine.Engine, + sessionIDMatch rulesengine.Engine, auditor audit.Auditor, tlsConfig *tls.Config, logger *slog.Logger, @@ -42,6 +43,7 @@ func NewLandJail( PprofPort: int(config.PprofPort), SessionID: config.SessionID, SessionIDHeader: config.SessionIDHeader, + SessionIDMatch: sessionIDMatch, }) return &LandJail{ diff --git a/landjail/parent.go b/landjail/parent.go index d399d3d8..080788b5 100644 --- a/landjail/parent.go +++ b/landjail/parent.go @@ -26,6 +26,13 @@ func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig // Create rule engine ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) + // Parse session-ID match rules and build a second engine that gates header injection. + sessionIDMatchRules, err := rulesengine.ParseAllowSpecs(config.SessionIDMatchRules) + if err != nil { + return fmt.Errorf("failed to parse session-id-match rules: %v", err) + } + sessionIDMatchEngine := rulesengine.NewRuleEngine(sessionIDMatchRules, logger) + // Create auditor auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath, config.SessionID) if err != nil { @@ -50,7 +57,7 @@ func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig return fmt.Errorf("failed to setup TLS and CA certificate: %v", err) } - landjail, err := NewLandJail(ruleEngine, auditor, tlsConfig, logger, config) + landjail, err := NewLandJail(ruleEngine, sessionIDMatchEngine, auditor, tlsConfig, logger, config) if err != nil { return fmt.Errorf("failed to create landjail: %v", err) } diff --git a/nsjail_manager/manager.go b/nsjail_manager/manager.go index a1185771..33e57f1c 100644 --- a/nsjail_manager/manager.go +++ b/nsjail_manager/manager.go @@ -28,6 +28,7 @@ type NSJailManager struct { func NewNSJailManager( ruleEngine rulesengine.Engine, + sessionIDMatch rulesengine.Engine, auditor audit.Auditor, tlsConfig *tls.Config, jailer nsjail.Jailer, @@ -45,6 +46,7 @@ func NewNSJailManager( PprofPort: int(config.PprofPort), SessionID: config.SessionID, SessionIDHeader: config.SessionIDHeader, + SessionIDMatch: sessionIDMatch, }) return &NSJailManager{ diff --git a/nsjail_manager/parent.go b/nsjail_manager/parent.go index e5d769b0..b11ae8ae 100644 --- a/nsjail_manager/parent.go +++ b/nsjail_manager/parent.go @@ -27,6 +27,13 @@ func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig // Create rule engine ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) + // Parse session-ID match rules and build a second engine that gates header injection. + sessionIDMatchRules, err := rulesengine.ParseAllowSpecs(config.SessionIDMatchRules) + if err != nil { + return fmt.Errorf("failed to parse session-id-match rules: %v", err) + } + sessionIDMatchEngine := rulesengine.NewRuleEngine(sessionIDMatchRules, logger) + // Create auditor auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath, config.SessionID) if err != nil { @@ -65,7 +72,7 @@ func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig } // Create boundary instance - nsJailMgr, err := NewNSJailManager(ruleEngine, auditor, tlsConfig, jailer, logger, config) + nsJailMgr, err := NewNSJailManager(ruleEngine, sessionIDMatchEngine, auditor, tlsConfig, jailer, logger, config) if err != nil { return fmt.Errorf("failed to create boundary instance: %v", err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index e7e2b297..a8ffbb80 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -31,7 +31,8 @@ type Server struct { httpPort int started atomic.Bool sessionID string - sessionIDHeader string // empty means the header is disabled + sessionIDHeader string // empty means the header is disabled + sessionIDMatch rulesengine.Engine // gates which requests get the header listener net.Listener pprofServer *http.Server @@ -48,10 +49,15 @@ type Config struct { TLSConfig *tls.Config PprofEnabled bool PprofPort int - // SessionID is the per-run UUID injected into forwarded requests. + // SessionID is the per-run UUID injected into matching forwarded requests. SessionID string // SessionIDHeader is the header name to use. Empty disables injection. SessionIDHeader string + // SessionIDMatch gates which requests receive the session ID header. + // The header is only set when this engine's Evaluate returns Allowed=true. + // An engine with zero rules returns Allowed=false for every request, which + // means an empty SessionIDMatch effectively disables header injection. + SessionIDMatch rulesengine.Engine } // NewProxyServer creates a new proxy server instance @@ -66,6 +72,7 @@ func NewProxyServer(config Config) *Server { pprofPort: config.PprofPort, sessionID: config.SessionID, sessionIDHeader: config.SessionIDHeader, + sessionIDMatch: config.SessionIDMatch, } } @@ -343,9 +350,9 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } } - // Stamp the session ID header, overwriting any value the jailed client may - // have set with the same name so that the upstream always sees boundary's ID. - if p.sessionIDHeader != "" && p.sessionID != "" { + // Stamp the session ID header on matching requests, overwriting any value + // the jailed client may have set so the upstream always sees boundary's ID. + if p.shouldStampSessionID(newReq.Method, targetURL.String()) { newReq.Header.Set(p.sessionIDHeader, p.sessionID) } @@ -405,6 +412,18 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { p.logger.Debug("Successfully wrote to connection") } +// shouldStampSessionID returns true when the session ID header should be set +// on the outgoing request. Injection requires both a header name and a session +// ID to be configured, and the request must match at least one session-ID match +// rule. An engine with zero rules always returns Allowed=false, so leaving +// SessionIDMatch unconfigured effectively disables injection. +func (p *Server) shouldStampSessionID(method, url string) bool { + if p.sessionIDHeader == "" || p.sessionID == "" { + return false + } + return p.sessionIDMatch.Evaluate(method, url).Allowed +} + func (p *Server) writeBlockedResponse(conn net.Conn, req *http.Request) { // Create a response object resp := &http.Response{ diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index d530c834..efdc9656 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -31,18 +31,19 @@ func (m *mockAuditor) AuditRequest(req audit.Request) { // ProxyTest is a high-level test framework for proxy tests type ProxyTest struct { - t *testing.T - server *Server - client *http.Client - proxyClient *http.Client - port int - useCertManager bool - configDir string - startupDelay time.Duration - allowedRules []string - auditor audit.Auditor - sessionID string - sessionIDHeader string + t *testing.T + server *Server + client *http.Client + proxyClient *http.Client + port int + useCertManager bool + configDir string + startupDelay time.Duration + allowedRules []string + auditor audit.Auditor + sessionID string + sessionIDHeader string + sessionIDMatches []string // match rules for session ID header injection } // ProxyTestOption is a function that configures ProxyTest @@ -110,6 +111,15 @@ func WithAuditor(auditor audit.Auditor) ProxyTestOption { } } +// WithSessionIDMatch adds a session-ID match rule. Only requests matching at +// least one rule will receive the session ID header. Omitting this option +// means no injection (same as an empty match list). +func WithSessionIDMatch(rule string) ProxyTestOption { + return func(pt *ProxyTest) { + pt.sessionIDMatches = append(pt.sessionIDMatches, rule) + } +} + // Start starts the proxy server func (pt *ProxyTest) Start() *ProxyTest { pt.t.Helper() @@ -123,6 +133,11 @@ func (pt *ProxyTest) Start() *ProxyTest { ruleEngine := rulesengine.NewRuleEngine(testRules, logger) + matchRules, err := rulesengine.ParseAllowSpecs(pt.sessionIDMatches) + require.NoError(pt.t, err, "Failed to parse session-ID match rules") + + sessionIDMatchEngine := rulesengine.NewRuleEngine(matchRules, logger) + // Use custom auditor if provided, otherwise use no-op mock auditor := pt.auditor if auditor == nil { @@ -161,6 +176,7 @@ func (pt *ProxyTest) Start() *ProxyTest { TLSConfig: tlsConfig, SessionID: pt.sessionID, SessionIDHeader: pt.sessionIDHeader, + SessionIDMatch: sessionIDMatchEngine, }) err = pt.server.Start() diff --git a/proxy/proxy_session_test.go b/proxy/proxy_session_test.go index 1c9f8457..4c47202a 100644 --- a/proxy/proxy_session_test.go +++ b/proxy/proxy_session_test.go @@ -66,6 +66,7 @@ func TestSessionIDHeader_InjectedOnForwardedRequest(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionID(sessionID), WithSessionIDHeader(headerName), + WithSessionIDMatch("domain="+backendURL.Hostname()), ).Start() defer pt.Stop() @@ -100,6 +101,7 @@ func TestSessionIDHeader_OverwritesClientValue(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionID(sessionID), WithSessionIDHeader(headerName), + WithSessionIDMatch("domain="+backendURL.Hostname()), ).Start() defer pt.Stop() @@ -132,7 +134,7 @@ func TestSessionIDHeader_OmittedWhenDisabled(t *testing.T) { // No WithSessionIDHeader option → sessionIDHeader stays empty → disabled. pt := NewProxyTest(t, - WithProxyPort(8087), + WithProxyPort(8083), WithAllowedDomain(backendURL.Hostname()), WithSessionID("some-id"), ).Start() @@ -147,3 +149,93 @@ func TestSessionIDHeader_OmittedWhenDisabled(t *testing.T) { assert.Empty(t, got.Get(headerName), "upstream should not receive the session ID header when the feature is disabled") } + +// TestSessionIDHeader_NotInjectedWithoutMatchRules verifies that no header is +// injected when no session-ID match rules are configured, even if a session ID +// and header name are set. This covers the new "empty rules = never inject" +// default introduced with the match-rule gating feature. +func TestSessionIDHeader_NotInjectedWithoutMatchRules(t *testing.T) { + const headerName = "X-Coder-Agent-Firewall-Session-Id" + + recorder := &recordingHandler{} + backend := httptest.NewServer(recorder) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + require.NoError(t, err) + + // No WithSessionIDMatch → match engine has zero rules → header must not appear. + pt := NewProxyTest(t, + WithProxyPort(8088), + WithAllowedDomain(backendURL.Hostname()), + WithSessionID("some-session-id"), + WithSessionIDHeader(headerName), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(backend.URL + "/path") + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := recorder.last() + require.NotNil(t, got) + assert.Empty(t, got.Get(headerName), + "upstream should not receive the session ID header when no match rules are configured") +} + +// TestSessionIDHeader_InjectedOnlyOnMatchingRequests verifies selective +// injection: the header is stamped on requests whose URL matches a session-ID +// rule but not on requests to a different backend. +func TestSessionIDHeader_InjectedOnlyOnMatchingRequests(t *testing.T) { + const sessionID = "selective-session-id" + const headerName = "X-Coder-Agent-Firewall-Session-Id" + + // Two upstream backends: one that matches the session-ID rule, one that does not. + matchedRecorder := &recordingHandler{} + matchedBackend := httptest.NewServer(matchedRecorder) + defer matchedBackend.Close() + + unmatchedRecorder := &recordingHandler{} + unmatchedBackend := httptest.NewServer(unmatchedRecorder) + defer unmatchedBackend.Close() + + matchedURL, err := url.Parse(matchedBackend.URL) + require.NoError(t, err) + unmatchedURL, err := url.Parse(unmatchedBackend.URL) + require.NoError(t, err) + + // Allow both backends through the proxy, but only inject the header for the + // matched backend's path prefix. + pt := NewProxyTest(t, + WithProxyPort(8089), + WithAllowedDomain(matchedURL.Hostname()), + WithAllowedDomain(unmatchedURL.Hostname()), + WithSessionID(sessionID), + WithSessionIDHeader(headerName), + WithSessionIDMatch("domain="+matchedURL.Hostname()+" path=/api/v2/aibridge/*"), + ).Start() + defer pt.Stop() + + // Request to the matched backend on the matching path. + resp, err := pt.proxyClient.Get(matchedBackend.URL + "/api/v2/aibridge/anthropic/v1/messages") + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode) + + got := matchedRecorder.last() + require.NotNil(t, got) + assert.Equal(t, sessionID, got.Get(headerName), + "matched backend should receive the session ID header") + + // Request to the unmatched backend: no header. + resp, err = pt.proxyClient.Get(unmatchedBackend.URL + "/other/path") + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode) + + got = unmatchedRecorder.last() + require.NotNil(t, got) + assert.Empty(t, got.Get(headerName), + "unmatched backend should not receive the session ID header") +}