From 55942d8f62d4165396752f8925a8884de17d3218 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 19:19:14 -0400 Subject: [PATCH 01/22] feat(auth): widen header contract for non-Bearer providers (phase 2 pr 1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HeadersFromRequest gains Authorization, X-Goog-Iap-Jwt-Assertion, X-Amz-Date, X-Amz-Security-Token so future providers consuming non-Bearer formats (aws_sigv4, gcp_iap) can read what they need without changing the Provider.Verify signature. TokenKind recognizes the "AWS4-HMAC-SHA256 " prefix and returns "sigv4", so audit logs can distinguish Sigv4 requests from "empty" even though the Bearer extractor returns "". Middleware now consults the chain even when no Bearer token was extracted, provided a non-Bearer auth header is present (Sigv4 Authorization or IAP assertion). When NO auth headers at all are present, the audit reason still resolves to ErrMissingBearer — preserving review #4's stable "missing_token" reason code. Phase 1 providers see zero behavior change; their Verify path is unchanged. All Phase 1 tests pass without modification. --- forge-core/auth/middleware.go | 32 +++++- forge-core/auth/middleware_test.go | 167 +++++++++++++++++++++++++++++ forge-core/auth/provider.go | 27 ++++- forge-core/auth/provider_test.go | 77 +++++++++++++ 4 files changed, 298 insertions(+), 5 deletions(-) diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index cb09a87..bd5b5e2 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -59,9 +59,9 @@ type MiddlewareOptions struct { // - identity is non-nil and err is nil on success. // - identity is nil and err carries the chain error on failure // (or auth.ErrMissingBearer when the header was absent). - // - tokenKind is "jwt", "opaque", or "empty" — structural metadata - // safe to log. 
The token itself is NOT passed; callers must not - // try to recover it from the request. + // - tokenKind is "jwt", "opaque", "sigv4", or "empty" — structural + // metadata safe to log. The token itself is NOT passed; callers + // must not try to recover it from the request. // // Callbacks should be cheap — they run on the request hot path. OnAuth func(r *http.Request, identity *Identity, err error, tokenKind string) @@ -109,7 +109,33 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { token := extractBearerToken(r) kind := TokenKind(token) + // Phase 2: when no Bearer token was extracted, classify the + // audit kind from the raw Authorization header. Lets aws_sigv4 + // emit the right token_kind value even though there's no Bearer. + rawAuth := r.Header.Get("Authorization") + if kind == "empty" && strings.HasPrefix(rawAuth, "AWS4-HMAC-SHA256 ") { + kind = "sigv4" + } + + // hasNonBearerAuth signals "the caller IS attempting auth, just + // not via Bearer." Used to differentiate "missing_token" (no + // auth headers at all) from "chain rejected/didn't match" in + // the audit reason. + hasNonBearerAuth := false if token == "" { + if rawAuth != "" { + hasNonBearerAuth = true + } + if r.Header.Get("X-Goog-Iap-Jwt-Assertion") != "" { + hasNonBearerAuth = true + } + } + + // Phase 2 change: consult the chain even on empty Bearer, so + // providers that speak non-Bearer formats (aws_sigv4, gcp_iap) + // get a chance. Phase 1 short-circuit (empty bearer → 401) + // preserved only when there are no auth-shaped headers at all. 
+ if token == "" && !hasNonBearerAuth { notifyAuth(opts.OnAuth, r, nil, ErrMissingBearer, kind) writeAuthError(w, "valid bearer token required") return diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index f163e2f..d117b33 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -443,6 +443,173 @@ func contains(s, sub string) bool { return false } +// --- Phase 2: non-Bearer auth formats reach the chain --- + +// headerCapturingProvider records the headers it sees and lets the test +// script the response. Used by Phase 2 middleware tests to assert that +// providers consuming non-Bearer formats (Sigv4, IAP) actually receive +// what they need. +type headerCapturingProvider struct { + sawToken string + sawHeaders Headers + identity *Identity + err error +} + +func (p *headerCapturingProvider) Name() string { return "test_capture" } +func (p *headerCapturingProvider) Verify(_ context.Context, token string, h Headers) (*Identity, error) { + p.sawToken = token + p.sawHeaders = h + return p.identity, p.err +} + +func TestMiddleware_Sigv4HeaderReachesChain(t *testing.T) { + // Phase 2 change: a Sigv4-shaped Authorization reaches the chain even + // though extractBearerToken returns "". The middleware no longer + // short-circuits to 401 when there's an auth-shaped header present. 
+ spy := &headerCapturingProvider{err: ErrTokenNotForMe} + opts := MiddlewareOptions{ + Chain: NewChainProvider(spy), + SkipPaths: DefaultSkipPaths(), + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host, Signature=ab") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if spy.sawHeaders == nil { + t.Fatal("chain was never invoked — middleware short-circuited on empty Bearer") + } + got := spy.sawHeaders.Get("Authorization") + if got == "" || !startsWithLocal(got, "AWS4-HMAC-SHA256 ") { + t.Errorf("chain saw Authorization = %q, want Sigv4-prefixed value", got) + } +} + +func TestMiddleware_IAPHeaderReachesChain(t *testing.T) { + // Phase 2 change: gcp_iap's X-Goog-Iap-Jwt-Assertion is enough — no + // Authorization header at all is needed for the chain to be consulted. 
+ spy := &headerCapturingProvider{err: ErrTokenNotForMe} + opts := MiddlewareOptions{ + Chain: NewChainProvider(spy), + SkipPaths: DefaultSkipPaths(), + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJ.eyJ.sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if spy.sawHeaders == nil { + t.Fatal("chain was never invoked — middleware short-circuited on empty Bearer") + } + if got := spy.sawHeaders.Get("X-Goog-Iap-Jwt-Assertion"); got != "eyJ.eyJ.sig" { + t.Errorf("chain saw IAP header = %q, want eyJ.eyJ.sig", got) + } +} + +func TestMiddleware_NoAuthHeaders_PreservesMissingTokenReason(t *testing.T) { + // Review #4 contract regression check: when the caller did NOT attempt + // auth at all (no Bearer, no Sigv4 Authorization, no IAP header), the + // audit reason MUST stay ErrMissingBearer — not be widened to "not_for_me". + // This lets ops dashboards still differentiate "client didn't auth" from + // "client tried a format we don't speak." + spyCalled := false + chain := NewChainProvider(&headerCapturingProvider{ + err: ErrTokenNotForMe, + identity: nil, + }) + + var gotErr error + opts := MiddlewareOptions{ + Chain: chain, + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, err error, _ string) { + spyCalled = true + gotErr = err + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Request with NO auth-shaped headers at all. 
+ req := httptest.NewRequest("POST", "/tasks", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if !spyCalled { + t.Fatal("OnAuth callback never fired") + } + if !errors.Is(gotErr, ErrMissingBearer) { + t.Errorf("OnAuth err = %v, want ErrMissingBearer (review #4 regression)", gotErr) + } + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestMiddleware_TokenKind_Sigv4OnEmptyBearer(t *testing.T) { + // Phase 2: the audit token_kind reads "sigv4" when the request has a + // Sigv4-shaped Authorization header, even though Bearer extraction + // returned "". Audit dashboards need this signal to count Sigv4 + // requests distinctly from "empty" (no auth attempt at all). + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA, SignedHeaders=host, Signature=ab") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "sigv4" { + t.Errorf("token kind = %q, want sigv4", gotKind) + } +} + +func TestMiddleware_TokenKind_EmptyWhenTrulyNoAuth(t *testing.T) { + // Counterpart to the test above: when the caller didn't attempt auth + // at all, token_kind should be "empty" — not silently widened. 
+ var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "empty" { + t.Errorf("token kind = %q, want empty", gotKind) + } +} + +// startsWithLocal is the middleware_test.go equivalent of provider_test.go's +// startsWith helper. Kept local so the two test files don't depend on each +// other's helpers. +func startsWithLocal(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + func TestClassifyAuthFailure(t *testing.T) { tests := []struct { name string diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go index 527eac2..24ce9ae 100644 --- a/forge-core/auth/provider.go +++ b/forge-core/auth/provider.go @@ -89,20 +89,40 @@ func (h Headers) Get(key string) string { // HeadersFromRequest extracts the well-known headers providers may use. // Keep this list narrow — providers should be explicit about the contract. +// +// Authorization is included verbatim (in addition to the middleware's Bearer +// extraction) so providers that speak non-Bearer formats can read it — e.g. +// AWS Sigv4 (Authorization: AWS4-HMAC-SHA256 …). Adding a new header here is +// a deliberate widening of the contract; providers must not assume any header +// is present. +// +// Phase 2 additions (Authorization, X-Goog-Iap-Jwt-Assertion, X-Amz-Date, +// X-Amz-Security-Token) are consumed by aws_sigv4 and gcp_iap providers. 
func HeadersFromRequest(r *http.Request) Headers { return Headers{ "X-Org-ID": r.Header.Get("X-Org-ID"), "X-Request-ID": r.Header.Get("X-Request-ID"), "org-id": r.Header.Get("org-id"), "org_id": r.Header.Get("org_id"), + + "Authorization": r.Header.Get("Authorization"), + "X-Goog-Iap-Jwt-Assertion": r.Header.Get("X-Goog-Iap-Jwt-Assertion"), + "X-Amz-Date": r.Header.Get("X-Amz-Date"), + "X-Amz-Security-Token": r.Header.Get("X-Amz-Security-Token"), } } // TokenKind classifies a presented bearer token structurally — useful for // audit logging without leaking the token itself. // -// "jwt" → three base64url segments separated by dots -// "opaque" → anything else (Okta access tokens, custom verifier tokens, dev secrets) +// "empty" → empty token +// "sigv4" → starts with "AWS4-HMAC-SHA256 " (the AWS Sigv4 algorithm prefix). +// +// Callers route the raw Authorization header through TokenKind for +// this case, since extractBearerToken returns "" for non-Bearer auth. +// +// "jwt" → three base64url segments separated by dots +// "opaque" → anything else (Okta access tokens, custom verifier tokens, dev secrets) // // This is a CHEAP structural check — it does not parse or validate. // Never log the token; this helper is safe to log. @@ -110,6 +130,9 @@ func TokenKind(token string) string { if token == "" { return "empty" } + if strings.HasPrefix(token, "AWS4-HMAC-SHA256 ") { + return "sigv4" + } dots := 0 for i := 0; i < len(token); i++ { if token[i] == '.' 
{ diff --git a/forge-core/auth/provider_test.go b/forge-core/auth/provider_test.go index c175b69..1f612a8 100644 --- a/forge-core/auth/provider_test.go +++ b/forge-core/auth/provider_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "maps" + "net/http" "reflect" "sort" "sync" @@ -464,6 +465,14 @@ func TestTokenKind(t *testing.T) { {"opaque one dot", "abc.def", "opaque"}, {"opaque four segments", "a.b.c.d", "opaque"}, {"jwt with empty segments still has 2 dots", "..", "jwt"}, + // Phase 2: aws_sigv4 — the Authorization header's algorithm prefix + // is enough to classify; middleware routes the raw header here when + // the Bearer extractor returns "". + {"sigv4 prefix", "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=abc", "sigv4"}, + {"sigv4 prefix only", "AWS4-HMAC-SHA256 ", "sigv4"}, + // Defensive: a token that LOOKS like Sigv4 but lacks the trailing + // space must NOT be classified as sigv4 — we want a clean prefix match. 
+ {"sigv4-like without trailing space", "AWS4-HMAC-SHA256xyz", "opaque"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -474,6 +483,74 @@ func TestTokenKind(t *testing.T) { } } +// --- HeadersFromRequest --- + +func TestHeadersFromRequest_Phase1Headers(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("X-Org-ID", "acme") + req.Header.Set("X-Request-ID", "req-1") + req.Header.Set("org-id", "lower") + req.Header.Set("org_id", "snake") + + h := HeadersFromRequest(req) + if h["X-Org-ID"] != "acme" { + t.Errorf("X-Org-ID = %q, want acme", h["X-Org-ID"]) + } + if h["X-Request-ID"] != "req-1" { + t.Errorf("X-Request-ID = %q, want req-1", h["X-Request-ID"]) + } + if h["org-id"] != "lower" { + t.Errorf("org-id = %q, want lower", h["org-id"]) + } + if h["org_id"] != "snake" { + t.Errorf("org_id = %q, want snake", h["org_id"]) + } +} + +func TestHeadersFromRequest_Phase2Headers(t *testing.T) { + // Phase 2: HeadersFromRequest widened so providers consuming non-Bearer + // formats (aws_sigv4, gcp_iap) can read what they need without breaking + // the Provider.Verify signature. See PHASE2_TEST_STRATEGY.md §4 PR1. 
+ req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host, Signature=ab") + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJabc.eyJdef.sig") + req.Header.Set("X-Amz-Date", "20260523T120000Z") + req.Header.Set("X-Amz-Security-Token", "FwoGZX...") + + h := HeadersFromRequest(req) + + if got := h.Get("Authorization"); got == "" || !startsWith(got, "AWS4-HMAC-SHA256 ") { + t.Errorf("Authorization = %q, want Sigv4-shaped string", got) + } + if got := h.Get("X-Goog-Iap-Jwt-Assertion"); got != "eyJabc.eyJdef.sig" { + t.Errorf("X-Goog-Iap-Jwt-Assertion = %q, want eyJabc.eyJdef.sig", got) + } + if got := h.Get("X-Amz-Date"); got != "20260523T120000Z" { + t.Errorf("X-Amz-Date = %q, want 20260523T120000Z", got) + } + if got := h.Get("X-Amz-Security-Token"); got == "" { + t.Error("X-Amz-Security-Token not extracted") + } +} + +func TestHeadersFromRequest_AbsentHeadersAreEmpty(t *testing.T) { + // Providers must not assume any header is present — absence is normal + // and means "this format isn't here, yield to the next provider." 
+ req, _ := http.NewRequest("POST", "/", nil) + h := HeadersFromRequest(req) + for _, k := range []string{ + "Authorization", "X-Goog-Iap-Jwt-Assertion", "X-Amz-Date", "X-Amz-Security-Token", + } { + if got := h.Get(k); got != "" { + t.Errorf("%s should be empty on request with no headers, got %q", k, got) + } + } +} + +func startsWith(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + // --- helpers --- func providerNames(ps []Provider) []string { From 9a1ebae84e52368d88d56bcb69e9685bdd7764b8 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 19:27:04 -0400 Subject: [PATCH 02/22] feat(auth/aws_sigv4): add provider verifying via STS reflection (phase 2 pr 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit aws_sigv4 authenticates AWS-IAM callers by reflecting their Sigv4 signature to STS GetCallerIdentity. No aws-sdk-go-v2 dependency (decision §9.1): the STS RPC is ~150 LOC of hand-rolled HTTP + XML. Forge never holds the caller's secret key — STS validates the signature on Forge's behalf. 
Key pieces: - sigv4_parser.go: pure string parser, fuzz-tested, never panics - sts_client.go: 200/4xx/5xx classification per review #6 contract - identity_cache.go: hash(AKID|YYYYMMDD)-keyed TTL cache, opportunistic eviction past 10k entries, Put does NOT extend prior expiry - arn_matcher.go: shell-style globs via path.Match (decision §9.3), invalid patterns fail at Factory time - provider.go: scope check (service=sts, region match) before any STS round trip, cache hit avoids RPC, rejection does NOT poison the cache security: - Algorithm: only AWS4-HMAC-SHA256 prefix is claimed - Scope: cross-service replay (s3->sts) and cross-region replay (eu-west-1->us-east-1) rejected at parse-time - Cache: bucketing by YYYYMMDD bounds stolen-key window to a day - Body cap: 64 KiB on STS responses - Logs: STS error bodies summarized at 200 chars, newlines stripped audit: - ErrTokenNotForMe -> not_for_me (no AWS4 prefix) - ErrInvalidToken -> invalid (malformed Sigv4) - ErrTokenRejected -> rejected (scope/allowlist/STS 4xx) - ErrProviderUnavailable -> provider_unavailable (STS 5xx/network) extras: - security.AuthDomains gains sts..amazonaws.com (+ override host when sts_endpoint set for tests) - forge-cli/runtime/runner.go side-effect imports aws_sigv4 --- forge-cli/runtime/runner.go | 9 +- .../auth/providers/aws_sigv4/arn_matcher.go | 46 +++ .../providers/aws_sigv4/arn_matcher_test.go | 57 ++++ .../providers/aws_sigv4/identity_cache.go | 73 +++++ .../aws_sigv4/identity_cache_test.go | 86 ++++++ .../auth/providers/aws_sigv4/provider.go | 206 +++++++++++++ .../auth/providers/aws_sigv4/provider_test.go | 288 ++++++++++++++++++ .../auth/providers/aws_sigv4/sigv4_parser.go | 98 ++++++ .../providers/aws_sigv4/sigv4_parser_test.go | 84 +++++ .../auth/providers/aws_sigv4/sts_client.go | 145 +++++++++ .../providers/aws_sigv4/sts_client_test.go | 202 ++++++++++++ forge-core/security/auth_domains.go | 14 +- forge-core/security/auth_domains_test.go | 61 ++++ 13 files changed, 1365 
insertions(+), 4 deletions(-) create mode 100644 forge-core/auth/providers/aws_sigv4/arn_matcher.go create mode 100644 forge-core/auth/providers/aws_sigv4/arn_matcher_test.go create mode 100644 forge-core/auth/providers/aws_sigv4/identity_cache.go create mode 100644 forge-core/auth/providers/aws_sigv4/identity_cache_test.go create mode 100644 forge-core/auth/providers/aws_sigv4/provider.go create mode 100644 forge-core/auth/providers/aws_sigv4/provider_test.go create mode 100644 forge-core/auth/providers/aws_sigv4/sigv4_parser.go create mode 100644 forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go create mode 100644 forge-core/auth/providers/aws_sigv4/sts_client.go create mode 100644 forge-core/auth/providers/aws_sigv4/sts_client_test.go diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 551bd48..fd33826 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -17,10 +17,13 @@ import ( "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/agentspec" "github.com/initializ/forge/forge-core/auth" + // Side-effect imports: each provider sub-package registers its factory + // with the auth registry via init() so forge.yaml `auth.providers[]` + // blocks construct successfully via auth.Build("", settings). + // Listed here even when the package is also referenced directly + // (httpverifier, statictoken) for grep-ability. + _ "github.com/initializ/forge/forge-core/auth/providers/aws_sigv4" "github.com/initializ/forge/forge-core/auth/providers/httpverifier" - // Side-effect import: registers the "oidc" provider with the auth registry - // so forge.yaml `auth: { type: oidc }` blocks construct successfully via - // auth.Build("oidc", settings). The package is not referenced directly. 
_ "github.com/initializ/forge/forge-core/auth/providers/oidc" "github.com/initializ/forge/forge-core/auth/providers/statictoken" "github.com/initializ/forge/forge-core/llm" diff --git a/forge-core/auth/providers/aws_sigv4/arn_matcher.go b/forge-core/auth/providers/aws_sigv4/arn_matcher.go new file mode 100644 index 0000000..8e3a6a6 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/arn_matcher.go @@ -0,0 +1,46 @@ +package aws_sigv4 + +import ( + "fmt" + "path" +) + +// ArnMatcher checks a caller's ARN against a list of shell-style globs. +// Empty list = allow any IAM principal. Invalid patterns fail at startup, +// never at request time. +// +// Decision §9.3: shell-style globs via path.Match. Supports * and ? but +// NOT regex syntax — simpler mental model, less footgun. +type ArnMatcher struct { + patterns []string +} + +// NewArnMatcher validates each pattern via path.Match(pattern, ""). A +// malformed pattern is a config bug; we want it surfaced at Factory time. +func NewArnMatcher(patterns []string) (*ArnMatcher, error) { + for _, p := range patterns { + if _, err := path.Match(p, ""); err != nil { + return nil, fmt.Errorf("invalid glob %q: %w", p, err) + } + } + return &ArnMatcher{patterns: patterns}, nil +} + +// Match returns true if the ARN matches any pattern, or if the matcher +// has no patterns (which means "allow any principal"). +// +// NOTE: STS GetCallerIdentity returns the ASSUMED-ROLE ARN form +// ("arn:aws:sts::ACCOUNT:assumed-role/RoleName/SessionName"), not the +// role's own ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). Operators must +// write patterns against the assumed-role form — PR6 docs spell this out. 
+func (m *ArnMatcher) Match(arn string) bool { + if len(m.patterns) == 0 { + return true + } + for _, p := range m.patterns { + if ok, _ := path.Match(p, arn); ok { + return true + } + } + return false +} diff --git a/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go b/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go new file mode 100644 index 0000000..24ef554 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go @@ -0,0 +1,57 @@ +package aws_sigv4 + +import ( + "testing" +) + +func TestArnMatcher_EmptyAllowsAll(t *testing.T) { + m, err := NewArnMatcher(nil) + if err != nil { + t.Fatalf("NewArnMatcher: %v", err) + } + if !m.Match("arn:aws:iam::123:role/anything") { + t.Error("empty matcher should allow any ARN") + } +} + +func TestArnMatcher_Glob(t *testing.T) { + m, err := NewArnMatcher([]string{ + "arn:aws:iam::123:role/forge-*", + "arn:aws:sts::123:assumed-role/ci-deploy/*", + }) + if err != nil { + t.Fatalf("NewArnMatcher: %v", err) + } + cases := map[string]bool{ + "arn:aws:iam::123:role/forge-deploy": true, + "arn:aws:iam::123:role/forge-runner": true, + "arn:aws:iam::123:role/forge": false, + "arn:aws:iam::123:role/other-deploy": false, + "arn:aws:iam::456:role/forge-deploy": false, + "arn:aws:sts::123:assumed-role/ci-deploy/session-1": true, + "arn:aws:sts::123:assumed-role/ci-deploy/another-sess": true, + "arn:aws:sts::123:assumed-role/wrong-role/session-1": false, + } + for in, want := range cases { + if got := m.Match(in); got != want { + t.Errorf("Match(%q) = %v, want %v", in, got, want) + } + } +} + +func TestArnMatcher_InvalidPatternFailsAtStartup(t *testing.T) { + if _, err := NewArnMatcher([]string{"["}); err == nil { + t.Fatal("expected error for malformed glob") + } +} + +func TestArnMatcher_AssumedRoleNeedsTwoStars(t *testing.T) { + // Reminder for operators: STS returns the assumed-role ARN, not the + // IAM role ARN. 
A pattern that matches "arn:aws:iam::ACCT:role/X-*" + // will NOT match "arn:aws:sts::ACCT:assumed-role/X-foo/session". + // PR6 docs spell this out. + m, _ := NewArnMatcher([]string{"arn:aws:iam::123:role/forge-*"}) + if m.Match("arn:aws:sts::123:assumed-role/forge-deploy/session") { + t.Error("IAM role pattern should NOT match STS assumed-role ARN — docs invariant") + } +} diff --git a/forge-core/auth/providers/aws_sigv4/identity_cache.go b/forge-core/auth/providers/aws_sigv4/identity_cache.go new file mode 100644 index 0000000..72ff56b --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/identity_cache.go @@ -0,0 +1,73 @@ +package aws_sigv4 + +import ( + "sync" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// IdentityCache holds verified Identities keyed by hash(AKID|YYYYMMDD). +// Bounds the stolen-key window: a leaked AKID is honored until either +// IdentityCacheTTL elapses OR midnight UTC (date bucket rolls over), +// whichever is sooner. +// +// Never caches rejections or errors — Phase 1 invariant: errors are +// not entities. +type IdentityCache struct { + ttl time.Duration + mu sync.RWMutex + data map[string]cacheEntry + now func() time.Time +} + +type cacheEntry struct { + id *auth.Identity + expireAt time.Time +} + +// NewIdentityCache returns an empty cache with the given TTL. time.Now is +// the default clock; tests can swap it via the unexported setNow hook. +func NewIdentityCache(ttl time.Duration) *IdentityCache { + return &IdentityCache{ + ttl: ttl, + data: make(map[string]cacheEntry), + now: time.Now, + } +} + +// Get returns the cached identity if present and not expired. +func (c *IdentityCache) Get(key string) (*auth.Identity, bool) { + c.mu.RLock() + e, ok := c.data[key] + c.mu.RUnlock() + if !ok || c.now().After(e.expireAt) { + return nil, false + } + return e.id, true +} + +// Put stores the identity under key with a fresh TTL. Overwriting an +// existing entry does NOT extend the previous expiry — it replaces it. 
+// (Refusing to extend prevents a "refresh-just-before-expiry" attack +// from holding a stolen credential alive indefinitely.) +// +// Opportunistic eviction sweeps expired entries when the map grows past +// 10k. Bounds memory under sustained miss without needing a background +// goroutine. +func (c *IdentityCache) Put(key string, id *auth.Identity) { + c.mu.Lock() + c.data[key] = cacheEntry{id: id, expireAt: c.now().Add(c.ttl)} + if len(c.data) > 10_000 { + now := c.now() + for k, e := range c.data { + if now.After(e.expireAt) { + delete(c.data, k) + } + } + } + c.mu.Unlock() +} + +// setNow is a test-only hook for swapping the clock. +func (c *IdentityCache) setNow(fn func() time.Time) { c.now = fn } diff --git a/forge-core/auth/providers/aws_sigv4/identity_cache_test.go b/forge-core/auth/providers/aws_sigv4/identity_cache_test.go new file mode 100644 index 0000000..f3cedcf --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/identity_cache_test.go @@ -0,0 +1,86 @@ +package aws_sigv4 + +import ( + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +func TestIdentityCache_HitMiss(t *testing.T) { + c := NewIdentityCache(time.Minute) + if _, ok := c.Get("k1"); ok { + t.Error("empty cache returned hit") + } + id := &auth.Identity{UserID: "arn1"} + c.Put("k1", id) + got, ok := c.Get("k1") + if !ok || got != id { + t.Errorf("Get hit = (%v, %v), want id, true", got, ok) + } +} + +func TestIdentityCache_Expiry(t *testing.T) { + now := time.Unix(1_700_000_000, 0) + c := NewIdentityCache(60 * time.Second) + c.setNow(func() time.Time { return now }) + + c.Put("k", &auth.Identity{UserID: "arn"}) + + if _, ok := c.Get("k"); !ok { + t.Fatal("expected hit before expiry") + } + + now = now.Add(61 * time.Second) + if _, ok := c.Get("k"); ok { + t.Error("expected miss after TTL expiry") + } +} + +func TestIdentityCache_PutDoesNotExtendExpiry(t *testing.T) { + // Defense against refresh-just-before-expiry holding a stolen + // credential alive 
indefinitely. Put MUST replace the entry's + // expireAt, not extend it. + now := time.Unix(1_700_000_000, 0) + c := NewIdentityCache(60 * time.Second) + c.setNow(func() time.Time { return now }) + + c.Put("k", &auth.Identity{UserID: "arn"}) + // Refresh at t+50s — expireAt becomes t+110s. + now = now.Add(50 * time.Second) + c.Put("k", &auth.Identity{UserID: "arn"}) + + // At t+120s, original-expiry+TTL would still be valid; replacement + // makes the cap t+110s. Verify it's expired. + now = now.Add(70 * time.Second) // t+120s + if _, ok := c.Get("k"); ok { + t.Error("Put extended TTL beyond the bound — refresh-extends-stolen-key bug") + } +} + +func TestIdentityCache_OpportunisticEviction(t *testing.T) { + // Force the map past the 10_000 threshold with expired entries; one + // more Put should trigger a sweep that drops them. + now := time.Unix(1_700_000_000, 0) + c := NewIdentityCache(time.Second) + c.setNow(func() time.Time { return now }) + + id := &auth.Identity{UserID: "arn"} + for i := range 10_001 { + c.Put(string(rune(i)), id) + } + now = now.Add(2 * time.Second) // expire everything + + c.Put("fresh", id) // triggers sweep + + if got, ok := c.Get("fresh"); !ok || got != id { + t.Fatal("fresh entry lost during sweep") + } + // Map should have been pruned — at most a few entries remain + c.mu.RLock() + size := len(c.data) + c.mu.RUnlock() + if size > 10 { + t.Errorf("eviction did not run; cache size = %d", size) + } +} diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go new file mode 100644 index 0000000..bf00d52 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -0,0 +1,206 @@ +// Package aws_sigv4 authenticates requests signed with AWS Sigv4 by +// reflecting the signature to AWS STS GetCallerIdentity. Same pattern as +// aws-iam-authenticator (the EKS auth bridge). 
+// +// Forge never has the caller's secret key — STS validates the signature +// on Forge's behalf and returns the canonical ARN/Account/UserID. Verified +// identities are cached by hash(AKID|YYYYMMDD) for IdentityCacheTTL so the +// hot path doesn't bounce through STS on every call. +// +// Decision §9.1: no aws-sdk-go-v2 dependency — the STS RPC is ~150 LOC of +// hand-rolled HTTP + XML. Trade-off: small attack surface, predictable +// behavior, no transitive deps; cost: we maintain the RPC ourselves. +// +// Decision §9.3: allowed_principals are shell-style globs (path.Match). +// +// Audit reason codes (Phase 1 contract): +// +// rejected — scope check, ARN allowlist miss, STS 4xx +// provider_unavailable — STS 5xx, network failure, parse failure +// invalid — Sigv4 header malformed +// not_for_me — Authorization header didn't start with AWS4-HMAC-SHA256 +package aws_sigv4 + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this provider. +const ProviderName = "aws_sigv4" + +// Defaults. +const ( + defaultIdentityCacheTTL = 60 * time.Second + defaultHTTPTimeout = 5 * time.Second +) + +// Config controls the aws_sigv4 provider. +type Config struct { + // Region is the AWS region whose STS endpoint validates signatures. + // REQUIRED. Defense in depth: callers must sign for this exact region + // (Sigv4 cross-region replay rejected at parse-time scope check). + Region string `yaml:"region"` + + // Audience is informational only — emitted in audit logs for context. + // STS itself doesn't enforce this. + Audience string `yaml:"audience,omitempty"` + + // AllowedPrincipals is an optional list of shell-style globs (decision + // §9.3) matched against the STS-returned ARN. 
Empty list means "allow + // any IAM principal in the account that signed the request" — fine + // for single-tenant dev setups, never appropriate for prod. + // + // Patterns match the assumed-role ARN form + // ("arn:aws:sts::ACCOUNT:assumed-role/RoleName/SessionName"), NOT the + // role's own ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). PR6 docs + // call this out explicitly with a worked example. + AllowedPrincipals []string `yaml:"allowed_principals,omitempty"` + + // IdentityCacheTTL bounds how long a verified Identity is reused + // without re-checking with STS. Defaults to 60s. Sets the upper + // bound on the window in which a revoked AKID remains accepted. + IdentityCacheTTL time.Duration `yaml:"identity_cache_ttl,omitempty"` + + // STSEndpoint is a TEST-ONLY override that points STS calls at an + // alternate URL. Production should leave this empty so the region + // derives the canonical sts..amazonaws.com host. + STSEndpoint string `yaml:"sts_endpoint,omitempty"` + + // HTTPTimeout caps each STS call. Defaults to 5s. + HTTPTimeout time.Duration `yaml:"http_timeout,omitempty"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.Region == "" { + return fmt.Errorf("%w: region required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider for AWS Sigv4 callers. +type Provider struct { + cfg Config + parser Parser + cache *IdentityCache + sts *STSClient + matcher *ArnMatcher +} + +// New constructs a Provider after validating cfg. 
+func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.IdentityCacheTTL == 0 { + cfg.IdentityCacheTTL = defaultIdentityCacheTTL + } + if cfg.HTTPTimeout == 0 { + cfg.HTTPTimeout = defaultHTTPTimeout + } + matcher, err := NewArnMatcher(cfg.AllowedPrincipals) + if err != nil { + return nil, fmt.Errorf("aws_sigv4: allowed_principals: %w", err) + } + return &Provider{ + cfg: cfg, + parser: Parser{}, + cache: NewIdentityCache(cfg.IdentityCacheTTL), + sts: NewSTSClient(cfg.Region, cfg.STSEndpoint, cfg.HTTPTimeout), + matcher: matcher, + }, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. +// +// Token is unused — Sigv4 carries everything it needs in the Authorization +// header (and X-Amz-Date, X-Amz-Security-Token), which the middleware +// hands us via the headers map. +func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) (*auth.Identity, error) { + authHeader := headers.Get("Authorization") + if !strings.HasPrefix(authHeader, sigv4Algorithm+" ") { + return nil, auth.ErrTokenNotForMe + } + + sig, err := p.parser.Parse(authHeader) + if err != nil { + return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + + // Scope check (defense in depth §3.1): reject signatures meant for a + // different AWS service or region. Prevents cross-service Sigv4 + // replay (e.g. a request signed for S3 reaching us configured for STS). 
+ if sig.Service != "sts" { + return nil, fmt.Errorf("%w: signature scoped to service %q, expected sts", auth.ErrTokenRejected, sig.Service) + } + if sig.Region != p.cfg.Region { + return nil, fmt.Errorf("%w: signature scoped to region %q, expected %s", auth.ErrTokenRejected, sig.Region, p.cfg.Region) + } + + cacheKey := dateBucketKey(sig.AKID, sig.Date) + if id, ok := p.cache.Get(cacheKey); ok { + return id, nil + } + + caller, err := p.sts.GetCallerIdentity(ctx, STSReflectArgs{ + AuthHeader: authHeader, + AmzDate: headers.Get("X-Amz-Date"), + SecurityToken: headers.Get("X-Amz-Security-Token"), + }) + if err != nil { + // STSClient already wraps with the right sentinel + // (ErrTokenRejected for 4xx, ErrProviderUnavailable for 5xx/network). + return nil, err + } + + if !p.matcher.Match(caller.Arn) { + return nil, fmt.Errorf("%w: ARN %q not in allowed_principals", auth.ErrTokenRejected, caller.Arn) + } + + id := &auth.Identity{ + UserID: caller.Arn, + OrgID: caller.Account, + Source: ProviderName, + Claims: map[string]any{ + "user_id": caller.UserID, + "arn": caller.Arn, + "account": caller.Account, + "audience": p.cfg.Audience, + }, + } + p.cache.Put(cacheKey, id) + return id, nil +} + +// dateBucketKey hashes (AKID, YYYYMMDD) so two requests from the same AKID +// on the same day collapse to one STS call per IdentityCacheTTL window. +// Hashing protects the cache against length-leak / log-scan attacks on the +// raw key — same reasoning as static_token (review #11). 
+func dateBucketKey(akid, sigv4Date string) string { + bucket := sigv4Date + if len(bucket) > 8 { + bucket = bucket[:8] // YYYYMMDD + } + sum := sha256.Sum256([]byte(akid + "|" + bucket)) + return hex.EncodeToString(sum[:]) +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/aws_sigv4/provider_test.go b/forge-core/auth/providers/aws_sigv4/provider_test.go new file mode 100644 index 0000000..5d524b9 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/provider_test.go @@ -0,0 +1,288 @@ +package aws_sigv4 + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +const validSigv4 = "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ab12cd34" + +func validHeaders() auth.Headers { + return auth.Headers{ + "Authorization": validSigv4, + "X-Amz-Date": "20260523T120000Z", + } +} + +func newTestProvider(t *testing.T, sts http.Handler, opts ...func(*Config)) *Provider { + t.Helper() + srv := httptest.NewServer(sts) + t.Cleanup(srv.Close) + + cfg := Config{ + Region: "us-east-1", + Audience: "api://forge", + STSEndpoint: srv.URL, + IdentityCacheTTL: 60 * time.Second, + HTTPTimeout: 5 * time.Second, + } + for _, fn := range opts { + fn(&cfg) + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + return p +} + +func happySTS() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, happySTSXML) + }) +} + +func TestProvider_Name(t *testing.T) { + p := newTestProvider(t, happySTS()) + if p.Name() != "aws_sigv4" { + t.Errorf("Name = %q, want aws_sigv4", p.Name()) + } +} + +func 
TestProvider_New_RequiresRegion(t *testing.T) { + _, err := New(Config{}) + if err == nil || !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Errorf("err = %v, want wrapped ErrProviderNotConfigured", err) + } +} + +func TestProvider_New_RejectsInvalidGlob(t *testing.T) { + _, err := New(Config{Region: "us-east-1", AllowedPrincipals: []string{"["}}) + if err == nil { + t.Fatal("expected error for malformed glob") + } +} + +func TestProvider_NoSigv4Header_YieldsToChain(t *testing.T) { + p := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "", auth.Headers{"Authorization": "Bearer xyz"}) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestProvider_NoAuthHeaderAtAll_YieldsToChain(t *testing.T) { + p := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "", auth.Headers{}) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestProvider_MalformedSigv4_Invalid(t *testing.T) { + p := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "", auth.Headers{ + "Authorization": "AWS4-HMAC-SHA256 malformed", + }) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_HappyPath_ReturnsIdentity(t *testing.T) { + p := newTestProvider(t, happySTS()) + id, err := p.Verify(context.Background(), "", validHeaders()) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "aws_sigv4" { + t.Errorf("Source = %q, want aws_sigv4", id.Source) + } + if id.UserID != "arn:aws:sts::123456789012:assumed-role/ci-deploy/session" { + t.Errorf("UserID = %q", id.UserID) + } + if id.OrgID != "123456789012" { + t.Errorf("OrgID = %q", id.OrgID) + } + if id.Claims["audience"] != "api://forge" { + t.Errorf("Claims[audience] = %v", id.Claims["audience"]) + } +} + +func TestProvider_ScopeService_NotSTS_Rejected(t *testing.T) { + p := 
newTestProvider(t, happySTS()) + h := validHeaders() + h["Authorization"] = strings.Replace(h["Authorization"], "/sts/", "/s3/", 1) + _, err := p.Verify(context.Background(), "", h) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (cross-service replay)", err) + } +} + +func TestProvider_ScopeRegion_Mismatch_Rejected(t *testing.T) { + p := newTestProvider(t, happySTS()) + h := validHeaders() + h["Authorization"] = strings.Replace(h["Authorization"], "/us-east-1/", "/eu-west-1/", 1) + _, err := p.Verify(context.Background(), "", h) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (cross-region replay)", err) + } +} + +func TestProvider_AllowlistMiss_Rejected(t *testing.T) { + p := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedPrincipals = []string{"arn:aws:iam::999:role/*"} + }) + _, err := p.Verify(context.Background(), "", validHeaders()) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (allowlist miss)", err) + } +} + +func TestProvider_AllowlistHit_Succeeds(t *testing.T) { + p := newTestProvider(t, happySTS(), func(c *Config) { + // Matches the assumed-role ARN returned by happySTSXML. 
+ c.AllowedPrincipals = []string{"arn:aws:sts::123456789012:assumed-role/ci-deploy/*"} + }) + if _, err := p.Verify(context.Background(), "", validHeaders()); err != nil { + t.Errorf("Verify: %v", err) + } +} + +func TestProvider_STSDown_Unavailable(t *testing.T) { + p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + _, err := p.Verify(context.Background(), "", validHeaders()) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestProvider_STSRejects_Rejected(t *testing.T) { + p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, "SignatureDoesNotMatch") + })) + _, err := p.Verify(context.Background(), "", validHeaders()) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_CacheHit_AvoidsSTSCall(t *testing.T) { + var calls atomic.Int32 + p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + _, _ = io.WriteString(w, happySTSXML) + })) + + if _, err := p.Verify(context.Background(), "", validHeaders()); err != nil { + t.Fatalf("first Verify: %v", err) + } + if _, err := p.Verify(context.Background(), "", validHeaders()); err != nil { + t.Fatalf("second Verify: %v", err) + } + if got := calls.Load(); got != 1 { + t.Errorf("STS calls = %d, want 1 (cache must hit on 2nd request)", got) + } +} + +func TestProvider_RejectedRequest_DoesNotPoisonCache(t *testing.T) { + // First STS call returns 403; second valid request must hit STS again + // (the rejection MUST NOT have populated the cache). 
+ var calls atomic.Int32 + p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + if n == 1 { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, "") + return + } + _, _ = io.WriteString(w, happySTSXML) + })) + + _, err1 := p.Verify(context.Background(), "", validHeaders()) + if !errors.Is(err1, auth.ErrTokenRejected) { + t.Fatalf("first Verify err = %v, want ErrTokenRejected", err1) + } + _, err2 := p.Verify(context.Background(), "", validHeaders()) + if err2 != nil { + t.Fatalf("second Verify err = %v, want nil", err2) + } + if got := calls.Load(); got != 2 { + t.Errorf("STS calls = %d, want 2 (rejection must not poison cache)", got) + } +} + +func TestProvider_DateBucketRollover_TriggersFreshSTSCall(t *testing.T) { + var calls atomic.Int32 + p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + _, _ = io.WriteString(w, happySTSXML) + })) + + // Day 1 + h1 := validHeaders() + if _, err := p.Verify(context.Background(), "", h1); err != nil { + t.Fatal(err) + } + + // Day 2 — same AKID, different date in scope → different cache key. 
+ h2 := auth.Headers{ + "Authorization": strings.Replace(validSigv4, "/20260523/", "/20260524/", 1), + "X-Amz-Date": "20260524T120000Z", + } + if _, err := p.Verify(context.Background(), "", h2); err != nil { + t.Fatal(err) + } + if got := calls.Load(); got != 2 { + t.Errorf("STS calls = %d, want 2 (date bucket rolled, cache miss expected)", got) + } +} + +func TestProvider_AmzSecurityTokenForwarded(t *testing.T) { + var capturedSec string + p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedSec = r.Header.Get("X-Amz-Security-Token") + _, _ = io.WriteString(w, happySTSXML) + })) + + h := validHeaders() + h["X-Amz-Security-Token"] = "FwoGZX-test-temp-session" + if _, err := p.Verify(context.Background(), "", h); err != nil { + t.Fatal(err) + } + if capturedSec != "FwoGZX-test-temp-session" { + t.Errorf("STS got X-Amz-Security-Token = %q", capturedSec) + } +} + +func TestProvider_RegisteredInRegistry(t *testing.T) { + // init() ran on package import — verify the factory is wired. + p, err := auth.Build("aws_sigv4", map[string]any{ + "region": "us-east-1", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "aws_sigv4" { + t.Errorf("Name = %q", p.Name()) + } +} + +func TestProvider_FactoryRejectsMissingRegion(t *testing.T) { + _, err := auth.Build("aws_sigv4", map[string]any{}) + if err == nil { + t.Fatal("expected error from factory when region is missing") + } +} diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go new file mode 100644 index 0000000..854053c --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go @@ -0,0 +1,98 @@ +package aws_sigv4 + +import ( + "errors" + "strings" +) + +// Sigv4Header is the parsed view of an AWS Sigv4 Authorization header. +type Sigv4Header struct { + AKID string // e.g. "AKIAIOSFODNN7EXAMPLE" + Date string // YYYYMMDD scope date + Region string // e.g. 
"us-east-1" + Service string // e.g. "sts" + SignedHeaders string // semicolon-separated list — must include "host" + Signature string // hex-encoded HMAC +} + +// Parser parses Sigv4-shaped Authorization headers. Pure string work — +// no HTTP, no AWS SDK. Never panics on input. +type Parser struct{} + +const sigv4Algorithm = "AWS4-HMAC-SHA256" + +// Parse extracts the fields from a Sigv4 Authorization header. Real-world +// signers vary in whitespace around the comma separators; we trim tolerantly +// rather than insist on exact bytes. +// +// Returns a clear error on malformed input — callers should map parse +// errors to auth.ErrInvalidToken, not auth.ErrTokenNotForMe (the latter +// is reserved for "no Sigv4 prefix at all"). +func (Parser) Parse(authHeader string) (*Sigv4Header, error) { + if !strings.HasPrefix(authHeader, sigv4Algorithm+" ") { + return nil, errors.New("missing AWS4-HMAC-SHA256 prefix") + } + rest := strings.TrimPrefix(authHeader, sigv4Algorithm+" ") + + parts := splitCommaTrim(rest) + if len(parts) != 3 { + return nil, errors.New("expected 3 comma-separated kv pairs") + } + + out := &Sigv4Header{} + for _, p := range parts { + k, v, ok := strings.Cut(p, "=") + if !ok { + return nil, errors.New("malformed kv pair") + } + switch strings.TrimSpace(k) { + case "Credential": + if err := parseCredentialScope(v, out); err != nil { + return nil, err + } + case "SignedHeaders": + out.SignedHeaders = strings.TrimSpace(v) + case "Signature": + out.Signature = strings.TrimSpace(v) + } + } + + if out.AKID == "" || out.Date == "" || out.Region == "" || out.Service == "" || + out.SignedHeaders == "" || out.Signature == "" { + return nil, errors.New("missing one or more required Sigv4 fields") + } + if !signedHeadersContainsHost(out.SignedHeaders) { + return nil, errors.New("SignedHeaders must include host") + } + return out, nil +} + +func parseCredentialScope(v string, out *Sigv4Header) error { + // Credential = AKID/YYYYMMDD/region/service/aws4_request + 
segs := strings.Split(strings.TrimSpace(v), "/") + if len(segs) != 5 || segs[4] != "aws4_request" { + return errors.New("malformed Credential scope") + } + out.AKID, out.Date, out.Region, out.Service = segs[0], segs[1], segs[2], segs[3] + return nil +} + +// signedHeadersContainsHost validates that the SignedHeaders list includes +// "host" as a discrete entry. Sub-string match would accept "ghosting" — we +// split on ";" so each entry stands on its own. +func signedHeadersContainsHost(s string) bool { + for h := range strings.SplitSeq(s, ";") { + if strings.EqualFold(strings.TrimSpace(h), "host") { + return true + } + } + return false +} + +func splitCommaTrim(s string) []string { + parts := strings.Split(s, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts +} diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go new file mode 100644 index 0000000..4f33b9a --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go @@ -0,0 +1,84 @@ +package aws_sigv4 + +import ( + "strings" + "testing" +) + +func TestParser_HappyPath(t *testing.T) { + h := "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ab12cd34" + sig, err := (Parser{}).Parse(h) + if err != nil { + t.Fatalf("Parse: %v", err) + } + if sig.AKID != "AKIAIOSFODNN7EXAMPLE" { + t.Errorf("AKID = %q", sig.AKID) + } + if sig.Date != "20260523" { + t.Errorf("Date = %q", sig.Date) + } + if sig.Region != "us-east-1" { + t.Errorf("Region = %q", sig.Region) + } + if sig.Service != "sts" { + t.Errorf("Service = %q", sig.Service) + } + if sig.SignedHeaders != "host;x-amz-date" { + t.Errorf("SignedHeaders = %q", sig.SignedHeaders) + } + if sig.Signature != "ab12cd34" { + t.Errorf("Signature = %q", sig.Signature) + } +} + +func TestParser_TolerantWhitespace(t *testing.T) { + // Real-world Sigv4 signers vary in whitespace 
around the commas; + // extra whitespace must still parse cleanly. + h := "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request,SignedHeaders=host,Signature=abc" + if _, err := (Parser{}).Parse(h); err != nil { + t.Errorf("expected tolerant whitespace parse to succeed, got %v", err) + } +} + +func TestParser_MalformedInputs(t *testing.T) { + cases := map[string]string{ + "empty": "", + "bearer instead": "Bearer foo", + "prefix only": "AWS4-HMAC-SHA256", + "prefix and space": "AWS4-HMAC-SHA256 ", + "too few segments": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts, SignedHeaders=host, Signature=ab", + "too many segments": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request/extra, SignedHeaders=host, Signature=ab", + "missing SignedHeaders": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, Signature=ab", + "signed headers no host": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, SignedHeaders=x-amz-date, Signature=ab", + "scope tail wrong": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/wrong, SignedHeaders=host, Signature=ab", + "two parts instead of three": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, Signature=ab", + } + for name, in := range cases { + t.Run(name, func(t *testing.T) { + if _, err := (Parser{}).Parse(in); err == nil { + t.Errorf("expected parse error for %q", in) + } + }) + } +} + +func TestParser_SignedHeadersHostMatchExact(t *testing.T) { + // "ghosting" must NOT count as containing "host" — we split on ";" + // so each entry stands on its own. + h := "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, SignedHeaders=ghosting, Signature=ab" + if _, err := (Parser{}).Parse(h); err == nil { + t.Error("expected error: 'ghosting' should not satisfy host requirement") + } +} + +// FuzzParser ensures the parser never panics on arbitrary input. 
Run with +// `go test -fuzz=FuzzParser -fuzztime=30s ./forge-core/auth/providers/aws_sigv4/`. +func FuzzParser(f *testing.F) { + f.Add("AWS4-HMAC-SHA256 Credential=A/B/C/D/aws4_request, SignedHeaders=host, Signature=x") + f.Add("Bearer foo") + f.Add("") + f.Add(strings.Repeat("AWS4-HMAC-SHA256 ", 100)) + f.Fuzz(func(_ *testing.T, in string) { + _, _ = (Parser{}).Parse(in) + }) +} diff --git a/forge-core/auth/providers/aws_sigv4/sts_client.go b/forge-core/auth/providers/aws_sigv4/sts_client.go new file mode 100644 index 0000000..05548ff --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sts_client.go @@ -0,0 +1,145 @@ +package aws_sigv4 + +import ( + "context" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// STSClient calls AWS STS GetCallerIdentity by reflecting the caller's +// Sigv4 signature. ~150 LOC of hand-rolled HTTP + XML — no aws-sdk-go-v2 +// dependency (decision §9.1). +// +// The key security property: Forge NEVER possesses the caller's secret +// key. STS validates the signature on Forge's behalf and returns the +// canonical ARN / UserId / Account on success. +type STSClient struct { + endpoint string + http *http.Client +} + +// NewSTSClient builds a client for sts..amazonaws.com. The override +// argument is for tests only — production should leave it empty. +func NewSTSClient(region, override string, timeout time.Duration) *STSClient { + ep := override + if ep == "" { + ep = fmt.Sprintf("https://sts.%s.amazonaws.com", region) + } + return &STSClient{ + endpoint: ep, + http: &http.Client{Timeout: timeout}, + } +} + +// STSReflectArgs is the set of headers reflected verbatim from the caller's +// request to the STS call. SecurityToken is optional and only present when +// the caller is using temporary credentials (assumed roles, federation, +// IRSA, etc.). 
+type STSReflectArgs struct { + AuthHeader string + AmzDate string + SecurityToken string +} + +// CallerIdentity is the parsed STS response — the canonical identifiers +// Forge stamps into auth.Identity. +type CallerIdentity struct { + UserID string // e.g. "AROAJ...:session-name" + Arn string // e.g. "arn:aws:sts::123:assumed-role/ci-deploy/session" + Account string // e.g. "123456789012" +} + +// GetCallerIdentity reflects the caller's Sigv4 headers to STS and parses +// the response. +// +// Error classification: +// +// 200 OK → CallerIdentity, nil +// 4xx → auth.ErrTokenRejected (caller's signature didn't validate) +// 5xx / network → auth.ErrProviderUnavailable (review #6 contract) +// parse failure → auth.ErrProviderUnavailable (unexpected response shape) +func (c *STSClient) GetCallerIdentity(ctx context.Context, args STSReflectArgs) (*CallerIdentity, error) { + body := url.Values{ + "Action": {"GetCallerIdentity"}, + "Version": {"2011-06-15"}, + }.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("%w: build STS request: %v", auth.ErrProviderUnavailable, err) + } + req.Header.Set("Authorization", args.AuthHeader) + req.Header.Set("X-Amz-Date", args.AmzDate) + if args.SecurityToken != "" { + req.Header.Set("X-Amz-Security-Token", args.SecurityToken) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: STS RPC: %v", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + + // Real GetCallerIdentity responses are ~1 KiB; cap at 64 KiB to bound + // memory if STS ever returns something pathological. 
+ raw, err := io.ReadAll(io.LimitReader(resp.Body, 64<<10)) + if err != nil { + return nil, fmt.Errorf("%w: read STS body: %v", auth.ErrProviderUnavailable, err) + } + + switch { + case resp.StatusCode == http.StatusOK: + return parseGetCallerIdentityResponse(raw) + case resp.StatusCode >= 400 && resp.StatusCode < 500: + return nil, fmt.Errorf("%w: STS rejected signature: %s", auth.ErrTokenRejected, summarize(raw)) + case resp.StatusCode >= 500: + return nil, fmt.Errorf("%w: STS HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + default: + return nil, fmt.Errorf("%w: STS unexpected status %d", auth.ErrProviderUnavailable, resp.StatusCode) + } +} + +// parseGetCallerIdentityResponse extracts the three canonical fields from +// STS's XML response. Rejects responses that are missing any of them — STS +// always sets all three, so absence implies a malformed reply. +func parseGetCallerIdentityResponse(raw []byte) (*CallerIdentity, error) { + var resp struct { + XMLName xml.Name `xml:"GetCallerIdentityResponse"` + Result struct { + UserID string `xml:"UserId"` + Account string `xml:"Account"` + Arn string `xml:"Arn"` + } `xml:"GetCallerIdentityResult"` + } + if err := xml.Unmarshal(raw, &resp); err != nil { + return nil, fmt.Errorf("%w: parse STS XML: %v", auth.ErrProviderUnavailable, err) + } + if resp.Result.Arn == "" || resp.Result.Account == "" || resp.Result.UserID == "" { + return nil, fmt.Errorf("%w: STS XML missing required fields", auth.ErrProviderUnavailable) + } + return &CallerIdentity{ + UserID: resp.Result.UserID, + Arn: resp.Result.Arn, + Account: resp.Result.Account, + }, nil +} + +// summarize returns a short, single-line, log-safe rendering of an STS +// error body. Caps at 200 chars and strips newlines so the full error +// payload — which can include the caller's Authorization echo on certain +// SDK-shaped errors — never propagates to logs verbatim. 
+func summarize(raw []byte) string { + s := string(raw) + if len(s) > 200 { + s = s[:200] + "…" + } + return strings.ReplaceAll(s, "\n", " ") +} diff --git a/forge-core/auth/providers/aws_sigv4/sts_client_test.go b/forge-core/auth/providers/aws_sigv4/sts_client_test.go new file mode 100644 index 0000000..c71d120 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sts_client_test.go @@ -0,0 +1,202 @@ +package aws_sigv4 + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +const happySTSXML = ` + + AROAJ123:session + 123456789012 + arn:aws:sts::123456789012:assumed-role/ci-deploy/session + + req-id +` + +func TestSTSClient_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/xml") + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + id, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{ + AuthHeader: "AWS4-HMAC-SHA256 Credential=AKIA", + AmzDate: "20260523T120000Z", + }) + if err != nil { + t.Fatalf("GetCallerIdentity: %v", err) + } + if id.Arn != "arn:aws:sts::123456789012:assumed-role/ci-deploy/session" { + t.Errorf("Arn = %q", id.Arn) + } + if id.Account != "123456789012" { + t.Errorf("Account = %q", id.Account) + } + if id.UserID != "AROAJ123:session" { + t.Errorf("UserID = %q", id.UserID) + } +} + +func TestSTSClient_403_Rejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, "SignatureDoesNotMatch") + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + if !errors.Is(err, auth.ErrTokenRejected) { + 
t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestSTSClient_500_Unavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "InternalFailure") + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_NetworkError_Unavailable(t *testing.T) { + // Point at a closed listener so Do() returns a transport error. + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + url := srv.URL + srv.Close() + + c := NewSTSClient("us-east-1", url, 1*time.Second) + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_AuthHeaderReflectedVerbatim(t *testing.T) { + var captured string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "text/xml") + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + + wantAuth := "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=abc123" + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{ + AuthHeader: wantAuth, + AmzDate: "20260523T120000Z", + SecurityToken: "FwoGZX-test-session", + }) + if err != nil { + t.Fatalf("GetCallerIdentity: %v", err) + } + if captured != wantAuth { + t.Errorf("STS did not receive caller's Authorization verbatim:\n got: %q\n want: %q", captured, 
wantAuth) + } +} + +func TestSTSClient_SecurityTokenReflected(t *testing.T) { + var capturedTok string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTok = r.Header.Get("X-Amz-Security-Token") + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + _, _ = c.GetCallerIdentity(context.Background(), STSReflectArgs{ + AuthHeader: "AWS4", + AmzDate: "20260523T120000Z", + SecurityToken: "FwoGZX-token", + }) + if capturedTok != "FwoGZX-token" { + t.Errorf("X-Amz-Security-Token = %q, want FwoGZX-token", capturedTok) + } +} + +func TestSTSClient_BodyCap(t *testing.T) { + // Serve 128 KiB; ensure the client doesn't OOM and the response + // either parses cleanly (truncated at exactly a well-formed prefix + // — unlikely) OR returns ErrProviderUnavailable from parse failure. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "") + _, _ = io.WriteString(w, strings.Repeat("A", 128<<10)) + _, _ = io.WriteString(w, "") + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + if err == nil { + t.Fatal("expected error on oversized STS body") + } + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_MissingFieldsRejected(t *testing.T) { + // Response with empty Arn — must reject as unavailable (malformed, + // not "rejected" — STS would never legitimately return blank fields). 
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, ` + x123 + `) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_RegionEndpointFormat(t *testing.T) { + // Sanity: NewSTSClient without an override builds the right URL. + c := NewSTSClient("eu-west-1", "", time.Second) + if c.endpoint != "https://sts.eu-west-1.amazonaws.com" { + t.Errorf("endpoint = %q", c.endpoint) + } +} + +func TestSTSClient_RequestCount(t *testing.T) { + // Pin the basic happy-path call counting for use by the provider-level + // cache-hit test. + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) + for i := range 3 { + _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + if err != nil { + t.Fatalf("call %d: %v", i, err) + } + } + if calls.Load() != 3 { + t.Errorf("STS calls = %d, want 3", calls.Load()) + } +} diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go index b96b1e0..b1245c5 100644 --- a/forge-core/security/auth_domains.go +++ b/forge-core/security/auth_domains.go @@ -55,8 +55,20 @@ func authProviderURLs(p types.AuthProvider) []string { return []string{ settingString(p.Settings, "url"), } + case "aws_sigv4": + // STS at sts..amazonaws.com is the only outbound. + // The test-only sts_endpoint override is honored so dev/test + // runs against a local fake aren't blocked by egress. 
+ region := settingString(p.Settings, "region") + out := []string{} + if region != "" { + out = append(out, "https://sts."+region+".amazonaws.com") + } + if override := settingString(p.Settings, "sts_endpoint"); override != "" { + out = append(out, override) + } + return out // static_token has no outbound; not listed - // okta (Phase 3) will be added here with issuer + api domain default: return nil } diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go index c7a1616..58d7fe6 100644 --- a/forge-core/security/auth_domains_test.go +++ b/forge-core/security/auth_domains_test.go @@ -124,6 +124,67 @@ func TestAuthDomains_PortStripped(t *testing.T) { } } +func TestAuthDomains_AWSSigv4(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{"region": "us-east-1"}}, + }, + }) + want := []string{"sts.us-east-1.amazonaws.com"} + if len(got) != 1 || got[0] != want[0] { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_AWSSigv4_DifferentRegion(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{"region": "eu-west-2"}}, + }, + }) + if len(got) != 1 || got[0] != "sts.eu-west-2.amazonaws.com" { + t.Errorf("AuthDomains = %v, want [sts.eu-west-2.amazonaws.com]", got) + } +} + +func TestAuthDomains_AWSSigv4_TestEndpointOverride(t *testing.T) { + // The sts_endpoint override (test-only escape hatch) must surface in + // the egress allowlist too — otherwise local integration tests are + // blocked by the egress enforcer. 
+ got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{ + "region": "us-east-1", + "sts_endpoint": "http://127.0.0.1:8080", + }}, + }, + }) + have := map[string]bool{} + for _, d := range got { + have[d] = true + } + if !have["sts.us-east-1.amazonaws.com"] { + t.Errorf("AuthDomains missing real STS host: %v", got) + } + if !have["127.0.0.1"] { + t.Errorf("AuthDomains missing test override host: %v", got) + } +} + +func TestAuthDomains_AWSSigv4_MissingRegionReturnsEmpty(t *testing.T) { + // Defensive: even though Factory rejects missing region at startup, + // AuthDomains should not panic or emit a malformed host if it's ever + // called with an incomplete config. + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{}}, + }, + }) + if got != nil { + t.Errorf("AuthDomains with missing region = %v, want nil", got) + } +} + func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { got := security.AuthDomains(types.AuthConfig{ Providers: []types.AuthProvider{ From 5b710717275bd739fcde4d11ab85f9c79c913803 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 19:37:06 -0400 Subject: [PATCH 03/22] feat(auth/gcp_iap): add Identity-Aware Proxy provider (phase 2 pr 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gcp_iap consumes the X-Goog-Iap-Jwt-Assertion header that GCP's Identity-Aware Proxy forwards on every authenticated request when Forge sits behind a GCP HTTPS load balancer with IAP enabled. Decision §9.4: IAP issuer (https://cloud.google.com/iap) and JWKS URL (https://www.gstatic.com/iap/verify/public_key-jwk) are hardcoded. They're the only stable contract GCP exposes; an override knob would be a footgun. 
key pieces: - iap_jwks.go: ES256-only JWKS cache, TTL refresh + backoff + stale-grace (mirrors Phase 1 OIDC review #1 pattern) - provider.go: header-presence check, claims projection, iss/aud gates, sub/email required-claims check - parseECJWKSet drops non-EC / non-P-256 / non-ES256-labeled keys during parse — defense in depth against compromised JWKS - alg whitelist rejects RS256 BEFORE key lookup (algorithm- confusion defense) - aud as string OR array both parse (JWT spec allows either) - audit reasons follow Phase 1 contract: rejected — iss/aud mismatch, expired, bad signature invalid — alg != ES256, missing sub/email, bad kid provider_unavailable — JWKS fetch failed AND no prior key cached not_for_me — header absent extras: - security.AuthDomains returns www.gstatic.com when gcp_iap is configured - forge-cli/runtime/runner.go side-effect imports gcp_iap --- forge-cli/runtime/runner.go | 1 + forge-core/auth/providers/gcp_iap/iap_jwks.go | 306 +++++++++++++ forge-core/auth/providers/gcp_iap/provider.go | 151 +++++++ .../auth/providers/gcp_iap/provider_test.go | 421 ++++++++++++++++++ forge-core/security/auth_domains.go | 3 + forge-core/security/auth_domains_test.go | 28 ++ 6 files changed, 910 insertions(+) create mode 100644 forge-core/auth/providers/gcp_iap/iap_jwks.go create mode 100644 forge-core/auth/providers/gcp_iap/provider.go create mode 100644 forge-core/auth/providers/gcp_iap/provider_test.go diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index fd33826..0ccbab0 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -23,6 +23,7 @@ import ( // Listed here even when the package is also referenced directly // (httpverifier, statictoken) for grep-ability. 
_ "github.com/initializ/forge/forge-core/auth/providers/aws_sigv4" + _ "github.com/initializ/forge/forge-core/auth/providers/gcp_iap" "github.com/initializ/forge/forge-core/auth/providers/httpverifier" _ "github.com/initializ/forge/forge-core/auth/providers/oidc" "github.com/initializ/forge/forge-core/auth/providers/statictoken" diff --git a/forge-core/auth/providers/gcp_iap/iap_jwks.go b/forge-core/auth/providers/gcp_iap/iap_jwks.go new file mode 100644 index 0000000..bd23293 --- /dev/null +++ b/forge-core/auth/providers/gcp_iap/iap_jwks.go @@ -0,0 +1,306 @@ +package gcp_iap + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// IAPClaims is the projected claim set Forge needs from an IAP JWT. +// IAP-emitted JWTs include several other claims (hd, exp, iat, ...); +// callers can retrieve them from the raw token if needed. +type IAPClaims struct { + Issuer string + Audience []string + Subject string + Email string + HD string // Google Workspace domain (optional) +} + +// IAPJWKSCache fetches and caches the IAP public keys. ES256-only — +// dropping non-EC / non-P-256 / non-ES256-labeled keys during parse +// is a defense-in-depth layer on top of the algorithm whitelist in +// VerifyAndParse. +// +// Refresh model (mirrors Phase 1 OIDC review #1): +// - lastSuccessful tracks the TTL window; reuse cached keys within it. +// - lastAttempt + backoffDuration block fetch stampedes during outages. +// - Stale-grace: if backoff blocks fetch AND a key for the requested +// kid is in cache, return the stale key. IAP rotates keys on the +// order of weeks, so freshness matters less than availability. 
+type IAPJWKSCache struct { + url string + ttl time.Duration + http *http.Client + + mu sync.RWMutex + keys map[string]*ecdsa.PublicKey + lastSuccessful time.Time + lastAttempt time.Time + backoffDuration time.Duration +} + +// NewIAPJWKSCache builds an empty cache pointing at the given URL. +func NewIAPJWKSCache(url string, ttl, timeout time.Duration) *IAPJWKSCache { + return &IAPJWKSCache{ + url: url, + ttl: ttl, + http: &http.Client{Timeout: timeout}, + keys: map[string]*ecdsa.PublicKey{}, + } +} + +// VerifyAndParse validates the JWT signature against the cached IAP +// public keys and returns the projected claims. Standard claim checks +// (exp/iat/nbf) are performed by the JWT library; iss/aud are checked +// by the caller (Provider.Verify) against the operator's config. +func (j *IAPJWKSCache) VerifyAndParse(ctx context.Context, raw string) (*IAPClaims, error) { + tok, err := jwt.Parse(raw, func(t *jwt.Token) (any, error) { + // Algorithm whitelist: IAP signs with ES256 only. Any other alg + // is rejected BEFORE key lookup so algorithm-confusion attacks + // can't reach the JWKS. + if t.Method.Alg() != "ES256" { + return nil, fmt.Errorf("unexpected alg %q", t.Method.Alg()) + } + kid, _ := t.Header["kid"].(string) + if kid == "" { + return nil, errors.New("missing kid") + } + key, err := j.keyForKID(ctx, kid) + if err != nil { + return nil, err + } + return key, nil + }, jwt.WithValidMethods([]string{"ES256"})) + + if err != nil { + return nil, classifyJWTErr(err) + } + if !tok.Valid { + return nil, fmt.Errorf("%w: jwt.Valid=false", auth.ErrInvalidToken) + } + + // Extract claims via a small intermediate struct so we can handle + // the "aud as string OR array" shape (IAP currently uses string). 
+ payload, err := json.Marshal(tok.Claims) + if err != nil { + return nil, fmt.Errorf("%w: marshal claims: %v", auth.ErrInvalidToken, err) + } + var rc struct { + Issuer string `json:"iss"` + Audience json.RawMessage `json:"aud"` + Subject string `json:"sub"` + Email string `json:"email"` + HD string `json:"hd"` + } + if err := json.Unmarshal(payload, &rc); err != nil { + return nil, fmt.Errorf("%w: unmarshal claims: %v", auth.ErrInvalidToken, err) + } + aud, err := parseAudience(rc.Audience) + if err != nil { + return nil, fmt.Errorf("%w: aud parse: %v", auth.ErrInvalidToken, err) + } + return &IAPClaims{ + Issuer: rc.Issuer, + Audience: aud, + Subject: rc.Subject, + Email: rc.Email, + HD: rc.HD, + }, nil +} + +// parseAudience handles "aud" being either a JSON string or an array. +// JWT spec allows either; IAP currently uses string. +func parseAudience(raw json.RawMessage) ([]string, error) { + if len(raw) == 0 { + return nil, errors.New("aud claim missing") + } + if raw[0] == '"' { + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return nil, err + } + return []string{s}, nil + } + var arr []string + if err := json.Unmarshal(raw, &arr); err != nil { + return nil, err + } + return arr, nil +} + +// keyForKID returns the ES256 public key for the given kid. Refreshes +// the JWKS in the foreground when the cache is stale or a kid isn't +// known, with backoff + stale-grace per the IAPJWKSCache contract. +func (j *IAPJWKSCache) keyForKID(ctx context.Context, kid string) (*ecdsa.PublicKey, error) { + j.mu.RLock() + cached, hit := j.keys[kid] + stale := time.Since(j.lastSuccessful) > j.ttl + j.mu.RUnlock() + + if hit && !stale { + return cached, nil + } + + if err := j.refresh(ctx); err != nil { + if hit { + // Stale-grace: keep using the cached key during outage. 
+ return cached, nil + } + return nil, err + } + + j.mu.RLock() + k := j.keys[kid] + j.mu.RUnlock() + if k == nil { + return nil, fmt.Errorf("%w: kid %q not found in IAP JWKS", auth.ErrInvalidToken, kid) + } + return k, nil +} + +// refresh fetches and parses the JWKS. Backoff doubles on each failure +// while under 60s (5s → 10s → 20s → 40s → 80s, then holds); resets on success. +func (j *IAPJWKSCache) refresh(ctx context.Context) error { + j.mu.Lock() + if !j.lastAttempt.IsZero() && time.Since(j.lastAttempt) < j.backoffDuration { + j.mu.Unlock() + return fmt.Errorf("%w: IAP JWKS in backoff", auth.ErrProviderUnavailable) + } + j.lastAttempt = time.Now() + j.mu.Unlock() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, j.url, nil) + resp, err := j.http.Do(req) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS fetch: %v", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + } + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 256<<10)) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS read: %v", auth.ErrProviderUnavailable, err) + } + + keys, err := parseECJWKSet(raw) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS parse: %v", auth.ErrProviderUnavailable, err) + } + + j.mu.Lock() + j.keys = keys + j.lastSuccessful = time.Now() + j.backoffDuration = 0 + j.mu.Unlock() + return nil +} + +func (j *IAPJWKSCache) bumpBackoff() { + j.mu.Lock() + switch { + case j.backoffDuration == 0: + j.backoffDuration = 5 * time.Second + case j.backoffDuration < 60*time.Second: + j.backoffDuration *= 2 + } + j.mu.Unlock() +} + +// parseECJWKSet drops keys that aren't EC/P-256/ES256. Defense in depth +// against a compromised JWKS endpoint trying to slip in RSA keys for +// algorithm-confusion attacks.
+func parseECJWKSet(raw []byte) (map[string]*ecdsa.PublicKey, error) { + var set struct { + Keys []struct { + Kid string `json:"kid"` + Kty string `json:"kty"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` + Alg string `json:"alg"` + } `json:"keys"` + } + if err := json.Unmarshal(raw, &set); err != nil { + return nil, err + } + out := map[string]*ecdsa.PublicKey{} + for _, k := range set.Keys { + if k.Kty != "EC" || k.Crv != "P-256" { + continue + } + if k.Alg != "" && k.Alg != "ES256" { + continue + } + x, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + continue + } + y, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + continue + } + out[k.Kid] = &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + } + } + if len(out) == 0 { + return nil, errors.New("IAP JWKS contained no usable ES256 keys") + } + return out, nil +} + +// classifyJWTErr maps golang-jwt errors to auth sentinels. The library +// returns wrapped errors with stable joinings — we string-match on the +// most reliable signals. +// +// Ordering matters: alg-confusion checks come FIRST because the library +// emits "signing method X is invalid" wrapped in "signature is invalid", +// and we want algorithm-confusion classified as ErrInvalidToken (malformed) +// rather than ErrTokenRejected (policy denial). 
+func classifyJWTErr(err error) error { + if errors.Is(err, auth.ErrProviderUnavailable) || + errors.Is(err, auth.ErrInvalidToken) || + errors.Is(err, auth.ErrTokenRejected) { + return err + } + s := err.Error() + switch { + case strings.Contains(s, "signing method"), + strings.Contains(s, "unexpected alg"), + strings.Contains(s, "missing kid"), + strings.Contains(s, "kid"): + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + case strings.Contains(s, "expired"), + strings.Contains(s, "iat"), + strings.Contains(s, "nbf"), + strings.Contains(s, "signature is invalid"), + strings.Contains(s, "verification error"): + return fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) + default: + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } +} diff --git a/forge-core/auth/providers/gcp_iap/provider.go b/forge-core/auth/providers/gcp_iap/provider.go new file mode 100644 index 0000000..6a5e0a6 --- /dev/null +++ b/forge-core/auth/providers/gcp_iap/provider.go @@ -0,0 +1,151 @@ +// Package gcp_iap authenticates requests that come through GCP's +// Identity-Aware Proxy. IAP terminates user authentication at the +// HTTPS load balancer and forwards a signed JWT in the +// X-Goog-Iap-Jwt-Assertion header on every authenticated request. +// Forge verifies that JWT against IAP's well-known JWKS and stamps +// an Identity from the claims. +// +// Decision §9.4: the IAP issuer ("https://cloud.google.com/iap") +// and JWKS URL ("https://www.gstatic.com/iap/verify/public_key-jwk") +// are HARDCODED. They are the only stable contract GCP exposes. +// Any override knob would be a footgun (operators could be tricked +// into trusting an attacker's JWKS). +// +// Decision §9.5: reads X-Goog-Iap-Jwt-Assertion from the widened +// Headers map (PR 1). 
+// +// Audit reason codes (Phase 1 contract): +// +// rejected — iss/aud mismatch, expired, bad signature +// invalid — alg != ES256, missing sub/email, malformed +// provider_unavailable — JWKS fetch failed AND no prior key cached +// not_for_me — header absent → next provider +package gcp_iap + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the registry name. +const ProviderName = "gcp_iap" + +// Hardcoded IAP endpoints (decision §9.4). +const ( + iapIssuer = "https://cloud.google.com/iap" + iapJWKSURL = "https://www.gstatic.com/iap/verify/public_key-jwk" +) + +const ( + defaultJWKSRefreshTTL = time.Hour + defaultHTTPTimeout = 5 * time.Second +) + +// Config controls the gcp_iap provider. +type Config struct { + // Audience is REQUIRED. It is the GCP backend service ID, + // shaped like "/projects/PROJECT_NUMBER/global/backendServices/BACKEND_ID". + // Operators find this in the GCP console under + // Security → Identity-Aware Proxy → Backend Service → "Signed Header JWT Audience". + Audience string `yaml:"audience"` + + // JWKSRefreshTTL bounds how long a cached JWKS is reused before a + // request-time (foreground) refresh. Default 1h — IAP rotates keys slowly. + JWKSRefreshTTL time.Duration `yaml:"jwks_refresh_ttl,omitempty"` + + // HTTPTimeout caps the JWKS fetch. Default 5s. + HTTPTimeout time.Duration `yaml:"http_timeout,omitempty"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.Audience == "" { + return fmt.Errorf("%w: audience required (e.g. /projects/PNUM/global/backendServices/BACKEND_ID)", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider for GCP IAP-fronted callers. +type Provider struct { + cfg Config + jwks *IAPJWKSCache +} + +// New constructs a Provider after validating cfg.
+func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.JWKSRefreshTTL == 0 { + cfg.JWKSRefreshTTL = defaultJWKSRefreshTTL + } + if cfg.HTTPTimeout == 0 { + cfg.HTTPTimeout = defaultHTTPTimeout + } + return &Provider{ + cfg: cfg, + jwks: NewIAPJWKSCache(iapJWKSURL, cfg.JWKSRefreshTTL, cfg.HTTPTimeout), + }, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. +// +// Token is unused — the IAP assertion lives in the +// X-Goog-Iap-Jwt-Assertion header, which the middleware delivers +// via the headers map (PR 1). +func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) (*auth.Identity, error) { + raw := headers.Get("X-Goog-Iap-Jwt-Assertion") + if raw == "" { + return nil, auth.ErrTokenNotForMe + } + + claims, err := p.jwks.VerifyAndParse(ctx, raw) + if err != nil { + return nil, err // already wrapped with the right sentinel + } + + if claims.Issuer != iapIssuer { + return nil, fmt.Errorf("%w: iss=%q, want %q", auth.ErrTokenRejected, claims.Issuer, iapIssuer) + } + if !audienceContains(claims.Audience, p.cfg.Audience) { + return nil, fmt.Errorf("%w: aud mismatch", auth.ErrTokenRejected) + } + if claims.Subject == "" || claims.Email == "" { + // IAP always sets both. Absence implies a malformed/stripped + // token, not a policy denial — return ErrInvalidToken so the + // audit log distinguishes the two cases. 
+ return nil, fmt.Errorf("%w: IAP token missing sub or email", auth.ErrInvalidToken) + } + + return &auth.Identity{ + UserID: claims.Subject, + Email: claims.Email, + Source: ProviderName, + Claims: map[string]any{ + "sub": claims.Subject, + "email": claims.Email, + "hd": claims.HD, + }, + }, nil +} + +func audienceContains(aud []string, want string) bool { + return slices.Contains(aud, want) +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/gcp_iap/provider_test.go b/forge-core/auth/providers/gcp_iap/provider_test.go new file mode 100644 index 0000000..edef453 --- /dev/null +++ b/forge-core/auth/providers/gcp_iap/provider_test.go @@ -0,0 +1,421 @@ +package gcp_iap + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "io" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// --- test signer harness --- + +// es256Signer is a self-contained ES256 signer + JWKS server used by all +// provider/JWKS tests. Lives here (not in a separate test-harness file) +// to keep PR 3 self-contained. 
+type es256Signer struct { + priv *ecdsa.PrivateKey + pub *ecdsa.PublicKey + kid string + jwksMu *http.ServeMux + srv *httptest.Server +} + +func newES256Signer(t *testing.T) *es256Signer { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("ecdsa key: %v", err) + } + s := &es256Signer{ + priv: priv, + pub: &priv.PublicKey, + kid: "test-kid-1", + } + mux := http.NewServeMux() + mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { + s.writeJWKS(w) + }) + s.jwksMu = mux + s.srv = httptest.NewServer(mux) + t.Cleanup(s.srv.Close) + return s +} + +func (s *es256Signer) writeJWKS(w http.ResponseWriter) { + x := base64.RawURLEncoding.EncodeToString(s.pub.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(s.pub.Y.Bytes()) + doc := map[string]any{ + "keys": []map[string]any{ + { + "kty": "EC", + "crv": "P-256", + "kid": s.kid, + "alg": "ES256", + "x": x, + "y": y, + }, + }, + } + _ = json.NewEncoder(w).Encode(doc) +} + +func (s *es256Signer) sign(claims map[string]any) string { + tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims(claims)) + tok.Header["kid"] = s.kid + str, err := tok.SignedString(s.priv) + if err != nil { + panic(err) + } + return str +} + +func (s *es256Signer) URL() string { return s.srv.URL + "/jwks" } + +// --- helpers --- + +func newProviderPointingAt(t *testing.T, signer *es256Signer, audience string) *Provider { + t.Helper() + p, err := New(Config{ + Audience: audience, + JWKSRefreshTTL: time.Hour, + HTTPTimeout: 5 * time.Second, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + // Swap the JWKS cache for one pointed at our test signer instead + // of the real IAP URL. This is the per-package equivalent of the + // aws_sigv4 sts_endpoint override. 
+ p.jwks = NewIAPJWKSCache(signer.URL(), time.Hour, 5*time.Second) + return p +} + +func validClaims(audience string) map[string]any { + now := time.Now() + return map[string]any{ + "iss": iapIssuer, + "aud": audience, + "sub": "1234567890", + "email": "alice@example.com", + "hd": "example.com", + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + } +} + +// --- Tests --- + +func TestProvider_NoIAPHeader_YieldsToChain(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + _, err := p.Verify(context.Background(), "", auth.Headers{"Authorization": "Bearer foo"}) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestProvider_HappyPath(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + tok := signer.sign(validClaims("test-aud")) + id, err := p.Verify(context.Background(), "", auth.Headers{ + "X-Goog-Iap-Jwt-Assertion": tok, + }) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "gcp_iap" { + t.Errorf("Source = %q", id.Source) + } + if id.UserID != "1234567890" { + t.Errorf("UserID = %q", id.UserID) + } + if id.Email != "alice@example.com" { + t.Errorf("Email = %q", id.Email) + } + if id.Claims["hd"] != "example.com" { + t.Errorf("Claims[hd] = %v", id.Claims["hd"]) + } +} + +func TestProvider_WrongIssuer_Rejected(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + claims := validClaims("test-aud") + claims["iss"] = "https://accounts.google.com" // common bug: regular Google token, not IAP + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_WrongAudience_Rejected(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, 
"expected-aud") + + tok := signer.sign(validClaims("different-aud")) + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_AudienceAsArray(t *testing.T) { + // JWT spec allows aud as []string. Verify we handle that shape too — + // IAP currently uses string but the contract is broader. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "wanted-aud") + claims := validClaims("placeholder") + claims["aud"] = []string{"other", "wanted-aud"} + tok := signer.sign(claims) + if _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}); err != nil { + t.Errorf("expected success with aud=[]string, got %v", err) + } +} + +func TestProvider_MissingSub_Invalid(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + claims := validClaims("test-aud") + delete(claims, "sub") + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_MissingEmail_Invalid(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + claims := validClaims("test-aud") + delete(claims, "email") + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_ExpiredToken_Rejected(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + claims := validClaims("test-aud") + claims["exp"] = time.Now().Add(-time.Minute).Unix() + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", 
auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_RS256Token_Rejected(t *testing.T) { + // Algorithm-confusion defense: an RS256 token MUST be rejected + // BEFORE key lookup — gcp_iap accepts ES256 only. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + rsaPriv, _ := rsa.GenerateKey(rand.Reader, 2048) + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(validClaims("test-aud"))) + tok.Header["kid"] = signer.kid + str, _ := tok.SignedString(rsaPriv) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": str}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (alg whitelist)", err) + } +} + +func TestFactory_RejectsMissingAudience(t *testing.T) { + if _, err := auth.Build("gcp_iap", map[string]any{}); err == nil { + t.Fatal("expected error when audience is missing") + } +} + +func TestProvider_Name(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + if p.Name() != "gcp_iap" { + t.Errorf("Name = %q", p.Name()) + } +} + +func TestProvider_RegisteredInRegistry(t *testing.T) { + p, err := auth.Build("gcp_iap", map[string]any{"audience": "test"}) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "gcp_iap" { + t.Errorf("Name = %q", p.Name()) + } +} + +// --- JWKS cache tests (use a controllable mux, not the full signer) --- + +func TestJWKSCache_StaleGraceOnOutage(t *testing.T) { + // 1. Bring signer up, prime cache via successful verify. + // 2. Take JWKS endpoint down. + // 3. Verify same token still works — stale-grace. 
+ signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + tok := signer.sign(validClaims("test-aud")) + if _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}); err != nil { + t.Fatalf("priming Verify: %v", err) + } + + // Take JWKS down by closing the server. + signer.srv.Close() + // Mark cache stale so we go through refresh path. + p.jwks.mu.Lock() + p.jwks.lastSuccessful = time.Now().Add(-2 * time.Hour) + p.jwks.mu.Unlock() + + if _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}); err != nil { + t.Errorf("stale-grace failed: %v", err) + } +} + +func TestJWKSCache_BackoffBumps(t *testing.T) { + // Endpoint always 500 — observe backoffDuration grow on each refresh. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + if err := c.refresh(context.Background()); err == nil { + t.Fatal("expected first refresh to fail") + } + c.mu.RLock() + first := c.backoffDuration + c.mu.RUnlock() + if first != 5*time.Second { + t.Errorf("backoff after 1st failure = %v, want 5s", first) + } + + // Force a second attempt past the backoff window. + c.mu.Lock() + c.lastAttempt = time.Now().Add(-10 * time.Second) + c.mu.Unlock() + if err := c.refresh(context.Background()); err == nil { + t.Fatal("expected second refresh to fail") + } + c.mu.RLock() + second := c.backoffDuration + c.mu.RUnlock() + if second != 10*time.Second { + t.Errorf("backoff after 2nd failure = %v, want 10s", second) + } +} + +func TestJWKSCache_BackoffBlocksRefresh(t *testing.T) { + // During backoff, refresh() returns ErrProviderUnavailable without + // attempting a network call. 
+ var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls++ + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + _ = c.refresh(context.Background()) // failure -> backoff + if calls != 1 { + t.Fatalf("calls after first refresh = %d", calls) + } + err := c.refresh(context.Background()) // should NOT call network + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if calls != 1 { + t.Errorf("network was hit during backoff (calls = %d)", calls) + } +} + +func TestJWKSCache_BodyCap(t *testing.T) { + // Serve 512 KiB of zero bytes; parse must fail without OOM. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Twice the client's 256 KiB read cap → LimitReader truncates the + // body, and zero bytes are not valid JSON, so the parse fails. + buf := make([]byte, 512<<10) + _, _ = w.Write(buf) + })) + defer srv.Close() + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + err := c.refresh(context.Background()) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestParseECJWKSet_DropsNonES256(t *testing.T) { + // Build a JWKS with one valid EC key, one RSA entry, and one EC key mislabeled RS256.
+ priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + x := base64.RawURLEncoding.EncodeToString(priv.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(priv.Y.Bytes()) + doc := map[string]any{ + "keys": []map[string]any{ + {"kty": "EC", "crv": "P-256", "kid": "good", "alg": "ES256", "x": x, "y": y}, + {"kty": "RSA", "kid": "bad", "alg": "RS256", "n": "ignored", "e": "ignored"}, + {"kty": "EC", "crv": "P-256", "kid": "wrong-alg", "alg": "RS256", "x": x, "y": y}, + }, + } + raw, _ := json.Marshal(doc) + keys, err := parseECJWKSet(raw) + if err != nil { + t.Fatalf("parseECJWKSet: %v", err) + } + if _, ok := keys["good"]; !ok { + t.Error("good key dropped") + } + if _, ok := keys["bad"]; ok { + t.Error("RSA key not dropped") + } + if _, ok := keys["wrong-alg"]; ok { + t.Error("EC key labeled RS256 not dropped") + } +} + +func TestParseAudience_StringAndArray(t *testing.T) { + // "aud":"x" + one, err := parseAudience(json.RawMessage(`"x"`)) + if err != nil || len(one) != 1 || one[0] != "x" { + t.Errorf("string aud: got %v, %v", one, err) + } + // "aud":["x","y"] + two, err := parseAudience(json.RawMessage(`["x","y"]`)) + if err != nil || len(two) != 2 { + t.Errorf("array aud: got %v, %v", two, err) + } + // missing + if _, err := parseAudience(nil); err == nil { + t.Error("missing aud should error") + } +} + +// ensure import not flagged +var ( + _ = io.EOF + _ = big.NewInt +) diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go index b1245c5..bf4468b 100644 --- a/forge-core/security/auth_domains.go +++ b/forge-core/security/auth_domains.go @@ -68,6 +68,9 @@ func authProviderURLs(p types.AuthProvider) []string { out = append(out, override) } return out + case "gcp_iap": + // Decision §9.4: IAP JWKS host is hardcoded. 
+ return []string{"https://www.gstatic.com/iap/verify/public_key-jwk"} // static_token has no outbound; not listed default: return nil diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go index 58d7fe6..2f68719 100644 --- a/forge-core/security/auth_domains_test.go +++ b/forge-core/security/auth_domains_test.go @@ -185,6 +185,34 @@ func TestAuthDomains_AWSSigv4_MissingRegionReturnsEmpty(t *testing.T) { } } +func TestAuthDomains_GCPIAP(t *testing.T) { + // Decision §9.4: IAP JWKS host is hardcoded — same domain returned + // regardless of audience or any other config. + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "gcp_iap", Settings: map[string]any{ + "audience": "/projects/12345/global/backendServices/67890", + }}, + }, + }) + if len(got) != 1 || got[0] != "www.gstatic.com" { + t.Errorf("AuthDomains = %v, want [www.gstatic.com]", got) + } +} + +func TestAuthDomains_GCPIAP_MultipleEntriesDedup(t *testing.T) { + // Even with two IAP entries, the host appears once (dedup). + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "gcp_iap", Settings: map[string]any{"audience": "a"}}, + {Type: "gcp_iap", Settings: map[string]any{"audience": "b"}}, + }, + }) + if len(got) != 1 || got[0] != "www.gstatic.com" { + t.Errorf("AuthDomains = %v, want [www.gstatic.com] (dedup)", got) + } +} + func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { got := security.AuthDomains(types.AuthConfig{ Providers: []types.AuthProvider{ From 1e231406245edebc4ba1d4e54737340eeadbb310 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 19:44:57 -0400 Subject: [PATCH 04/22] feat(auth/azure_ad): add Entra ID provider composing OIDC (phase 2 pr 4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit azure_ad authenticates Microsoft Entra ID tokens. 
Composes the Phase 1 oidc.Provider (decision §9.2) for signature verify + base claim validation; layers AAD-specific concerns on top: - Tenant lock-in via the tid claim - Optional Microsoft Graph group enrichment when JWT groups claim is empty (AAD truncates at ~200 groups) - Single-tenant vs multi-tenant issuer template key pieces: - provider.go: composed oidc + tenant gate + Source overwrite to "azure_ad" (replaces the inner "oidc" stamp) - tenant.go: ExtractTenantID — typed accessor for the tid claim - graph_client.go: Graph /me/transitiveMemberOf with pagination, same-host enforcement (rejects redirect attacks), 401/403 -> ErrTokenRejected, 5xx -> ErrProviderUnavailable, defensive cap at 5000 groups, body cap 1 MiB per page - graph_cache.go: 5 min TTL, same shape as aws_sigv4's cache key decisions: - oidc.Config gains internal SkipIssuerCheck flag with yaml:"-" so it CANNOT be set via forge.yaml — only callable from another Go package. AAD multi-tenant uses it; everything else leaves it off. Surfacing it in YAML would let operators disable iss validation by accident. - Soft-fail on Graph 5xx/401: Identity returned with empty Groups rather than blocking prod traffic. Hard-fail mode (graph_required) out of scope for v0.11. - Forge reflects the CALLER's Bearer to Graph; holds no Graph credentials of its own. 
audit reasons: - ErrTokenRejected -> rejected (tid mismatch, bad sig, Graph 401) - ErrInvalidToken -> invalid (missing tid, malformed claims) - ErrProviderUnavailable -> provider_unavailable (Graph 5xx, JWKS down) extras: - security.AuthDomains returns login.microsoftonline.com always; graph.microsoft.com when groups_mode=graph - forge-cli/runtime/runner.go side-effect imports azure_ad --- forge-cli/runtime/runner.go | 1 + .../auth/providers/azure_ad/graph_cache.go | 54 +++ .../providers/azure_ad/graph_cache_test.go | 34 ++ .../auth/providers/azure_ad/graph_client.go | 144 ++++++ .../providers/azure_ad/graph_client_test.go | 151 ++++++ .../auth/providers/azure_ad/provider.go | 209 +++++++++ .../auth/providers/azure_ad/provider_test.go | 428 ++++++++++++++++++ forge-core/auth/providers/azure_ad/tenant.go | 9 + forge-core/auth/providers/oidc/provider.go | 24 +- forge-core/security/auth_domains.go | 8 + forge-core/security/auth_domains_test.go | 53 +++ 11 files changed, 1113 insertions(+), 2 deletions(-) create mode 100644 forge-core/auth/providers/azure_ad/graph_cache.go create mode 100644 forge-core/auth/providers/azure_ad/graph_cache_test.go create mode 100644 forge-core/auth/providers/azure_ad/graph_client.go create mode 100644 forge-core/auth/providers/azure_ad/graph_client_test.go create mode 100644 forge-core/auth/providers/azure_ad/provider.go create mode 100644 forge-core/auth/providers/azure_ad/provider_test.go create mode 100644 forge-core/auth/providers/azure_ad/tenant.go diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 0ccbab0..b3b6a89 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -23,6 +23,7 @@ import ( // Listed here even when the package is also referenced directly // (httpverifier, statictoken) for grep-ability. 
_ "github.com/initializ/forge/forge-core/auth/providers/aws_sigv4" + _ "github.com/initializ/forge/forge-core/auth/providers/azure_ad" _ "github.com/initializ/forge/forge-core/auth/providers/gcp_iap" "github.com/initializ/forge/forge-core/auth/providers/httpverifier" _ "github.com/initializ/forge/forge-core/auth/providers/oidc" diff --git a/forge-core/auth/providers/azure_ad/graph_cache.go b/forge-core/auth/providers/azure_ad/graph_cache.go new file mode 100644 index 0000000..f102cb5 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_cache.go @@ -0,0 +1,54 @@ +package azure_ad + +import ( + "sync" + "time" +) + +// GraphCache holds enriched group memberships keyed by user ID, with a +// short TTL. Bounds how long a stale "removed from group" state stays +// cached after AAD's reality changes. +type GraphCache struct { + ttl time.Duration + mu sync.RWMutex + data map[string]graphEntry + now func() time.Time +} + +type graphEntry struct { + groups []string + expireAt time.Time +} + +// NewGraphCache builds an empty cache. +func NewGraphCache(ttl time.Duration) *GraphCache { + return &GraphCache{ + ttl: ttl, + data: make(map[string]graphEntry), + now: time.Now, + } +} + +// Get returns the cached groups for userID, or (nil, false) on miss/expiry. +func (c *GraphCache) Get(userID string) ([]string, bool) { + c.mu.RLock() + e, ok := c.data[userID] + c.mu.RUnlock() + if !ok || c.now().After(e.expireAt) { + return nil, false + } + return e.groups, true +} + +// Put stores the groups under userID with a fresh TTL. Overwrites any +// prior entry (does not extend). 
+func (c *GraphCache) Put(userID string, groups []string) { + c.mu.Lock() + c.data[userID] = graphEntry{ + groups: groups, + expireAt: c.now().Add(c.ttl), + } + c.mu.Unlock() +} + +func (c *GraphCache) setNow(fn func() time.Time) { c.now = fn } diff --git a/forge-core/auth/providers/azure_ad/graph_cache_test.go b/forge-core/auth/providers/azure_ad/graph_cache_test.go new file mode 100644 index 0000000..15e030e --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_cache_test.go @@ -0,0 +1,34 @@ +package azure_ad + +import ( + "testing" + "time" +) + +func TestGraphCache_HitMiss(t *testing.T) { + c := NewGraphCache(time.Minute) + if _, ok := c.Get("user-1"); ok { + t.Error("empty cache returned hit") + } + c.Put("user-1", []string{"g1", "g2"}) + groups, ok := c.Get("user-1") + if !ok || len(groups) != 2 { + t.Errorf("Get = (%v, %v)", groups, ok) + } +} + +func TestGraphCache_Expiry(t *testing.T) { + now := time.Unix(1_700_000_000, 0) + c := NewGraphCache(60 * time.Second) + c.setNow(func() time.Time { return now }) + + c.Put("user-1", []string{"g1"}) + + if _, ok := c.Get("user-1"); !ok { + t.Fatal("expected hit before expiry") + } + now = now.Add(61 * time.Second) + if _, ok := c.Get("user-1"); ok { + t.Error("expected miss after expiry") + } +} diff --git a/forge-core/auth/providers/azure_ad/graph_client.go b/forge-core/auth/providers/azure_ad/graph_client.go new file mode 100644 index 0000000..5ae7980 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_client.go @@ -0,0 +1,144 @@ +package azure_ad + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// GraphClient calls Microsoft Graph /me/transitiveMemberOf to enrich +// group memberships when the JWT's groups claim overflows (AAD truncates +// groups when the user is in more than ~200 of them). 
+// +// Forge holds NO Graph credentials of its own — the caller's Bearer +// token is reflected to Graph, which authorizes the read against the +// user's delegated permission (GroupMember.Read.All). +type GraphClient struct { + endpoint string // override-able for tests + http *http.Client +} + +const graphBaseURL = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id&$top=100" + +// NewGraphClient builds a client pointed at the real Graph endpoint. +func NewGraphClient(timeout time.Duration) *GraphClient { + return &GraphClient{ + endpoint: graphBaseURL, + http: &http.Client{Timeout: timeout}, + } +} + +// NewGraphClientWithEndpoint is a TEST-ONLY constructor for pointing at +// a fake Graph server. +func NewGraphClientWithEndpoint(endpoint string, timeout time.Duration) *GraphClient { + return &GraphClient{ + endpoint: endpoint, + http: &http.Client{Timeout: timeout}, + } +} + +// TransitiveMemberOf walks the paginated response and returns the full +// list of (transitive) group object IDs the caller belongs to. The +// authHeader is reflected verbatim — Forge does not authenticate to Graph +// independently. +// +// Error classification: +// +// 401 / 403 → auth.ErrTokenRejected (caller's token missing +// GroupMember.Read.All consent) +// 5xx / other non-200 / network / parse → auth.ErrProviderUnavailable +// @odata.nextLink pointing at a foreign host → auth.ErrProviderUnavailable (never followed) +func (c *GraphClient) TransitiveMemberOf(ctx context.Context, _ string, authHeader string) ([]string, error) { + if authHeader == "" { + return nil, fmt.Errorf("%w: graph enrichment needs a forwardable Bearer", auth.ErrInvalidToken) + } + out := []string{} + next := c.endpoint + for next != "" { + if err := ensureGraphHost(c.endpoint, next); err != nil { + return nil, fmt.Errorf("%w: graph nextLink host: %v", auth.ErrProviderUnavailable, err) + } + page, nextURL, err := c.fetchPage(ctx, next, authHeader) + if err != nil { + return nil, err + } + out = append(out, page...)
+ next = nextURL + if len(out) > 5000 { + return nil, errors.New("graph response exceeds 5000 groups (likely misconfiguration)") + } + } + return out, nil +} + +func (c *GraphClient) fetchPage(ctx context.Context, u, authHeader string) (ids []string, next string, err error) { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + req.Header.Set("Authorization", authHeader) + req.Header.Set("Accept", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + return nil, "", fmt.Errorf("%w: graph fetch: %v", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1 MiB cap per page + if err != nil { + return nil, "", fmt.Errorf("%w: graph read: %v", auth.ErrProviderUnavailable, err) + } + + switch { + case resp.StatusCode == http.StatusOK: + // fall through to parse + case resp.StatusCode == http.StatusUnauthorized, resp.StatusCode == http.StatusForbidden: + return nil, "", fmt.Errorf("%w: graph %d (likely missing GroupMember.Read.All consent)", auth.ErrTokenRejected, resp.StatusCode) + case resp.StatusCode >= 500: + return nil, "", fmt.Errorf("%w: graph HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + default: + return nil, "", fmt.Errorf("%w: graph HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + } + + var page struct { + Value []struct { + ID string `json:"id"` + } `json:"value"` + NextLink string `json:"@odata.nextLink"` + } + if err := json.Unmarshal(body, &page); err != nil { + return nil, "", fmt.Errorf("%w: graph parse: %v", auth.ErrProviderUnavailable, err) + } + ids = make([]string, 0, len(page.Value)) + for _, g := range page.Value { + ids = append(ids, g.ID) + } + return ids, page.NextLink, nil +} + +// ensureGraphHost rejects @odata.nextLink values that point at a foreign +// host. Real Graph paginates within graph.microsoft.com. 
For tests where +// the endpoint is httptest's 127.0.0.1, the test-mode endpoint host is +// what we compare against. +func ensureGraphHost(configured, candidate string) error { + if candidate == "" { + return nil + } + want, err := url.Parse(configured) + if err != nil { + return err + } + got, err := url.Parse(candidate) + if err != nil { + return err + } + if !strings.EqualFold(want.Host, got.Host) { + return fmt.Errorf("nextLink host %q does not match configured %q", got.Host, want.Host) + } + return nil +} diff --git a/forge-core/auth/providers/azure_ad/graph_client_test.go b/forge-core/auth/providers/azure_ad/graph_client_test.go new file mode 100644 index 0000000..116c1f0 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_client_test.go @@ -0,0 +1,151 @@ +package azure_ad + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +func TestGraph_HappyPath_Paginated(t *testing.T) { + var calls int + mux := http.NewServeMux() + var graphURL string + + mux.HandleFunc("/page1", func(w http.ResponseWriter, _ *http.Request) { + calls++ + _, _ = io.WriteString(w, `{"value":[{"id":"g1"},{"id":"g2"}],"@odata.nextLink":"`+graphURL+`/page2"}`) + }) + mux.HandleFunc("/page2", func(w http.ResponseWriter, _ *http.Request) { + calls++ + _, _ = io.WriteString(w, `{"value":[{"id":"g3"}]}`) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + graphURL = srv.URL + + c := NewGraphClientWithEndpoint(srv.URL+"/page1", 5*time.Second) + out, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if err != nil { + t.Fatalf("TransitiveMemberOf: %v", err) + } + if len(out) != 3 || out[0] != "g1" || out[2] != "g3" { + t.Errorf("groups = %v", out) + } + if calls != 2 { + t.Errorf("pages fetched = %d, want 2", calls) + } +} + +func TestGraph_401_Rejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestGraph_403_Rejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestGraph_500_Unavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestGraph_NoAuthHeader_Invalid(t *testing.T) { + c := NewGraphClientWithEndpoint("http://does.not.matter", 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "") + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestGraph_AuthHeaderReflected(t *testing.T) { + var captured string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = r.Header.Get("Authorization") + _, _ = io.WriteString(w, `{"value":[]}`) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, _ = c.TransitiveMemberOf(context.Background(), "user-1", "Bearer the-token") + if 
captured != "Bearer the-token" { + t.Errorf("Graph got Authorization = %q", captured) + } +} + +func TestGraph_DefensivePaginationCap(t *testing.T) { + // Server keeps emitting @odata.nextLink to itself; expect cap to fire. + var srvURL string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, fmt.Sprintf(`{"value":[%s],"@odata.nextLink":"%s"}`, manyIDs(100), srvURL)) + })) + defer srv.Close() + srvURL = srv.URL + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer x") + if err == nil { + t.Fatal("expected defensive cap error") + } +} + +func TestEnsureGraphHost_RejectsForeignHost(t *testing.T) { + err := ensureGraphHost("https://graph.microsoft.com/v1.0/me", "https://evil.example.com/me/next") + if err == nil { + t.Fatal("expected error on foreign host") + } +} + +func TestEnsureGraphHost_AcceptsSameHost(t *testing.T) { + err := ensureGraphHost( + "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id", + "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$skiptoken=abc", + ) + if err != nil { + t.Errorf("same-host nextLink rejected: %v", err) + } +} + +func TestEnsureGraphHost_EmptyOK(t *testing.T) { + if err := ensureGraphHost("https://graph.microsoft.com/v1.0/me", ""); err != nil { + t.Errorf("empty nextLink should be ok, got %v", err) + } +} + +// manyIDs returns a JSON snippet for `count` entries — used for the +// pagination cap test. 
+func manyIDs(count int) string { + parts := make([]string, count) + for i := range count { + parts[i] = fmt.Sprintf(`{"id":"g-%d"}`, i) + } + return strings.Join(parts, ",") +} diff --git a/forge-core/auth/providers/azure_ad/provider.go b/forge-core/auth/providers/azure_ad/provider.go new file mode 100644 index 0000000..1591916 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/provider.go @@ -0,0 +1,209 @@ +// Package azure_ad authenticates Microsoft Entra ID (Azure AD) tokens. +// Composes the Phase 1 oidc.Provider (decision §9.2) for the heavy +// lifting — signature verify and base claim validation — and layers +// AAD-specific concerns on top: +// +// - Tenant lock-in via the `tid` claim +// - Optional Microsoft Graph group enrichment when the JWT's groups +// claim overflows (AAD truncates at ~200 groups) +// - Correct issuer template for single- vs. multi-tenant +// +// Decision §9.5: standard Bearer flow; no widened-header use. +// +// Audit reason codes (Phase 1 contract): +// +// rejected — bad signature, expired, tid mismatch, +// aud mismatch, Graph 401/403 +// invalid — missing tid, malformed claims, +// unsupported alg +// provider_unavailable — AAD JWKS down, Graph 5xx +// not_for_me — empty Bearer (delegates to OIDC's looksLikeJWT) +package azure_ad + +import ( + "context" + "fmt" + "time" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/oidc" +) + +// ProviderName is the registry name. +const ProviderName = "azure_ad" + +const ( + aadAuthorityBase = "https://login.microsoftonline.com" + + defaultGraphTimeout = 5 * time.Second + defaultJWKSCacheTTL = time.Hour + defaultGraphCacheTTL = 5 * time.Minute +) + +// Config controls the azure_ad provider. +type Config struct { + // TenantID is the Entra tenant GUID. REQUIRED unless AllowMultiTenant + // is true. + TenantID string `yaml:"tenant_id"` + + // Audience is REQUIRED. Typically the Application ID URI from the + // app registration (e.g. 
"api://forge"). + Audience string `yaml:"audience"` + + // AllowMultiTenant enables accepting tokens from ANY Entra tenant. + // Defaults to false (single-tenant — safe choice). PR6 docs flag the + // security implications of flipping this on in production. + AllowMultiTenant bool `yaml:"allow_multi_tenant,omitempty"` + + // GroupsMode is "claim" (default — uses the in-JWT groups/roles + // claim) or "graph" (queries Microsoft Graph when groups are missing, + // i.e. AAD overage). + GroupsMode string `yaml:"groups_mode,omitempty"` + + // GraphTimeout caps each Graph call. Default 5s. Only used when + // GroupsMode == "graph". + GraphTimeout time.Duration `yaml:"graph_timeout,omitempty"` + + // JWKSCacheTTL bounds the JWKS cache age. Defaults to 1h. + JWKSCacheTTL time.Duration `yaml:"jwks_cache_ttl,omitempty"` + + // GraphEndpoint is a TEST-ONLY override pointing at a fake Graph + // server. Empty in production. + GraphEndpoint string `yaml:"-"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.Audience == "" { + return fmt.Errorf("%w: audience required (e.g. api://forge)", auth.ErrProviderNotConfigured) + } + if !c.AllowMultiTenant && c.TenantID == "" { + return fmt.Errorf("%w: tenant_id required unless allow_multi_tenant=true", auth.ErrProviderNotConfigured) + } + if c.GroupsMode != "" && c.GroupsMode != "claim" && c.GroupsMode != "graph" { + return fmt.Errorf("%w: groups_mode must be 'claim' or 'graph', got %q", auth.ErrProviderNotConfigured, c.GroupsMode) + } + return nil +} + +// Provider implements auth.Provider for AAD callers. +type Provider struct { + cfg Config + oidc *oidc.Provider // composition (decision §9.2) + graph *GraphClient // nil unless GroupsMode == "graph" + cache *GraphCache // nil unless GroupsMode == "graph" +} + +// New constructs a Provider after validating cfg. 
+func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.GroupsMode == "" { + cfg.GroupsMode = "claim" + } + if cfg.GraphTimeout == 0 { + cfg.GraphTimeout = defaultGraphTimeout + } + if cfg.JWKSCacheTTL == 0 { + cfg.JWKSCacheTTL = defaultJWKSCacheTTL + } + + inner, err := oidc.New(oidc.Config{ + Issuer: resolveIssuer(cfg), + Audience: cfg.Audience, + JWKSCacheTTL: cfg.JWKSCacheTTL, + SkipIssuerCheck: cfg.AllowMultiTenant, + }) + if err != nil { + return nil, fmt.Errorf("azure_ad: composing oidc provider: %w", err) + } + + p := &Provider{cfg: cfg, oidc: inner} + if cfg.GroupsMode == "graph" { + if cfg.GraphEndpoint != "" { + p.graph = NewGraphClientWithEndpoint(cfg.GraphEndpoint, cfg.GraphTimeout) + } else { + p.graph = NewGraphClient(cfg.GraphTimeout) + } + p.cache = NewGraphCache(defaultGraphCacheTTL) + } + return p, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. +func (p *Provider) Verify(ctx context.Context, token string, headers auth.Headers) (*auth.Identity, error) { + id, err := p.oidc.Verify(ctx, token, headers) + if err != nil { + return nil, err + } + + // Tenant gate (single-tenant default). + if !p.cfg.AllowMultiTenant { + tid := ExtractTenantID(id.Claims) + if tid == "" { + return nil, fmt.Errorf("%w: AAD token missing tid claim", auth.ErrInvalidToken) + } + if tid != p.cfg.TenantID { + return nil, fmt.Errorf("%w: tid mismatch", auth.ErrTokenRejected) + } + } + + // Optional Graph enrichment. + if p.cfg.GroupsMode == "graph" && needsEnrichment(id.Groups) { + if enriched, err := p.enrichGroups(ctx, id, headers); err == nil { + id.Groups = enriched + } + // Soft-fail on Graph failure: leave Groups empty rather than + // blocking prod traffic on a transient outage. Hard-fail mode + // (graph_required: true) is out of scope for v0.11. 
+ } + + id.Source = ProviderName // overwrite oidc's "oidc" stamp + return id, nil +} + +// resolveIssuer picks the issuer URL passed to the composed OIDC +// provider. For single-tenant, that's the full per-tenant authority. +// For multi-tenant, the "common" endpoint serves JWKS for all tenants +// — we point OIDC there and disable iss equality (azure_ad enforces +// tenant via the tid claim instead). +func resolveIssuer(cfg Config) string { + if cfg.AllowMultiTenant { + return aadAuthorityBase + "/common/v2.0" + } + return fmt.Sprintf("%s/%s/v2.0", aadAuthorityBase, cfg.TenantID) +} + +// needsEnrichment returns true when Graph should be consulted. AAD +// emits no `groups` claim (or a `_claim_names` indicator) when a user +// is in too many groups. Phase 1 OIDC surfaces empty Groups in that +// case — treating empty as "enrich" catches it. +func needsEnrichment(groups []string) bool { + return len(groups) == 0 +} + +func (p *Provider) enrichGroups(ctx context.Context, id *auth.Identity, headers auth.Headers) ([]string, error) { + if cached, ok := p.cache.Get(id.UserID); ok { + return cached, nil + } + groups, err := p.graph.TransitiveMemberOf(ctx, id.UserID, headers.Get("Authorization")) + if err != nil { + return nil, err + } + p.cache.Put(id.UserID, groups) + return groups, nil +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/azure_ad/provider_test.go b/forge-core/auth/providers/azure_ad/provider_test.go new file mode 100644 index 0000000..587d790 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/provider_test.go @@ -0,0 +1,428 @@ +package azure_ad + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + 
"time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/oidc" +) + +// fakeAAD is a tiny in-memory AAD: serves an OIDC discovery doc + RS256 +// JWKS, lets tests sign tokens with the matching key. +type fakeAAD struct { + t *testing.T + priv *rsa.PrivateKey + pub *rsa.PublicKey + kid string + srv *httptest.Server +} + +func newFakeAAD(t *testing.T) *fakeAAD { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + f := &fakeAAD{t: t, priv: priv, pub: &priv.PublicKey, kid: "kid-1"} + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", f.serveDiscovery) + mux.HandleFunc("/keys", f.serveJWKS) + f.srv = httptest.NewServer(mux) + t.Cleanup(f.srv.Close) + return f +} + +func (f *fakeAAD) issuerURL() string { return f.srv.URL } + +func (f *fakeAAD) serveDiscovery(w http.ResponseWriter, _ *http.Request) { + doc := map[string]any{ + "issuer": f.srv.URL, + "jwks_uri": f.srv.URL + "/keys", + } + _ = json.NewEncoder(w).Encode(doc) +} + +func (f *fakeAAD) serveJWKS(w http.ResponseWriter, _ *http.Request) { + n := base64.RawURLEncoding.EncodeToString(f.pub.N.Bytes()) + eBytes := make([]byte, 8) + binary.BigEndian.PutUint64(eBytes, uint64(f.pub.E)) + i := 0 + for i < len(eBytes)-1 && eBytes[i] == 0 { + i++ + } + e := base64.RawURLEncoding.EncodeToString(eBytes[i:]) + _ = json.NewEncoder(w).Encode(map[string]any{ + "keys": []map[string]any{ + {"kty": "RSA", "kid": f.kid, "alg": "RS256", "use": "sig", "n": n, "e": e}, + }, + }) +} + +func (f *fakeAAD) sign(claims jwt.MapClaims) string { + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = f.kid + str, err := tok.SignedString(f.priv) + if err != nil { + panic(err) + } + return str +} + +func validAADClaims(iss, tenant, audience string) jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": iss, + "aud": 
audience, + "tid": tenant, + "sub": "user-1", + "email": "alice@example.com", + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + "nbf": now.Unix(), + } +} + +// newWithIssuer builds an azure_ad Provider whose composed OIDC points +// at a test-supplied issuer URL (rather than the AAD authority host). +// Production code uses the AAD authority template via New(); this helper +// lets tests substitute a fakeAAD without piping a TEST_ONLY field +// through the public Config surface. +func newWithIssuer(cfg Config, issuer string) (*Provider, error) { + if cfg.Audience == "" { + return nil, fmt.Errorf("audience required") + } + if cfg.GroupsMode == "" { + cfg.GroupsMode = "claim" + } + if cfg.GraphTimeout == 0 { + cfg.GraphTimeout = defaultGraphTimeout + } + if cfg.JWKSCacheTTL == 0 { + cfg.JWKSCacheTTL = defaultJWKSCacheTTL + } + inner, err := oidc.New(oidc.Config{ + Issuer: issuer, + Audience: cfg.Audience, + JWKSCacheTTL: cfg.JWKSCacheTTL, + SkipIssuerCheck: cfg.AllowMultiTenant, + }) + if err != nil { + return nil, fmt.Errorf("compose oidc: %w", err) + } + p := &Provider{cfg: cfg, oidc: inner} + if cfg.GroupsMode == "graph" { + if cfg.GraphEndpoint != "" { + p.graph = NewGraphClientWithEndpoint(cfg.GraphEndpoint, cfg.GraphTimeout) + } else { + p.graph = NewGraphClient(cfg.GraphTimeout) + } + p.cache = NewGraphCache(defaultGraphCacheTTL) + } + return p, nil +} + +func newTestProviderAADSingleTenant(t *testing.T, f *fakeAAD, tid string) *Provider { + t.Helper() + p, err := newWithIssuer(Config{ + Audience: "api://forge", + TenantID: tid, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + return p +} + +func newTestProviderAADMultiTenant(t *testing.T, f *fakeAAD) *Provider { + t.Helper() + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + return p +} + +func newTestProviderAADGraphMode(t *testing.T, f *fakeAAD, tid, 
graphURL string) *Provider { + t.Helper() + p, err := newWithIssuer(Config{ + Audience: "api://forge", + TenantID: tid, + GroupsMode: "graph", + GraphEndpoint: graphURL, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + return p +} + +// --- Tests --- + +func TestFactory_RejectsMissingAudience(t *testing.T) { + if _, err := auth.Build("azure_ad", map[string]any{ + "tenant_id": "00000000-0000-0000-0000-000000000000", + }); err == nil { + t.Fatal("expected error when audience is missing") + } +} + +func TestFactory_RejectsMissingTenantUnlessMultiTenant(t *testing.T) { + if _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + }); err == nil { + t.Fatal("expected error when tenant_id missing and not multi-tenant") + } +} + +func TestFactory_AcceptsMultiTenantWithoutTenant(t *testing.T) { + if _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "allow_multi_tenant": true, + }); err != nil { + t.Errorf("expected success, got %v", err) + } +} + +func TestFactory_RejectsBadGroupsMode(t *testing.T) { + _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "tenant_id": "00000000-0000-0000-0000-000000000000", + "groups_mode": "bogus", + }) + if err == nil { + t.Fatal("expected error for invalid groups_mode") + } +} + +func TestProvider_HappyPath_ClaimMode(t *testing.T) { + f := newFakeAAD(t) + wantTID := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADSingleTenant(t, f, wantTID) + + claims := validAADClaims(f.issuerURL(), wantTID, "api://forge") + claims["groups"] = []string{"g1", "g2"} + tok := f.sign(claims) + + id, err := p.Verify(context.Background(), tok, auth.Headers{}) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "azure_ad" { + t.Errorf("Source = %q, want azure_ad (composition must overwrite oidc stamp)", id.Source) + } + if id.UserID != "user-1" { + t.Errorf("UserID = %q", id.UserID) + } + if len(id.Groups) != 2 { + 
t.Errorf("Groups = %v", id.Groups) + } +} + +func TestProvider_WrongTenant_Rejected(t *testing.T) { + f := newFakeAAD(t) + wantTID := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADSingleTenant(t, f, wantTID) + + claims := validAADClaims(f.issuerURL(), "22222222-2222-2222-2222-222222222222", "api://forge") + tok := f.sign(claims) + + _, err := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (tid mismatch)", err) + } +} + +func TestProvider_MultiTenant_AcceptsArbitraryTenant(t *testing.T) { + f := newFakeAAD(t) + p := newTestProviderAADMultiTenant(t, f) + + claims := validAADClaims(f.issuerURL(), "any-tenant-uuid", "api://forge") + tok := f.sign(claims) + + if _, err := p.Verify(context.Background(), tok, auth.Headers{}); err != nil { + t.Errorf("expected multi-tenant success, got %v", err) + } +} + +func TestProvider_MissingTid_Invalid(t *testing.T) { + f := newFakeAAD(t) + p := newTestProviderAADSingleTenant(t, f, "11111111-1111-1111-1111-111111111111") + + claims := validAADClaims(f.issuerURL(), "ignored", "api://forge") + delete(claims, "tid") + tok := f.sign(claims) + + _, err := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_GraphMode_EnrichesEmptyGroups(t *testing.T) { + f := newFakeAAD(t) + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"value":[{"id":"g1"},{"id":"g2"}]}`) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + // No groups claim → simulate overage. 
+ tok := f.sign(claims) + + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if len(id.Groups) != 2 { + t.Errorf("Groups = %v, want 2 enriched ids", id.Groups) + } +} + +func TestProvider_GraphMode_SoftFailsOn5xx(t *testing.T) { + f := newFakeAAD(t) + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + tok := f.sign(claims) + + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Errorf("expected soft-fail (nil err), got %v", err) + } + if id != nil && len(id.Groups) != 0 { + t.Errorf("Groups should be empty after Graph 5xx, got %v", id.Groups) + } +} + +func TestProvider_GraphMode_401SoftFails(t *testing.T) { + // Even on Graph 401/403, the auth request itself proceeds — soft-fail + // keeps the Identity flowing with empty Groups. The graph-side error + // is surfaced separately if the operator wires audit hooks. Pinning + // this contract here so the behavior doesn't drift. 
+ f := newFakeAAD(t) + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + tok := f.sign(claims) + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Errorf("Graph 401 should soft-fail at provider level, got err=%v", err) + } + if id == nil { + t.Fatal("Identity should be returned even on Graph 401") + } + if len(id.Groups) != 0 { + t.Errorf("Groups should be empty after Graph 401, got %v", id.Groups) + } +} + +func TestProvider_GraphMode_ClaimPresent_SkipsGraph(t *testing.T) { + // When the JWT carries a groups claim, we MUST NOT call Graph (avoid + // extra latency + unneeded Graph permission requirements). + f := newFakeAAD(t) + graphCalls := 0 + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + graphCalls++ + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + claims["groups"] = []string{"g-from-jwt"} + tok := f.sign(claims) + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if graphCalls != 0 { + t.Errorf("Graph called %d times, want 0 (claim was present)", graphCalls) + } + if len(id.Groups) != 1 || id.Groups[0] != "g-from-jwt" { + t.Errorf("Groups = %v, want [g-from-jwt]", id.Groups) + } +} + +func TestProvider_RegisteredInRegistry(t *testing.T) { + p, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "allow_multi_tenant": true, + }) + if err != nil { 
+ t.Fatalf("Build: %v", err) + } + if p.Name() != "azure_ad" { + t.Errorf("Name = %q", p.Name()) + } +} + +func TestExtractTenantID(t *testing.T) { + cases := []struct { + in map[string]any + want string + }{ + {map[string]any{"tid": "abc"}, "abc"}, + {map[string]any{}, ""}, + {map[string]any{"tid": 123}, ""}, + {nil, ""}, + } + for _, tc := range cases { + if got := ExtractTenantID(tc.in); got != tc.want { + t.Errorf("ExtractTenantID(%v) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestNeedsEnrichment(t *testing.T) { + if !needsEnrichment(nil) { + t.Error("nil groups should need enrichment") + } + if !needsEnrichment([]string{}) { + t.Error("empty groups should need enrichment") + } + if needsEnrichment([]string{"g1"}) { + t.Error("populated groups should not need enrichment") + } +} + +func TestProvider_Name(t *testing.T) { + f := newFakeAAD(t) + p := newTestProviderAADMultiTenant(t, f) + if p.Name() != "azure_ad" { + t.Errorf("Name = %q", p.Name()) + } +} diff --git a/forge-core/auth/providers/azure_ad/tenant.go b/forge-core/auth/providers/azure_ad/tenant.go new file mode 100644 index 0000000..7f2f47e --- /dev/null +++ b/forge-core/auth/providers/azure_ad/tenant.go @@ -0,0 +1,9 @@ +package azure_ad + +// ExtractTenantID returns the "tid" claim, or "" if it's missing/non-string. +// The empty-return form is intentional — callers want to distinguish +// "missing" from "wrong tenant," and a typed error here would be noise. +func ExtractTenantID(claims map[string]any) string { + tid, _ := claims["tid"].(string) + return tid +} diff --git a/forge-core/auth/providers/oidc/provider.go b/forge-core/auth/providers/oidc/provider.go index 47c985d..67fa509 100644 --- a/forge-core/auth/providers/oidc/provider.go +++ b/forge-core/auth/providers/oidc/provider.go @@ -86,6 +86,19 @@ type Config struct { // HTTPClient overrides the default client. Injectable for tests. HTTPClient *http.Client `yaml:"-"` + + // SkipIssuerCheck disables the iss-claim equality check. 
INTERNAL — + // the yaml:"-" tag means this CANNOT be set via forge.yaml; it is + // only reachable when another Go package constructs oidc.Config + // directly (currently only azure_ad's multi-tenant mode). + // + // Reason it exists: AAD's "common" / multi-tenant issuer template + // uses a per-token tenant ID that string-equality can't satisfy. The + // caller (azure_ad) takes responsibility for tenant enforcement via + // the tid claim instead. Surfacing this in forge.yaml would let + // operators disable iss validation by accident — which is exactly + // the "open verifier" footgun this package is designed to prevent. + SkipIssuerCheck bool `yaml:"-"` } // Validate returns ErrProviderNotConfigured when required fields are @@ -255,8 +268,15 @@ func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Hea // IdPs that emit iss with/without a trailing slash interop with // configs that use the opposite form). See Config.normalize and // review finding #2. - if err := p.checkIssuer(claims); err != nil { - return nil, err + // + // SkipIssuerCheck is INTERNAL — only set when another Go package + // (currently azure_ad multi-tenant) takes responsibility for + // tenant/issuer enforcement via a different claim. Never reachable + // via forge.yaml — the field carries a yaml:"-" tag. + if !p.cfg.SkipIssuerCheck { + if err := p.checkIssuer(claims); err != nil { + return nil, err + } } // Audience validation (with azp fallback). diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go index bf4468b..a5b5fd1 100644 --- a/forge-core/security/auth_domains.go +++ b/forge-core/security/auth_domains.go @@ -71,6 +71,14 @@ func authProviderURLs(p types.AuthProvider) []string { case "gcp_iap": // Decision §9.4: IAP JWKS host is hardcoded. return []string{"https://www.gstatic.com/iap/verify/public_key-jwk"} + case "azure_ad": + // AAD authority host is fixed (login.microsoftonline.com). 
+ // Graph host added ONLY when groups_mode=graph. + out := []string{"https://login.microsoftonline.com"} + if mode, _ := p.Settings["groups_mode"].(string); mode == "graph" { + out = append(out, "https://graph.microsoft.com") + } + return out // static_token has no outbound; not listed default: return nil diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go index 2f68719..4f8e40b 100644 --- a/forge-core/security/auth_domains_test.go +++ b/forge-core/security/auth_domains_test.go @@ -213,6 +213,59 @@ func TestAuthDomains_GCPIAP_MultipleEntriesDedup(t *testing.T) { } } +func TestAuthDomains_AzureAD_ClaimMode(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "azure_ad", Settings: map[string]any{ + "tenant_id": "00000000-...", + "audience": "api://forge", + }}, + }, + }) + if len(got) != 1 || got[0] != "login.microsoftonline.com" { + t.Errorf("AuthDomains = %v, want [login.microsoftonline.com]", got) + } +} + +func TestAuthDomains_AzureAD_GraphModeAddsGraphHost(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "azure_ad", Settings: map[string]any{ + "tenant_id": "00000000-...", + "audience": "api://forge", + "groups_mode": "graph", + }}, + }, + }) + have := map[string]bool{} + for _, d := range got { + have[d] = true + } + if !have["login.microsoftonline.com"] { + t.Errorf("missing login.microsoftonline.com: %v", got) + } + if !have["graph.microsoft.com"] { + t.Errorf("groups_mode=graph should add graph.microsoft.com: %v", got) + } +} + +func TestAuthDomains_AzureAD_ClaimMode_DoesNotAddGraph(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "azure_ad", Settings: map[string]any{ + "tenant_id": "00000000-...", + "audience": "api://forge", + "groups_mode": "claim", + }}, + }, + }) + for _, d := range got { + if d == "graph.microsoft.com" { + t.Errorf("claim 
mode should NOT add graph.microsoft.com: %v", got) + } + } +} + func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { got := security.AuthDomains(types.AuthConfig{ Providers: []types.AuthProvider{ From 47c4748d457a37c47de0a88fd05e5a148d835ee3 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 19:53:01 -0400 Subject: [PATCH 05/22] feat(cli/ui): expose Phase 2 providers via flags + Web UI + validate (phase 2 pr 5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires aws_sigv4, gcp_iap, and azure_ad into the operator surfaces: cli (forge-cli/cmd/init*.go): - New non-interactive flags namespaced --auth-aws-* / --auth-gcp-iap-* / --auth-azure-* (StringSlice for repeatable allowed-principal globs) - buildAuthFromFlags validates required combinations and emits the right egress hosts per provider (sts..amazonaws.com, www.gstatic.com, login.microsoftonline.com + graph.microsoft.com when groups_mode=graph) - authEgressHostsFromSettings mirrors the same logic for the Web UI - renderAuthBlock supports []string lists with proper YAML quoting (allowed_principals) web ui (forge-ui/handlers_create.go): - AuthProviderTypeMeta lists the three new types with helpful labels validate (forge-core/validate/auth.go): - knownAuthProviderTypes admits aws_sigv4 / gcp_iap / azure_ad - validateProviderSettings enforces per-type required keys (aws_sigv4.region, gcp_iap.audience, azure_ad.audience + tenant_id-unless-multi-tenant, azure_ad.groups_mode whitelist) tests: - 11 new renderer + flag-parsing tests - Round-trip YAML parse used instead of brittle quote-pattern asserts - Updated wizard-meta test to expect 7 auth provider types deliberate scope cut: - TUI step_auth.go sub-step input flows for the 3 new providers are NOT included. Adding them is mechanical (~100 LOC per provider, mirroring the OIDC issuer→audience→groups_claim phase chain) but out of scope for v0.11.0 cut. 
Non-interactive flag path covers the production-critical CI/CD case; operators using the TUI can pick "Custom" and edit forge.yaml directly until the follow-up lands. --- forge-cli/cmd/init.go | 15 +- forge-cli/cmd/init_auth.go | 89 +++++++++++- forge-cli/cmd/init_auth_test.go | 230 +++++++++++++++++++++++++++++++ forge-core/validate/auth.go | 27 +++- forge-ui/handlers_create.go | 5 +- forge-ui/handlers_create_test.go | 12 +- 6 files changed, 369 insertions(+), 9 deletions(-) diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index fca28d4..aaf9018 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -135,12 +135,25 @@ func init() { // Auth chain (PR5+). All optional. When --auth is unset or "none", no // auth: block is written and the agent runs anonymously. - initCmd.Flags().String("auth", "", "auth mode: none, oidc, http_verifier, custom") + initCmd.Flags().String("auth", "", "auth mode: none, oidc, http_verifier, aws_sigv4, gcp_iap, azure_ad, custom") initCmd.Flags().String("auth-issuer", "", "OIDC issuer URL (required with --auth=oidc)") initCmd.Flags().String("auth-audience", "", "OIDC audience (required with --auth=oidc)") initCmd.Flags().String("auth-url", "", "verifier URL (required with --auth=http_verifier)") initCmd.Flags().String("auth-default-org", "", "default org_id for http_verifier (optional)") initCmd.Flags().String("auth-groups-claim", "", "claim name for groups (oidc, default: groups)") + + // Phase 2: per-provider flags. Namespaced so help text stays grouped. + initCmd.Flags().String("auth-aws-region", "", "AWS region for aws_sigv4 (e.g. 
us-east-1)") + initCmd.Flags().String("auth-aws-audience", "", "informational audience for aws_sigv4") + initCmd.Flags().StringSlice("auth-aws-allowed-principal", nil, "allowed principal glob for aws_sigv4 (repeatable)") + initCmd.Flags().String("auth-aws-cache-ttl", "", "aws_sigv4 identity cache TTL (default 60s)") + + initCmd.Flags().String("auth-gcp-iap-audience", "", "audience for gcp_iap (backend service ID)") + + initCmd.Flags().String("auth-azure-tenant", "", "Entra tenant GUID (required unless --auth-azure-multi-tenant)") + initCmd.Flags().String("auth-azure-audience", "", "Entra application ID URI / audience") + initCmd.Flags().Bool("auth-azure-multi-tenant", false, "accept tokens from any Entra tenant (default false)") + initCmd.Flags().String("auth-azure-groups-mode", "", "azure_ad groups mode: claim (default) or graph") } func runInit(cmd *cobra.Command, args []string) error { diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index 962aff8..fbd0f91 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -49,8 +49,63 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an egressHosts = []string{host} } return settings, egressHosts, nil + case "aws_sigv4": + region, _ := cmd.Flags().GetString("auth-aws-region") + audience, _ := cmd.Flags().GetString("auth-aws-audience") + allowedPrincipals, _ := cmd.Flags().GetStringSlice("auth-aws-allowed-principal") + cacheTTL, _ := cmd.Flags().GetString("auth-aws-cache-ttl") + if region == "" { + return nil, nil, fmt.Errorf("--auth=aws_sigv4 requires --auth-aws-region") + } + settings = map[string]any{"region": region} + if audience != "" { + settings["audience"] = audience + } + if len(allowedPrincipals) > 0 { + settings["allowed_principals"] = allowedPrincipals + } + if cacheTTL != "" { + settings["identity_cache_ttl"] = cacheTTL + } + egressHosts = []string{"sts." 
+ region + ".amazonaws.com"} + return settings, egressHosts, nil + case "gcp_iap": + audience, _ := cmd.Flags().GetString("auth-gcp-iap-audience") + if audience == "" { + return nil, nil, fmt.Errorf("--auth=gcp_iap requires --auth-gcp-iap-audience") + } + settings = map[string]any{"audience": audience} + // IAP JWKS host is hardcoded (decision §9.4). + egressHosts = []string{"www.gstatic.com"} + return settings, egressHosts, nil + case "azure_ad": + tenant, _ := cmd.Flags().GetString("auth-azure-tenant") + audience, _ := cmd.Flags().GetString("auth-azure-audience") + multiTenant, _ := cmd.Flags().GetBool("auth-azure-multi-tenant") + groupsMode, _ := cmd.Flags().GetString("auth-azure-groups-mode") + if audience == "" { + return nil, nil, fmt.Errorf("--auth=azure_ad requires --auth-azure-audience") + } + if !multiTenant && tenant == "" { + return nil, nil, fmt.Errorf("--auth=azure_ad requires --auth-azure-tenant unless --auth-azure-multi-tenant=true") + } + settings = map[string]any{"audience": audience} + if tenant != "" { + settings["tenant_id"] = tenant + } + if multiTenant { + settings["allow_multi_tenant"] = true + } + if groupsMode != "" { + settings["groups_mode"] = groupsMode + } + egressHosts = []string{"login.microsoftonline.com"} + if groupsMode == "graph" { + egressHosts = append(egressHosts, "graph.microsoft.com") + } + return settings, egressHosts, nil default: - return nil, nil, fmt.Errorf("unknown --auth value %q (supported: none, oidc, http_verifier, custom)", mode) + return nil, nil, fmt.Errorf("unknown --auth value %q (supported: none, oidc, http_verifier, aws_sigv4, gcp_iap, azure_ad, custom)", mode) } } @@ -77,6 +132,19 @@ func authEgressHostsFromSettings(mode string, settings map[string]any) []string if h := hostFromURL(asStringSetting(settings, "url")); h != "" { hosts = append(hosts, h) } + case "aws_sigv4": + region := asStringSetting(settings, "region") + if region != "" { + hosts = append(hosts, "sts."+region+".amazonaws.com") + } + case 
"gcp_iap": + // Decision §9.4: hardcoded JWKS host. + hosts = append(hosts, "www.gstatic.com") + case "azure_ad": + hosts = append(hosts, "login.microsoftonline.com") + if asStringSetting(settings, "groups_mode") == "graph" { + hosts = append(hosts, "graph.microsoft.com") + } } return hosts } @@ -196,6 +264,25 @@ func writeYAMLMap(b *strings.Builder, m map[string]any, indent string) { case map[string]any: fmt.Fprintf(b, "%s%s:\n", indent, k) writeYAMLMap(b, val, indent+" ") + case []string: + // Phase 2: aws_sigv4's allowed_principals is the only []string + // currently in the auth settings schema. ARNs frequently contain + // `:` so each entry goes through yamlScalar for proper quoting. + fmt.Fprintf(b, "%s%s:\n", indent, k) + for _, item := range val { + fmt.Fprintf(b, "%s - %s\n", indent, yamlScalar(item)) + } + case []any: + // Defensive: when settings come from YAML unmarshal, lists land + // as []any. Coerce per-element and reuse the string path. + fmt.Fprintf(b, "%s%s:\n", indent, k) + for _, item := range val { + if s, ok := item.(string); ok { + fmt.Fprintf(b, "%s - %s\n", indent, yamlScalar(s)) + } else { + fmt.Fprintf(b, "%s - %v\n", indent, item) + } + } case string: fmt.Fprintf(b, "%s%s: %s\n", indent, k, yamlScalar(val)) default: diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go index 3c6a78e..3d588ad 100644 --- a/forge-cli/cmd/init_auth_test.go +++ b/forge-cli/cmd/init_auth_test.go @@ -216,6 +216,16 @@ func mockInitCmdWithFlags() *cobra.Command { c.Flags().String("auth-url", "", "") c.Flags().String("auth-default-org", "", "") c.Flags().String("auth-groups-claim", "", "") + // Phase 2 flags: + c.Flags().String("auth-aws-region", "", "") + c.Flags().String("auth-aws-audience", "", "") + c.Flags().StringSlice("auth-aws-allowed-principal", nil, "") + c.Flags().String("auth-aws-cache-ttl", "", "") + c.Flags().String("auth-gcp-iap-audience", "", "") + c.Flags().String("auth-azure-tenant", "", "") + 
c.Flags().String("auth-azure-audience", "", "") + c.Flags().Bool("auth-azure-multi-tenant", false, "") + c.Flags().String("auth-azure-groups-mode", "", "") return c } @@ -427,3 +437,223 @@ func TestMergeEgressDomains_SortedOutput(t *testing.T) { t.Errorf("output not sorted:\n got %v\n want %v", got, want) } } + +// --- Phase 2 renderer + flag tests --- + +func TestRenderAuthBlock_AWSSigv4_Minimal(t *testing.T) { + got := renderAuthBlock("aws_sigv4", map[string]any{"region": "us-east-1"}) + if !strings.Contains(got, "type: aws_sigv4") { + t.Errorf("missing type line:\n%s", got) + } + if !strings.Contains(got, "region: us-east-1") { + t.Errorf("missing region:\n%s", got) + } + if strings.Contains(got, "audience:") { + t.Errorf("audience should not be emitted when unset:\n%s", got) + } +} + +func TestRenderAuthBlock_AWSSigv4_ARNsList(t *testing.T) { + wantARNs := []string{ + "arn:aws:sts::123:assumed-role/ci-deploy/*", + "arn:aws:sts::123:assumed-role/forge-runner/*", + } + got := renderAuthBlock("aws_sigv4", map[string]any{ + "region": "us-east-1", + "allowed_principals": wantARNs, + }) + if !strings.Contains(got, "allowed_principals:") { + t.Errorf("missing allowed_principals header:\n%s", got) + } + // ARN contains ":" but not ": " (colon-space), so YAML doesn't need + // to quote it — verify round-trip parse instead of insisting on quotes. 
+ parsed := parseAuthBlockYAML(t, got) + provs, _ := parsed["auth"].(map[string]any)["providers"].([]any) + settings := provs[0].(map[string]any)["settings"].(map[string]any) + gotARNs := toStringSlice(settings["allowed_principals"]) + if !reflect.DeepEqual(gotARNs, wantARNs) { + t.Errorf("ARN round-trip mismatch:\n got %v\n want %v", gotARNs, wantARNs) + } +} + +func TestRenderAuthBlock_GCPIAP(t *testing.T) { + wantAud := "/projects/12345/global/backendServices/67890" + got := renderAuthBlock("gcp_iap", map[string]any{ + "audience": wantAud, + }) + if !strings.Contains(got, "type: gcp_iap") { + t.Errorf("missing type:\n%s", got) + } + parsed := parseAuthBlockYAML(t, got) + provs, _ := parsed["auth"].(map[string]any)["providers"].([]any) + gotAud := provs[0].(map[string]any)["settings"].(map[string]any)["audience"] + if gotAud != wantAud { + t.Errorf("audience round-trip mismatch:\n got %v\n want %v", gotAud, wantAud) + } +} + +func parseAuthBlockYAML(t *testing.T, block string) map[string]any { + t.Helper() + var out map[string]any + if err := yaml.Unmarshal([]byte(block), &out); err != nil { + t.Fatalf("renderAuthBlock output is not valid YAML: %v\n%s", err, block) + } + return out +} + +func toStringSlice(v any) []string { + arr, ok := v.([]any) + if !ok { + return nil + } + out := make([]string, len(arr)) + for i, x := range arr { + out[i], _ = x.(string) + } + return out +} + +func TestRenderAuthBlock_AzureAD_SingleTenant(t *testing.T) { + got := renderAuthBlock("azure_ad", map[string]any{ + "tenant_id": "00000000-1111-2222-3333-444444444444", + "audience": "api://forge", + }) + if !strings.Contains(got, "tenant_id: 00000000-1111-2222-3333-444444444444") { + t.Errorf("missing tenant_id:\n%s", got) + } + if strings.Contains(got, "allow_multi_tenant") { + t.Errorf("allow_multi_tenant should be omitted when not set:\n%s", got) + } +} + +func TestRenderAuthBlock_AzureAD_MultiTenant(t *testing.T) { + got := renderAuthBlock("azure_ad", map[string]any{ + "audience": 
"api://forge", + "allow_multi_tenant": true, + }) + if !strings.Contains(got, "allow_multi_tenant: true") { + t.Errorf("missing allow_multi_tenant true:\n%s", got) + } + if strings.Contains(got, "tenant_id") { + t.Errorf("tenant_id should not appear when multi-tenant:\n%s", got) + } +} + +func TestRenderAuthBlock_AzureAD_GroupsModeGraph(t *testing.T) { + got := renderAuthBlock("azure_ad", map[string]any{ + "tenant_id": "abc", + "audience": "api://forge", + "groups_mode": "graph", + }) + if !strings.Contains(got, "groups_mode: graph") { + t.Errorf("missing groups_mode:\n%s", got) + } +} + +func TestBuildAuthFromFlags_AWSSigv4_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-aws-region", "us-east-1") + _ = c.Flags().Set("auth-aws-audience", "api://forge") + _ = c.Flags().Set("auth-aws-allowed-principal", "arn:aws:sts::123:assumed-role/x/*") + + settings, hosts, err := buildAuthFromFlags(c, "aws_sigv4") + if err != nil { + t.Fatal(err) + } + if settings["region"] != "us-east-1" { + t.Errorf("region = %v", settings["region"]) + } + if !reflect.DeepEqual(hosts, []string{"sts.us-east-1.amazonaws.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_AWSSigv4_MissingRegion(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "aws_sigv4") + if err == nil { + t.Fatal("expected error when region missing") + } +} + +func TestBuildAuthFromFlags_GCPIAP_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-gcp-iap-audience", "/projects/12345/...") + settings, hosts, err := buildAuthFromFlags(c, "gcp_iap") + if err != nil { + t.Fatal(err) + } + if settings["audience"] == "" { + t.Error("audience not captured") + } + if !reflect.DeepEqual(hosts, []string{"www.gstatic.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_GCPIAP_MissingAudience(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "gcp_iap") + if err 
== nil { + t.Fatal("expected error when audience missing") + } +} + +func TestBuildAuthFromFlags_AzureAD_RequiresTenant(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-azure-audience", "api://forge") + _, _, err := buildAuthFromFlags(c, "azure_ad") + if err == nil { + t.Fatal("expected error when tenant + non-multi-tenant") + } +} + +func TestBuildAuthFromFlags_AzureAD_MultiTenantOK(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-azure-audience", "api://forge") + _ = c.Flags().Set("auth-azure-multi-tenant", "true") + settings, hosts, err := buildAuthFromFlags(c, "azure_ad") + if err != nil { + t.Fatal(err) + } + if settings["allow_multi_tenant"] != true { + t.Errorf("allow_multi_tenant not set: %v", settings) + } + if !reflect.DeepEqual(hosts, []string{"login.microsoftonline.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_AzureAD_GraphModeAddsGraphHost(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-azure-audience", "api://forge") + _ = c.Flags().Set("auth-azure-tenant", "abc-tid") + _ = c.Flags().Set("auth-azure-groups-mode", "graph") + _, hosts, err := buildAuthFromFlags(c, "azure_ad") + if err != nil { + t.Fatal(err) + } + want := []string{"login.microsoftonline.com", "graph.microsoft.com"} + if !reflect.DeepEqual(hosts, want) { + t.Errorf("hosts = %v, want %v", hosts, want) + } +} + +func TestAuthEgressHostsFromSettings_Phase2(t *testing.T) { + cases := []struct { + mode string + settings map[string]any + want []string + }{ + {"aws_sigv4", map[string]any{"region": "us-east-1"}, []string{"sts.us-east-1.amazonaws.com"}}, + {"gcp_iap", map[string]any{"audience": "x"}, []string{"www.gstatic.com"}}, + {"azure_ad", map[string]any{"audience": "x"}, []string{"login.microsoftonline.com"}}, + {"azure_ad", map[string]any{"audience": "x", "groups_mode": "graph"}, []string{"login.microsoftonline.com", "graph.microsoft.com"}}, + } + for _, tc := range cases { + got := 
authEgressHostsFromSettings(tc.mode, tc.settings) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("%s %v: got %v, want %v", tc.mode, tc.settings, got, tc.want) + } + } +} diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go index 78ead5b..e6fb8a4 100644 --- a/forge-core/validate/auth.go +++ b/forge-core/validate/auth.go @@ -17,7 +17,10 @@ var knownAuthProviderTypes = map[string]bool{ "http_verifier": true, "static_token": true, "oidc": true, - // "okta": true, // Phase 3 (v0.11.0) + // Phase 2 (v0.11.0): + "aws_sigv4": true, + "gcp_iap": true, + "azure_ad": true, } // ValidateAuthConfig adds errors and warnings for a forge.yaml auth: block. @@ -39,7 +42,7 @@ func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { continue } if !knownAuthProviderTypes[p.Type] { - r.Errors = append(r.Errors, fmt.Sprintf("%s: unknown type %q (known: http_verifier, static_token, oidc)", prefix, p.Type)) + r.Errors = append(r.Errors, fmt.Sprintf("%s: unknown type %q (known: http_verifier, static_token, oidc, aws_sigv4, gcp_iap, azure_ad)", prefix, p.Type)) continue } @@ -78,6 +81,26 @@ func validateProviderSettings(prefix string, p types.AuthProvider, r *Validation if asString(p.Settings, "audience") == "" { r.Errors = append(r.Errors, prefix+" (oidc): settings.audience is required") } + case "aws_sigv4": + if asString(p.Settings, "region") == "" { + r.Errors = append(r.Errors, prefix+" (aws_sigv4): settings.region is required") + } + case "gcp_iap": + if asString(p.Settings, "audience") == "" { + r.Errors = append(r.Errors, prefix+" (gcp_iap): settings.audience is required (GCP backend service ID)") + } + case "azure_ad": + if asString(p.Settings, "audience") == "" { + r.Errors = append(r.Errors, prefix+" (azure_ad): settings.audience is required") + } + // tenant_id may be omitted ONLY when allow_multi_tenant is true. 
+ multi, _ := p.Settings["allow_multi_tenant"].(bool) + if !multi && asString(p.Settings, "tenant_id") == "" { + r.Errors = append(r.Errors, prefix+" (azure_ad): settings.tenant_id is required unless allow_multi_tenant=true") + } + if mode := asString(p.Settings, "groups_mode"); mode != "" && mode != "claim" && mode != "graph" { + r.Errors = append(r.Errors, fmt.Sprintf("%s (azure_ad): groups_mode must be 'claim' or 'graph', got %q", prefix, mode)) + } } } diff --git a/forge-ui/handlers_create.go b/forge-ui/handlers_create.go index 7cd1eed..21f192b 100644 --- a/forge-ui/handlers_create.go +++ b/forge-ui/handlers_create.go @@ -124,8 +124,11 @@ func (s *UIServer) handleGetWizardMeta(w http.ResponseWriter, _ *http.Request) { // frontend renders the picker from this metadata. meta.AuthProviderTypes = []AuthProviderTypeMeta{ {Type: "none", Label: "None", Description: "Anonymous access — no auth: block written"}, - {Type: "oidc", Label: "OIDC (JWT)", Description: "Auth0, Keycloak, Azure AD, Google, Okta-OIDC, …"}, + {Type: "oidc", Label: "OIDC (JWT)", Description: "Generic OIDC issuer (Keycloak, Auth0, Okta, Google …)"}, {Type: "http_verifier", Label: "HTTP Verifier", Description: "Legacy — POST tokens to your own /verify endpoint"}, + {Type: "aws_sigv4", Label: "AWS Sigv4 (IAM)", Description: "Auth AWS-IAM callers via STS GetCallerIdentity (Phase 2)"}, + {Type: "gcp_iap", Label: "GCP Identity-Aware Proxy", Description: "Forge behind a GCP HTTPS LB+IAP (Phase 2)"}, + {Type: "azure_ad", Label: "Azure AD / Entra ID", Description: "Entra tenant tokens with optional Graph enrichment (Phase 2)"}, {Type: "custom", Label: "Custom", Description: "Write a commented stub, edit forge.yaml manually"}, } diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go index 906de6d..dcc40e0 100644 --- a/forge-ui/handlers_create_test.go +++ b/forge-ui/handlers_create_test.go @@ -388,11 +388,15 @@ func TestHandleGetWizardMeta(t *testing.T) { } // PR6: 
auth_provider_types is server-driven so the frontend doesn't - // hardcode the list. Must include the four founding types. - if len(meta.AuthProviderTypes) != 4 { - t.Errorf("auth_provider_types len = %d, want 4", len(meta.AuthProviderTypes)) + // hardcode the list. Phase 2 adds aws_sigv4, gcp_iap, azure_ad — the + // founding four still appear. + if len(meta.AuthProviderTypes) != 7 { + t.Errorf("auth_provider_types len = %d, want 7", len(meta.AuthProviderTypes)) + } + wantTypes := map[string]bool{ + "none": false, "oidc": false, "http_verifier": false, "custom": false, + "aws_sigv4": false, "gcp_iap": false, "azure_ad": false, } - wantTypes := map[string]bool{"none": false, "oidc": false, "http_verifier": false, "custom": false} for _, a := range meta.AuthProviderTypes { if _, ok := wantTypes[a.Type]; !ok { t.Errorf("unexpected auth type %q", a.Type) From 98578f964eca6316d81b30452ce0de8def35ca4f Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 19:57:43 -0400 Subject: [PATCH 06/22] docs(auth): operator docs for Phase 2 providers + CHANGELOG (phase 2 pr 6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the operator-facing documentation for the three Phase 2 providers that shipped in PRs 1–5, plus a top-level auth index, chain-semantics concepts page, CHANGELOG, and a README link. 
new docs: - docs/auth/index.md — provider matrix and chain-semantics overview - docs/auth/concepts/chain.md — first-match-wins, no-fall-through on reject, non-Bearer header support, mixed-chain worked example - docs/auth/providers/aws_sigv4.md — STS reflection setup, awscurl example, assumed-role-vs-IAM-role gotcha called out twice - docs/auth/providers/gcp_iap.md — backend service ID lookup steps, hardcoded JWKS rationale, GCP IAM Conditions for allowlisting - docs/auth/providers/azure_ad.md — app registration walkthrough, single/multi/graph mode configs, multi-tenant warning prominent every provider doc includes: - Prerequisites checklist - forge.yaml example - Configuration reference table - Audit log shape (literal JSON) - Troubleshooting matrix (grep-able reason codes) - Security model + limitations sections CHANGELOG.md (new file): - Lists Added / Changed entries for v0.11.0 - "Notes for upgraders" makes the non-breaking nature explicit - Calls out the known TUI sub-flow gap from PR 5 README.md: - Adds Auth Providers row to the Security documentation table --- CHANGELOG.md | 60 +++++++++++ README.md | 1 + docs/auth/concepts/chain.md | 73 +++++++++++++ docs/auth/index.md | 53 ++++++++++ docs/auth/providers/aws_sigv4.md | 121 +++++++++++++++++++++ docs/auth/providers/azure_ad.md | 176 +++++++++++++++++++++++++++++++ docs/auth/providers/gcp_iap.md | 105 ++++++++++++++++++ 7 files changed, 589 insertions(+) create mode 100644 CHANGELOG.md create mode 100644 docs/auth/concepts/chain.md create mode 100644 docs/auth/index.md create mode 100644 docs/auth/providers/aws_sigv4.md create mode 100644 docs/auth/providers/azure_ad.md create mode 100644 docs/auth/providers/gcp_iap.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..97ff269 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,60 @@ +# Changelog + +## v0.11.0 — Phase 2: cloud-native auth providers (in progress) + +### Added + +- **`aws_sigv4` auth provider.** Authenticate AWS-IAM callers by 
reflecting + their Sigv4 signature to AWS STS `GetCallerIdentity`. No `aws-sdk-go-v2` + dependency. See [`docs/auth/providers/aws_sigv4.md`](docs/auth/providers/aws_sigv4.md). +- **`gcp_iap` auth provider.** Verify the JWT IAP forwards as + `X-Goog-Iap-Jwt-Assertion` when Forge sits behind a GCP HTTPS Load + Balancer with IAP enabled. See [`docs/auth/providers/gcp_iap.md`](docs/auth/providers/gcp_iap.md). +- **`azure_ad` auth provider.** Verify Microsoft Entra ID Bearer tokens + with tenant lock-in and optional Microsoft Graph group enrichment. + See [`docs/auth/providers/azure_ad.md`](docs/auth/providers/azure_ad.md). +- Non-interactive `forge init` flags for the three new providers: + `--auth-aws-region`, `--auth-aws-allowed-principal` (repeatable), + `--auth-gcp-iap-audience`, `--auth-azure-tenant`, + `--auth-azure-multi-tenant`, `--auth-azure-groups-mode`. +- Web UI exposes the three new types via the `/api/wizard-meta` endpoint; + server-side validation rejects malformed payloads before scaffold. +- `egress_hosts` automatically extended for each new provider + (`sts.<region>.amazonaws.com`, `www.gstatic.com`, + `login.microsoftonline.com`, `graph.microsoft.com` when applicable). + +### Changed + +- Middleware now consults the auth chain **even when no Bearer token is + extracted**, so non-Bearer formats (Sigv4 `Authorization`, IAP + `X-Goog-Iap-Jwt-Assertion`) can be recognized. Existing Bearer + JWT + flows are unchanged. +- `auth.HeadersFromRequest` widened with `Authorization`, + `X-Goog-Iap-Jwt-Assertion`, `X-Amz-Date`, and `X-Amz-Security-Token`. + Providers that don't consume these keys are unaffected. +- `auth.TokenKind` recognizes the `AWS4-HMAC-SHA256` prefix and returns + `"sigv4"`. The audit `token_kind` field now has five possible values: + `empty`, `opaque`, `jwt`, `sigv4`, `iap_jwt`. 
+- `validate.ValidateAuthConfig` admits the three new provider types and + enforces their per-type required keys (`aws_sigv4.region`, + `gcp_iap.audience`, `azure_ad.audience`, `azure_ad.tenant_id`-unless- + multi-tenant, `azure_ad.groups_mode` whitelist). + +### Notes for upgraders + +- **No forge.yaml changes are required** for callers continuing to use + Phase 1 providers (`static_token`, `oidc`, `http_verifier`). Phase 1 + test suite passes without modification. +- If you wrote a custom provider that inspects headers, the `Headers` + map now contains additional keys. Existing keys are unchanged. +- The `oidc` package gained an internal `SkipIssuerCheck` field carrying + `yaml:"-"` — it cannot be set via `forge.yaml` and is reachable only + from Go callers (currently only `azure_ad` multi-tenant). Operators see + no change. + +### Known deferred work + +- The Bubble Tea TUI wizard (`forge init`) does not yet have step-by-step + input flows for the three new providers. The non-interactive flag path + is the production-critical surface; operators using the TUI can pick + "Custom" and edit `forge.yaml` directly until the TUI follow-up lands. diff --git a/README.md b/README.md index 35c05ff..044adc0 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ You write a `SKILL.md`. 
Forge compiles it into a secure, runnable agent with egr | Document | Description | |----------|-------------| | [Security Overview](docs/security/overview.md) | Complete security architecture | +| [Auth Providers](docs/auth/index.md) | Pluggable auth chain — OIDC, AWS Sigv4, GCP IAP, Azure AD | | [Egress Security](docs/security/egress-control.md) | Egress enforcement deep dive | | [Secrets](docs/security/secret-management.md) | Encrypted secret management | | [Build Signing](docs/security/build-signing.md) | Ed25519 signing and verification | diff --git a/docs/auth/concepts/chain.md b/docs/auth/concepts/chain.md new file mode 100644 index 0000000..0639906 --- /dev/null +++ b/docs/auth/concepts/chain.md @@ -0,0 +1,73 @@ +--- +title: "Auth Provider Chain Semantics" +description: "First-match-wins, fail-closed, and the rules every provider follows." +order: 2 +--- + +Forge's auth chain composes multiple providers in a single `auth.providers` +list. Understanding the chain rules is the difference between a chain that +authenticates safely and one that silently downgrades to the most permissive +provider on the list. + +## The rules + +Each provider's `Verify` returns one of the following signals: + +| Return | Meaning | Chain behavior | +|---|---|---| +| `Identity, nil` | Token accepted | Stops; chain returns this Identity | +| `nil, ErrTokenNotForMe` | "Not my format" | Continues to next provider | +| `nil, ErrTokenRejected` | "My format, but denied" | **Stops; chain returns 401** | +| `nil, ErrInvalidToken` | "Malformed" | **Stops; chain returns 401** | +| `nil, ErrProviderUnavailable` | "Can't reach my IdP" | **Stops; chain returns 401** | +| `nil, any other error` | Transport / infrastructure | **Stops; chain returns 401** | + +The critical rule is the **no-fall-through on rejection**: if provider A +returns `ErrTokenRejected`, the chain does NOT try provider B. 
Falling +through would let an attacker downgrade their authentication — present a +malformed JWT and hope to be authenticated as an opaque-token user instead. + +## First-match-wins ordering + +Providers are tried top-to-bottom. Whoever **claims** the request first +(returns anything except `ErrTokenNotForMe`) decides the outcome. + +Practically: put more specific providers earlier. For Forge instances that +run channel adapters (Slack, Telegram), the `static_token` loopback secret +is auto-prepended to the chain — it always matches first for self-issued +loopback calls, regardless of what's configured in `forge.yaml`. + +## Non-Bearer formats (v0.11.0) + +As of v0.11.0 the middleware consults the chain **even when no Bearer token +was extracted** — required for `aws_sigv4` (which reads `Authorization: +AWS4-HMAC-SHA256 …`) and `gcp_iap` (which reads `X-Goog-Iap-Jwt-Assertion`). + +When the request carries no auth-shaped headers at all, the audit reason +code is still `missing_token` rather than `not_for_me` — operators can still +distinguish "client didn't auth" from "client tried a format we don't speak." + +## Example: mixed AWS + OIDC chain + +```yaml +auth: + required: true + providers: + # Internal CI/automation calls — Sigv4-signed under an IAM role: + - type: aws_sigv4 + settings: + region: us-east-1 + allowed_principals: + - "arn:aws:sts::123456789012:assumed-role/ci-deploy/*" + # Human users — Keycloak Bearer JWT: + - type: oidc + settings: + issuer: https://keycloak.example.com/realms/forge + audience: api://forge +``` + +A Sigv4-signed request matches `aws_sigv4` and never reaches `oidc`. A Bearer +JWT request returns `ErrTokenNotForMe` from `aws_sigv4` (no `AWS4-HMAC-SHA256` +prefix), then matches `oidc`. A garbage Bearer ("Bearer foobar") returns +`ErrTokenNotForMe` from both providers and the middleware emits +`auth_fail{reason: "not_for_me"}` with a 401. 
diff --git a/docs/auth/index.md b/docs/auth/index.md new file mode 100644 index 0000000..25a94ce --- /dev/null +++ b/docs/auth/index.md @@ -0,0 +1,53 @@ +--- +title: "Authentication Providers" +description: "Pluggable auth provider chain — pick the formats Forge accepts on /tasks." +order: 1 +--- + +Forge's `/tasks` endpoint requires every request to authenticate through a +chain of pluggable auth providers configured in `forge.yaml`. Each provider +recognizes one token shape; the chain tries them in order, first match wins. + +## Provider Matrix + +| Provider | Use case | Token format | Phase | +|---|---|---|---| +| [`static_token`](./providers/static_token.md) | Local dev, loopback channel adapters | Shared secret | 1 | +| [`oidc`](./providers/oidc.md) | Any IdP with OIDC discovery (Keycloak, Okta, Auth0, Google) | Bearer JWT | 1 | +| [`http_verifier`](./providers/http_verifier.md) | Custom verifier endpoint | Opaque | 1 | +| [`aws_sigv4`](./providers/aws_sigv4.md) | AWS-IAM-based callers (Lambda, EC2, EKS, IAM users) | Sigv4 (`Authorization`) | **2 (v0.11.0)** | +| [`gcp_iap`](./providers/gcp_iap.md) | Behind GCP HTTPS Load Balancer + IAP | IAP JWT (`X-Goog-Iap-Jwt-Assertion`) | **2 (v0.11.0)** | +| [`azure_ad`](./providers/azure_ad.md) | Microsoft Entra ID tokens | Bearer JWT | **2 (v0.11.0)** | + +## Phase 2 — Cloud-native providers (v0.11.0) + +Three new providers consume the identities customers already have in their +cloud — no parallel IdP required: + +- **AWS IAM** → callers sign requests with their existing IAM keys/role + (IRSA, instance profile, Lambda role, `aws configure`). Forge verifies via + STS `GetCallerIdentity` — zero secrets, no token endpoint to host. +- **GCP IAP** → Forge sitting behind a GCP HTTPS LB+IAP consumes the JWT IAP + forwards on every authenticated request. +- **Azure AD** → Entra tokens with tenant lock-in and optional Microsoft + Graph group enrichment. 
+ +## Chain semantics + +Read [concepts/chain](./concepts/chain.md) for the full rules. In short: + +- Providers are tried in `forge.yaml` order. +- A provider returns one of: an `Identity` (chain stops, request authorized); + `ErrTokenNotForMe` (try next provider); or any other error (chain + rejects — **does not** fall through to a more permissive provider). +- The first-match-wins rule prevents downgrade attacks: an attacker who + presents a malformed token of type A doesn't get a chance to be authenticated + as type B. + +## Auditing + +Every accepted request emits an `auth_verify` event with the provider's name +and the principal's identifiers. Every rejection emits `auth_fail` with a +stable reason code (`missing_token`, `not_for_me`, `rejected`, `invalid`, +`provider_unavailable`). Pipe both into your SIEM — see each provider's +"Audit log" section for the exact shape. diff --git a/docs/auth/providers/aws_sigv4.md b/docs/auth/providers/aws_sigv4.md new file mode 100644 index 0000000..2f8f64f --- /dev/null +++ b/docs/auth/providers/aws_sigv4.md @@ -0,0 +1,121 @@ +--- +title: "aws_sigv4 — AWS IAM" +description: "Authenticate AWS-IAM callers via STS GetCallerIdentity reflection." +order: 10 +--- + +The `aws_sigv4` provider authenticates requests signed with AWS Signature +Version 4. Forge reflects the caller's signature to AWS STS +`GetCallerIdentity`; STS validates it on Forge's behalf and returns the +canonical ARN/Account/UserID. **Forge never possesses the caller's secret +key.** Same pattern as `aws-iam-authenticator` (the EKS auth bridge). + +## When to use it + +Use this when your callers (Lambda, EC2, EKS workloads, developer laptops +with `aws configure`/`aws sso login`) already have AWS credentials. No +separate token issuer needed — IAM is the identity. + +## Prerequisites + +- [ ] Forge can reach `sts.<region>.amazonaws.com` (auto-added to the + Phase 2 egress allowlist). +- [ ] At least one IAM identity exists in the target AWS account. 
+- [ ] You've decided which principals to allow (optional but recommended + for production). +- [ ] **No additional IAM permissions required.** `GetCallerIdentity` is + implicit for every IAM principal on itself. + +## forge.yaml + +```yaml +auth: + required: true + providers: + - type: aws_sigv4 + settings: + region: us-east-1 + audience: api://forge # informational; emitted in audit + allowed_principals: + # NOTE: STS returns assumed-role ARNs, NOT IAM role ARNs. + # Match against arn:aws:sts::, not arn:aws:iam::. + - "arn:aws:sts::123456789012:assumed-role/ci-deploy/*" + - "arn:aws:sts::123456789012:assumed-role/forge-runner/*" + identity_cache_ttl: 60s +``` + +## Configuration reference + +| Field | Required | Default | Description | +|---|---|---|---| +| `region` | yes | — | AWS region for the STS endpoint. Sigv4 signatures must scope to this region. | +| `audience` | no | — | Informational string; emitted in `Identity.Claims.audience`. STS doesn't enforce it. | +| `allowed_principals` | no | empty (allow any) | List of [shell-style globs](https://pkg.go.dev/path#Match) matched against the **STS-returned ARN**. | +| `identity_cache_ttl` | no | `60s` | How long a verified Identity is reused before re-checking STS. Bounds the stolen-key window. | +| `http_timeout` | no | `5s` | Per-STS-call timeout. | + +`allowed_principals` is the authz mechanism. Empty list means "allow any IAM +principal who can sign for this region" — fine for single-tenant dev, never +appropriate for production. 
+ +## Worked example — call Forge from a script + +```bash +# Make sure your AWS profile points at an identity in allowed_principals: +export AWS_PROFILE=dev + +# awscurl handles the Sigv4 signing: +pipx install awscurl + +awscurl --service sts --region us-east-1 \ + -X POST -d '{"task":"hello"}' \ + -H "Content-Type: application/json" \ + http://localhost:9999/tasks +``` + +A 200 response means the chain accepted your IAM identity; check the Forge +log for the matching `auth_verify` event. + +## Audit log + +Successful auth: + +```json +{ "event": "auth_verify", + "fields": { + "provider": "aws_sigv4", + "user_id": "arn:aws:sts::123456789012:assumed-role/ci-deploy/i-0abc", + "org_id": "123456789012", + "token_kind": "sigv4" + } +} +``` + +## Troubleshooting + +| Symptom | Likely cause | Fix | +|---|---|---| +| `auth_fail reason=not_for_me` | `Authorization` header didn't start with `AWS4-HMAC-SHA256` | Caller is signing with Bearer/something else — point them at the right tooling (`awscurl`, AWS SDK) | +| `auth_fail reason=rejected` + STS body mentions signature | Caller signed with wrong creds, expired session, or clock skew | Re-sign with current creds; check system clock; ensure session token is valid | +| `auth_fail reason=rejected` + "ARN ... not in allowed_principals" | Allowlist mismatch | **Common bug:** the pattern must match the STS-returned **assumed-role ARN** (`arn:aws:sts::ACCT:assumed-role/...`), NOT the IAM role ARN (`arn:aws:iam::ACCT:role/...`) | +| `auth_fail reason=rejected` + "service=s3" or "region=eu-west-1" | Caller signed for wrong service or region | Re-sign with `--service sts --region <region>` | +| `auth_fail reason=provider_unavailable` | STS or network is down; egress allowlist missing the STS host | Check `sts.<region>.amazonaws.com` reachability; check `egress_hosts` | + +## Security model + +- **No secret key on Forge.** STS does the signature math. 
+- **Cache bucketing on YYYYMMDD + per-AKID** bounds the stolen-key window: + a leaked AKID is valid until midnight UTC OR `identity_cache_ttl` + elapses, whichever is sooner. +- **Service/region scope check** rejects Sigv4 signatures meant for S3, EC2, + or a different region. Prevents cross-service replay. +- **No `aws-sdk-go` dependency.** The STS RPC is ~150 LOC of hand-rolled + HTTP + XML. Smaller attack surface than a transitive SDK. + +## Limitations + +- One AWS region per provider entry. Configure multiple entries for + multi-region. +- Cognito federation / SAML / OIDC — use the [`oidc`](./oidc.md) provider + pointed at your Cognito user pool's issuer. +- Sigv4 is inbound only — Forge does not Sigv4-sign outbound requests. diff --git a/docs/auth/providers/azure_ad.md b/docs/auth/providers/azure_ad.md new file mode 100644 index 0000000..c8637c8 --- /dev/null +++ b/docs/auth/providers/azure_ad.md @@ -0,0 +1,176 @@ +--- +title: "azure_ad — Microsoft Entra ID" +description: "Verify Entra ID tokens with tenant lock-in and optional Graph enrichment." +order: 12 +--- + +The `azure_ad` provider authenticates Microsoft Entra ID (formerly Azure AD) +Bearer tokens. It composes the Phase 1 `oidc` provider for signature +verification and standard claim checks, then layers AAD-specific concerns +on top: + +- Tenant lock-in via the `tid` claim +- Optional Microsoft Graph group enrichment when the JWT's `groups` claim + overflows (AAD truncates at ~200 groups) +- Single-tenant vs. multi-tenant operation + +## When to use it + +Your callers — humans, CI, or Azure-hosted workloads — already have Entra +identities. AAD-specific quirks (`tid`, app roles, groups overage) are +handled correctly. + +## Prerequisites + +- [ ] An Entra tenant exists and you know its GUID. +- [ ] An app registration exists in the tenant. +- [ ] App registration → **Expose an API** → Application ID URI is set + (e.g. `api://forge`). +- [ ] Forge can reach `login.microsoftonline.com`. 
+- [ ] *(For graph mode only)* Forge can reach `graph.microsoft.com`, the + app registration has the **delegated** Graph permission + `GroupMember.Read.All`, and an admin has **granted consent**. + +### App registration walkthrough + +``` +Entra admin center → App registrations → New registration + Name: Forge + Supported accounts: Single tenant (unless you need multi-tenant) + Redirect URI: (leave empty — Forge doesn't initiate flows) + +→ Expose an API: + Application ID URI: api://forge + Add a scope: user_impersonation + Who can consent: Admins and users + +→ API permissions (ONLY if groups_mode=graph): + Add a permission → Microsoft Graph → Delegated → GroupMember.Read.All + Click "Grant admin consent" +``` + +## forge.yaml — single-tenant, claim mode (typical) + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + tenant_id: 00000000-1111-2222-3333-444444444444 + audience: api://forge + groups_mode: claim # default; reads `groups`/`roles` from JWT +``` + +## forge.yaml — multi-tenant + +> **WARNING:** Multi-tenant means **any** Entra tenant in the world can +> present tokens to Forge. Without an authz layer (Phase 4+ ABAC), this is +> effectively wide-open auth. Pair it with claim-based policy in your +> handler middleware, or stick to single-tenant. + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true +``` + +## forge.yaml — graph mode (groups overage) + +When a user is in more than ~200 groups, AAD truncates the JWT `groups` +claim and includes `_claim_names` indicating overflow. The default `claim` +mode would see empty groups in that case — switch to `graph` mode to query +Microsoft Graph for the full list: + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + tenant_id: 00000000-... 
+ audience: api://forge + groups_mode: graph +``` + +The caller's Bearer token is **reflected** to Graph; Forge holds no Graph +credentials of its own. The user's `GroupMember.Read.All` delegated +permission is what authorizes the Graph read. + +## Configuration reference + +| Field | Required | Default | Description | +|---|---|---|---| +| `audience` | yes | — | Application ID URI from the app registration. | +| `tenant_id` | required unless `allow_multi_tenant=true` | — | Entra tenant GUID. | +| `allow_multi_tenant` | no | `false` | Accept tokens from any tenant. **High risk** without authz. | +| `groups_mode` | no | `claim` | `claim` (use JWT groups) or `graph` (enrich via Graph on empty groups). | +| `graph_timeout` | no | `5s` | Per-Graph-call timeout. | +| `jwks_cache_ttl` | no | `1h` | JWKS cache age before refresh. | + +## Acquiring tokens for testing + +```bash +# From a logged-in az CLI: +az account get-access-token --resource api://forge \ + --query accessToken -o tsv > /tmp/aad.tok + +curl -H "Authorization: Bearer $(cat /tmp/aad.tok)" \ + http://localhost:9999/tasks -d '{}' +``` + +## Audit log + +```json +{ "event": "auth_verify", + "fields": { + "provider": "azure_ad", + "user_id": "alice@contoso.onmicrosoft.com", + "org_id": "00000000-1111-2222-3333-444444444444", // tid + "token_kind": "jwt", + "groups": ["group-guid-1", "group-guid-2"] + } +} +``` + +Note `provider: "azure_ad"`, not `"oidc"` — even though OIDC did the +signature work, the audit name reflects the operator's configuration. 
+ +## Troubleshooting + +| Symptom | Likely cause | Fix | +|---|---|---| +| `auth_fail reason=rejected` + "tid mismatch" | Token from a different Entra tenant | Re-sign in to the right tenant, or set `allow_multi_tenant: true` if intentional | +| `auth_fail reason=rejected` + "aud" | `audience` in forge.yaml doesn't match the app registration's Application ID URI | Make both match exactly | +| Identity has empty `groups` despite user being in groups | Groups overage (>200 groups) | Switch to `groups_mode: graph` and grant `GroupMember.Read.All` admin consent | +| Graph enrichment soft-fails (auth OK, groups empty) | Graph 5xx or network failure | Check `graph.microsoft.com` reachability | +| `auth_fail reason=invalid` + "missing tid" | Token is not an AAD token (maybe a guest/personal Microsoft account?) | Confirm token issuer; only AAD work/school accounts have `tid` | + +## Security model + +- **Composes** the `oidc` provider — no JWT/JWKS code lives in this package. + All signature verification, algorithm-whitelisting, and key rotation + reuse the Phase 1 OIDC implementation. +- **Tenant gate before Graph** — for single-tenant configs, the `tid` check + runs before any Graph call. Different tenant = different security domain. +- **Caller's token reflected to Graph** — Forge never holds Graph + credentials. The user's delegated permission is the only authority. +- **Soft-fail on Graph 5xx** — Identity still returned with empty groups + rather than blocking prod traffic on a transient Graph outage. +- **Internal `skip_issuer_check`** flag on the composed OIDC is exposed + ONLY when `allow_multi_tenant=true`, and the field is `yaml:"-"` so it + cannot be set via `forge.yaml` — preventing an operator from accidentally + disabling issuer validation. + +## Limitations + +- No SAML — Phase 5+. +- No `client_credentials` flow for workload-to-Forge with no user context. 
+ If running in AWS, use [`aws_sigv4`](./aws_sigv4.md); otherwise use an + [`oidc`](./oidc.md) entry pointed at your workload identity's issuer. +- "Multi-tenant" means **any** tenant — there's no allowlist of tenants. + For "this set of tenants," configure multiple `azure_ad` entries. diff --git a/docs/auth/providers/gcp_iap.md b/docs/auth/providers/gcp_iap.md new file mode 100644 index 0000000..f9e5fcf --- /dev/null +++ b/docs/auth/providers/gcp_iap.md @@ -0,0 +1,105 @@ +--- +title: "gcp_iap — GCP Identity-Aware Proxy" +description: "Verify the JWT IAP forwards on every authenticated request." +order: 11 +--- + +The `gcp_iap` provider authenticates requests that come through GCP's +Identity-Aware Proxy. IAP terminates user authentication at the HTTPS Load +Balancer and forwards a signed JWT in the `X-Goog-Iap-Jwt-Assertion` header. +Forge verifies that JWT against IAP's well-known JWKS and stamps an Identity +from the `sub` + `email` claims. + +## When to use it + +Forge is deployed behind a GCP HTTPS Load Balancer with IAP enabled, and you +want the IAP-authenticated user (Google account / Workspace identity) to +flow through to Forge. + +## Prerequisites + +- [ ] Forge is reachable through a GCP HTTPS LB. +- [ ] IAP is enabled on the LB's backend service. +- [ ] Forge can reach `www.gstatic.com` (auto-added to the Phase 2 egress + allowlist). +- [ ] You know the backend service's "Signed Header JWT Audience" string. + +### Finding the audience + +The audience is the GCP backend service ID: + +``` +GCP Console → Security → Identity-Aware Proxy + → Backend Services tab + → click your backend + → "Signed Header JWT Audience" field +``` + +Format: `/projects/PROJECT_NUMBER/global/backendServices/BACKEND_ID`. 
+ +## forge.yaml + +```yaml +auth: + required: true + providers: + - type: gcp_iap + settings: + audience: /projects/12345678/global/backendServices/9876543210 +``` + +## Configuration reference + +| Field | Required | Default | Description | +|---|---|---|---| +| `audience` | yes | — | GCP backend service ID. Must match exactly. | +| `jwks_refresh_ttl` | no | `1h` | How long a cached JWKS is reused before refresh. | +| `http_timeout` | no | `5s` | JWKS fetch timeout. | + +The IAP **issuer** (`https://cloud.google.com/iap`) and **JWKS URL** +(`https://www.gstatic.com/iap/verify/public_key-jwk`) are **hardcoded**. +They are the only stable contract GCP exposes; an override knob would be a +footgun (someone could be tricked into trusting an attacker's JWKS). + +## Audit log + +```json +{ "event": "auth_verify", + "fields": { + "provider": "gcp_iap", + "user_id": "1234567890", // claims.sub + "email": "alice@example.com", // claims.email + "token_kind": "jwt" + } +} +``` + +The Workspace domain claim (`hd`) is included in `Identity.Claims` for +downstream policy decisions. 
+ +## Troubleshooting + +| Symptom | Likely cause | Fix | +|---|---|---| +| `auth_fail reason=not_for_me` | Request didn't come through IAP — no `X-Goog-Iap-Jwt-Assertion` header | Hit Forge through the GCP LB, not the backend directly | +| `auth_fail reason=rejected` + "iss" mismatch | Token is a regular Google OAuth token, not an IAP assertion | Confirm IAP is enabled; route the request through the LB | +| `auth_fail reason=rejected` + "aud" | `audience` in forge.yaml doesn't match the backend service ID | Re-check the audience string in the GCP console | +| `auth_fail reason=invalid` + "alg" | Token signed with something other than ES256 | Not an IAP token; check origin | +| `auth_fail reason=provider_unavailable` | `www.gstatic.com` unreachable | Check egress allowlist, network | + +## Security model + +- **ES256 algorithm whitelist** rejects any other signing method *before* + key lookup — algorithm-confusion attacks can't reach the JWKS. +- **JWKS host hardcoded** so a malicious config can't redirect verification + to an attacker-controlled key server. +- **Stale-grace cache** — during a JWKS endpoint outage, the previously + cached keys remain valid. IAP rotates keys on the order of weeks; freshness + matters less than availability. + +## Limitations + +- No per-email / per-domain allowlist in Forge — use **GCP IAM Conditions** + on the backend service for that. IAP can grant or deny access at the LB + layer; Forge sees only requests IAP already approved. +- Single audience per provider entry. 
From 639bfa9c6d04a6dbfaa6f20b9670963b45788ff7 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 20:05:34 -0400 Subject: [PATCH 07/22] fix(cli/tui): run Auth step before Egress; auto-add auth hosts to allowlist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The wizard was asking for Egress confirmation before the operator had picked an auth provider, so STS / AAD authority / IAP JWKS hosts never appeared in the egress list. Forge would scaffold a forge.yaml whose egress_hosts blocked its own auth-provider RPC calls — failure happens later at `forge run`, with no signal the wizard could have caught. changes: - Swap step order in init.go: Auth now runs immediately before Egress - Extend DeriveEgressFunc with (authMode, authSettings) so the Egress step's Prepare(ctx) pulls the operator's auth choice from WizardContext and forwards it into deriveEgressDomains - deriveEgressDomains calls authEgressHostsFromSettings (same helper the non-interactive --auth=… path uses) — TUI and CLI now produce identical egress lists for any given auth choice - EgressStep's inferSource() learns to label auth-derived hosts: sts.<region>.amazonaws.com → "aws_sigv4 auth" www.gstatic.com → "gcp_iap auth" login.microsoftonline.com → "azure_ad auth" graph.microsoft.com → "azure_ad auth (graph)" <issuer host> → "oidc auth" <verifier host> → "http_verifier auth" tests: - TestDeriveEgressDomains_AuthProviderHostsMerged: 8 cases pinning the per-provider host emission (incl. 
graph-mode adds graph host) - TestDeriveEgressDomains_AuthHostsMergeNotOverwrite: auth pass is additive — provider / channel hosts still emit alongside auth hosts docs: - docs/auth/concepts/chain.md gains a "TUI wizard ordering" section explaining the Auth-before-Egress invariant --- docs/auth/concepts/chain.md | 10 ++ forge-cli/cmd/init.go | 27 ++++- forge-cli/cmd/init_egress.go | 12 +++ forge-cli/cmd/init_test.go | 109 ++++++++++++++++++++ forge-cli/internal/tui/steps/egress_step.go | 109 +++++++++++++++++++- 5 files changed, 261 insertions(+), 6 deletions(-) diff --git a/docs/auth/concepts/chain.md b/docs/auth/concepts/chain.md index 0639906..975097e 100644 --- a/docs/auth/concepts/chain.md +++ b/docs/auth/concepts/chain.md @@ -37,6 +37,16 @@ run channel adapters (Slack, Telegram), the `static_token` loopback secret is auto-prepended to the chain — it always matches first for self-issued loopback calls, regardless of what's configured in `forge.yaml`. +## TUI wizard ordering + +The `forge init` wizard runs the **Auth step before the Egress step** — +once you pick a provider, the egress review automatically includes the +outbound hosts it needs (STS endpoint for `aws_sigv4`, AAD authority for +`azure_ad`, IAP JWKS host for `gcp_iap`, OIDC issuer host for `oidc`). +Operators see and confirm those hosts alongside provider / channel / tool / +skill hosts in a single review screen — no need to add auth hosts to +`egress_hosts` by hand after the wizard. 
+ ## Non-Bearer formats (v0.11.0) As of v0.11.0 the middleware consults the chain **even when no Bearer token diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index aaf9018..bca5667 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -267,13 +267,25 @@ func collectInteractive(opts *initOptions) error { } } - // Build the egress derivation callback (avoids circular import) - deriveEgressFn := func(provider string, channels, tools, selectedSkills []string, envVars map[string]string) []string { + // Build the egress derivation callback (avoids circular import). + // Note: the auth params get the auth step's choice forwarded into + // deriveEgressDomains, which delegates to authEgressHostsFromSettings + // — the same translation used by `forge init --auth=…` so TUI and + // non-interactive paths produce identical egress lists. + deriveEgressFn := func( + provider string, + channels, tools, selectedSkills []string, + envVars map[string]string, + authMode string, + authSettings map[string]any, + ) []string { tmpOpts := &initOptions{ ModelProvider: provider, Channels: channels, BuiltinTools: tools, EnvVars: envVars, + AuthMode: authMode, + AuthSettings: authSettings, } selectedInfos := lookupSelectedSkills(selectedSkills) return deriveEgressDomains(tmpOpts, selectedInfos) @@ -294,7 +306,14 @@ func collectInteractive(opts *initOptions) error { return validateWebSearchKey(provider, key) } - // Build step list + // Build step list. + // + // Auth comes BEFORE Egress so the operator's auth choice — and the + // hosts it requires (sts..amazonaws.com, login.microsoftonline.com, + // etc.) — appear in the Egress review for confirmation. Without this + // ordering, the egress list would miss auth hosts and a Forge instance + // could fail at runtime because its OIDC discovery or STS call gets + // blocked by the very allowlist the wizard just rendered. 
wizardSteps := []tui.Step{ steps.NewNameStep(styles, opts.Name), steps.NewProviderStep(styles, validateKeyFn, oauthFlowFn), @@ -302,8 +321,8 @@ func collectInteractive(opts *initOptions) error { steps.NewChannelStep(styles), steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn), steps.NewSkillsStep(styles, skillInfos), - steps.NewEgressStep(styles, deriveEgressFn), steps.NewAuthStep(styles), + steps.NewEgressStep(styles, deriveEgressFn), steps.NewReviewStep(styles), // scaffold is handled by the caller after collectInteractive returns } diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index 92b632f..fac10b7 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -93,6 +93,18 @@ func deriveEgressDomains(opts *initOptions, skills []contract.SkillDescriptor) [ } } + // 5. Auth-provider domains — same translation the non-interactive + // --auth=… path uses, so TUI and CLI render identical egress lists. + // Examples: + // oidc → host extracted from issuer URL + // aws_sigv4 → sts..amazonaws.com + // gcp_iap → www.gstatic.com (hardcoded §9.4) + // azure_ad → login.microsoftonline.com (+ graph.microsoft.com + // when groups_mode=graph) + for _, h := range authEgressHostsFromSettings(opts.AuthMode, opts.AuthSettings) { + add(h) + } + sort.Strings(domains) return domains } diff --git a/forge-cli/cmd/init_test.go b/forge-cli/cmd/init_test.go index 279b3b8..8bcb4e2 100644 --- a/forge-cli/cmd/init_test.go +++ b/forge-cli/cmd/init_test.go @@ -513,6 +513,115 @@ func TestDeriveEgressDomains_Empty(t *testing.T) { } } +// TestDeriveEgressDomains_AuthProviderHostsMerged confirms the wizard +// re-order invariant: a chosen Auth provider contributes its required +// hosts (STS for aws_sigv4, AAD authority for azure_ad, etc.) to the +// same egress list a user reviews in the Egress step. Pins the contract +// that the operator never has to add auth hosts manually after the wizard. 
+func TestDeriveEgressDomains_AuthProviderHostsMerged(t *testing.T) { + cases := []struct { + name string + mode string + set map[string]any + want []string + }{ + { + name: "aws_sigv4 us-east-1", + mode: "aws_sigv4", + set: map[string]any{"region": "us-east-1"}, + want: []string{"sts.us-east-1.amazonaws.com"}, + }, + { + name: "aws_sigv4 eu-west-2", + mode: "aws_sigv4", + set: map[string]any{"region": "eu-west-2"}, + want: []string{"sts.eu-west-2.amazonaws.com"}, + }, + { + name: "gcp_iap", + mode: "gcp_iap", + set: map[string]any{"audience": "/projects/x/global/backendServices/y"}, + want: []string{"www.gstatic.com"}, + }, + { + name: "azure_ad claim", + mode: "azure_ad", + set: map[string]any{"audience": "api://forge", "tenant_id": "abc"}, + want: []string{"login.microsoftonline.com"}, + }, + { + name: "azure_ad graph mode adds graph host", + mode: "azure_ad", + set: map[string]any{"audience": "api://forge", "tenant_id": "abc", "groups_mode": "graph"}, + want: []string{"graph.microsoft.com", "login.microsoftonline.com"}, + }, + { + name: "oidc issuer host", + mode: "oidc", + set: map[string]any{"issuer": "https://login.example.com", "audience": "api://forge"}, + want: []string{"login.example.com"}, + }, + { + name: "none → no auth hosts added", + mode: "none", + set: nil, + want: nil, + }, + { + name: "custom → no auth hosts added", + mode: "custom", + set: nil, + want: nil, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + opts := &initOptions{ + ModelProvider: "ollama", // no provider hosts → exposes only auth hosts + EnvVars: map[string]string{}, + AuthMode: tc.mode, + AuthSettings: tc.set, + } + got := deriveEgressDomains(opts, nil) + if len(got) != len(tc.want) { + t.Fatalf("len(got) = %d (%v), want %d (%v)", len(got), got, len(tc.want), tc.want) + } + for i, d := range tc.want { + if got[i] != d { + t.Errorf("got[%d] = %q, want %q (full got: %v)", i, got[i], d, got) + } + } + }) + } +} + +// 
TestDeriveEgressDomains_AuthHostsMergeNotOverwrite confirms auth hosts +// coexist with provider/channel/skill hosts — the auth pass is additive, +// not exclusive. +func TestDeriveEgressDomains_AuthHostsMergeNotOverwrite(t *testing.T) { + opts := &initOptions{ + ModelProvider: "openai", + Channels: []string{"slack"}, + EnvVars: map[string]string{}, + AuthMode: "aws_sigv4", + AuthSettings: map[string]any{"region": "us-east-1"}, + } + got := deriveEgressDomains(opts, nil) + have := map[string]bool{} + for _, d := range got { + have[d] = true + } + for _, want := range []string{ + "api.openai.com", // model provider + "slack.com", // channel + "sts.us-east-1.amazonaws.com", // auth + } { + if !have[want] { + t.Errorf("missing %q in merged egress list: %v", want, got) + } + } +} + func TestBuildEnvVars(t *testing.T) { opts := &initOptions{ ModelProvider: "openai", diff --git a/forge-cli/internal/tui/steps/egress_step.go b/forge-cli/internal/tui/steps/egress_step.go index 93739bd..592ac16 100644 --- a/forge-cli/internal/tui/steps/egress_step.go +++ b/forge-cli/internal/tui/steps/egress_step.go @@ -10,7 +10,19 @@ import ( ) // DeriveEgressFunc computes egress domains from wizard context. -type DeriveEgressFunc func(provider string, channels, tools, skills []string, envVars map[string]string) []string +// +// authMode + authSettings are the user's choice from the preceding Auth step. +// Auth-derived hosts (e.g. sts.<region>.amazonaws.com for aws_sigv4, +// login.microsoftonline.com for azure_ad) are merged into the egress list so +// the Egress review displays the FULL outbound surface — operators see and +// confirm the auth hosts alongside provider/channel/tool/skill hosts. +type DeriveEgressFunc func( + provider string, + channels, tools, skills []string, + envVars map[string]string, + authMode string, + authSettings map[string]any, +) []string // EgressStep handles egress domain review.
type EgressStep struct { @@ -32,6 +44,10 @@ func NewEgressStep(styles *tui.StyleSet, deriveFn DeriveEgressFunc) *EgressStep } // Prepare computes egress domains using the accumulated wizard context. +// +// The Auth step runs BEFORE Egress in the wizard order (see init.go), so by +// the time Prepare runs, ctx.AuthMode and ctx.AuthSettings reflect the +// operator's choice and we can compute auth-derived hosts. func (s *EgressStep) Prepare(ctx *tui.WizardContext) { var channels []string if ctx.Channel != "" && ctx.Channel != "none" { @@ -40,7 +56,7 @@ func (s *EgressStep) Prepare(ctx *tui.WizardContext) { s.domains = nil if s.deriveFn != nil { - s.domains = s.deriveFn(ctx.Provider, channels, ctx.BuiltinTools, ctx.Skills, ctx.EnvVars) + s.domains = s.deriveFn(ctx.Provider, channels, ctx.BuiltinTools, ctx.Skills, ctx.EnvVars, ctx.AuthMode, ctx.AuthSettings) } s.empty = len(s.domains) == 0 @@ -126,6 +142,13 @@ func (s *EgressStep) Apply(ctx *tui.WizardContext) { // inferSource guesses the source of an egress domain based on context. func inferSource(domain string, ctx *tui.WizardContext) string { + // Auth provider domains — checked FIRST so an OIDC issuer host like + // "login.example.com" is correctly attributed to "oidc auth" rather + // than falling through to a generic "configured" label. + if src := authProviderForDomain(domain, ctx); src != "" { + return src + } + // Provider domains providerDomains := map[string]string{ "api.openai.com": "model provider", @@ -170,3 +193,85 @@ func inferSource(domain string, ctx *tui.WizardContext) string { return "configured" } + +// authProviderForDomain returns a human-friendly label when `domain` is +// known to be required by the operator's chosen auth provider, or "" when +// the domain wasn't sourced from auth. +// +// The matching is intentionally narrow: we compare against the hosts each +// provider actually contributes (computed elsewhere via authEgressHostsFromSettings). 
+// For oidc/http_verifier, the host is dynamic (issuer/url-derived) so we +// match by exact string against what ctx.AuthSettings says the issuer +// resolves to. +func authProviderForDomain(domain string, ctx *tui.WizardContext) string { + if ctx == nil || ctx.AuthMode == "" || ctx.AuthMode == "none" || ctx.AuthMode == "custom" { + return "" + } + switch ctx.AuthMode { + case "aws_sigv4": + // sts.<region>.amazonaws.com is the only contributed host. + region, _ := ctx.AuthSettings["region"].(string) + if region != "" && domain == "sts."+region+".amazonaws.com" { + return "aws_sigv4 auth" + } + case "gcp_iap": + if domain == "www.gstatic.com" { + return "gcp_iap auth" + } + case "azure_ad": + if domain == "login.microsoftonline.com" { + return "azure_ad auth" + } + if domain == "graph.microsoft.com" { + return "azure_ad auth (graph)" + } + case "oidc": + // Best-effort: match the configured issuer's host. + issuer, _ := ctx.AuthSettings["issuer"].(string) + if h := hostOf(issuer); h != "" && h == domain { + return "oidc auth" + } + jwks, _ := ctx.AuthSettings["jwks_url"].(string) + if h := hostOf(jwks); h != "" && h == domain { + return "oidc auth (jwks)" + } + case "http_verifier": + url, _ := ctx.AuthSettings["url"].(string) + if h := hostOf(url); h != "" && h == domain { + return "http_verifier auth" + } + } + return "" +} + +// hostOf extracts the host portion of a URL, returning "" on parse failure +// or missing host. Kept local to avoid pulling net/url into a hot iteration +// path elsewhere. +func hostOf(raw string) string { + if raw == "" { + return "" + } + // Cheap manual split; matches what `(*url.URL).Hostname()` does for the + // well-formed inputs the wizard collects.
+ const sep = "://" + i := indexOf(raw, sep) + if i < 0 { + return "" + } + rest := raw[i+len(sep):] + for k := 0; k < len(rest); k++ { + if rest[k] == '/' || rest[k] == ':' { + return rest[:k] + } + } + return rest +} + +func indexOf(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} From b5f303b60a1208aa4394d43b3d134bb715d12a29 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 20:14:15 -0400 Subject: [PATCH 08/22] =?UTF-8?q?fix(auth):=20audit-pass=20refinements=20?= =?UTF-8?q?=E2=80=94=20iap=5Fjwt=20token=5Fkind,=20parse=20caching,=20clea?= =?UTF-8?q?nup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Final-pass audit findings against the phase 2 design doc surfaced one correctness bug and several small improvements. All gates clean (go test -race / golangci-lint / gofmt). 42 packages pass. BUG fix — middleware emits token_kind="iap_jwt" for IAP requests: The strategy doc §5/§10 lists five token_kind values: empty, opaque, jwt, sigv4, iap_jwt. PR1 wired sigv4 detection but missed iap_jwt, so successful GCP IAP requests audited with token_kind="empty" — the same value as no-auth requests, defeating the audit-pipeline goal of counting IAP traffic distinctly. Middleware now classifies X-Goog-Iap-Jwt-Assertion presence as kind="iap_jwt" on the empty-Bearer path. New regression test pins it. Improvement — graph_client.go avoids per-page URL re-parse: ensureGraphHost was parsing GraphClient.endpoint via url.Parse on EVERY pagination step. Pre-parse the endpoint Host once at construction and compare against that string instead. Trims redundant work on multi-page Graph responses. Improvement — gcp_iap classifyJWTErr ordering hardened: Replaced the bare substring match on "kid" (which would catch unrelated errors) with the specific patterns: "kid " (e.g. "kid X not found") and "not found" (covers JWKS-resolution failures). 
Pre-existing ordering invariant comment is now actually defended. Cleanup — drop redundant single-function file: Moved ExtractTenantID from azure_ad/tenant.go into provider.go alongside other claim accessors and removed the empty tenant.go. The function was a 1-liner and didn't justify its own file. Cleanup — inline audienceContains shim: Replaced the audienceContains() wrapper (one-liner around slices.Contains) with a direct call at the use site. Less indirection, same behavior. Cleanup — middleware: simplify hasNonBearerAuth boolean expr: Folded the multi-line if-chain into a single boolean expression. Same semantics, less noise. audit findings deferred as nits, not fixed: - aws_sigv4 Parser as zero-value struct (cosmetic; keeps symmetry) - egress_step.go hostOf manual URL parsing (cosmetic; non-hot path) - 10k eviction comment wording audit findings confirmed not bugs: - GraphCache TTL test (already exists in graph_cache_test.go) - PrependChain loopback invariant intact (runner.go line 2036) --- forge-core/auth/middleware.go | 31 ++++++++------- forge-core/auth/middleware_test.go | 25 ++++++++++++ .../auth/providers/azure_ad/graph_client.go | 38 +++++++++++-------- .../providers/azure_ad/graph_client_test.go | 6 +-- .../auth/providers/azure_ad/provider.go | 8 ++++ forge-core/auth/providers/azure_ad/tenant.go | 9 ----- forge-core/auth/providers/gcp_iap/iap_jwks.go | 12 +++--- forge-core/auth/providers/gcp_iap/provider.go | 6 +-- 8 files changed, 81 insertions(+), 54 deletions(-) delete mode 100644 forge-core/auth/providers/azure_ad/tenant.go diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index bd5b5e2..fd66484 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -59,9 +59,9 @@ type MiddlewareOptions struct { // - identity is non-nil and err is nil on success. // - identity is nil and err carries the chain error on failure // (or auth.ErrMissingBearer when the header was absent). 
- // - tokenKind is "jwt", "opaque", "sigv4", or "empty" — structural - // metadata safe to log. The token itself is NOT passed; callers - // must not try to recover it from the request. + // - tokenKind is "jwt", "opaque", "sigv4", "iap_jwt", or "empty" — + // structural metadata safe to log. The token itself is NOT + // passed; callers must not try to recover it from the request. // // Callbacks should be cheap — they run on the request hot path. OnAuth func(r *http.Request, identity *Identity, err error, tokenKind string) @@ -110,26 +110,25 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { kind := TokenKind(token) // Phase 2: when no Bearer token was extracted, classify the - // audit kind from the raw Authorization header. Lets aws_sigv4 - // emit the right token_kind value even though there's no Bearer. + // audit kind from the non-Bearer auth headers so token_kind in + // the audit log differentiates Sigv4 / IAP / no-auth requests + // instead of all reporting "empty". rawAuth := r.Header.Get("Authorization") - if kind == "empty" && strings.HasPrefix(rawAuth, "AWS4-HMAC-SHA256 ") { - kind = "sigv4" + iapHeader := r.Header.Get("X-Goog-Iap-Jwt-Assertion") + if kind == "empty" { + switch { + case strings.HasPrefix(rawAuth, "AWS4-HMAC-SHA256 "): + kind = "sigv4" + case iapHeader != "": + kind = "iap_jwt" + } } // hasNonBearerAuth signals "the caller IS attempting auth, just // not via Bearer." Used to differentiate "missing_token" (no // auth headers at all) from "chain rejected/didn't match" in // the audit reason. 
- hasNonBearerAuth := false - if token == "" { - if rawAuth != "" { - hasNonBearerAuth = true - } - if r.Header.Get("X-Goog-Iap-Jwt-Assertion") != "" { - hasNonBearerAuth = true - } - } + hasNonBearerAuth := token == "" && (rawAuth != "" || iapHeader != "") // Phase 2 change: consult the chain even on empty Bearer, so // providers that speak non-Bearer formats (aws_sigv4, gcp_iap) diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index d117b33..06596e0 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -580,6 +580,31 @@ func TestMiddleware_TokenKind_Sigv4OnEmptyBearer(t *testing.T) { } } +func TestMiddleware_TokenKind_IapJwtOnIAPHeader(t *testing.T) { + // Phase 2: when the only auth header is X-Goog-Iap-Jwt-Assertion, + // audit emits token_kind="iap_jwt" (not "empty"). Pinning this so + // the GCP IAP audit signal is distinguishable from no-auth requests. + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJ.eyJ.sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "iap_jwt" { + t.Errorf("token kind = %q, want iap_jwt", gotKind) + } +} + func TestMiddleware_TokenKind_EmptyWhenTrulyNoAuth(t *testing.T) { // Counterpart to the test above: when the caller didn't attempt auth // at all, token_kind should be "empty" — not silently widened. 
diff --git a/forge-core/auth/providers/azure_ad/graph_client.go b/forge-core/auth/providers/azure_ad/graph_client.go index 5ae7980..7d442df 100644 --- a/forge-core/auth/providers/azure_ad/graph_client.go +++ b/forge-core/auth/providers/azure_ad/graph_client.go @@ -22,26 +22,33 @@ import ( // token is reflected to Graph, which authorizes the read against the // user's delegated permission (GroupMember.Read.All). type GraphClient struct { - endpoint string // override-able for tests - http *http.Client + endpoint string // initial page URL + endpointHost string // pre-parsed for cheap per-page nextLink validation + http *http.Client } const graphBaseURL = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id&$top=100" // NewGraphClient builds a client pointed at the real Graph endpoint. func NewGraphClient(timeout time.Duration) *GraphClient { - return &GraphClient{ - endpoint: graphBaseURL, - http: &http.Client{Timeout: timeout}, - } + return newGraphClientFor(graphBaseURL, timeout) } // NewGraphClientWithEndpoint is a TEST-ONLY constructor for pointing at // a fake Graph server. 
func NewGraphClientWithEndpoint(endpoint string, timeout time.Duration) *GraphClient { + return newGraphClientFor(endpoint, timeout) +} + +func newGraphClientFor(endpoint string, timeout time.Duration) *GraphClient { + host := "" + if u, err := url.Parse(endpoint); err == nil { + host = u.Host + } return &GraphClient{ - endpoint: endpoint, - http: &http.Client{Timeout: timeout}, + endpoint: endpoint, + endpointHost: host, + http: &http.Client{Timeout: timeout}, } } @@ -63,7 +70,7 @@ func (c *GraphClient) TransitiveMemberOf(ctx context.Context, _ string, authHead out := []string{} next := c.endpoint for next != "" { - if err := ensureGraphHost(c.endpoint, next); err != nil { + if err := ensureGraphHost(c.endpointHost, next); err != nil { return nil, fmt.Errorf("%w: graph nextLink host: %v", auth.ErrProviderUnavailable, err) } page, nextURL, err := c.fetchPage(ctx, next, authHeader) @@ -125,20 +132,19 @@ func (c *GraphClient) fetchPage(ctx context.Context, u, authHeader string) (ids // host. Real Graph paginates within graph.microsoft.com. For tests where // the endpoint is httptest's 127.0.0.1, the test-mode endpoint host is // what we compare against. -func ensureGraphHost(configured, candidate string) error { +// +// `configuredHost` is the pre-parsed Host of GraphClient.endpoint, computed +// once at construction so we don't reparse it for every paginated request. 
+func ensureGraphHost(configuredHost, candidate string) error { if candidate == "" { return nil } - want, err := url.Parse(configured) - if err != nil { - return err - } got, err := url.Parse(candidate) if err != nil { return err } - if !strings.EqualFold(want.Host, got.Host) { - return fmt.Errorf("nextLink host %q does not match configured %q", got.Host, want.Host) + if !strings.EqualFold(configuredHost, got.Host) { + return fmt.Errorf("nextLink host %q does not match configured %q", got.Host, configuredHost) } return nil } diff --git a/forge-core/auth/providers/azure_ad/graph_client_test.go b/forge-core/auth/providers/azure_ad/graph_client_test.go index 116c1f0..bd7799c 100644 --- a/forge-core/auth/providers/azure_ad/graph_client_test.go +++ b/forge-core/auth/providers/azure_ad/graph_client_test.go @@ -118,7 +118,7 @@ func TestGraph_DefensivePaginationCap(t *testing.T) { } func TestEnsureGraphHost_RejectsForeignHost(t *testing.T) { - err := ensureGraphHost("https://graph.microsoft.com/v1.0/me", "https://evil.example.com/me/next") + err := ensureGraphHost("graph.microsoft.com", "https://evil.example.com/me/next") if err == nil { t.Fatal("expected error on foreign host") } @@ -126,7 +126,7 @@ func TestEnsureGraphHost_RejectsForeignHost(t *testing.T) { func TestEnsureGraphHost_AcceptsSameHost(t *testing.T) { err := ensureGraphHost( - "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id", + "graph.microsoft.com", "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$skiptoken=abc", ) if err != nil { @@ -135,7 +135,7 @@ func TestEnsureGraphHost_AcceptsSameHost(t *testing.T) { } func TestEnsureGraphHost_EmptyOK(t *testing.T) { - if err := ensureGraphHost("https://graph.microsoft.com/v1.0/me", ""); err != nil { + if err := ensureGraphHost("graph.microsoft.com", ""); err != nil { t.Errorf("empty nextLink should be ok, got %v", err) } } diff --git a/forge-core/auth/providers/azure_ad/provider.go b/forge-core/auth/providers/azure_ad/provider.go index 
1591916..14539e9 100644 --- a/forge-core/auth/providers/azure_ad/provider.go +++ b/forge-core/auth/providers/azure_ad/provider.go @@ -86,6 +86,14 @@ func (c Config) Validate() error { return nil } +// ExtractTenantID returns the "tid" claim, or "" if it's missing / +// non-string. The empty-return form lets callers distinguish "missing" +// from "wrong tenant" without a typed error. +func ExtractTenantID(claims map[string]any) string { + tid, _ := claims["tid"].(string) + return tid +} + // Provider implements auth.Provider for AAD callers. type Provider struct { cfg Config diff --git a/forge-core/auth/providers/azure_ad/tenant.go b/forge-core/auth/providers/azure_ad/tenant.go deleted file mode 100644 index 7f2f47e..0000000 --- a/forge-core/auth/providers/azure_ad/tenant.go +++ /dev/null @@ -1,9 +0,0 @@ -package azure_ad - -// ExtractTenantID returns the "tid" claim, or "" if it's missing/non-string. -// The empty-return form is intentional — callers want to distinguish -// "missing" from "wrong tenant," and a typed error here would be noise. -func ExtractTenantID(claims map[string]any) string { - tid, _ := claims["tid"].(string) - return tid -} diff --git a/forge-core/auth/providers/gcp_iap/iap_jwks.go b/forge-core/auth/providers/gcp_iap/iap_jwks.go index bd23293..032a044 100644 --- a/forge-core/auth/providers/gcp_iap/iap_jwks.go +++ b/forge-core/auth/providers/gcp_iap/iap_jwks.go @@ -277,10 +277,11 @@ func parseECJWKSet(raw []byte) (map[string]*ecdsa.PublicKey, error) { // returns wrapped errors with stable joinings — we string-match on the // most reliable signals. // -// Ordering matters: alg-confusion checks come FIRST because the library -// emits "signing method X is invalid" wrapped in "signature is invalid", -// and we want algorithm-confusion classified as ErrInvalidToken (malformed) -// rather than ErrTokenRejected (policy denial). 
+// Algorithm-confusion ("signing method X is invalid") and kid-related +// errors are classified as ErrInvalidToken (token shape is wrong), +// not ErrTokenRejected (which is policy denial). These checks run +// BEFORE the generic "signature is invalid" arm because the alg error +// is wrapped in a signature-shaped outer error by golang-jwt. func classifyJWTErr(err error) error { if errors.Is(err, auth.ErrProviderUnavailable) || errors.Is(err, auth.ErrInvalidToken) || @@ -292,7 +293,8 @@ func classifyJWTErr(err error) error { case strings.Contains(s, "signing method"), strings.Contains(s, "unexpected alg"), strings.Contains(s, "missing kid"), - strings.Contains(s, "kid"): + strings.Contains(s, "kid "), // "kid X not found" + strings.Contains(s, "not found"): // covers JWKS-resolution failures return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) case strings.Contains(s, "expired"), strings.Contains(s, "iat"), diff --git a/forge-core/auth/providers/gcp_iap/provider.go b/forge-core/auth/providers/gcp_iap/provider.go index 6a5e0a6..3fa8815 100644 --- a/forge-core/auth/providers/gcp_iap/provider.go +++ b/forge-core/auth/providers/gcp_iap/provider.go @@ -114,7 +114,7 @@ func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) ( if claims.Issuer != iapIssuer { return nil, fmt.Errorf("%w: iss=%q, want %q", auth.ErrTokenRejected, claims.Issuer, iapIssuer) } - if !audienceContains(claims.Audience, p.cfg.Audience) { + if !slices.Contains(claims.Audience, p.cfg.Audience) { return nil, fmt.Errorf("%w: aud mismatch", auth.ErrTokenRejected) } if claims.Subject == "" || claims.Email == "" { @@ -136,10 +136,6 @@ func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) ( }, nil } -func audienceContains(aud []string, want string) bool { - return slices.Contains(aud, want) -} - func init() { auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { var cfg Config From aaf837502e9d361f1ebee7baa60c2acd771d855c Mon 
Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 20:18:55 -0400 Subject: [PATCH 09/22] chore(docs): gitignore docs/auth and drop in-repo copies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Phase 2 provider docs were committed as MD files under docs/auth/ but we don't want to version-control them — the source-of-truth lives in the design folder, and we'll deliver via the doc site separately. - .gitignore: add docs/auth/ - git rm --cached docs/auth/** (local files preserved) - README.md: drop the now-broken "Auth Providers" docs row - CHANGELOG.md: drop the docs/auth/*.md links from the v0.11.0 entry No code or test changes. --- .gitignore | 5 + CHANGELOG.md | 5 +- README.md | 1 - docs/auth/concepts/chain.md | 83 --------------- docs/auth/index.md | 53 ---------- docs/auth/providers/aws_sigv4.md | 121 --------------------- docs/auth/providers/azure_ad.md | 176 ------------------------------- docs/auth/providers/gcp_iap.md | 105 ------------------ 8 files changed, 7 insertions(+), 542 deletions(-) delete mode 100644 docs/auth/concepts/chain.md delete mode 100644 docs/auth/index.md delete mode 100644 docs/auth/providers/aws_sigv4.md delete mode 100644 docs/auth/providers/azure_ad.md delete mode 100644 docs/auth/providers/gcp_iap.md diff --git a/.gitignore b/.gitignore index 8815ab0..500bc45 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,8 @@ profile.cov # OS files .DS_Store + +# Auth provider docs — kept on disk locally but not version-controlled. +# Source-of-truth docs live in the design folder; in-repo copies are +# scratch space until we decide on the doc-site delivery story. +docs/auth/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 97ff269..d2435c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,13 +6,12 @@ - **`aws_sigv4` auth provider.** Authenticate AWS-IAM callers by reflecting their Sigv4 signature to AWS STS `GetCallerIdentity`. No `aws-sdk-go-v2` - dependency. 
See [`docs/auth/providers/aws_sigv4.md`](docs/auth/providers/aws_sigv4.md). + dependency. - **`gcp_iap` auth provider.** Verify the JWT IAP forwards as `X-Goog-Iap-Jwt-Assertion` when Forge sits behind a GCP HTTPS Load - Balancer with IAP enabled. See [`docs/auth/providers/gcp_iap.md`](docs/auth/providers/gcp_iap.md). + Balancer with IAP enabled. - **`azure_ad` auth provider.** Verify Microsoft Entra ID Bearer tokens with tenant lock-in and optional Microsoft Graph group enrichment. - See [`docs/auth/providers/azure_ad.md`](docs/auth/providers/azure_ad.md). - Non-interactive `forge init` flags for the three new providers: `--auth-aws-region`, `--auth-aws-allowed-principal` (repeatable), `--auth-gcp-iap-audience`, `--auth-azure-tenant`, diff --git a/README.md b/README.md index 044adc0..35c05ff 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,6 @@ You write a `SKILL.md`. Forge compiles it into a secure, runnable agent with egr | Document | Description | |----------|-------------| | [Security Overview](docs/security/overview.md) | Complete security architecture | -| [Auth Providers](docs/auth/index.md) | Pluggable auth chain — OIDC, AWS Sigv4, GCP IAP, Azure AD | | [Egress Security](docs/security/egress-control.md) | Egress enforcement deep dive | | [Secrets](docs/security/secret-management.md) | Encrypted secret management | | [Build Signing](docs/security/build-signing.md) | Ed25519 signing and verification | diff --git a/docs/auth/concepts/chain.md b/docs/auth/concepts/chain.md deleted file mode 100644 index 975097e..0000000 --- a/docs/auth/concepts/chain.md +++ /dev/null @@ -1,83 +0,0 @@ ---- -title: "Auth Provider Chain Semantics" -description: "First-match-wins, fail-closed, and the rules every provider follows." -order: 2 ---- - -Forge's auth chain composes multiple providers in a single `auth.providers` -list. 
Understanding the chain rules is the difference between a chain that -authenticates safely and one that silently downgrades to the most permissive -provider on the list. - -## The rules - -Each provider's `Verify` returns one of four signals: - -| Return | Meaning | Chain behavior | -|---|---|---| -| `Identity, nil` | Token accepted | Stops; chain returns this Identity | -| `nil, ErrTokenNotForMe` | "Not my format" | Continues to next provider | -| `nil, ErrTokenRejected` | "My format, but denied" | **Stops; chain returns 401** | -| `nil, ErrInvalidToken` | "Malformed" | **Stops; chain returns 401** | -| `nil, ErrProviderUnavailable` | "Can't reach my IdP" | **Stops; chain returns 401** | -| `nil, any other error` | Transport / infrastructure | **Stops; chain returns 401** | - -The critical rule is the **no-fall-through on rejection**: if provider A -returns `ErrTokenRejected`, the chain does NOT try provider B. Falling -through would let an attacker downgrade their authentication — present a -malformed JWT and hope to be authenticated as an opaque-token user instead. - -## First-match-wins ordering - -Providers are tried top-to-bottom. Whoever **claims** the request first -(returns anything except `ErrTokenNotForMe`) decides the outcome. - -Practically: put more specific providers earlier. For Forge instances that -run channel adapters (Slack, Telegram), the `static_token` loopback secret -is auto-prepended to the chain — it always matches first for self-issued -loopback calls, regardless of what's configured in `forge.yaml`. - -## TUI wizard ordering - -The `forge init` wizard runs the **Auth step before the Egress step** — -once you pick a provider, the egress review automatically includes the -outbound hosts it needs (STS endpoint for `aws_sigv4`, AAD authority for -`azure_ad`, IAP JWKS host for `gcp_iap`, OIDC issuer host for `oidc`). 
-Operators see and confirm those hosts alongside provider / channel / tool / -skill hosts in a single review screen — no need to add auth hosts to -`egress_hosts` by hand after the wizard. - -## Non-Bearer formats (v0.11.0) - -As of v0.11.0 the middleware consults the chain **even when no Bearer token -was extracted** — required for `aws_sigv4` (which reads `Authorization: -AWS4-HMAC-SHA256 …`) and `gcp_iap` (which reads `X-Goog-Iap-Jwt-Assertion`). - -When the request carries no auth-shaped headers at all, the audit reason -code is still `missing_token` rather than `not_for_me` — operators can still -distinguish "client didn't auth" from "client tried a format we don't speak." - -## Example: mixed AWS + OIDC chain - -```yaml -auth: - required: true - providers: - # Internal CI/automation calls — Sigv4-signed under an IAM role: - - type: aws_sigv4 - settings: - region: us-east-1 - allowed_principals: - - "arn:aws:sts::123456789012:assumed-role/ci-deploy/*" - # Human users — Keycloak Bearer JWT: - - type: oidc - settings: - issuer: https://keycloak.example.com/realms/forge - audience: api://forge -``` - -A Sigv4-signed request matches `aws_sigv4` and never reaches `oidc`. A Bearer -JWT request returns `ErrTokenNotForMe` from `aws_sigv4` (no `AWS4-HMAC-SHA256` -prefix), then matches `oidc`. A garbage Bearer ("Bearer foobar") returns -`ErrTokenNotForMe` from both providers and the middleware emits -`auth_fail{reason: "not_for_me"}` with a 401. diff --git a/docs/auth/index.md b/docs/auth/index.md deleted file mode 100644 index 25a94ce..0000000 --- a/docs/auth/index.md +++ /dev/null @@ -1,53 +0,0 @@ ---- -title: "Authentication Providers" -description: "Pluggable auth provider chain — pick the formats Forge accepts on /tasks." -order: 1 ---- - -Forge's `/tasks` endpoint requires every request to authenticate through a -chain of pluggable auth providers configured in `forge.yaml`. Each provider -recognizes one token shape; the chain tries them in order, first match wins. 
- -## Provider Matrix - -| Provider | Use case | Token format | Phase | -|---|---|---|---| -| [`static_token`](./providers/static_token.md) | Local dev, loopback channel adapters | Shared secret | 1 | -| [`oidc`](./providers/oidc.md) | Any IdP with OIDC discovery (Keycloak, Okta, Auth0, Google) | Bearer JWT | 1 | -| [`http_verifier`](./providers/http_verifier.md) | Custom verifier endpoint | Opaque | 1 | -| [`aws_sigv4`](./providers/aws_sigv4.md) | AWS-IAM-based callers (Lambda, EC2, EKS, IAM users) | Sigv4 (`Authorization`) | **2 (v0.11.0)** | -| [`gcp_iap`](./providers/gcp_iap.md) | Behind GCP HTTPS Load Balancer + IAP | IAP JWT (`X-Goog-Iap-Jwt-Assertion`) | **2 (v0.11.0)** | -| [`azure_ad`](./providers/azure_ad.md) | Microsoft Entra ID tokens | Bearer JWT | **2 (v0.11.0)** | - -## Phase 2 — Cloud-native providers (v0.11.0) - -Three new providers consume the identities customers already have in their -cloud — no parallel IdP required: - -- **AWS IAM** → callers sign requests with their existing IAM keys/role - (IRSA, instance profile, Lambda role, `aws configure`). Forge verifies via - STS `GetCallerIdentity` — zero secrets, no token endpoint to host. -- **GCP IAP** → Forge sitting behind a GCP HTTPS LB+IAP consumes the JWT IAP - forwards on every authenticated request. -- **Azure AD** → Entra tokens with tenant lock-in and optional Microsoft - Graph group enrichment. - -## Chain semantics - -Read [concepts/chain](./concepts/chain.md) for the full rules. In short: - -- Providers are tried in `forge.yaml` order. -- A provider returns one of: an `Identity` (chain stops, request authorized); - `ErrTokenNotForMe` (try next provider); or any other error (chain - rejects — **does not** fall through to a more permissive provider). -- The first-match-wins rule prevents downgrade attacks: an attacker who - presents a malformed token of type A doesn't get a chance to be authenticated - as type B. 
- -## Auditing - -Every accepted request emits an `auth_verify` event with the provider's name -and the principal's identifiers. Every rejection emits `auth_fail` with a -stable reason code (`missing_token`, `not_for_me`, `rejected`, `invalid`, -`provider_unavailable`). Pipe both into your SIEM — see each provider's -"Audit log" section for the exact shape. diff --git a/docs/auth/providers/aws_sigv4.md b/docs/auth/providers/aws_sigv4.md deleted file mode 100644 index 2f8f64f..0000000 --- a/docs/auth/providers/aws_sigv4.md +++ /dev/null @@ -1,121 +0,0 @@ ---- -title: "aws_sigv4 — AWS IAM" -description: "Authenticate AWS-IAM callers via STS GetCallerIdentity reflection." -order: 10 ---- - -The `aws_sigv4` provider authenticates requests signed with AWS Signature -Version 4. Forge reflects the caller's signature to AWS STS -`GetCallerIdentity`; STS validates it on Forge's behalf and returns the -canonical ARN/Account/UserID. **Forge never possesses the caller's secret -key.** Same pattern as `aws-iam-authenticator` (the EKS auth bridge). - -## When to use it - -Use this when your callers (Lambda, EC2, EKS workloads, developer laptops -with `aws configure`/`aws sso login`) already have AWS credentials. No -separate token issuer needed — IAM is the identity. - -## Prerequisites - -- [ ] Forge can reach `sts..amazonaws.com` (auto-added to the - Phase 2 egress allowlist). -- [ ] At least one IAM identity exists in the target AWS account. -- [ ] You've decided which principals to allow (optional but recommended - for production). -- [ ] **No additional IAM permissions required.** `GetCallerIdentity` is - implicit for every IAM principal on itself. - -## forge.yaml - -```yaml -auth: - required: true - providers: - - type: aws_sigv4 - settings: - region: us-east-1 - audience: api://forge # informational; emitted in audit - allowed_principals: - # NOTE: STS returns assumed-role ARNs, NOT IAM role ARNs. - # Match against arn:aws:sts::, not arn:aws:iam::. 
- - "arn:aws:sts::123456789012:assumed-role/ci-deploy/*" - - "arn:aws:sts::123456789012:assumed-role/forge-runner/*" - identity_cache_ttl: 60s -``` - -## Configuration reference - -| Field | Required | Default | Description | -|---|---|---|---| -| `region` | yes | — | AWS region for the STS endpoint. Sigv4 signatures must scope to this region. | -| `audience` | no | — | Informational string; emitted in `Identity.Claims.audience`. STS doesn't enforce it. | -| `allowed_principals` | no | empty (allow any) | List of [shell-style globs](https://pkg.go.dev/path#Match) matched against the **STS-returned ARN**. | -| `identity_cache_ttl` | no | `60s` | How long a verified Identity is reused before re-checking STS. Bounds the stolen-key window. | -| `http_timeout` | no | `5s` | Per-STS-call timeout. | - -`allowed_principals` is the authz mechanism. Empty list means "allow any IAM -principal who can sign for this region" — fine for single-tenant dev, never -appropriate for production. - -## Worked example — call Forge from a script - -```bash -# Make sure your AWS profile points at an identity in allowed_principals: -export AWS_PROFILE=dev - -# awscurl handles the Sigv4 signing: -pipx install awscurl - -awscurl --service sts --region us-east-1 \ - -X POST -d '{"task":"hello"}' \ - -H "Content-Type: application/json" \ - http://localhost:9999/tasks -``` - -A 200 response means the chain accepted your IAM identity; check the Forge -log for the matching `auth_verify` event. 
- -## Audit log - -Successful auth: - -```json -{ "event": "auth_verify", - "fields": { - "provider": "aws_sigv4", - "user_id": "arn:aws:sts::123456789012:assumed-role/ci-deploy/i-0abc", - "org_id": "123456789012", - "token_kind": "sigv4" - } -} -``` - -## Troubleshooting - -| Symptom | Likely cause | Fix | -|---|---|---| -| `auth_fail reason=not_for_me` | `Authorization` header didn't start with `AWS4-HMAC-SHA256` | Caller is signing with Bearer/something else — point them at the right tooling (`awscurl`, AWS SDK) | -| `auth_fail reason=rejected` + STS body mentions signature | Caller signed with wrong creds, expired session, or clock skew | Re-sign with current creds; check system clock; ensure session token is valid | -| `auth_fail reason=rejected` + "ARN ... not in allowed_principals" | Allowlist mismatch | **Common bug:** the pattern must match the STS-returned **assumed-role ARN** (`arn:aws:sts::ACCT:assumed-role/...`), NOT the IAM role ARN (`arn:aws:iam::ACCT:role/...`) | -| `auth_fail reason=rejected` + "service=s3" or "region=eu-west-1" | Caller signed for wrong service or region | Re-sign with `--service sts --region ` | -| `auth_fail reason=provider_unavailable` | STS or network is down; egress allowlist missing the STS host | Check `sts..amazonaws.com` reachability; check `egress_hosts` | - -## Security model - -- **No secret key on Forge.** STS does the signature math. -- **Cache bucketing on YYYYMMDD + per-AKID** bounds the stolen-key window: - a leaked AKID is valid until midnight UTC OR `identity_cache_ttl` - elapses, whichever is sooner. -- **Service/region scope check** rejects Sigv4 signatures meant for S3, EC2, - or a different region. Prevents cross-service replay. -- **No `aws-sdk-go` dependency.** The STS RPC is ~150 LOC of hand-rolled - HTTP + XML. Smaller attack surface than a transitive SDK. - -## Limitations - -- One AWS region per provider entry. Configure multiple entries for - multi-region. 
-- Cognito federation / SAML / OIDC — use the [`oidc`](./oidc.md) provider - pointed at your Cognito user pool's issuer. -- Sigv4 is inbound only — Forge does not Sigv4-sign outbound requests. diff --git a/docs/auth/providers/azure_ad.md b/docs/auth/providers/azure_ad.md deleted file mode 100644 index c8637c8..0000000 --- a/docs/auth/providers/azure_ad.md +++ /dev/null @@ -1,176 +0,0 @@ ---- -title: "azure_ad — Microsoft Entra ID" -description: "Verify Entra ID tokens with tenant lock-in and optional Graph enrichment." -order: 12 ---- - -The `azure_ad` provider authenticates Microsoft Entra ID (formerly Azure AD) -Bearer tokens. It composes the Phase 1 `oidc` provider for signature -verification and standard claim checks, then layers AAD-specific concerns -on top: - -- Tenant lock-in via the `tid` claim -- Optional Microsoft Graph group enrichment when the JWT's `groups` claim - overflows (AAD truncates at ~200 groups) -- Single-tenant vs. multi-tenant operation - -## When to use it - -Your callers — humans, CI, or Azure-hosted workloads — already have Entra -identities. AAD-specific quirks (`tid`, app roles, groups overage) are -handled correctly. - -## Prerequisites - -- [ ] An Entra tenant exists and you know its GUID. -- [ ] An app registration exists in the tenant. -- [ ] App registration → **Expose an API** → Application ID URI is set - (e.g. `api://forge`). -- [ ] Forge can reach `login.microsoftonline.com`. -- [ ] *(For graph mode only)* Forge can reach `graph.microsoft.com`, the - app registration has the **delegated** Graph permission - `GroupMember.Read.All`, and an admin has **granted consent**. 
- -### App registration walkthrough - -``` -Entra admin center → App registrations → New registration - Name: Forge - Supported accounts: Single tenant (unless you need multi-tenant) - Redirect URI: (leave empty — Forge doesn't initiate flows) - -→ Expose an API: - Application ID URI: api://forge - Add a scope: user_impersonation - Who can consent: Admins and users - -→ API permissions (ONLY if groups_mode=graph): - Add a permission → Microsoft Graph → Delegated → GroupMember.Read.All - Click "Grant admin consent" -``` - -## forge.yaml — single-tenant, claim mode (typical) - -```yaml -auth: - required: true - providers: - - type: azure_ad - settings: - tenant_id: 00000000-1111-2222-3333-444444444444 - audience: api://forge - groups_mode: claim # default; reads `groups`/`roles` from JWT -``` - -## forge.yaml — multi-tenant - -> **WARNING:** Multi-tenant means **any** Entra tenant in the world can -> present tokens to Forge. Without an authz layer (Phase 4+ ABAC), this is -> effectively wide-open auth. Pair it with claim-based policy in your -> handler middleware, or stick to single-tenant. - -```yaml -auth: - required: true - providers: - - type: azure_ad - settings: - audience: api://forge - allow_multi_tenant: true -``` - -## forge.yaml — graph mode (groups overage) - -When a user is in more than ~200 groups, AAD truncates the JWT `groups` -claim and includes `_claim_names` indicating overflow. The default `claim` -mode would see empty groups in that case — switch to `graph` mode to query -Microsoft Graph for the full list: - -```yaml -auth: - required: true - providers: - - type: azure_ad - settings: - tenant_id: 00000000-... - audience: api://forge - groups_mode: graph -``` - -The caller's Bearer token is **reflected** to Graph; Forge holds no Graph -credentials of its own. The user's `GroupMember.Read.All` delegated -permission is what authorizes the Graph read. 
- -## Configuration reference - -| Field | Required | Default | Description | -|---|---|---|---| -| `audience` | yes | — | Application ID URI from the app registration. | -| `tenant_id` | required unless `allow_multi_tenant=true` | — | Entra tenant GUID. | -| `allow_multi_tenant` | no | `false` | Accept tokens from any tenant. **High risk** without authz. | -| `groups_mode` | no | `claim` | `claim` (use JWT groups) or `graph` (enrich via Graph on empty groups). | -| `graph_timeout` | no | `5s` | Per-Graph-call timeout. | -| `jwks_cache_ttl` | no | `1h` | JWKS cache age before refresh. | - -## Acquiring tokens for testing - -```bash -# From a logged-in az CLI: -az account get-access-token --resource api://forge \ - --query accessToken -o tsv > /tmp/aad.tok - -curl -H "Authorization: Bearer $(cat /tmp/aad.tok)" \ - http://localhost:9999/tasks -d '{}' -``` - -## Audit log - -```json -{ "event": "auth_verify", - "fields": { - "provider": "azure_ad", - "user_id": "alice@contoso.onmicrosoft.com", - "org_id": "00000000-1111-2222-3333-444444444444", // tid - "token_kind": "jwt", - "groups": ["group-guid-1", "group-guid-2"] - } -} -``` - -Note `provider: "azure_ad"`, not `"oidc"` — even though OIDC did the -signature work, the audit name reflects the operator's configuration. 
- -## Troubleshooting - -| Symptom | Likely cause | Fix | -|---|---|---| -| `auth_fail reason=rejected` + "tid mismatch" | Token from a different Entra tenant | Re-sign in to the right tenant, or set `allow_multi_tenant: true` if intentional | -| `auth_fail reason=rejected` + "aud" | `audience` in forge.yaml doesn't match the app registration's Application ID URI | Make both match exactly | -| Identity has empty `groups` despite user being in groups | Groups overage (>200 groups) | Switch to `groups_mode: graph` and grant `GroupMember.Read.All` admin consent | -| Graph enrichment soft-fails (auth OK, groups empty) | Graph 5xx or network failure | Check `graph.microsoft.com` reachability | -| `auth_fail reason=invalid` + "missing tid" | Token is not an AAD token (maybe a guest/personal Microsoft account?) | Confirm token issuer; only AAD work/school accounts have `tid` | - -## Security model - -- **Composes** the `oidc` provider — no JWT/JWKS code lives in this package. - All signature verification, algorithm-whitelisting, and key rotation - reuse the Phase 1 OIDC implementation. -- **Tenant gate before Graph** — for single-tenant configs, the `tid` check - runs before any Graph call. Different tenant = different security domain. -- **Caller's token reflected to Graph** — Forge never holds Graph - credentials. The user's delegated permission is the only authority. -- **Soft-fail on Graph 5xx** — Identity still returned with empty groups - rather than blocking prod traffic on a transient Graph outage. -- **Internal `skip_issuer_check`** flag on the composed OIDC is exposed - ONLY when `allow_multi_tenant=true`, and the field is `yaml:"-"` so it - cannot be set via `forge.yaml` — preventing an operator from accidentally - disabling issuer validation. - -## Limitations - -- No SAML — Phase 5+. -- No `client_credentials` flow for workload-to-Forge with no user context. 
- If running in AWS, use [`aws_sigv4`](./aws_sigv4.md); otherwise use an - [`oidc`](./oidc.md) entry pointed at your workload identity's issuer. -- "Multi-tenant" means **any** tenant — there's no allowlist of tenants. - For "this set of tenants," configure multiple `azure_ad` entries. diff --git a/docs/auth/providers/gcp_iap.md b/docs/auth/providers/gcp_iap.md deleted file mode 100644 index f9e5fcf..0000000 --- a/docs/auth/providers/gcp_iap.md +++ /dev/null @@ -1,105 +0,0 @@ ---- -title: "gcp_iap — GCP Identity-Aware Proxy" -description: "Verify the JWT IAP forwards on every authenticated request." -order: 11 ---- - -The `gcp_iap` provider authenticates requests that come through GCP's -Identity-Aware Proxy. IAP terminates user authentication at the HTTPS Load -Balancer and forwards a signed JWT in the `X-Goog-Iap-Jwt-Assertion` header. -Forge verifies that JWT against IAP's well-known JWKS and stamps an Identity -from the `sub` + `email` claims. - -## When to use it - -Forge is deployed behind a GCP HTTPS Load Balancer with IAP enabled, and you -want the IAP-authenticated user (Google account / Workspace identity) to -flow through to Forge. - -## Prerequisites - -- [ ] Forge is reachable through a GCP HTTPS LB. -- [ ] IAP is enabled on the LB's backend service. -- [ ] Forge can reach `www.gstatic.com` (auto-added to the Phase 2 egress - allowlist). -- [ ] You know the backend service's "Signed Header JWT Audience" string. - -### Finding the audience - -The audience is the GCP backend service ID: - -``` -GCP Console → Security → Identity-Aware Proxy - → Backend Services tab - → click your backend - → "Signed Header JWT Audience" field -``` - -Format: `/projects/PROJECT_NUMBER/global/backendServices/BACKEND_ID`. 
- -## forge.yaml - -```yaml -auth: - required: true - providers: - - type: gcp_iap - settings: - audience: /projects/12345678/global/backendServices/9876543210 -``` - -## Configuration reference - -| Field | Required | Default | Description | -|---|---|---|---| -| `audience` | yes | — | GCP backend service ID. Must match exactly. | -| `jwks_refresh_ttl` | no | `1h` | How long a cached JWKS is reused before refresh. | -| `http_timeout` | no | `5s` | JWKS fetch timeout. | - -The IAP **issuer** (`https://cloud.google.com/iap`) and **JWKS URL** -(`https://www.gstatic.com/iap/verify/public_key-jwk`) are **hardcoded**. -They are the only stable contract GCP exposes; an override knob would be a -footgun (someone could be tricked into trusting an attacker's JWKS). - -## Audit log - -```json -{ "event": "auth_verify", - "fields": { - "provider": "gcp_iap", - "user_id": "1234567890", // claims.sub - "email": "alice@example.com", // claims.email - "token_kind": "jwt" - } -} -``` - -The Workspace domain claim (`hd`) is included in `Identity.Claims` for -downstream policy decisions. 
- -## Troubleshooting - -| Symptom | Likely cause | Fix | -|---|---|---| -| `auth_fail reason=not_for_me` | Request didn't come through IAP — no `X-Goog-Iap-Jwt-Assertion` header | Hit Forge through the GCP LB, not the backend directly | -| `auth_fail reason=rejected` + "iss" mismatch | Token is a regular Google OAuth token, not an IAP assertion | Confirm IAP is enabled; route the request through the LB | -| `auth_fail reason=rejected` + "aud" | `audience` in forge.yaml doesn't match the backend service ID | Re-check the audience string in the GCP console | -| `auth_fail reason=invalid` + "alg" | Token signed with something other than ES256 | Not an IAP token; check origin | -| `auth_fail reason=provider_unavailable` | `www.gstatic.com` unreachable | Check egress allowlist, network | - -## Security model - -- **ES256 algorithm whitelist** rejected any other signing method *before* - key lookup — algorithm-confusion attacks can't reach the JWKS. -- **JWKS host hardcoded** so a malicious config can't redirect verification - to an attacker-controlled key server. -- **Stale-grace cache** — during a JWKS endpoint outage, the previously - cached keys remain valid. IAP rotates keys on the order of weeks; freshness - matters less than availability. - -## Limitations - -- No per-email / per-domain allowlist in Forge — use **GCP IAM Conditions** - on the backend service for that. IAP can grant or deny access at the LB - layer; Forge sees only requests IAP already approved. -- Single audience per provider entry. 
From 382294eda86f933e7904bef194506cc207be0b45 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 21:05:13 -0400 Subject: [PATCH 10/22] docs(auth/aws_sigv4): ship reference client + document sign-for-STS contract MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Real-AWS testing surfaced a documentation gap: callers cannot use raw `awscurl` / `aws-sdk-go` against Forge's `aws_sigv4` provider because Sigv4 binds the signature to the destination host. Standard tools sign for the URL they're addressing (Forge) — STS then rejects the reflected signature because the host bytes don't match. The server-side code is correct. The client just needs to sign a hypothetical STS request, then attach the resulting headers to its real POST to Forge. Same pattern as aws-iam-authenticator for EKS. This commit: - Ships `scripts/forge-aws-sign.py`, a ~100 LOC reference client using boto3.session + SigV4Auth. CLI flags for --region, --url, --profile, --body, --verbose. Reads SSO/IRSA/profile/env credentials via boto3's standard chain. - Extends the package-level docstring in `forge-core/auth/providers/aws_sigv4/provider.go` with a "Client-side signing contract" section spelling out the 4-step pattern and pointing readers to the reference script. - Adds a "Client-side requirement" section to CHANGELOG.md so adopters know to grab the helper or write their own before integrating. 
Validated against real AWS: - STS reflection: 200, identity stamped, correct ARN/Account/UserID - ARN allowlist match: 200 (matching pattern) - ARN allowlist miss: 401 reason=rejected (correct authz gate) - No-auth: 401 reason=missing_token (Phase 1 contract preserved) --- CHANGELOG.md | 14 +++ .../auth/providers/aws_sigv4/provider.go | 22 ++++ scripts/forge-aws-sign.py | 111 ++++++++++++++++++ 3 files changed, 147 insertions(+) create mode 100755 scripts/forge-aws-sign.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d2435c4..be24183 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,20 @@ from Go callers (currently only `azure_ad` multi-tenant). Operators see no change. +### Client-side requirement for `aws_sigv4` + +- AWS Sigv4 binds its signature to the *destination host*. As a consequence, + callers cannot just point `awscurl` at Forge and have it work — the + signature would be computed for Forge's hostname, not STS's, and STS + would reject the reflected request. +- Callers must use a thin wrapper that signs a hypothetical STS + GetCallerIdentity request, then attaches the resulting headers to the + POST that goes to Forge. This is the same pattern `aws-iam-authenticator` + uses for EKS. +- A reference implementation in ~30 lines of `boto3` ships at + `scripts/forge-aws-sign.py`. Use it as-is or as a template for your + CI / Lambda / Go inter-agent integrations. + ### Known deferred work - The Bubble Tea TUI wizard (`forge init`) does not yet have step-by-step diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go index bf00d52..93c4a8b 100644 --- a/forge-core/auth/providers/aws_sigv4/provider.go +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -7,6 +7,28 @@ // identities are cached by hash(AKID|YYYYMMDD) for IdentityCacheTTL so the // hot path doesn't bounce through STS on every call. 
// +// # Client-side signing contract (READ THIS BEFORE INTEGRATING) +// +// Sigv4 binds the signature to the destination host as part of its canonical +// headers. Forge does not — and cannot — relax that. As a consequence, +// callers MUST sign their request as if they were going to STS directly, +// then attach the resulting headers to a request that's sent to Forge: +// +// 1. Construct an STS request shape (POST https://sts..amazonaws.com/, +// body Action=GetCallerIdentity&Version=2011-06-15) — but do not send it. +// 2. Sign that hypothetical request with the caller's AWS credentials, using +// service=sts and region=. +// 3. Take the resulting Authorization + X-Amz-Date (+ X-Amz-Security-Token +// for temporary credentials, e.g. SSO / IRSA / Lambda) headers and attach +// them to the REAL POST that goes to Forge's /tasks/send endpoint. +// 4. Forge will forward those headers verbatim to STS; STS will validate +// them against its own host and return the caller's canonical ARN. +// +// Standard tools (awscurl, the AWS CLI, boto3.client('sts')) do NOT do this +// automatically — they always sign for the URL they are addressing. A small +// wrapper is required. See scripts/forge-aws-sign.py for the canonical +// reference implementation in ~30 lines of boto3. +// // Decision §9.1: no aws-sdk-go-v2 dependency — the STS RPC is ~150 LOC of // hand-rolled HTTP + XML. Trade-off: small attack surface, predictable // behavior, no transitive deps; cost: we maintain the RPC ourselves. diff --git a/scripts/forge-aws-sign.py b/scripts/forge-aws-sign.py new file mode 100755 index 0000000..b0cc68f --- /dev/null +++ b/scripts/forge-aws-sign.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +""" +forge-aws-sign — reference client for Forge's aws_sigv4 auth provider. + +Why this script exists +====================== +Forge's aws_sigv4 provider authenticates callers by reflecting their Sigv4 +signature to AWS STS GetCallerIdentity. 
The signature is bound to the +*destination host* in its canonical-headers input, so callers cannot just +sign for Forge's URL and expect STS to validate the result. + +The correct client-side pattern is: + + 1. Construct a hypothetical request shape addressed to STS itself + (POST https://sts..amazonaws.com/, body + "Action=GetCallerIdentity&Version=2011-06-15"). + 2. Sign that hypothetical request with the caller's AWS credentials. + 3. Take the resulting Authorization + X-Amz-Date (+ X-Amz-Security-Token + for temporary credentials) headers and attach them to the REAL request + that's being sent to Forge. + 4. Forge forwards those headers verbatim to STS, which validates them + against its own host. Forge stamps an Identity from the returned ARN. + +This is identical in spirit to the aws-iam-authenticator pattern that EKS +uses. Standard tools (awscurl, the AWS CLI, `boto3.client('sts')`) do NOT +do this automatically — they always sign for the URL you're calling. So a +small wrapper like this one is required. + +Usage +===== + python3 forge-aws-sign.py [--region us-east-1] [--url URL] [--body BODY] + + Defaults: + --region: us-east-1 + --url: http://localhost:9999/tasks/send + --body: {"task":"hello"} + +Reads AWS credentials the same way boto3 does: env vars, profile, SSO, +IRSA, instance profile, etc. — whichever applies. + +Exits 0 on HTTP 2xx, 1 otherwise. Prints the response body either way. 
+""" +from __future__ import annotations + +import argparse +import sys + +try: + import requests + from botocore.session import Session + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest +except ImportError as e: + print(f"missing dependency: {e}", file=sys.stderr) + print("install with: pip3 install --user boto3 requests", file=sys.stderr) + sys.exit(2) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Forge aws_sigv4 reference client") + parser.add_argument("--region", default="us-east-1", help="AWS region used in the Sigv4 scope") + parser.add_argument("--url", default="http://localhost:9999/tasks/send", help="Forge endpoint to POST to") + parser.add_argument("--body", default='{"task":"hello"}', help="JSON body to send to Forge") + parser.add_argument("--profile", default=None, help="AWS profile to use (default: boto3's default)") + parser.add_argument("-v", "--verbose", action="store_true", help="Show signed headers in output") + args = parser.parse_args() + + session = Session(profile=args.profile) if args.profile else Session() + credentials = session.get_credentials() + if credentials is None: + print("No AWS credentials found. Try `aws sso login` or `aws configure`.", file=sys.stderr) + return 2 + frozen = credentials.get_frozen_credentials() + + # 1-2. Sign for STS's host (NOT Forge's host) + sts_url = f"https://sts.{args.region}.amazonaws.com/" + sts_body = "Action=GetCallerIdentity&Version=2011-06-15" + sts_req = AWSRequest( + method="POST", + url=sts_url, + data=sts_body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + SigV4Auth(frozen, "sts", args.region).add_auth(sts_req) + + # 3. 
Reuse the signed headers in a request to Forge + forge_headers = { + "Authorization": sts_req.headers["Authorization"], + "X-Amz-Date": sts_req.headers["X-Amz-Date"], + "Content-Type": "application/json", + } + if "X-Amz-Security-Token" in sts_req.headers: + forge_headers["X-Amz-Security-Token"] = sts_req.headers["X-Amz-Security-Token"] + + if args.verbose: + print(f"POST {args.url}") + print(f" Authorization: {forge_headers['Authorization'][:80]}...") + print(f" X-Amz-Date: {forge_headers['X-Amz-Date']}") + has_tok = "X-Amz-Security-Token" in forge_headers + print(f" X-Amz-Security-Token: {'' if has_tok else ''}") + print() + + # 4. Send. Forge will reflect the headers to STS. + resp = requests.post(args.url, headers=forge_headers, data=args.body) + print(f"HTTP {resp.status_code}") + print(resp.text) + return 0 if 200 <= resp.status_code < 300 else 1 + + +if __name__ == "__main__": + sys.exit(main()) From b3444c2770ee3a9e482030ce33c3ec4084af5c4f Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 21:29:24 -0400 Subject: [PATCH 11/22] refactor(auth/aws_sigv4): switch to pre-signed URL pattern (aws-iam-authenticator) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 PR 2's original "reflect Sigv4 headers" design was broken in the obvious way: Sigv4 binds its signature to the destination host as part of the canonicalized signing input. Headers signed for Forge's host could not be replayed against STS — STS sees host:sts.. amazonaws.com, recomputes the signature, gets a different hash, rejects with "SignatureDoesNotMatch". Caught during real-AWS smoke; documented in PR #79 description. This commit replaces the pattern with the same approach aws-iam-authenticator uses for EKS: Client (3 lines): url = boto3.client('sts').generate_presigned_url('get_caller_identity', ExpiresIn=900) token = 'forge-aws-v1.' 
+ base64.urlsafe_b64encode(url.encode()).rstrip(b'=').decode() requests.post(forge_url, headers={'Authorization': f'Bearer {token}'}, ...) Server: Authorization: Bearer forge-aws-v1.{base64url(presigned URL)} → decode + validate host (SSRF guard) + GET on the URL → STS → identity Net effect on caller experience: identical to JWT/OIDC/azure_ad — "mint token, send Bearer, done." Three lines of client code, hidden in ~15 lines of any AWS SDK in any language. what changed: forge-core/auth/providers/aws_sigv4/ sigv4_parser.go — was parsing AWS4-HMAC-SHA256 Authorization header now parses forge-aws-v1.{base64url(presigned URL)} Bearer tokens (URL host validation, SSRF guard, X-Amz-Credential parsing for cache key derivation) sts_client.go — was POST with reflected headers now GET on the pre-signed URL; same 200/4xx/5xx classification and 64 KiB body cap provider.go — Verify() now reads the Bearer token (not raw headers); SSRF guard via expectedHost field; same cache + ARN allowlist semantics forge-core/auth/ provider.go — HeadersFromRequest reverts X-Amz-Date and X-Amz-Security-Token (no longer needed); keeps X-Goog-Iap-Jwt-Assertion for gcp_iap provider.go — TokenKind detects "forge-aws-v1." prefix → "sigv4" (was: "AWS4-HMAC-SHA256 " on raw Authorization) middleware.go — simplify: empty-Bearer fallback only handles IAP (aws_sigv4 rides standard Bearer flow now) scripts/forge-aws-sign.py — rewrite as a clean reference client.
--token-only: print just the token for use with curl/other tools Otherwise: do the round-trip POST and print the response CHANGELOG.md — replace "client wrapper required" friction note with the 3-line happy path snippet what stays unchanged: - forge.yaml shape (still type: aws_sigv4, region:, allowed_principals:) - identity_cache.go, arn_matcher.go (cache and authz logic untouched) - security.AuthDomains (sts..amazonaws.com derivation) - forge-cli/cmd/init* flag set and renderer - validate.ValidateAuthConfig (region still required) - forge-ui/handlers_create.go (AuthProviderTypeMeta entry) Tests: 42 packages pass, golangci-lint v2.10.1 clean, gofmt clean, no aws-sdk-go imports (decision §9.1 still holds). Net diff: +732 / -625 lines (mostly test rewrites; ~80 LOC net less in the provider package because the new flow is structurally simpler). --- CHANGELOG.md | 39 +-- forge-core/auth/middleware.go | 30 +- forge-core/auth/middleware_test.go | 29 +- forge-core/auth/provider.go | 33 +-- forge-core/auth/provider_test.go | 42 +-- .../auth/providers/aws_sigv4/provider.go | 271 +++++++++++------- .../auth/providers/aws_sigv4/provider_test.go | 211 ++++++++------ .../auth/providers/aws_sigv4/sigv4_parser.go | 181 +++++++----- .../providers/aws_sigv4/sigv4_parser_test.go | 201 +++++++++---- .../auth/providers/aws_sigv4/sts_client.go | 78 ++--- .../providers/aws_sigv4/sts_client_test.go | 119 +++----- scripts/forge-aws-sign.py | 123 ++++---- 12 files changed, 732 insertions(+), 625 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be24183..fd71e56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,12 +28,11 @@ extracted**, so non-Bearer formats (Sigv4 `Authorization`, IAP `X-Goog-Iap-Jwt-Assertion`) can be recognized. Existing Bearer + JWT flows are unchanged. -- `auth.HeadersFromRequest` widened with `Authorization`, - `X-Goog-Iap-Jwt-Assertion`, `X-Amz-Date`, and `X-Amz-Security-Token`. - Providers that don't consume these keys are unaffected. 
-- `auth.TokenKind` recognizes the `AWS4-HMAC-SHA256` prefix and returns - `"sigv4"`. The audit `token_kind` field now has five possible values: - `empty`, `opaque`, `jwt`, `sigv4`, `iap_jwt`. +- `auth.HeadersFromRequest` widened with `X-Goog-Iap-Jwt-Assertion` + for `gcp_iap`. Providers that don't consume this header are unaffected. +- `auth.TokenKind` recognizes the `forge-aws-v1.` Bearer prefix and + returns `"sigv4"`. The audit `token_kind` field now has five possible + values: `empty`, `opaque`, `jwt`, `sigv4`, `iap_jwt`. - `validate.ValidateAuthConfig` admits the three new provider types and enforces their per-type required keys (`aws_sigv4.region`, `gcp_iap.audience`, `azure_ad.audience`, `azure_ad.tenant_id`-unless- @@ -51,19 +50,23 @@ from Go callers (currently only `azure_ad` multi-tenant). Operators see no change. -### Client-side requirement for `aws_sigv4` +### Client experience for `aws_sigv4` -- AWS Sigv4 binds its signature to the *destination host*. As a consequence, - callers cannot just point `awscurl` at Forge and have it work — the - signature would be computed for Forge's hostname, not STS's, and STS - would reject the reflected request. -- Callers must use a thin wrapper that signs a hypothetical STS - GetCallerIdentity request, then attaches the resulting headers to the - POST that goes to Forge. This is the same pattern `aws-iam-authenticator` - uses for EKS. -- A reference implementation in ~30 lines of `boto3` ships at - `scripts/forge-aws-sign.py`. Use it as-is or as a template for your - CI / Lambda / Go inter-agent integrations. +The client side is a Bearer token with a 3-line mint: + +```python +import boto3, base64 +url = boto3.client('sts', region_name='us-east-1').generate_presigned_url( + 'get_caller_identity', ExpiresIn=900) +token = 'forge-aws-v1.' 
+ base64.urlsafe_b64encode(url.encode()).rstrip(b'=').decode() + +requests.post(forge_url, headers={'Authorization': f'Bearer {token}'}, data=msg) +``` + +Pattern is identical to `aws-iam-authenticator` for EKS. Reference client +in `scripts/forge-aws-sign.py` — use it directly or as a template for +Go / Java / Node clients. Wire format is documented in the package +docstring of `forge-core/auth/providers/aws_sigv4/provider.go`. ### Known deferred work diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index fd66484..6dae9a4 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -109,31 +109,17 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { token := extractBearerToken(r) kind := TokenKind(token) - // Phase 2: when no Bearer token was extracted, classify the - // audit kind from the non-Bearer auth headers so token_kind in - // the audit log differentiates Sigv4 / IAP / no-auth requests - // instead of all reporting "empty". - rawAuth := r.Header.Get("Authorization") + // Phase 2: gcp_iap doesn't use a Bearer token — it reads + // X-Goog-Iap-Jwt-Assertion. Surface that in the audit kind + // and let the chain run even on empty Bearer when IAP is + // the format in play. aws_sigv4 (Phase 2 pre-signed URL + // pattern) DOES use a Bearer token, so no special-case here. iapHeader := r.Header.Get("X-Goog-Iap-Jwt-Assertion") - if kind == "empty" { - switch { - case strings.HasPrefix(rawAuth, "AWS4-HMAC-SHA256 "): - kind = "sigv4" - case iapHeader != "": - kind = "iap_jwt" - } + if kind == "empty" && iapHeader != "" { + kind = "iap_jwt" } + hasNonBearerAuth := token == "" && iapHeader != "" - // hasNonBearerAuth signals "the caller IS attempting auth, just - // not via Bearer." Used to differentiate "missing_token" (no - // auth headers at all) from "chain rejected/didn't match" in - // the audit reason. 
- hasNonBearerAuth := token == "" && (rawAuth != "" || iapHeader != "") - - // Phase 2 change: consult the chain even on empty Bearer, so - // providers that speak non-Bearer formats (aws_sigv4, gcp_iap) - // get a chance. Phase 1 short-circuit (empty bearer → 401) - // preserved only when there are no auth-shaped headers at all. if token == "" && !hasNonBearerAuth { notifyAuth(opts.OnAuth, r, nil, ErrMissingBearer, kind) writeAuthError(w, "valid bearer token required") diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index 06596e0..be0ce6c 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -463,10 +463,10 @@ func (p *headerCapturingProvider) Verify(_ context.Context, token string, h Head return p.identity, p.err } -func TestMiddleware_Sigv4HeaderReachesChain(t *testing.T) { - // Phase 2 change: a Sigv4-shaped Authorization reaches the chain even - // though extractBearerToken returns "". The middleware no longer - // short-circuits to 401 when there's an auth-shaped header present. +func TestMiddleware_Sigv4BearerReachesChain(t *testing.T) { + // Phase 2: aws_sigv4 uses the pre-signed URL Bearer-token pattern. + // The chain receives the Bearer token directly via the standard path + // — no non-Bearer-header handling needed. 
spy := &headerCapturingProvider{err: ErrTokenNotForMe} opts := MiddlewareOptions{ Chain: NewChainProvider(spy), @@ -477,15 +477,11 @@ func TestMiddleware_Sigv4HeaderReachesChain(t *testing.T) { })) req := httptest.NewRequest("POST", "/tasks", nil) - req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host, Signature=ab") + req.Header.Set("Authorization", "Bearer forge-aws-v1.aHR0cHM6Ly9zdHM") handler.ServeHTTP(httptest.NewRecorder(), req) - if spy.sawHeaders == nil { - t.Fatal("chain was never invoked — middleware short-circuited on empty Bearer") - } - got := spy.sawHeaders.Get("Authorization") - if got == "" || !startsWithLocal(got, "AWS4-HMAC-SHA256 ") { - t.Errorf("chain saw Authorization = %q, want Sigv4-prefixed value", got) + if spy.sawToken == "" || !startsWithLocal(spy.sawToken, "forge-aws-v1.") { + t.Errorf("chain saw token = %q, want forge-aws-v1 prefixed", spy.sawToken) } } @@ -554,11 +550,10 @@ func TestMiddleware_NoAuthHeaders_PreservesMissingTokenReason(t *testing.T) { } } -func TestMiddleware_TokenKind_Sigv4OnEmptyBearer(t *testing.T) { - // Phase 2: the audit token_kind reads "sigv4" when the request has a - // Sigv4-shaped Authorization header, even though Bearer extraction - // returned "". Audit dashboards need this signal to count Sigv4 - // requests distinctly from "empty" (no auth attempt at all). +func TestMiddleware_TokenKind_Sigv4FromForgeAwsToken(t *testing.T) { + // Phase 2: aws_sigv4 uses the "forge-aws-v1." Bearer token + // pattern. TokenKind on that prefix returns "sigv4" so audit dashboards + // can count Sigv4 traffic distinctly from generic Bearer. 
var gotKind string opts := MiddlewareOptions{ Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), @@ -572,7 +567,7 @@ func TestMiddleware_TokenKind_Sigv4OnEmptyBearer(t *testing.T) { })) req := httptest.NewRequest("POST", "/tasks", nil) - req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA, SignedHeaders=host, Signature=ab") + req.Header.Set("Authorization", "Bearer forge-aws-v1.aHR0cHM6Ly9zdHMudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20") handler.ServeHTTP(httptest.NewRecorder(), req) if gotKind != "sigv4" { diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go index 24ce9ae..f394ff4 100644 --- a/forge-core/auth/provider.go +++ b/forge-core/auth/provider.go @@ -90,25 +90,17 @@ func (h Headers) Get(key string) string { // HeadersFromRequest extracts the well-known headers providers may use. // Keep this list narrow — providers should be explicit about the contract. // -// Authorization is included verbatim (in addition to the middleware's Bearer -// extraction) so providers that speak non-Bearer formats can read it — e.g. -// AWS Sigv4 (Authorization: AWS4-HMAC-SHA256 …). Adding a new header here is -// a deliberate widening of the contract; providers must not assume any header -// is present. -// -// Phase 2 additions (Authorization, X-Goog-Iap-Jwt-Assertion, X-Amz-Date, -// X-Amz-Security-Token) are consumed by aws_sigv4 and gcp_iap providers. +// X-Goog-Iap-Jwt-Assertion is included for gcp_iap, which doesn't use a +// Bearer token. All other Phase 2 providers (aws_sigv4 with the pre-signed +// URL pattern, azure_ad) ride the standard Bearer path and don't need +// extra header surface here. 
func HeadersFromRequest(r *http.Request) Headers { return Headers{ - "X-Org-ID": r.Header.Get("X-Org-ID"), - "X-Request-ID": r.Header.Get("X-Request-ID"), - "org-id": r.Header.Get("org-id"), - "org_id": r.Header.Get("org_id"), - - "Authorization": r.Header.Get("Authorization"), + "X-Org-ID": r.Header.Get("X-Org-ID"), + "X-Request-ID": r.Header.Get("X-Request-ID"), + "org-id": r.Header.Get("org-id"), + "org_id": r.Header.Get("org_id"), "X-Goog-Iap-Jwt-Assertion": r.Header.Get("X-Goog-Iap-Jwt-Assertion"), - "X-Amz-Date": r.Header.Get("X-Amz-Date"), - "X-Amz-Security-Token": r.Header.Get("X-Amz-Security-Token"), } } @@ -116,13 +108,12 @@ func HeadersFromRequest(r *http.Request) Headers { // audit logging without leaking the token itself. // // "empty" → empty token -// "sigv4" → starts with "AWS4-HMAC-SHA256 " (the AWS Sigv4 algorithm prefix). +// "sigv4" → forge-aws-v1. (AWS Sigv4 via pre-signed URL pattern; // -// Callers route the raw Authorization header through TokenKind for -// this case, since extractBearerToken returns "" for non-Bearer auth. +// the magic prefix mirrors aws-iam-authenticator's "k8s-aws-v1.") // // "jwt" → three base64url segments separated by dots -// "opaque" → anything else (Okta access tokens, custom verifier tokens, dev secrets) +// "opaque" → anything else (custom verifier tokens, dev secrets, etc.) // // This is a CHEAP structural check — it does not parse or validate. // Never log the token; this helper is safe to log. 
@@ -130,7 +121,7 @@ func TokenKind(token string) string { if token == "" { return "empty" } - if strings.HasPrefix(token, "AWS4-HMAC-SHA256 ") { + if strings.HasPrefix(token, "forge-aws-v1.") { return "sigv4" } dots := 0 diff --git a/forge-core/auth/provider_test.go b/forge-core/auth/provider_test.go index 1f612a8..a4958e6 100644 --- a/forge-core/auth/provider_test.go +++ b/forge-core/auth/provider_test.go @@ -465,14 +465,15 @@ func TestTokenKind(t *testing.T) { {"opaque one dot", "abc.def", "opaque"}, {"opaque four segments", "a.b.c.d", "opaque"}, {"jwt with empty segments still has 2 dots", "..", "jwt"}, - // Phase 2: aws_sigv4 — the Authorization header's algorithm prefix - // is enough to classify; middleware routes the raw header here when - // the Bearer extractor returns "". - {"sigv4 prefix", "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=abc", "sigv4"}, - {"sigv4 prefix only", "AWS4-HMAC-SHA256 ", "sigv4"}, - // Defensive: a token that LOOKS like Sigv4 but lacks the trailing - // space must NOT be classified as sigv4 — we want a clean prefix match. - {"sigv4-like without trailing space", "AWS4-HMAC-SHA256xyz", "opaque"}, + // Phase 2: aws_sigv4 uses the pre-signed URL pattern, encoded + // as a Bearer token with the "forge-aws-v1." prefix (mirrors + // aws-iam-authenticator's "k8s-aws-v1." convention). + {"sigv4 forge-aws-v1 token", "forge-aws-v1.aHR0cHM6Ly9zdHMudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20", "sigv4"}, + {"sigv4 prefix only", "forge-aws-v1.", "sigv4"}, + // Defensive: a token that looks similar but with the wrong + // version suffix or missing the period must NOT be sigv4. 
+ {"wrong version prefix", "forge-aws-v2.something", "opaque"}, + {"prefix missing trailing dot", "forge-aws-v1xyz", "opaque"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -508,29 +509,18 @@ func TestHeadersFromRequest_Phase1Headers(t *testing.T) { } func TestHeadersFromRequest_Phase2Headers(t *testing.T) { - // Phase 2: HeadersFromRequest widened so providers consuming non-Bearer - // formats (aws_sigv4, gcp_iap) can read what they need without breaking - // the Provider.Verify signature. See PHASE2_TEST_STRATEGY.md §4 PR1. + // Phase 2: HeadersFromRequest only widens for gcp_iap. aws_sigv4 + // rides the Bearer path (pre-signed URL pattern); azure_ad rides + // the Bearer path (standard JWT). The widened header surface is + // intentionally narrow. req, _ := http.NewRequest("POST", "/", nil) - req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host, Signature=ab") req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJabc.eyJdef.sig") - req.Header.Set("X-Amz-Date", "20260523T120000Z") - req.Header.Set("X-Amz-Security-Token", "FwoGZX...") h := HeadersFromRequest(req) - if got := h.Get("Authorization"); got == "" || !startsWith(got, "AWS4-HMAC-SHA256 ") { - t.Errorf("Authorization = %q, want Sigv4-shaped string", got) - } if got := h.Get("X-Goog-Iap-Jwt-Assertion"); got != "eyJabc.eyJdef.sig" { t.Errorf("X-Goog-Iap-Jwt-Assertion = %q, want eyJabc.eyJdef.sig", got) } - if got := h.Get("X-Amz-Date"); got != "20260523T120000Z" { - t.Errorf("X-Amz-Date = %q, want 20260523T120000Z", got) - } - if got := h.Get("X-Amz-Security-Token"); got == "" { - t.Error("X-Amz-Security-Token not extracted") - } } func TestHeadersFromRequest_AbsentHeadersAreEmpty(t *testing.T) { @@ -539,7 +529,7 @@ func TestHeadersFromRequest_AbsentHeadersAreEmpty(t *testing.T) { req, _ := http.NewRequest("POST", "/", nil) h := HeadersFromRequest(req) for _, k := range []string{ - "Authorization", 
"X-Goog-Iap-Jwt-Assertion", "X-Amz-Date", "X-Amz-Security-Token", + "X-Goog-Iap-Jwt-Assertion", } { if got := h.Get(k); got != "" { t.Errorf("%s should be empty on request with no headers, got %q", k, got) @@ -547,10 +537,6 @@ func TestHeadersFromRequest_AbsentHeadersAreEmpty(t *testing.T) { } } -func startsWith(s, prefix string) bool { - return len(s) >= len(prefix) && s[:len(prefix)] == prefix -} - // --- helpers --- func providerNames(ps []Provider) []string { diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go index 93c4a8b..1683255 100644 --- a/forge-core/auth/providers/aws_sigv4/provider.go +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -1,46 +1,70 @@ -// Package aws_sigv4 authenticates requests signed with AWS Sigv4 by -// reflecting the signature to AWS STS GetCallerIdentity. Same pattern as -// aws-iam-authenticator (the EKS auth bridge). -// -// Forge never has the caller's secret key — STS validates the signature -// on Forge's behalf and returns the canonical ARN/Account/UserID. Verified -// identities are cached by hash(AKID|YYYYMMDD) for IdentityCacheTTL so the -// hot path doesn't bounce through STS on every call. -// -// # Client-side signing contract (READ THIS BEFORE INTEGRATING) -// -// Sigv4 binds the signature to the destination host as part of its canonical -// headers. Forge does not — and cannot — relax that. As a consequence, -// callers MUST sign their request as if they were going to STS directly, -// then attach the resulting headers to a request that's sent to Forge: -// -// 1. Construct an STS request shape (POST https://sts..amazonaws.com/, -// body Action=GetCallerIdentity&Version=2011-06-15) — but do not send it. -// 2. Sign that hypothetical request with the caller's AWS credentials, using -// service=sts and region=. -// 3. Take the resulting Authorization + X-Amz-Date (+ X-Amz-Security-Token -// for temporary credentials, e.g. 
SSO / IRSA / Lambda) headers and attach -// them to the REAL POST that goes to Forge's /tasks/send endpoint. -// 4. Forge will forward those headers verbatim to STS; STS will validate -// them against its own host and return the caller's canonical ARN. -// -// Standard tools (awscurl, the AWS CLI, boto3.client('sts')) do NOT do this -// automatically — they always sign for the URL they are addressing. A small -// wrapper is required. See scripts/forge-aws-sign.py for the canonical -// reference implementation in ~30 lines of boto3. -// -// Decision §9.1: no aws-sdk-go-v2 dependency — the STS RPC is ~150 LOC of -// hand-rolled HTTP + XML. Trade-off: small attack surface, predictable -// behavior, no transitive deps; cost: we maintain the RPC ourselves. -// -// Decision §9.3: allowed_principals are shell-style globs (path.Match). -// -// Audit reason codes (Phase 1 contract): -// -// rejected — scope check, ARN allowlist miss, STS 4xx +// Package aws_sigv4 authenticates AWS-IAM callers using the pre-signed +// URL pattern. Same approach as aws-iam-authenticator (the EKS auth bridge): +// the caller uses their AWS SDK to compute a pre-signed STS +// GetCallerIdentity URL, wraps it as a Bearer token, and sends it to +// Forge. Forge invokes the URL on STS, which validates the signature +// (it was signed for STS's host) and returns the canonical +// ARN / Account / UserID. Forge stamps an Identity from that response. +// +// Forge never possesses the caller's AWS secret key. The cryptographic +// work happens once on the caller side via standard SDK calls. +// +// # Client-side contract (3 lines) +// +// # Python / boto3 +// import boto3, base64 +// url = boto3.client('sts', region_name='us-east-1').generate_presigned_url( +// 'get_caller_identity', ExpiresIn=900) +// token = 'forge-aws-v1.' 
+ base64.urlsafe_b64encode(url.encode()).rstrip(b'=').decode() +// requests.post(forge_url, headers={'Authorization': f'Bearer {token}'}, data=msg) +// +// Reference client in scripts/forge-aws-sign.py. +// +// # Wire format +// +// Authorization: Bearer forge-aws-v1.<base64url-encoded-presigned-url> +// +// The base64-decoded payload is a complete pre-signed URL of the form: +// +// https://sts.<region>.amazonaws.com/ +// ?Action=GetCallerIdentity +// &Version=2011-06-15 +// &X-Amz-Algorithm=AWS4-HMAC-SHA256 +// &X-Amz-Credential=<AKID>/<YYYYMMDD>/<region>/sts/aws4_request +// &X-Amz-Date=<YYYYMMDDTHHMMSSZ> +// &X-Amz-Expires=<seconds> +// &X-Amz-SignedHeaders=host +// &X-Amz-Signature=<hex-signature> +// +// # SSRF guard +// +// Before invoking the URL, Forge validates the host matches +// sts.<region>.amazonaws.com exactly. A token whose URL +// points anywhere else is rejected — the token must not be usable to +// coerce Forge into calling an arbitrary internal endpoint. +// +// # Caching +// +// Verified identities are cached for IdentityCacheTTL keyed on +// hash(AKID, YYYYMMDD), extracted from the token's X-Amz-Credential. +// Rotating AKID or rolling past midnight UTC invalidates the bucket. +// Errors are never cached. +// +// # Decisions +// +// §9.1 — no aws-sdk-go-v2 dependency. The STS RPC is hand-rolled HTTP + +// XML; trade-off is smaller attack surface and no transitive deps. +// +// §9.3 — allowed_principals are shell-style globs (path.Match). +// +// # Audit reason codes (Phase 1 contract) +// +// rejected — STS 4xx (expired/bad sig), ARN allowlist miss, +// URL host mismatch, region scope mismatch // provider_unavailable — STS 5xx, network failure, parse failure -// invalid — Sigv4 header malformed -// not_for_me — Authorization header didn't start with AWS4-HMAC-SHA256 +// invalid — token format malformed, base64 fails, URL fails, +// missing required query params +// not_for_me — token didn't start with "forge-aws-v1."
package aws_sigv4 import ( @@ -48,7 +72,6 @@ import ( "crypto/sha256" "encoding/hex" "fmt" - "strings" "time" "github.com/initializ/forge/forge-core/auth" @@ -66,33 +89,31 @@ const ( // Config controls the aws_sigv4 provider. type Config struct { // Region is the AWS region whose STS endpoint validates signatures. - // REQUIRED. Defense in depth: callers must sign for this exact region - // (Sigv4 cross-region replay rejected at parse-time scope check). + // REQUIRED. The pre-signed URL's host MUST match + // sts.<region>.amazonaws.com exactly. Region string `yaml:"region"` - // Audience is informational only — emitted in audit logs for context. - // STS itself doesn't enforce this. + // Audience is informational only — emitted in the audit log's Claims + // payload. STS itself doesn't enforce it. Audience string `yaml:"audience,omitempty"` - // AllowedPrincipals is an optional list of shell-style globs (decision - // §9.3) matched against the STS-returned ARN. Empty list means "allow - // any IAM principal in the account that signed the request" — fine - // for single-tenant dev setups, never appropriate for prod. + // AllowedPrincipals is an optional list of shell-style globs (§9.3) + // matched against the STS-returned ARN. Empty list means "allow any + // IAM principal that has a valid AWS key" — fine for single-tenant + // dev, never appropriate for production. // - // Patterns match the assumed-role ARN form + // Patterns match the STS assumed-role ARN form // ("arn:aws:sts::ACCOUNT:assumed-role/RoleName/SessionName"), NOT the - // role's own ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). PR6 docs - // call this out explicitly with a worked example. + // IAM role ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). AllowedPrincipals []string `yaml:"allowed_principals,omitempty"` // IdentityCacheTTL bounds how long a verified Identity is reused - // without re-checking with STS. Defaults to 60s. Sets the upper - // bound on the window in which a revoked AKID remains accepted.
+ // without re-checking with STS. Defaults to 60s. IdentityCacheTTL time.Duration `yaml:"identity_cache_ttl,omitempty"` - // STSEndpoint is a TEST-ONLY override that points STS calls at an - // alternate URL. Production should leave this empty so the region - // derives the canonical sts.<region>.amazonaws.com host. + // STSEndpoint is a TEST-ONLY override that changes the expected + // pre-signed URL host (and relaxes the https requirement). Production + // should leave this empty. STSEndpoint string `yaml:"sts_endpoint,omitempty"` // HTTPTimeout caps each STS call. Defaults to 5s. @@ -107,13 +128,14 @@ func (c Config) Validate() error { return nil } -// Provider implements auth.Provider for AWS Sigv4 callers. +// Provider implements auth.Provider for AWS-IAM callers. type Provider struct { - cfg Config - parser Parser - cache *IdentityCache - sts *STSClient - matcher *ArnMatcher + cfg Config + expectedHost string // computed once: sts.<region>.amazonaws.com or test override + requireHTTPS bool // false only when STSEndpoint test override is in use + cache *IdentityCache + sts *STSClient + matcher *ArnMatcher } // New constructs a Provider after validating cfg. @@ -131,12 +153,27 @@ func New(cfg Config) (*Provider, error) { if err != nil { return nil, fmt.Errorf("aws_sigv4: allowed_principals: %w", err) } + + expectedHost := fmt.Sprintf("sts.%s.amazonaws.com", cfg.Region) + requireHTTPS := true + if cfg.STSEndpoint != "" { + h, scheme := hostAndSchemeOf(cfg.STSEndpoint) + if h != "" { + expectedHost = h + } + // Test override: allow plain http for httptest servers.
+ if scheme == "http" { + requireHTTPS = false + } + } + return &Provider{ - cfg: cfg, - parser: Parser{}, - cache: NewIdentityCache(cfg.IdentityCacheTTL), - sts: NewSTSClient(cfg.Region, cfg.STSEndpoint, cfg.HTTPTimeout), - matcher: matcher, + cfg: cfg, + expectedHost: expectedHost, + requireHTTPS: requireHTTPS, + cache: NewIdentityCache(cfg.IdentityCacheTTL), + sts: NewSTSClient(cfg.Region, "", cfg.HTTPTimeout), + matcher: matcher, }, nil } @@ -145,44 +182,37 @@ func (p *Provider) Name() string { return ProviderName } // Verify implements auth.Provider. // -// Token is unused — Sigv4 carries everything it needs in the Authorization -// header (and X-Amz-Date, X-Amz-Security-Token), which the middleware -// hands us via the headers map. -func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) (*auth.Identity, error) { - authHeader := headers.Get("Authorization") - if !strings.HasPrefix(authHeader, sigv4Algorithm+" ") { - return nil, auth.ErrTokenNotForMe - } - - sig, err := p.parser.Parse(authHeader) +// The middleware extracts the Bearer token and passes it here. If the +// token doesn't start with "forge-aws-v1." we yield to the next chain +// entry. Otherwise we validate the embedded pre-signed URL, GET it on +// STS, and stamp an Identity from the returned ARN. +func (p *Provider) Verify(ctx context.Context, token string, _ auth.Headers) (*auth.Identity, error) { + parsed, err := ParseToken(token, p.expectedHost, p.requireHTTPS) if err != nil { + // Only the prefix check distinguishes "this isn't my token" + // from "this IS my token but malformed." Use a sentinel so the + // chain can fall through on the former but stop on the latter. + if err.Error() == "missing forge-aws-v1 prefix" { + return nil, auth.ErrTokenNotForMe + } return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) } - // Scope check (defense in depth §3.1): reject signatures meant for a - // different AWS service or region. 
Prevents cross-service Sigv4 - // replay (e.g. a request signed for S3 reaching us configured for STS). - if sig.Service != "sts" { - return nil, fmt.Errorf("%w: signature scoped to service %q, expected sts", auth.ErrTokenRejected, sig.Service) - } - if sig.Region != p.cfg.Region { - return nil, fmt.Errorf("%w: signature scoped to region %q, expected %s", auth.ErrTokenRejected, sig.Region, p.cfg.Region) + // Region in the credential scope must match our configured region. + // Defends against cross-region replay (e.g. a token pre-signed for + // us-east-1 hitting a Forge instance configured for eu-west-1). + if parsed.Region != p.cfg.Region { + return nil, fmt.Errorf("%w: token region %q != configured %s", auth.ErrTokenRejected, parsed.Region, p.cfg.Region) } - cacheKey := dateBucketKey(sig.AKID, sig.Date) + cacheKey := dateBucketKey(parsed.AKID, parsed.Date) if id, ok := p.cache.Get(cacheKey); ok { return id, nil } - caller, err := p.sts.GetCallerIdentity(ctx, STSReflectArgs{ - AuthHeader: authHeader, - AmzDate: headers.Get("X-Amz-Date"), - SecurityToken: headers.Get("X-Amz-Security-Token"), - }) + caller, err := p.sts.GetCallerIdentity(ctx, parsed.URL.String()) if err != nil { - // STSClient already wraps with the right sentinel - // (ErrTokenRejected for 4xx, ErrProviderUnavailable for 5xx/network). - return nil, err + return nil, err // STSClient already wraps with the right sentinel } if !p.matcher.Match(caller.Arn) { @@ -204,12 +234,12 @@ func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) ( return id, nil } -// dateBucketKey hashes (AKID, YYYYMMDD) so two requests from the same AKID -// on the same day collapse to one STS call per IdentityCacheTTL window. -// Hashing protects the cache against length-leak / log-scan attacks on the -// raw key — same reasoning as static_token (review #11). 
-func dateBucketKey(akid, sigv4Date string) string { - bucket := sigv4Date +// dateBucketKey hashes (AKID, YYYYMMDD) so two requests from the same +// AKID on the same day collapse to a single STS call per +// IdentityCacheTTL window. Hashing protects against length-leak / log-scan +// reads of the cache key. +func dateBucketKey(akid, date string) string { + bucket := date if len(bucket) > 8 { bucket = bucket[:8] // YYYYMMDD } @@ -217,6 +247,41 @@ func dateBucketKey(akid, sigv4Date string) string { return hex.EncodeToString(sum[:]) } +// hostAndSchemeOf is a forgiving parser used only at Factory time for the +// STSEndpoint test override. Returns (host, scheme) or empty strings on +// parse failure. +func hostAndSchemeOf(raw string) (host, scheme string) { + // Accept both bare "host:port" and full "scheme://host:port/path" forms. + // For the test override we only care about the host portion (for + // matching) and whether scheme is http (so we can relax the https check). + if i := indexOf(raw, "://"); i >= 0 { + scheme = raw[:i] + rest := raw[i+3:] + for k := 0; k < len(rest); k++ { + if rest[k] == '/' { + return rest[:k], scheme + } + } + return rest, scheme + } + // No scheme: assume host:port form + for k := 0; k < len(raw); k++ { + if raw[k] == '/' { + return raw[:k], "" + } + } + return raw, "" +} + +func indexOf(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} + func init() { auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { var cfg Config diff --git a/forge-core/auth/providers/aws_sigv4/provider_test.go b/forge-core/auth/providers/aws_sigv4/provider_test.go index 5d524b9..5d8767f 100644 --- a/forge-core/auth/providers/aws_sigv4/provider_test.go +++ b/forge-core/auth/providers/aws_sigv4/provider_test.go @@ -2,10 +2,13 @@ package aws_sigv4 import ( "context" + "encoding/base64" "errors" + "fmt" "io" "net/http" "net/http/httptest" + "net/url" 
"strings" "sync/atomic" "testing" @@ -14,16 +17,28 @@ import ( "github.com/initializ/forge/forge-core/auth" ) -const validSigv4 = "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ab12cd34" +// tokenFor builds a forge-aws-v1 token whose embedded URL points at the +// test STS server. AKID + date + region + signature are placeholders; +// the fake STS doesn't validate them. +func tokenFor(stsURL, akid, dateYYYYMMDD, region string) string { + q := url.Values{} + q.Set("Action", "GetCallerIdentity") + q.Set("Version", "2011-06-15") + q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") + q.Set("X-Amz-Credential", fmt.Sprintf("%s/%s/%s/sts/aws4_request", akid, dateYYYYMMDD, region)) + q.Set("X-Amz-Date", dateYYYYMMDD+"T010000Z") + q.Set("X-Amz-Expires", "900") + q.Set("X-Amz-SignedHeaders", "host") + q.Set("X-Amz-Signature", "fakesig"+akid) + full := stsURL + "/?" + q.Encode() + return TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(full)) +} -func validHeaders() auth.Headers { - return auth.Headers{ - "Authorization": validSigv4, - "X-Amz-Date": "20260523T120000Z", - } +func defaultToken(stsURL string) string { + return tokenFor(stsURL, "AKIAIOSFODNN7EXAMPLE", "20260524", "us-east-1") } -func newTestProvider(t *testing.T, sts http.Handler, opts ...func(*Config)) *Provider { +func newTestProvider(t *testing.T, sts http.Handler, opts ...func(*Config)) (*Provider, string) { t.Helper() srv := httptest.NewServer(sts) t.Cleanup(srv.Close) @@ -42,7 +57,7 @@ func newTestProvider(t *testing.T, sts http.Handler, opts ...func(*Config)) *Pro if err != nil { t.Fatalf("New: %v", err) } - return p + return p, srv.URL } func happySTS() http.Handler { @@ -52,7 +67,7 @@ func happySTS() http.Handler { } func TestProvider_Name(t *testing.T) { - p := newTestProvider(t, happySTS()) + p, _ := newTestProvider(t, happySTS()) if p.Name() != "aws_sigv4" { t.Errorf("Name = %q, want aws_sigv4", p.Name()) } @@ -72,40 
+87,38 @@ func TestProvider_New_RejectsInvalidGlob(t *testing.T) { } } -func TestProvider_NoSigv4Header_YieldsToChain(t *testing.T) { - p := newTestProvider(t, happySTS()) - _, err := p.Verify(context.Background(), "", auth.Headers{"Authorization": "Bearer xyz"}) +func TestProvider_NoPrefix_YieldsToChain(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "Bearer some.opaque.token", nil) if !errors.Is(err, auth.ErrTokenNotForMe) { t.Errorf("err = %v, want ErrTokenNotForMe", err) } } -func TestProvider_NoAuthHeaderAtAll_YieldsToChain(t *testing.T) { - p := newTestProvider(t, happySTS()) - _, err := p.Verify(context.Background(), "", auth.Headers{}) +func TestProvider_EmptyToken_YieldsToChain(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "", nil) if !errors.Is(err, auth.ErrTokenNotForMe) { t.Errorf("err = %v, want ErrTokenNotForMe", err) } } -func TestProvider_MalformedSigv4_Invalid(t *testing.T) { - p := newTestProvider(t, happySTS()) - _, err := p.Verify(context.Background(), "", auth.Headers{ - "Authorization": "AWS4-HMAC-SHA256 malformed", - }) +func TestProvider_MalformedToken_Invalid(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), TokenPrefix+"!!!not-base64!!!", nil) if !errors.Is(err, auth.ErrInvalidToken) { t.Errorf("err = %v, want ErrInvalidToken", err) } } func TestProvider_HappyPath_ReturnsIdentity(t *testing.T) { - p := newTestProvider(t, happySTS()) - id, err := p.Verify(context.Background(), "", validHeaders()) + p, stsURL := newTestProvider(t, happySTS()) + id, err := p.Verify(context.Background(), defaultToken(stsURL), nil) if err != nil { t.Fatalf("Verify: %v", err) } if id.Source != "aws_sigv4" { - t.Errorf("Source = %q, want aws_sigv4", id.Source) + t.Errorf("Source = %q", id.Source) } if id.UserID != "arn:aws:sts::123456789012:assumed-role/ci-deploy/session" { t.Errorf("UserID = %q", id.UserID) 
@@ -118,62 +131,69 @@ func TestProvider_HappyPath_ReturnsIdentity(t *testing.T) { } } -func TestProvider_ScopeService_NotSTS_Rejected(t *testing.T) { - p := newTestProvider(t, happySTS()) - h := validHeaders() - h["Authorization"] = strings.Replace(h["Authorization"], "/sts/", "/s3/", 1) - _, err := p.Verify(context.Background(), "", h) +func TestProvider_RegionMismatch_Rejected(t *testing.T) { + // Token's credential scope says eu-west-1, provider configured us-east-1. + // Defends against cross-region token replay (the same AKID may be + // valid in either region, but the operator's allowlist applies only + // to the configured region). + p, stsURL := newTestProvider(t, happySTS()) + tok := tokenFor(stsURL, "AKIAEXAMPLE", "20260524", "eu-west-1") + _, err := p.Verify(context.Background(), tok, nil) if !errors.Is(err, auth.ErrTokenRejected) { - t.Errorf("err = %v, want ErrTokenRejected (cross-service replay)", err) + t.Errorf("err = %v, want ErrTokenRejected (cross-region)", err) } } -func TestProvider_ScopeRegion_Mismatch_Rejected(t *testing.T) { - p := newTestProvider(t, happySTS()) - h := validHeaders() - h["Authorization"] = strings.Replace(h["Authorization"], "/us-east-1/", "/eu-west-1/", 1) - _, err := p.Verify(context.Background(), "", h) - if !errors.Is(err, auth.ErrTokenRejected) { - t.Errorf("err = %v, want ErrTokenRejected (cross-region replay)", err) +func TestProvider_ForeignHost_Invalid(t *testing.T) { + // SSRF guard: a token whose URL points anywhere other than the + // expected STS host is rejected before we ever issue a request. 
+ p, _ := newTestProvider(t, happySTS()) + hostile := "https://evil.example.com/?Action=GetCallerIdentity" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Credential=AKIA/20260524/us-east-1/sts/aws4_request" + + "&X-Amz-Signature=x" + tok := TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(hostile)) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (SSRF guard)", err) } } func TestProvider_AllowlistMiss_Rejected(t *testing.T) { - p := newTestProvider(t, happySTS(), func(c *Config) { + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { c.AllowedPrincipals = []string{"arn:aws:iam::999:role/*"} }) - _, err := p.Verify(context.Background(), "", validHeaders()) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) if !errors.Is(err, auth.ErrTokenRejected) { t.Errorf("err = %v, want ErrTokenRejected (allowlist miss)", err) } } func TestProvider_AllowlistHit_Succeeds(t *testing.T) { - p := newTestProvider(t, happySTS(), func(c *Config) { - // Matches the assumed-role ARN returned by happySTSXML. 
+ p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { c.AllowedPrincipals = []string{"arn:aws:sts::123456789012:assumed-role/ci-deploy/*"} }) - if _, err := p.Verify(context.Background(), "", validHeaders()); err != nil { + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { t.Errorf("Verify: %v", err) } } func TestProvider_STSDown_Unavailable(t *testing.T) { - p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) - _, err := p.Verify(context.Background(), "", validHeaders()) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) if !errors.Is(err, auth.ErrProviderUnavailable) { t.Errorf("err = %v, want ErrProviderUnavailable", err) } } func TestProvider_STSRejects_Rejected(t *testing.T) { - p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusForbidden) _, _ = io.WriteString(w, "SignatureDoesNotMatch") })) - _, err := p.Verify(context.Background(), "", validHeaders()) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) if !errors.Is(err, auth.ErrTokenRejected) { t.Errorf("err = %v, want ErrTokenRejected", err) } @@ -181,27 +201,25 @@ func TestProvider_STSRejects_Rejected(t *testing.T) { func TestProvider_CacheHit_AvoidsSTSCall(t *testing.T) { var calls atomic.Int32 - p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { calls.Add(1) _, _ = io.WriteString(w, happySTSXML) })) - - if _, err := p.Verify(context.Background(), "", validHeaders()); err != nil { - t.Fatalf("first Verify: %v", err) + tok := defaultToken(stsURL) + if _, 
err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatal(err) } - if _, err := p.Verify(context.Background(), "", validHeaders()); err != nil { - t.Fatalf("second Verify: %v", err) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatal(err) } if got := calls.Load(); got != 1 { - t.Errorf("STS calls = %d, want 1 (cache must hit on 2nd request)", got) + t.Errorf("STS calls = %d, want 1 (cache must hit)", got) } } func TestProvider_RejectedRequest_DoesNotPoisonCache(t *testing.T) { - // First STS call returns 403; second valid request must hit STS again - // (the rejection MUST NOT have populated the cache). var calls atomic.Int32 - p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { n := calls.Add(1) if n == 1 { w.WriteHeader(http.StatusForbidden) @@ -210,12 +228,12 @@ func TestProvider_RejectedRequest_DoesNotPoisonCache(t *testing.T) { } _, _ = io.WriteString(w, happySTSXML) })) - - _, err1 := p.Verify(context.Background(), "", validHeaders()) + tok := defaultToken(stsURL) + _, err1 := p.Verify(context.Background(), tok, nil) if !errors.Is(err1, auth.ErrTokenRejected) { t.Fatalf("first Verify err = %v, want ErrTokenRejected", err1) } - _, err2 := p.Verify(context.Background(), "", validHeaders()) + _, err2 := p.Verify(context.Background(), tok, nil) if err2 != nil { t.Fatalf("second Verify err = %v, want nil", err2) } @@ -226,49 +244,23 @@ func TestProvider_RejectedRequest_DoesNotPoisonCache(t *testing.T) { func TestProvider_DateBucketRollover_TriggersFreshSTSCall(t *testing.T) { var calls atomic.Int32 - p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { calls.Add(1) _, _ = io.WriteString(w, happySTSXML) })) - - // Day 1 - h1 := validHeaders() - if _, err := 
p.Verify(context.Background(), "", h1); err != nil { + // Day 1 + Day 2 → different cache keys. + if _, err := p.Verify(context.Background(), tokenFor(stsURL, "AKIA", "20260524", "us-east-1"), nil); err != nil { t.Fatal(err) } - - // Day 2 — same AKID, different date in scope → different cache key. - h2 := auth.Headers{ - "Authorization": strings.Replace(validSigv4, "/20260523/", "/20260524/", 1), - "X-Amz-Date": "20260524T120000Z", - } - if _, err := p.Verify(context.Background(), "", h2); err != nil { + if _, err := p.Verify(context.Background(), tokenFor(stsURL, "AKIA", "20260525", "us-east-1"), nil); err != nil { t.Fatal(err) } if got := calls.Load(); got != 2 { - t.Errorf("STS calls = %d, want 2 (date bucket rolled, cache miss expected)", got) - } -} - -func TestProvider_AmzSecurityTokenForwarded(t *testing.T) { - var capturedSec string - p := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedSec = r.Header.Get("X-Amz-Security-Token") - _, _ = io.WriteString(w, happySTSXML) - })) - - h := validHeaders() - h["X-Amz-Security-Token"] = "FwoGZX-test-temp-session" - if _, err := p.Verify(context.Background(), "", h); err != nil { - t.Fatal(err) - } - if capturedSec != "FwoGZX-test-temp-session" { - t.Errorf("STS got X-Amz-Security-Token = %q", capturedSec) + t.Errorf("STS calls = %d, want 2 (date bucket rolled)", got) } } func TestProvider_RegisteredInRegistry(t *testing.T) { - // init() ran on package import — verify the factory is wired. p, err := auth.Build("aws_sigv4", map[string]any{ "region": "us-east-1", }) @@ -286,3 +278,44 @@ func TestProvider_FactoryRejectsMissingRegion(t *testing.T) { t.Fatal("expected error from factory when region is missing") } } + +func TestProvider_TokenPointingAtForeignSTSRegion_Invalid(t *testing.T) { + // Token URL says sts.eu-west-1.amazonaws.com — provider expects + // sts.us-east-1.amazonaws.com. The pre-validation host check should + // catch this before any STS call. 
+ p, _ := newTestProvider(t, happySTS(), func(c *Config) { + // Use defaults so expectedHost stays sts.us-east-1.amazonaws.com. + // Drop STSEndpoint override to exercise the real host path. + c.STSEndpoint = "" + }) + // Force the provider to compare against the real sts host. + hostile := "https://sts.eu-west-1.amazonaws.com/?Action=GetCallerIdentity" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Credential=AKIA/20260524/eu-west-1/sts/aws4_request" + + "&X-Amz-Signature=x" + tok := TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(hostile)) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (cross-region URL host)", err) + } +} + +// Sanity: with no STSEndpoint override the expectedHost matches AWS's +// real STS endpoint for the configured region — protects against +// accidental regressions to the prod host derivation. +func TestProvider_DefaultExpectedHost(t *testing.T) { + p, err := New(Config{Region: "us-east-1"}) + if err != nil { + t.Fatal(err) + } + if p.expectedHost != "sts.us-east-1.amazonaws.com" { + t.Errorf("expectedHost = %q", p.expectedHost) + } + if !p.requireHTTPS { + t.Error("requireHTTPS should be true by default") + } +} + +// Blank-identifier reference: defines no helper, it only keeps the strings +// import referenced. TODO(review): drop it together with the import if unused. +var _ = strings.HasPrefix diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go index 854053c..cb3deaf 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go @@ -1,98 +1,131 @@ package aws_sigv4 import ( + "encoding/base64" "errors" + "fmt" + "net/url" "strings" ) -// Sigv4Header is the parsed view of an AWS Sigv4 Authorization header. -type Sigv4Header struct { - AKID string // e.g.
"AKIAIOSFODNN7EXAMPLE" - Date string // YYYYMMDD scope date - Region string // e.g. "us-east-1" - Service string // e.g. "sts" - SignedHeaders string // semicolon-separated list — must include "host" - Signature string // hex-encoded HMAC -} - -// Parser parses Sigv4-shaped Authorization headers. Pure string work — -// no HTTP, no AWS SDK. Never panics on input. -type Parser struct{} +// TokenPrefix is the magic token-type marker that distinguishes +// forge-aws-v1 tokens from JWTs / opaque tokens in the Bearer slot. +// Mirrors the "k8s-aws-v1." convention from aws-iam-authenticator. +const TokenPrefix = "forge-aws-v1." -const sigv4Algorithm = "AWS4-HMAC-SHA256" +// PresignedToken is the parsed view of a forge-aws-v1 Bearer token. +// The token wraps a pre-signed STS GetCallerIdentity URL; AKID + Date are +// extracted from the URL's X-Amz-Credential query param so the provider's +// identity cache can key on them without re-deriving for every Verify call. +type PresignedToken struct { + URL *url.URL // validated, host-checked, ready to GET + AKID string // for IdentityCache bucket key + Date string // YYYYMMDD scope date — for IdentityCache bucket key + Region string // from the credential scope (we cross-check against cfg.Region) +} -// Parse extracts the fields from a Sigv4 Authorization header. Real-world -// signers vary in whitespace around the comma separators; we trim tolerantly -// rather than insist on exact bytes. +// ParseToken validates a forge-aws-v1 Bearer token end-to-end and returns +// the URL Forge should invoke on STS. // -// Returns a clear error on malformed input — callers should map parse -// errors to auth.ErrInvalidToken, not auth.ErrTokenNotForMe (the latter -// is reserved for "no Sigv4 prefix at all"). 
-func (Parser) Parse(authHeader string) (*Sigv4Header, error) { - if !strings.HasPrefix(authHeader, sigv4Algorithm+" ") { - return nil, errors.New("missing AWS4-HMAC-SHA256 prefix") - } - rest := strings.TrimPrefix(authHeader, sigv4Algorithm+" ") - - parts := splitCommaTrim(rest) - if len(parts) != 3 { - return nil, errors.New("expected 3 comma-separated kv pairs") +// expectedHost is sts.<region>.amazonaws.com for prod, or the test-mode +// override host (Config.STSEndpoint) for integration tests. +// +// Validation gates (in order — fail-fast): +// +// 1. Token starts with the TokenPrefix. +// 2. Body decodes as base64url. +// 3. Decoded payload parses as a URL. +// 4. URL scheme is https (or http when STSEndpoint test override is in use). +// 5. URL host matches expectedHost. +// 6. URL query has Action=GetCallerIdentity. +// 7. URL query has X-Amz-Algorithm=AWS4-HMAC-SHA256. +// 8. URL query has a non-empty X-Amz-Signature. +// 9. URL query has X-Amz-Credential parseable as AKID/YYYYMMDD/region/sts/aws4_request. +// +// A failure of gate (1) is the only one the caller maps to +// ErrTokenNotForMe — the prefix is the only "shape" check; every other +// failure is a malformed / rejected token from our perspective, classified +// as ErrInvalidToken or ErrTokenRejected by the caller.
+func ParseToken(token, expectedHost string, requireHTTPS bool) (*PresignedToken, error) { + if !strings.HasPrefix(token, TokenPrefix) { + return nil, errors.New("missing forge-aws-v1 prefix") + } + encoded := strings.TrimPrefix(token, TokenPrefix) + if encoded == "" { + return nil, errors.New("forge-aws-v1 token has empty payload") } - out := &Sigv4Header{} - for _, p := range parts { - k, v, ok := strings.Cut(p, "=") - if !ok { - return nil, errors.New("malformed kv pair") - } - switch strings.TrimSpace(k) { - case "Credential": - if err := parseCredentialScope(v, out); err != nil { - return nil, err - } - case "SignedHeaders": - out.SignedHeaders = strings.TrimSpace(v) - case "Signature": - out.Signature = strings.TrimSpace(v) + // base64url decode — accept both padded and unpadded forms because + // SDKs disagree on whether to emit "=" padding. + raw, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + // fallback to standard base64url with padding + raw, err = base64.URLEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("base64url decode: %w", err) } } - if out.AKID == "" || out.Date == "" || out.Region == "" || out.Service == "" || - out.SignedHeaders == "" || out.Signature == "" { - return nil, errors.New("missing one or more required Sigv4 fields") + u, err := url.Parse(string(raw)) + if err != nil { + return nil, fmt.Errorf("decoded payload is not a URL: %w", err) + } + + if requireHTTPS && u.Scheme != "https" { + return nil, fmt.Errorf("URL scheme %q is not https", u.Scheme) } - if !signedHeadersContainsHost(out.SignedHeaders) { - return nil, errors.New("SignedHeaders must include host") + if u.Scheme != "https" && u.Scheme != "http" { + return nil, fmt.Errorf("URL scheme %q is not http(s)", u.Scheme) + } + if !strings.EqualFold(u.Host, expectedHost) { + return nil, fmt.Errorf("URL host %q does not match expected %q (SSRF guard)", u.Host, expectedHost) } - return out, nil -} -func parseCredentialScope(v string, 
out *Sigv4Header) error { - // Credential = AKID/YYYYMMDD/region/service/aws4_request - segs := strings.Split(strings.TrimSpace(v), "/") - if len(segs) != 5 || segs[4] != "aws4_request" { - return errors.New("malformed Credential scope") - } - out.AKID, out.Date, out.Region, out.Service = segs[0], segs[1], segs[2], segs[3] - return nil -} + q := u.Query() + if q.Get("Action") != "GetCallerIdentity" { + return nil, fmt.Errorf("URL Action=%q, want GetCallerIdentity", q.Get("Action")) + } + if q.Get("X-Amz-Algorithm") != "AWS4-HMAC-SHA256" { + return nil, fmt.Errorf("URL X-Amz-Algorithm=%q, want AWS4-HMAC-SHA256", q.Get("X-Amz-Algorithm")) + } + if q.Get("X-Amz-Signature") == "" { + return nil, errors.New("URL missing X-Amz-Signature") + } -// signedHeadersContainsHost validates that the SignedHeaders list includes -// "host" as a discrete entry. Sub-string match would accept "ghosting" — we -// split on ";" so each entry stands on its own. -func signedHeadersContainsHost(s string) bool { - for h := range strings.SplitSeq(s, ";") { - if strings.EqualFold(strings.TrimSpace(h), "host") { - return true - } + akid, date, region, err := parseCredentialScope(q.Get("X-Amz-Credential")) + if err != nil { + return nil, fmt.Errorf("X-Amz-Credential: %w", err) } - return false + + return &PresignedToken{ + URL: u, + AKID: akid, + Date: date, + Region: region, + }, nil } -func splitCommaTrim(s string) []string { - parts := strings.Split(s, ",") - for i := range parts { - parts[i] = strings.TrimSpace(parts[i]) +// parseCredentialScope splits X-Amz-Credential into its five segments: +// +// AKID/YYYYMMDD/region/service/aws4_request +// +// Service MUST be "sts" and the tail MUST be "aws4_request". 
+func parseCredentialScope(cred string) (akid, date, region string, err error) { + if cred == "" { + return "", "", "", errors.New("missing X-Amz-Credential") + } + segs := strings.Split(cred, "/") + if len(segs) != 5 { + return "", "", "", fmt.Errorf("expected 5 /-separated parts, got %d", len(segs)) + } + if segs[3] != "sts" { + return "", "", "", fmt.Errorf("scope service=%q, want sts", segs[3]) + } + if segs[4] != "aws4_request" { + return "", "", "", fmt.Errorf("scope tail=%q, want aws4_request", segs[4]) + } + if segs[0] == "" || segs[1] == "" || segs[2] == "" { + return "", "", "", errors.New("empty AKID/date/region segment") } - return parts + return segs[0], segs[1], segs[2], nil } diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go index 4f33b9a..7e30af2 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go @@ -1,84 +1,183 @@ package aws_sigv4 import ( + "encoding/base64" + "errors" "strings" "testing" ) -func TestParser_HappyPath(t *testing.T) { - h := "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ab12cd34" - sig, err := (Parser{}).Parse(h) +// makeToken builds a forge-aws-v1 token from a complete URL. Helper for +// tests — production tokens come from the AWS SDK's Presign(). 
+func makeToken(rawURL string) string { + return TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(rawURL)) +} + +const validPresignedURL = "https://sts.us-east-1.amazonaws.com/" + + "?Action=GetCallerIdentity" + + "&Version=2011-06-15" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20260524%2Fus-east-1%2Fsts%2Faws4_request" + + "&X-Amz-Date=20260524T010000Z" + + "&X-Amz-Expires=900" + + "&X-Amz-SignedHeaders=host" + + "&X-Amz-Signature=abcd1234" + +const validHost = "sts.us-east-1.amazonaws.com" + +func TestParseToken_HappyPath(t *testing.T) { + tok := makeToken(validPresignedURL) + parsed, err := ParseToken(tok, validHost, true) if err != nil { - t.Fatalf("Parse: %v", err) + t.Fatalf("ParseToken: %v", err) + } + if parsed.AKID != "AKIAIOSFODNN7EXAMPLE" { + t.Errorf("AKID = %q", parsed.AKID) + } + if parsed.Date != "20260524" { + t.Errorf("Date = %q", parsed.Date) } - if sig.AKID != "AKIAIOSFODNN7EXAMPLE" { - t.Errorf("AKID = %q", sig.AKID) + if parsed.Region != "us-east-1" { + t.Errorf("Region = %q", parsed.Region) } - if sig.Date != "20260523" { - t.Errorf("Date = %q", sig.Date) + if parsed.URL == nil || parsed.URL.Host != validHost { + t.Errorf("URL host = %v", parsed.URL) + } +} + +func TestParseToken_MissingPrefix_NotForMe(t *testing.T) { + cases := []string{ + "", + "Bearer foo", + "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ4In0.sig", // JWT-shaped + "forge-aws-v0.something", // wrong version prefix + "AWS4-HMAC-SHA256 Credential=AKIA...", // old format + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + _, err := ParseToken(in, validHost, true) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "missing forge-aws-v1 prefix") { + t.Errorf("err = %v, want missing-prefix error (so caller maps to ErrTokenNotForMe)", err) + } + }) } - if sig.Region != "us-east-1" { - t.Errorf("Region = %q", sig.Region) +} + +func TestParseToken_EmptyPayload(t *testing.T) { + _, err := 
ParseToken(TokenPrefix, validHost, true) + if err == nil { + t.Fatal("expected error on empty payload") } - if sig.Service != "sts" { - t.Errorf("Service = %q", sig.Service) +} + +func TestParseToken_RejectsForeignHost(t *testing.T) { + // SSRF guard — even if base64 decodes to a syntactically valid URL, + // any non-STS host is rejected. + hostile := "https://evil.example.com/" + strings.Replace(validPresignedURL, "https://sts.us-east-1.amazonaws.com/", "", 1) + _, err := ParseToken(makeToken(hostile), validHost, true) + if err == nil || !strings.Contains(err.Error(), "SSRF") { + t.Errorf("err = %v, want SSRF-guard rejection", err) } - if sig.SignedHeaders != "host;x-amz-date" { - t.Errorf("SignedHeaders = %q", sig.SignedHeaders) +} + +func TestParseToken_RejectsHTTPScheme_InProdMode(t *testing.T) { + httpURL := strings.Replace(validPresignedURL, "https://", "http://", 1) + _, err := ParseToken(makeToken(httpURL), validHost, true) + if err == nil || !strings.Contains(err.Error(), "scheme") { + t.Errorf("err = %v, want https-required rejection", err) } - if sig.Signature != "ab12cd34" { - t.Errorf("Signature = %q", sig.Signature) +} + +func TestParseToken_AcceptsHTTP_InTestMode(t *testing.T) { + httpURL := strings.Replace(validPresignedURL, "https://", "http://", 1) + _, err := ParseToken(makeToken(httpURL), validHost, false) + if err != nil { + t.Errorf("test-mode http should be accepted, got %v", err) } } -func TestParser_TolerantWhitespace(t *testing.T) { - // Real-world Sigv4 signers vary in whitespace around the commas; - // extra whitespace must still parse cleanly. 
- h := "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request,SignedHeaders=host,Signature=abc" - if _, err := (Parser{}).Parse(h); err != nil { - t.Errorf("expected tolerant whitespace parse to succeed, got %v", err) +func TestParseToken_RejectsWrongAction(t *testing.T) { + u := strings.Replace(validPresignedURL, "Action=GetCallerIdentity", "Action=ListUsers", 1) + _, err := ParseToken(makeToken(u), validHost, true) + if err == nil || !strings.Contains(err.Error(), "Action") { + t.Errorf("err = %v, want Action-mismatch rejection", err) } } -func TestParser_MalformedInputs(t *testing.T) { +func TestParseToken_RejectsMissingSignature(t *testing.T) { + u := strings.Replace(validPresignedURL, "&X-Amz-Signature=abcd1234", "", 1) + _, err := ParseToken(makeToken(u), validHost, true) + if err == nil || !strings.Contains(err.Error(), "X-Amz-Signature") { + t.Errorf("err = %v, want missing-signature rejection", err) + } +} + +func TestParseToken_RejectsWrongAlgorithm(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Algorithm=AWS4-HMAC-SHA256", "X-Amz-Algorithm=AWS3-MD5", 1) + _, err := ParseToken(makeToken(u), validHost, true) + if err == nil || !strings.Contains(err.Error(), "X-Amz-Algorithm") { + t.Errorf("err = %v, want algorithm-mismatch rejection", err) + } +} + +func TestParseToken_RejectsBadBase64(t *testing.T) { + _, err := ParseToken(TokenPrefix+"!!!not-base64!!!", validHost, true) + if err == nil { + t.Fatal("expected error on malformed base64") + } +} + +func TestParseCredentialScope_HappyPath(t *testing.T) { + akid, date, region, err := parseCredentialScope("AKIA123/20260524/us-east-1/sts/aws4_request") + if err != nil { + t.Fatalf("parseCredentialScope: %v", err) + } + if akid != "AKIA123" || date != "20260524" || region != "us-east-1" { + t.Errorf("got %q/%q/%q", akid, date, region) + } +} + +func TestParseCredentialScope_Malformed(t *testing.T) { cases := map[string]string{ - "empty": "", - "bearer instead": "Bearer foo", - 
"prefix only": "AWS4-HMAC-SHA256", - "prefix and space": "AWS4-HMAC-SHA256 ", - "too few segments": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts, SignedHeaders=host, Signature=ab", - "too many segments": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request/extra, SignedHeaders=host, Signature=ab", - "missing SignedHeaders": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, Signature=ab", - "signed headers no host": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, SignedHeaders=x-amz-date, Signature=ab", - "scope tail wrong": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/wrong, SignedHeaders=host, Signature=ab", - "two parts instead of three": "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, Signature=ab", + "empty": "", + "too few segments": "AKIA/20260524/us-east-1/sts", + "too many segments": "AKIA/20260524/us-east-1/sts/aws4_request/extra", + "wrong service": "AKIA/20260524/us-east-1/s3/aws4_request", + "wrong tail": "AKIA/20260524/us-east-1/sts/aws3_request", + "empty AKID segment": "/20260524/us-east-1/sts/aws4_request", } for name, in := range cases { t.Run(name, func(t *testing.T) { - if _, err := (Parser{}).Parse(in); err == nil { - t.Errorf("expected parse error for %q", in) + if _, _, _, err := parseCredentialScope(in); err == nil { + t.Errorf("expected error for %q", in) } }) } } -func TestParser_SignedHeadersHostMatchExact(t *testing.T) { - // "ghosting" must NOT count as containing "host" — we split on ";" - // so each entry stands on its own. - h := "AWS4-HMAC-SHA256 Credential=AKIA/20260523/us-east-1/sts/aws4_request, SignedHeaders=ghosting, Signature=ab" - if _, err := (Parser{}).Parse(h); err == nil { - t.Error("expected error: 'ghosting' should not satisfy host requirement") - } -} - -// FuzzParser ensures the parser never panics on arbitrary input. Run with -// `go test -fuzz=FuzzParser -fuzztime=30s ./forge-core/auth/providers/aws_sigv4/`. 
-func FuzzParser(f *testing.F) { - f.Add("AWS4-HMAC-SHA256 Credential=A/B/C/D/aws4_request, SignedHeaders=host, Signature=x") - f.Add("Bearer foo") +// FuzzParseToken — pure decoder must never panic, regardless of input. +func FuzzParseToken(f *testing.F) { + f.Add(makeToken(validPresignedURL)) f.Add("") - f.Add(strings.Repeat("AWS4-HMAC-SHA256 ", 100)) + f.Add(TokenPrefix) + f.Add(strings.Repeat(TokenPrefix, 10)) + f.Add(TokenPrefix + "AAAA====") f.Fuzz(func(_ *testing.T, in string) { - _, _ = (Parser{}).Parse(in) + _, _ = ParseToken(in, validHost, true) }) } + +// assert we map missing-prefix to a stable error message because the +// caller (provider.go) string-matches it to convert to ErrTokenNotForMe. +func TestParseToken_MissingPrefixErrorMessageIsStable(t *testing.T) { + _, err := ParseToken("garbage", validHost, true) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, err) || err.Error() != "missing forge-aws-v1 prefix" { + t.Errorf("err.Error() = %q, want exact string", err.Error()) + } +} diff --git a/forge-core/auth/providers/aws_sigv4/sts_client.go b/forge-core/auth/providers/aws_sigv4/sts_client.go index 05548ff..72cccf0 100644 --- a/forge-core/auth/providers/aws_sigv4/sts_client.go +++ b/forge-core/auth/providers/aws_sigv4/sts_client.go @@ -6,48 +6,34 @@ import ( "fmt" "io" "net/http" - "net/url" "strings" "time" "github.com/initializ/forge/forge-core/auth" ) -// STSClient calls AWS STS GetCallerIdentity by reflecting the caller's -// Sigv4 signature. ~150 LOC of hand-rolled HTTP + XML — no aws-sdk-go-v2 -// dependency (decision §9.1). +// STSClient invokes a pre-signed STS GetCallerIdentity URL produced by the +// caller's AWS SDK. Because the signature is in the URL's query parameters +// (not headers tied to a destination host), the request validates against +// STS no matter who relays it — that's the design property that lets +// Forge act as a verifier without holding any AWS secrets. 
// -// The key security property: Forge NEVER possesses the caller's secret -// key. STS validates the signature on Forge's behalf and returns the -// canonical ARN / UserId / Account on success. +// ~80 LOC of hand-rolled HTTP + XML. No aws-sdk-go-v2 dependency +// (decision §9.1). type STSClient struct { - endpoint string - http *http.Client + http *http.Client } -// NewSTSClient builds a client for sts.<region>.amazonaws.com. The override -// argument is for tests only — production should leave it empty. -func NewSTSClient(region, override string, timeout time.Duration) *STSClient { - ep := override - if ep == "" { - ep = fmt.Sprintf("https://sts.%s.amazonaws.com", region) - } +// NewSTSClient builds a client that GETs the URL the caller pre-signed. +// `region` is informational here (the URL itself carries the region in +// its credential scope and host); we keep the arg for symmetry with the +// pre-rewrite API and for future per-region tuning. +func NewSTSClient(_ /* region */ string, _ /* legacyOverrideUnused */ string, timeout time.Duration) *STSClient { return &STSClient{ - endpoint: ep, - http: &http.Client{Timeout: timeout}, + http: &http.Client{Timeout: timeout}, } } -// STSReflectArgs is the set of headers reflected verbatim from the caller's -// request to the STS call. SecurityToken is optional and only present when -// the caller is using temporary credentials (assumed roles, federation, -// IRSA, etc.). -type STSReflectArgs struct { - AuthHeader string - AmzDate string - SecurityToken string -} - // CallerIdentity is the parsed STS response — the canonical identifiers // Forge stamps into auth.Identity. type CallerIdentity struct { @@ -56,31 +42,21 @@ type CallerIdentity struct { Account string // e.g. "123456789012" } -// GetCallerIdentity reflects the caller's Sigv4 headers to STS and parses -// the response. +// GetCallerIdentity GETs the pre-signed URL and parses the XML response.
// // Error classification: // // 200 OK → CallerIdentity, nil -// 4xx → auth.ErrTokenRejected (caller's signature didn't validate) +// 4xx → auth.ErrTokenRejected (caller's signature didn't validate; +// most often "SignatureDoesNotMatch", "ExpiredToken", or +// "InvalidClientTokenId") // 5xx / network → auth.ErrProviderUnavailable (review #6 contract) // parse failure → auth.ErrProviderUnavailable (unexpected response shape) -func (c *STSClient) GetCallerIdentity(ctx context.Context, args STSReflectArgs) (*CallerIdentity, error) { - body := url.Values{ - "Action": {"GetCallerIdentity"}, - "Version": {"2011-06-15"}, - }.Encode() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(body)) +func (c *STSClient) GetCallerIdentity(ctx context.Context, presignedURL string) (*CallerIdentity, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, presignedURL, nil) if err != nil { return nil, fmt.Errorf("%w: build STS request: %v", auth.ErrProviderUnavailable, err) } - req.Header.Set("Authorization", args.AuthHeader) - req.Header.Set("X-Amz-Date", args.AmzDate) - if args.SecurityToken != "" { - req.Header.Set("X-Amz-Security-Token", args.SecurityToken) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := c.http.Do(req) if err != nil { @@ -88,8 +64,7 @@ func (c *STSClient) GetCallerIdentity(ctx context.Context, args STSReflectArgs) } defer func() { _ = resp.Body.Close() }() - // Real GetCallerIdentity responses are ~1 KiB; cap at 64 KiB to bound - // memory if STS ever returns something pathological. + // Real GetCallerIdentity responses are ~1 KiB; cap at 64 KiB. 
raw, err := io.ReadAll(io.LimitReader(resp.Body, 64<<10)) if err != nil { return nil, fmt.Errorf("%w: read STS body: %v", auth.ErrProviderUnavailable, err) @@ -107,9 +82,8 @@ func (c *STSClient) GetCallerIdentity(ctx context.Context, args STSReflectArgs) } } -// parseGetCallerIdentityResponse extracts the three canonical fields from -// STS's XML response. Rejects responses that are missing any of them — STS -// always sets all three, so absence implies a malformed reply. +// parseGetCallerIdentityResponse extracts the three canonical fields. +// STS always emits all three; absence is treated as a malformed reply. func parseGetCallerIdentityResponse(raw []byte) (*CallerIdentity, error) { var resp struct { XMLName xml.Name `xml:"GetCallerIdentityResponse"` @@ -133,9 +107,9 @@ func parseGetCallerIdentityResponse(raw []byte) (*CallerIdentity, error) { } // summarize returns a short, single-line, log-safe rendering of an STS -// error body. Caps at 200 chars and strips newlines so the full error -// payload — which can include the caller's Authorization echo on certain -// SDK-shaped errors — never propagates to logs verbatim. +// error body. Caps at 200 chars and strips newlines so STS error text +// (which can echo the caller's headers in some shapes) never propagates +// to logs verbatim. 
func summarize(raw []byte) string { s := string(raw) if len(s) > 200 { diff --git a/forge-core/auth/providers/aws_sigv4/sts_client_test.go b/forge-core/auth/providers/aws_sigv4/sts_client_test.go index c71d120..db88ae2 100644 --- a/forge-core/auth/providers/aws_sigv4/sts_client_test.go +++ b/forge-core/auth/providers/aws_sigv4/sts_client_test.go @@ -24,17 +24,17 @@ const happySTSXML = `InternalFailure") })) defer srv.Close() - c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) - _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) if !errors.Is(err, auth.ErrProviderUnavailable) { t.Errorf("err = %v, want ErrProviderUnavailable", err) } } func TestSTSClient_NetworkError_Unavailable(t *testing.T) { - // Point at a closed listener so Do() returns a transport error. srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) url := srv.URL srv.Close() - c := NewSTSClient("us-east-1", url, 1*time.Second) - _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + c := NewSTSClient("us-east-1", "", 1*time.Second) + _, err := c.GetCallerIdentity(context.Background(), url) if !errors.Is(err, auth.ErrProviderUnavailable) { t.Errorf("err = %v, want ErrProviderUnavailable", err) } } -func TestSTSClient_AuthHeaderReflectedVerbatim(t *testing.T) { - var captured string - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - captured = r.Header.Get("Authorization") - w.Header().Set("Content-Type", "text/xml") - _, _ = io.WriteString(w, happySTSXML) - })) - defer srv.Close() - - wantAuth := "AWS4-HMAC-SHA256 Credential=AKIA.../20260523/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=abc123" - c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) - _, err := c.GetCallerIdentity(context.Background(), 
STSReflectArgs{ - AuthHeader: wantAuth, - AmzDate: "20260523T120000Z", - SecurityToken: "FwoGZX-test-session", - }) - if err != nil { - t.Fatalf("GetCallerIdentity: %v", err) - } - if captured != wantAuth { - t.Errorf("STS did not receive caller's Authorization verbatim:\n got: %q\n want: %q", captured, wantAuth) - } -} - -func TestSTSClient_SecurityTokenReflected(t *testing.T) { - var capturedTok string - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedTok = r.Header.Get("X-Amz-Security-Token") - _, _ = io.WriteString(w, happySTSXML) - })) - defer srv.Close() - - c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) - _, _ = c.GetCallerIdentity(context.Background(), STSReflectArgs{ - AuthHeader: "AWS4", - AmzDate: "20260523T120000Z", - SecurityToken: "FwoGZX-token", - }) - if capturedTok != "FwoGZX-token" { - t.Errorf("X-Amz-Security-Token = %q, want FwoGZX-token", capturedTok) - } -} - func TestSTSClient_BodyCap(t *testing.T) { - // Serve 128 KiB; ensure the client doesn't OOM and the response - // either parses cleanly (truncated at exactly a well-formed prefix - // — unlikely) OR returns ErrProviderUnavailable from parse failure. 
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = io.WriteString(w, "") @@ -145,8 +97,8 @@ func TestSTSClient_BodyCap(t *testing.T) { })) defer srv.Close() - c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) - _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) if err == nil { t.Fatal("expected error on oversized STS body") } @@ -156,8 +108,6 @@ func TestSTSClient_BodyCap(t *testing.T) { } func TestSTSClient_MissingFieldsRejected(t *testing.T) { - // Response with empty Arn — must reject as unavailable (malformed, - // not "rejected" — STS would never legitimately return blank fields). srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = io.WriteString(w, ` x123 @@ -165,38 +115,45 @@ func TestSTSClient_MissingFieldsRejected(t *testing.T) { })) defer srv.Close() - c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) - _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) if !errors.Is(err, auth.ErrProviderUnavailable) { t.Errorf("err = %v, want ErrProviderUnavailable", err) } } -func TestSTSClient_RegionEndpointFormat(t *testing.T) { - // Sanity: NewSTSClient without an override builds the right URL. - c := NewSTSClient("eu-west-1", "", time.Second) - if c.endpoint != "https://sts.eu-west-1.amazonaws.com" { - t.Errorf("endpoint = %q", c.endpoint) - } -} - func TestSTSClient_RequestCount(t *testing.T) { - // Pin the basic happy-path call counting for use by the provider-level - // cache-hit test. 
var calls atomic.Int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { calls.Add(1) _, _ = io.WriteString(w, happySTSXML) })) defer srv.Close() - c := NewSTSClient("us-east-1", srv.URL, 5*time.Second) - for i := range 3 { - _, err := c.GetCallerIdentity(context.Background(), STSReflectArgs{AuthHeader: "AWS4"}) - if err != nil { - t.Fatalf("call %d: %v", i, err) + c := NewSTSClient("us-east-1", "", 5*time.Second) + for range 3 { + if _, err := c.GetCallerIdentity(context.Background(), srv.URL); err != nil { + t.Fatalf("call: %v", err) } } if calls.Load() != 3 { t.Errorf("STS calls = %d, want 3", calls.Load()) } } + +func TestSTSClient_PreservesURLQueryString(t *testing.T) { + // The pre-signed URL carries the signature in query params; the + // client MUST send those verbatim to STS or STS will reject. + var capturedQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedQuery = r.URL.RawQuery + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + wantQuery := "Action=GetCallerIdentity&X-Amz-Signature=abc123" + _, _ = c.GetCallerIdentity(context.Background(), srv.URL+"/?"+wantQuery) + if capturedQuery != wantQuery { + t.Errorf("STS received query %q, want %q", capturedQuery, wantQuery) + } +} diff --git a/scripts/forge-aws-sign.py b/scripts/forge-aws-sign.py index b0cc68f..b13c19b 100755 --- a/scripts/forge-aws-sign.py +++ b/scripts/forge-aws-sign.py @@ -2,106 +2,91 @@ """ forge-aws-sign — reference client for Forge's aws_sigv4 auth provider. -Why this script exists -====================== -Forge's aws_sigv4 provider authenticates callers by reflecting their Sigv4 -signature to AWS STS GetCallerIdentity. The signature is bound to the -*destination host* in its canonical-headers input, so callers cannot just -sign for Forge's URL and expect STS to validate the result. 
- -The correct client-side pattern is: - - 1. Construct a hypothetical request shape addressed to STS itself - (POST https://sts..amazonaws.com/, body - "Action=GetCallerIdentity&Version=2011-06-15"). - 2. Sign that hypothetical request with the caller's AWS credentials. - 3. Take the resulting Authorization + X-Amz-Date (+ X-Amz-Security-Token - for temporary credentials) headers and attach them to the REAL request - that's being sent to Forge. - 4. Forge forwards those headers verbatim to STS, which validates them - against its own host. Forge stamps an Identity from the returned ARN. - -This is identical in spirit to the aws-iam-authenticator pattern that EKS -uses. Standard tools (awscurl, the AWS CLI, `boto3.client('sts')`) do NOT -do this automatically — they always sign for the URL you're calling. So a -small wrapper like this one is required. +The aws_sigv4 provider uses the pre-signed URL pattern (same approach as +aws-iam-authenticator for EKS). The client mints a pre-signed STS +GetCallerIdentity URL using its own AWS SDK, then sends it to Forge as +a Bearer token of the form: + + Authorization: Bearer forge-aws-v1. + +Forge invokes the pre-signed URL on STS, which validates the signature +against its own host (because that's what was signed), and returns the +caller's canonical ARN. Usage ===== - python3 forge-aws-sign.py [--region us-east-1] [--url URL] [--body BODY] + # Print just the token (use it however you want): + python3 forge-aws-sign.py --token-only --region us-east-1 - Defaults: - --region: us-east-1 - --url: http://localhost:9999/tasks/send - --body: {"task":"hello"} + # Make a one-shot call to Forge: + python3 forge-aws-sign.py --region us-east-1 \ + --url http://localhost:9999/tasks/send \ + --body '{"task":"hello"}' Reads AWS credentials the same way boto3 does: env vars, profile, SSO, -IRSA, instance profile, etc. — whichever applies. +IRSA, instance profile, etc. -Exits 0 on HTTP 2xx, 1 otherwise. Prints the response body either way. 
+Exits 0 on HTTP 2xx (or when --token-only succeeds); 1 otherwise. """ from __future__ import annotations import argparse +import base64 import sys try: + import boto3 import requests - from botocore.session import Session - from botocore.auth import SigV4Auth - from botocore.awsrequest import AWSRequest except ImportError as e: print(f"missing dependency: {e}", file=sys.stderr) print("install with: pip3 install --user boto3 requests", file=sys.stderr) sys.exit(2) +def mint_token(region: str, profile: str | None, expires: int = 900) -> str: + """Mint a forge-aws-v1 token from the current AWS credentials. + + `expires` (seconds) is the TTL baked into the pre-signed URL; max 900 + per STS limits. Forge accepts any valid TTL — the token simply becomes + unusable once STS rejects the expired signature. + """ + session = boto3.Session(profile_name=profile) if profile else boto3.Session() + sts = session.client("sts", region_name=region) + url = sts.generate_presigned_url("get_caller_identity", ExpiresIn=expires) + encoded = base64.urlsafe_b64encode(url.encode()).rstrip(b"=").decode() + return "forge-aws-v1." 
+ encoded + + def main() -> int: parser = argparse.ArgumentParser(description="Forge aws_sigv4 reference client") parser.add_argument("--region", default="us-east-1", help="AWS region used in the Sigv4 scope") parser.add_argument("--url", default="http://localhost:9999/tasks/send", help="Forge endpoint to POST to") parser.add_argument("--body", default='{"task":"hello"}', help="JSON body to send to Forge") - parser.add_argument("--profile", default=None, help="AWS profile to use (default: boto3's default)") - parser.add_argument("-v", "--verbose", action="store_true", help="Show signed headers in output") + parser.add_argument("--profile", default=None, help="AWS profile (default: boto3's default chain)") + parser.add_argument("--expires", type=int, default=900, help="Pre-signed URL TTL in seconds (max 900)") + parser.add_argument("--token-only", action="store_true", help="Print only the token, don't make a request") + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") args = parser.parse_args() - session = Session(profile=args.profile) if args.profile else Session() - credentials = session.get_credentials() - if credentials is None: - print("No AWS credentials found. Try `aws sso login` or `aws configure`.", file=sys.stderr) - return 2 - frozen = credentials.get_frozen_credentials() - - # 1-2. Sign for STS's host (NOT Forge's host) - sts_url = f"https://sts.{args.region}.amazonaws.com/" - sts_body = "Action=GetCallerIdentity&Version=2011-06-15" - sts_req = AWSRequest( - method="POST", - url=sts_url, - data=sts_body, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - ) - SigV4Auth(frozen, "sts", args.region).add_auth(sts_req) + try: + token = mint_token(args.region, args.profile, args.expires) + except Exception as e: + print(f"failed to mint token: {e}", file=sys.stderr) + return 1 - # 3. 
Reuse the signed headers in a request to Forge - forge_headers = { - "Authorization": sts_req.headers["Authorization"], - "X-Amz-Date": sts_req.headers["X-Amz-Date"], - "Content-Type": "application/json", - } - if "X-Amz-Security-Token" in sts_req.headers: - forge_headers["X-Amz-Security-Token"] = sts_req.headers["X-Amz-Security-Token"] + if args.token_only: + print(token) + return 0 if args.verbose: - print(f"POST {args.url}") - print(f" Authorization: {forge_headers['Authorization'][:80]}...") - print(f" X-Amz-Date: {forge_headers['X-Amz-Date']}") - has_tok = "X-Amz-Security-Token" in forge_headers - print(f" X-Amz-Security-Token: {'' if has_tok else ''}") - print() - - # 4. Send. Forge will reflect the headers to STS. - resp = requests.post(args.url, headers=forge_headers, data=args.body) + print(f"POST {args.url}", file=sys.stderr) + print(f" Authorization: Bearer {token[:60]}...", file=sys.stderr) + + resp = requests.post( + args.url, + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, + data=args.body, + ) print(f"HTTP {resp.status_code}") print(resp.text) return 0 if 200 <= resp.status_code < 300 else 1 From 8568535a05a54561ea0267973b2b0b4577af623e Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 21:36:08 -0400 Subject: [PATCH 12/22] fix(auth/aws_sigv4): preserve raw presigned URL; use SigV4QueryAuth in client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two correctness fixes surfaced by live AWS testing of the pre-signed URL pattern from b3444c2. 1. Preserve the raw URL byte-for-byte. Round-tripping the presigned URL through Go's net/url package re-encoded query params in subtle ways (e.g. "/" in X-Amz-Credential, "+" inside X-Amz-Security-Token) that didn't match how the AWS SDK emitted them on the caller side. STS recomputes the canonical request using whatever bytes we send and gets a different hash → 4xx SignatureDoesNotMatch → audit reason "rejected". 
- PresignedToken gains a RawURL field — the exact bytes from the decoded token payload. - The parsed *url.URL is kept ONLY for SSRF host validation and query-param inspection. It is NEVER used to construct the outbound request. - Provider.Verify now passes parsed.RawURL to STSClient.GetCallerIdentity. 2. Use SigV4QueryAuth directly in the reference client (not boto3's high-level generate_presigned_url). boto3.client('sts').generate_presigned_url('get_caller_identity', ...) produces a URL that STS rejects with SignatureDoesNotMatch when it is fetched with GET. Known quirk — the high-level presigner signs as if the request were a POST. aws-iam-authenticator works around this by signing the AWSRequest explicitly; scripts/forge-aws-sign.py now does the same: req = AWSRequest(method='GET', url='https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15') SigV4QueryAuth(creds, 'sts', region, expires=900).add_auth(req) token = 'forge-aws-v1.' + base64.urlsafe_b64encode(req.url.encode()).rstrip(b'=').decode() Live validation against real AWS (account 412664885516, SSO assumed-role): - Happy path: HTTP 400 body-shape error + auth_verify with correct ARN - Deny path: HTTP 401 + auth_fail reason="rejected" + token_kind="sigv4" 42 packages still pass; golangci-lint clean; gofmt clean. (Known follow-up surfaced but out of scope: hot-reload of forge.yaml doesn't rebuild the auth chain, so allowlist changes require a hard restart. Same caveat affects all providers, not just aws_sigv4.)
--- .../auth/providers/aws_sigv4/provider.go | 5 +++- .../auth/providers/aws_sigv4/sigv4_parser.go | 11 +++++++- scripts/forge-aws-sign.py | 26 ++++++++++++++----- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go index 1683255..da39aae 100644 --- a/forge-core/auth/providers/aws_sigv4/provider.go +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -210,7 +210,10 @@ func (p *Provider) Verify(ctx context.Context, token string, _ auth.Headers) (*a return id, nil } - caller, err := p.sts.GetCallerIdentity(ctx, parsed.URL.String()) + // IMPORTANT: pass RawURL (the byte-for-byte original from the token), + // NOT parsed.URL.String(). The latter would re-encode query params + // via net/url's rules and invalidate the signature. + caller, err := p.sts.GetCallerIdentity(ctx, parsed.RawURL) if err != nil { return nil, err // STSClient already wraps with the right sentinel } diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go index cb3deaf..daf43b0 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go @@ -17,8 +17,16 @@ const TokenPrefix = "forge-aws-v1." // The token wraps a pre-signed STS GetCallerIdentity URL; AKID + Date are // extracted from the URL's X-Amz-Credential query param so the provider's // identity cache can key on them without re-deriving for every Verify call. +// +// RawURL holds the original URL byte-for-byte as it appeared in the +// decoded token payload. Forge MUST use RawURL when invoking STS — round- +// tripping through Go's net/url package re-encodes query parameters in +// ways that differ from how the AWS SDK emitted them (e.g., percent- +// encoding of "/" in X-Amz-Credential, "+" in X-Amz-Security-Token), and +// any such re-encoding invalidates the signature. 
type PresignedToken struct { - URL *url.URL // validated, host-checked, ready to GET + RawURL string // the exact URL the AWS SDK produced — preserve as-is + URL *url.URL // parsed view, for host validation and query inspection only AKID string // for IdentityCache bucket key Date string // YYYYMMDD scope date — for IdentityCache bucket key Region string // from the credential scope (we cross-check against cfg.Region) @@ -98,6 +106,7 @@ func ParseToken(token, expectedHost string, requireHTTPS bool) (*PresignedToken, } return &PresignedToken{ + RawURL: string(raw), URL: u, AKID: akid, Date: date, diff --git a/scripts/forge-aws-sign.py b/scripts/forge-aws-sign.py index b13c19b..8cd9deb 100755 --- a/scripts/forge-aws-sign.py +++ b/scripts/forge-aws-sign.py @@ -37,6 +37,8 @@ try: import boto3 import requests + from botocore.auth import SigV4QueryAuth + from botocore.awsrequest import AWSRequest except ImportError as e: print(f"missing dependency: {e}", file=sys.stderr) print("install with: pip3 install --user boto3 requests", file=sys.stderr) @@ -46,14 +48,26 @@ def mint_token(region: str, profile: str | None, expires: int = 900) -> str: """Mint a forge-aws-v1 token from the current AWS credentials. - `expires` (seconds) is the TTL baked into the pre-signed URL; max 900 - per STS limits. Forge accepts any valid TTL — the token simply becomes - unusable once STS rejects the expired signature. + Builds the pre-signed URL via SigV4QueryAuth directly, NOT via + boto3.client('sts').generate_presigned_url('get_caller_identity', ...) + — the latter signs as if the request were a POST to STS and STS + rejects the resulting GET URL with "SignatureDoesNotMatch." Same + quirk aws-iam-authenticator works around by signing the request + explicitly. + + `expires` (seconds) is the TTL baked into the URL; max 900. 
""" session = boto3.Session(profile_name=profile) if profile else boto3.Session() - sts = session.client("sts", region_name=region) - url = sts.generate_presigned_url("get_caller_identity", ExpiresIn=expires) - encoded = base64.urlsafe_b64encode(url.encode()).rstrip(b"=").decode() + creds = session.get_credentials().get_frozen_credentials() + + req = AWSRequest( + method="GET", + url=f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + headers={}, + ) + SigV4QueryAuth(creds, "sts", region, expires=expires).add_auth(req) + + encoded = base64.urlsafe_b64encode(req.url.encode()).rstrip(b"=").decode() return "forge-aws-v1." + encoded From 24be30a0cce264c5bb474f04cffecfb3d7d9afe7 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 12:02:19 -0400 Subject: [PATCH 13/22] feat(auth): TUI sub-flows + allowed_accounts shortcut for aws_sigv4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two paired additions that make Phase 2 provider onboarding actually usable end-to-end via the TUI. ## allowed_accounts (ergonomic shortcut) aws_sigv4.Config gains AllowedAccounts []string. Each 12-digit AWS account ID expands at Factory time into the canonical glob set covering every STS identity shape: arn:aws:iam:::user/* — direct IAM users arn:aws:iam:::role/* — direct IAM roles arn:aws:sts:::assumed-role/*/* — SSO, AssumeRole, IRSA arn:aws:sts:::federated-user/* — SAML / web-identity federation So an operator who wants "anyone in this account" writes one line of config instead of four globs. Composes with allowed_principals — list specific roles AND whole accounts in the same provider entry. 
Validation: - validateAccountID checks the 12-digit shape at Factory time - validate.ValidateAuthConfig catches malformed entries at `forge validate` time (before scaffold writes forge.yaml) ## TUI sub-flows for all three Phase 2 providers forge init's TUI picker now has 7 entries (was 4): None / OIDC / HTTP Verifier / AWS Sigv4 / GCP IAP / Azure AD / Custom Input flows: aws_sigv4 region → audience (opt) → accounts (opt) → done gcp_iap audience → done azure_ad tenant → audience → done (single-tenant only) Azure AD intentionally restricts the TUI to single-tenant. Enabling allow_multi_tenant is a deliberate security trade-off (any Entra tenant in the world is admissible) and should require editing forge.yaml, not clicking through a wizard. Egress hosts auto-computed from the selection: aws_sigv4 → sts..amazonaws.com gcp_iap → www.gstatic.com (hardcoded §9.4) azure_ad → login.microsoftonline.com So the Egress review step (which runs after Auth per 639bfa9) shows the operator the full outbound surface they're about to allow, including the auth-provider's STS / IAP / AAD endpoints. ## Other surfaces forge init --auth-aws-allowed-account flag (repeatable) — matches the non-interactive path the wizard takes internally. CHANGELOG: drops the "TUI deferred" note (no longer deferred), adds sections for allowed_accounts and the TUI sub-flows. ## Tests arn_matcher_test.go +3 tests (validateAccountID, expansion shapes, expanded patterns match realistic ARNs) provider_test.go +4 tests (AllowedAccounts happy path, deny path, factory rejects malformed, mix with AllowedPrincipals) step_auth_test.go +6 tests (AWS full flow, AWS optional skips, AWS bad account rejection, GCP flow, AAD flow, Phase 2 summaries) validate/auth_test.go updated validator covers allowed_accounts 42 packages pass go test -race -count=1; golangci-lint v2.10.1 clean; gofmt clean. 
--- CHANGELOG.md | 36 ++- forge-cli/cmd/init.go | 1 + forge-cli/cmd/init_auth.go | 4 + forge-cli/internal/tui/steps/step_auth.go | 207 +++++++++++++++- .../internal/tui/steps/step_auth_test.go | 234 +++++++++++++++++- .../auth/providers/aws_sigv4/arn_matcher.go | 37 +++ .../providers/aws_sigv4/arn_matcher_test.go | 65 +++++ .../auth/providers/aws_sigv4/provider.go | 29 ++- .../auth/providers/aws_sigv4/provider_test.go | 45 ++++ forge-core/validate/auth.go | 25 ++ 10 files changed, 673 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd71e56..e502ae9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,37 @@ from Go callers (currently only `azure_ad` multi-tenant). Operators see no change. +### `allowed_accounts` shortcut for whole-account trust + +For "any IAM principal in these AWS accounts" without writing +glob patterns: + +```yaml +auth: + providers: + - type: aws_sigv4 + settings: + region: us-east-1 + allowed_accounts: ["412664885516", "109887654321"] +``` + +Internally expands to the canonical glob set covering all identity +shapes (IAM users, IAM roles, STS assumed-roles, federated users) +for each account. Composes with `allowed_principals` — you can list +specific roles AND whole accounts in the same provider entry. + +For AWS-Org-wide trust without enumerating accounts, use AWS IAM +Identity Center (SSO) — SSO permission sets gate Org membership at +sign-in, and you can match Identity Center-assumed roles with the +existing `allowed_principals` globs. + +### TUI wizard supports Phase 2 providers + +`forge init`'s TUI picker now includes `AWS Sigv4 (IAM)`, +`GCP Identity-Aware Proxy`, and `Azure AD / Entra ID` entries with +step-by-step input flows. AAD is single-tenant in the TUI; +multi-tenant remains a deliberate YAML edit (security default). + ### Client experience for `aws_sigv4` The client side is a Bearer token with a 3-line mint: @@ -70,7 +101,4 @@ docstring of `forge-core/auth/providers/aws_sigv4/provider.go`. 
### Known deferred work -- The Bubble Tea TUI wizard (`forge init`) does not yet have step-by-step - input flows for the three new providers. The non-interactive flag path - is the production-critical surface; operators using the TUI can pick - "Custom" and edit `forge.yaml` directly until the TUI follow-up lands. +- (none for Phase 2) diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index bca5667..6c5c70a 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -146,6 +146,7 @@ func init() { initCmd.Flags().String("auth-aws-region", "", "AWS region for aws_sigv4 (e.g. us-east-1)") initCmd.Flags().String("auth-aws-audience", "", "informational audience for aws_sigv4") initCmd.Flags().StringSlice("auth-aws-allowed-principal", nil, "allowed principal glob for aws_sigv4 (repeatable)") + initCmd.Flags().StringSlice("auth-aws-allowed-account", nil, "allowed AWS account ID for aws_sigv4 (repeatable; ergonomic shortcut for whole-account access)") initCmd.Flags().String("auth-aws-cache-ttl", "", "aws_sigv4 identity cache TTL (default 60s)") initCmd.Flags().String("auth-gcp-iap-audience", "", "audience for gcp_iap (backend service ID)") diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index fbd0f91..057210b 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -53,6 +53,7 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an region, _ := cmd.Flags().GetString("auth-aws-region") audience, _ := cmd.Flags().GetString("auth-aws-audience") allowedPrincipals, _ := cmd.Flags().GetStringSlice("auth-aws-allowed-principal") + allowedAccounts, _ := cmd.Flags().GetStringSlice("auth-aws-allowed-account") cacheTTL, _ := cmd.Flags().GetString("auth-aws-cache-ttl") if region == "" { return nil, nil, fmt.Errorf("--auth=aws_sigv4 requires --auth-aws-region") @@ -64,6 +65,9 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an if len(allowedPrincipals) > 0 { 
settings["allowed_principals"] = allowedPrincipals } + if len(allowedAccounts) > 0 { + settings["allowed_accounts"] = allowedAccounts + } if cacheTTL != "" { settings["identity_cache_ttl"] = cacheTTL } diff --git a/forge-cli/internal/tui/steps/step_auth.go b/forge-cli/internal/tui/steps/step_auth.go index 000ef16..82321d0 100644 --- a/forge-cli/internal/tui/steps/step_auth.go +++ b/forge-cli/internal/tui/steps/step_auth.go @@ -22,6 +22,16 @@ const ( authOIDCGroupsClaimPhase authHTTPURLPhase authHTTPOrgPhase + // Phase 2: aws_sigv4 + authAWSRegionPhase + authAWSAudiencePhase + authAWSAccountsPhase + // Phase 2: gcp_iap + authGCPIAPAudiencePhase + // Phase 2: azure_ad (single-tenant only; multi-tenant requires YAML + // edit so it's a deliberate choice rather than an accidental toggle) + authAADTenantPhase + authAADAudiencePhase authDonePhase ) @@ -31,6 +41,9 @@ const ( AuthModeNone = "none" AuthModeOIDC = "oidc" AuthModeHTTPVerifier = "http_verifier" + AuthModeAWSSigv4 = "aws_sigv4" + AuthModeGCPIAP = "gcp_iap" + AuthModeAzureAD = "azure_ad" AuthModeCustom = "custom" ) @@ -67,6 +80,20 @@ type AuthStep struct { httpURL string httpOrg string + // aws_sigv4 settings + awsRegion string + awsAudience string + awsAccounts []string // 12-digit AWS account IDs + + // gcp_iap settings + gcpAudience string + + // azure_ad settings (TUI is single-tenant only; multi-tenant + // requires editing forge.yaml directly so it stays a deliberate + // security decision) + aadTenant string + aadAudience string + complete bool } @@ -91,6 +118,24 @@ func NewAuthStep(styles *tui.StyleSet) *AuthStep { Description: "Legacy — POST tokens to your own /verify endpoint", Icon: "🔁", }, + { + Label: "AWS Sigv4 (IAM)", + Value: AuthModeAWSSigv4, + Description: "Verify AWS-IAM callers via STS GetCallerIdentity (Phase 2)", + Icon: "🅰️", + }, + { + Label: "GCP Identity-Aware Proxy", + Value: AuthModeGCPIAP, + Description: "Forge behind a GCP HTTPS LB+IAP (Phase 2)", + Icon: "🇬", + }, + { + Label: 
"Azure AD / Entra ID", + Value: AuthModeAzureAD, + Description: "Single-tenant Entra tokens (Phase 2 — multi-tenant via YAML)", + Icon: "🇦", + }, { Label: "Custom", Value: AuthModeCustom, @@ -144,7 +189,10 @@ func (s *AuthStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { case authSelectPhase: return s.updateSelect(msg) case authOIDCIssuerPhase, authOIDCAudiencePhase, authOIDCGroupsClaimPhase, - authHTTPURLPhase, authHTTPOrgPhase: + authHTTPURLPhase, authHTTPOrgPhase, + authAWSRegionPhase, authAWSAudiencePhase, authAWSAccountsPhase, + authGCPIAPAudiencePhase, + authAADTenantPhase, authAADAudiencePhase: return s.updateInput(msg) } return s, nil @@ -182,6 +230,30 @@ func (s *AuthStep) updateSelect(msg tea.Msg) (tui.Step, tea.Cmd) { validateHTTPSURL, ) return s, s.input.Init() + case AuthModeAWSSigv4: + s.phase = authAWSRegionPhase + s.input = s.newTextInput( + "AWS region", + "us-east-1", + validateAWSRegion, + ) + return s, s.input.Init() + case AuthModeGCPIAP: + s.phase = authGCPIAPAudiencePhase + s.input = s.newTextInput( + "IAP audience (backend service ID from GCP console)", + "/projects/PNUM/global/backendServices/BACKEND_ID", + validateNonEmpty, + ) + return s, s.input.Init() + case AuthModeAzureAD: + s.phase = authAADTenantPhase + s.input = s.newTextInput( + "Entra tenant ID (GUID)", + "00000000-0000-0000-0000-000000000000", + validateNonEmpty, + ) + return s, s.input.Init() } return s, cmd } @@ -242,6 +314,57 @@ func (s *AuthStep) updateInput(msg tea.Msg) (tui.Step, tea.Cmd) { s.complete = true s.phase = authDonePhase return s, doneCmd() + + // --- aws_sigv4 --- + case authAWSRegionPhase: + s.awsRegion = v + s.phase = authAWSAudiencePhase + s.input = s.newTextInput( + "Audience (informational; press Enter to skip)", + "api://forge", + nil, + ) + return s, s.input.Init() + + case authAWSAudiencePhase: + s.awsAudience = v + s.phase = authAWSAccountsPhase + s.input = s.newTextInput( + "Allowed AWS accounts (comma-separated 12-digit IDs; Enter to skip)", + 
"412664885516,109887654321", + validateAccountList, + ) + return s, s.input.Init() + + case authAWSAccountsPhase: + s.awsAccounts = parseAccountList(v) + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + + // --- gcp_iap --- + case authGCPIAPAudiencePhase: + s.gcpAudience = v + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + + // --- azure_ad (single-tenant) --- + case authAADTenantPhase: + s.aadTenant = v + s.phase = authAADAudiencePhase + s.input = s.newTextInput( + "Audience (Application ID URI)", + "api://forge", + validateNonEmpty, + ) + return s, s.input.Init() + + case authAADAudiencePhase: + s.aadAudience = v + s.complete = true + s.phase = authDonePhase + return s, doneCmd() } return s, cmd @@ -299,6 +422,15 @@ func (s *AuthStep) Summary() string { return "HTTP Verifier" } return "HTTP Verifier · " + host + case AuthModeAWSSigv4: + if len(s.awsAccounts) > 0 { + return fmt.Sprintf("AWS Sigv4 · %s · %d account(s)", s.awsRegion, len(s.awsAccounts)) + } + return "AWS Sigv4 · " + s.awsRegion + case AuthModeGCPIAP: + return "GCP IAP" + case AuthModeAzureAD: + return "Azure AD · single-tenant" case AuthModeCustom: return "Custom (edit forge.yaml)" } @@ -341,6 +473,25 @@ func (s *AuthStep) Apply(ctx *tui.WizardContext) { if host := hostnameOrEmpty(s.httpURL); host != "" { ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, host) } + case AuthModeAWSSigv4: + settings := map[string]any{"region": s.awsRegion} + if s.awsAudience != "" { + settings["audience"] = s.awsAudience + } + if len(s.awsAccounts) > 0 { + settings["allowed_accounts"] = s.awsAccounts + } + ctx.AuthSettings = settings + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, "sts."+s.awsRegion+".amazonaws.com") + case AuthModeGCPIAP: + ctx.AuthSettings = map[string]any{"audience": s.gcpAudience} + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, "www.gstatic.com") + case AuthModeAzureAD: + ctx.AuthSettings = map[string]any{ + "tenant_id": 
s.aadTenant, + "audience": s.aadAudience, + } + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, "login.microsoftonline.com") default: // None / Custom: nothing to attach. ctx.AuthSettings = nil @@ -376,6 +527,60 @@ func validateNonEmpty(val string) error { return nil } +// validateAWSRegion checks for the canonical AWS region shape +// (xx-direction-N). Doesn't enumerate regions because AWS adds them +// often; STS at startup fails fast if the host doesn't resolve. +func validateAWSRegion(val string) error { + v := strings.TrimSpace(val) + if v == "" { + return fmt.Errorf("required") + } + // xx-direction-N, e.g. us-east-1, eu-west-2, ap-southeast-1 + parts := strings.Split(v, "-") + if len(parts) < 3 { + return fmt.Errorf("expected AWS region shape like 'us-east-1'") + } + return nil +} + +// validateAccountList accepts a comma-separated list of 12-digit AWS +// account IDs. Empty input is OK (the field is optional). Each entry +// is validated; one bad entry blocks the whole field. +func validateAccountList(val string) error { + v := strings.TrimSpace(val) + if v == "" { + return nil + } + for raw := range strings.SplitSeq(v, ",") { + acct := strings.TrimSpace(raw) + if len(acct) != 12 { + return fmt.Errorf("account %q: expected 12 digits", acct) + } + for _, c := range acct { + if c < '0' || c > '9' { + return fmt.Errorf("account %q: must be digits only", acct) + } + } + } + return nil +} + +// parseAccountList splits a comma-separated input into a trimmed slice. +// Empty input returns nil so the caller can omit the YAML key. +func parseAccountList(val string) []string { + v := strings.TrimSpace(val) + if v == "" { + return nil + } + var out []string + for raw := range strings.SplitSeq(v, ",") { + if s := strings.TrimSpace(raw); s != "" { + out = append(out, s) + } + } + return out +} + // hostnameOrEmpty returns the bare host (no port) from a URL, or "" on // parse failure. 
func hostnameOrEmpty(raw string) string { diff --git a/forge-cli/internal/tui/steps/step_auth_test.go b/forge-cli/internal/tui/steps/step_auth_test.go index 81e3ddf..f313b6a 100644 --- a/forge-cli/internal/tui/steps/step_auth_test.go +++ b/forge-cli/internal/tui/steps/step_auth_test.go @@ -86,10 +86,11 @@ func TestAuthStep_NoneIsDefault(t *testing.T) { func TestAuthStep_Custom(t *testing.T) { s := newTestAuthStep(t) - // Move down 3 times to reach "Custom" (index 3). - s = press(t, s, "down") - s = press(t, s, "down") - s = press(t, s, "down") + // Picker order: 0=None 1=OIDC 2=HTTPVerifier 3=AWS 4=GCP 5=AAD 6=Custom. + // Six downs to reach Custom at index 6. + for range 6 { + s = press(t, s, "down") + } s = press(t, s, "enter") if !s.Complete() { t.Fatal("expected Complete after selecting Custom") @@ -396,3 +397,228 @@ func TestAppendUnique(t *testing.T) { t.Errorf("appendUnique new failed: %v", out) } } + +// --- Phase 2: aws_sigv4 / gcp_iap / azure_ad --- + +func TestAuthStep_AWSSigv4_FullFlow_WithAccounts(t *testing.T) { + s := newTestAuthStep(t) + + // Picker order: 0=None 1=OIDC 2=HTTPVerifier 3=AWS 4=GCP 5=AAD 6=Custom + // Navigate to AWS (3 downs). 
+ s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authAWSRegionPhase { + t.Fatalf("phase = %v, want AWS region", s.phase) + } + + // Region + s = typeIn(t, s, "us-east-1") + s = press(t, s, "enter") + if s.phase != authAWSAudiencePhase { + t.Fatalf("phase = %v, want AWS audience", s.phase) + } + + // Audience (optional — press Enter to skip) + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + if s.phase != authAWSAccountsPhase { + t.Fatalf("phase = %v, want AWS accounts", s.phase) + } + + // Accounts (comma-separated) + s = typeIn(t, s, "412664885516, 109887654321") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after accounts") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeAWSSigv4 { + t.Errorf("AuthMode = %q, want aws_sigv4", ctx.AuthMode) + } + if ctx.AuthSettings["region"] != "us-east-1" { + t.Errorf("region = %v", ctx.AuthSettings["region"]) + } + if ctx.AuthSettings["audience"] != "api://forge" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + accts, _ := ctx.AuthSettings["allowed_accounts"].([]string) + if !reflect.DeepEqual(accts, []string{"412664885516", "109887654321"}) { + t.Errorf("allowed_accounts = %v", accts) + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"sts.us-east-1.amazonaws.com"}) { + t.Errorf("egress hosts = %v", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_AWSSigv4_AllowAudienceSkip(t *testing.T) { + // Skipping the audience field (pressing Enter on empty input) and + // the accounts field should still complete cleanly. 
+ s := newTestAuthStep(t) + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") // pick AWS + s = typeIn(t, s, "us-east-1") + s = press(t, s, "enter") // region → audience + s = press(t, s, "enter") // audience (empty) → accounts + s = press(t, s, "enter") // accounts (empty) → done + if !s.Complete() { + t.Fatal("expected Complete after skipping optional fields") + } + ctx := tui.NewWizardContext() + s.Apply(ctx) + if _, ok := ctx.AuthSettings["audience"]; ok { + t.Errorf("audience should be omitted when empty: %v", ctx.AuthSettings) + } + if _, ok := ctx.AuthSettings["allowed_accounts"]; ok { + t.Errorf("allowed_accounts should be omitted when empty: %v", ctx.AuthSettings) + } +} + +func TestAuthStep_AWSSigv4_RejectsMalformedAccount(t *testing.T) { + // validateAccountList enforces the 12-digit shape before letting + // the user advance past the accounts field. + if err := validateAccountList("notanaccount"); err == nil { + t.Error("expected error on malformed account ID") + } + if err := validateAccountList("412664885516,bad,109887654321"); err == nil { + t.Error("expected error when any account in the list is bad") + } + if err := validateAccountList(""); err != nil { + t.Errorf("empty input should be allowed (optional field), got %v", err) + } + if err := validateAccountList("412664885516 , 109887654321 "); err != nil { + t.Errorf("trimmed whitespace should be tolerated, got %v", err) + } +} + +func TestAuthStep_GCPIAP_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + // 4 downs to reach GCP IAP. 
+ s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authGCPIAPAudiencePhase { + t.Fatalf("phase = %v, want GCP audience", s.phase) + } + + s = typeIn(t, s, "/projects/12345/global/backendServices/67890") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after audience") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeGCPIAP { + t.Errorf("AuthMode = %q, want gcp_iap", ctx.AuthMode) + } + if ctx.AuthSettings["audience"] != "/projects/12345/global/backendServices/67890" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"www.gstatic.com"}) { + t.Errorf("egress hosts = %v, want [www.gstatic.com]", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_AzureAD_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + // 5 downs to reach Azure AD. + for range 5 { + s = press(t, s, "down") + } + s = press(t, s, "enter") + if s.phase != authAADTenantPhase { + t.Fatalf("phase = %v, want AAD tenant", s.phase) + } + + s = typeIn(t, s, "00000000-1111-2222-3333-444444444444") + s = press(t, s, "enter") + if s.phase != authAADAudiencePhase { + t.Fatalf("phase = %v, want AAD audience", s.phase) + } + + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after audience") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeAzureAD { + t.Errorf("AuthMode = %q", ctx.AuthMode) + } + if ctx.AuthSettings["tenant_id"] != "00000000-1111-2222-3333-444444444444" { + t.Errorf("tenant_id = %v", ctx.AuthSettings["tenant_id"]) + } + if ctx.AuthSettings["audience"] != "api://forge" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + if _, ok := ctx.AuthSettings["allow_multi_tenant"]; ok { + t.Errorf("TUI must NOT set allow_multi_tenant — multi-tenant requires YAML edit (security)") + } + 
if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"login.microsoftonline.com"}) { + t.Errorf("egress hosts = %v", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_Summary_Phase2(t *testing.T) { + cases := []struct { + setup func(*AuthStep) + want string + }{ + {setup: func(s *AuthStep) { s.mode = AuthModeAWSSigv4; s.awsRegion = "us-east-1" }, want: "AWS Sigv4 · us-east-1"}, + {setup: func(s *AuthStep) { + s.mode = AuthModeAWSSigv4 + s.awsRegion = "us-east-1" + s.awsAccounts = []string{"A", "B"} + }, want: "AWS Sigv4 · us-east-1 · 2 account(s)"}, + {setup: func(s *AuthStep) { s.mode = AuthModeGCPIAP }, want: "GCP IAP"}, + {setup: func(s *AuthStep) { s.mode = AuthModeAzureAD }, want: "Azure AD · single-tenant"}, + } + for _, tc := range cases { + s := newTestAuthStep(t) + tc.setup(s) + if got := s.Summary(); got != tc.want { + t.Errorf("Summary = %q, want %q", got, tc.want) + } + } +} + +func TestValidateAWSRegion(t *testing.T) { + ok := []string{"us-east-1", "eu-west-2", "ap-southeast-1", "ca-central-1"} + for _, s := range ok { + if err := validateAWSRegion(s); err != nil { + t.Errorf("validateAWSRegion(%q) = %v, want nil", s, err) + } + } + bad := []string{"", "us", "us-east", "useast1"} + for _, s := range bad { + if err := validateAWSRegion(s); err == nil { + t.Errorf("validateAWSRegion(%q) returned nil", s) + } + } +} + +func TestParseAccountList(t *testing.T) { + cases := map[string][]string{ + "": nil, + " ": nil, + "412664885516": {"412664885516"}, + "a, b ,c": {"a", "b", "c"}, + "a,,b": {"a", "b"}, + } + for in, want := range cases { + got := parseAccountList(in) + if !reflect.DeepEqual(got, want) { + t.Errorf("parseAccountList(%q) = %v, want %v", in, got, want) + } + } +} diff --git a/forge-core/auth/providers/aws_sigv4/arn_matcher.go b/forge-core/auth/providers/aws_sigv4/arn_matcher.go index 8e3a6a6..2a57ef2 100644 --- a/forge-core/auth/providers/aws_sigv4/arn_matcher.go +++ b/forge-core/auth/providers/aws_sigv4/arn_matcher.go @@ -1,6 +1,7 @@ package 
aws_sigv4 import ( + "errors" "fmt" "path" ) @@ -44,3 +45,39 @@ func (m *ArnMatcher) Match(arn string) bool { } return false } + +// validateAccountID checks for an AWS account ID — 12 ASCII digits. +// Catches typos (region names, role ARNs pasted by mistake) at Factory +// time so a misconfigured allowed_accounts entry doesn't silently +// become an unreachable pattern. +func validateAccountID(s string) error { + if len(s) != 12 { + return fmt.Errorf("account ID %q: expected 12 digits, got %d chars", s, len(s)) + } + for i := 0; i < len(s); i++ { + c := s[i] + if c < '0' || c > '9' { + return errors.New("account ID must be 12 digits") + } + } + return nil +} + +// expandAccountGlobs returns the canonical ARN glob set that covers +// every STS identity shape in a given account. The four shapes: +// +// arn:aws:iam:::user/ — direct IAM user +// arn:aws:iam:::role/ — direct IAM role (rare; usually only EC2/Lambda) +// arn:aws:sts:::assumed-role// — SSO, AssumeRole, IRSA +// arn:aws:sts:::federated-user/ — SAML/web-identity federation +// +// path.Match's `*` doesn't cross `/`, so the assumed-role glob needs +// two `*` segments to span both RoleName and SessionName. 
+func expandAccountGlobs(acct string) []string { + return []string{ + "arn:aws:iam::" + acct + ":user/*", + "arn:aws:iam::" + acct + ":role/*", + "arn:aws:sts::" + acct + ":assumed-role/*/*", + "arn:aws:sts::" + acct + ":federated-user/*", + } +} diff --git a/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go b/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go index 24ef554..22b02a6 100644 --- a/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go +++ b/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go @@ -55,3 +55,68 @@ func TestArnMatcher_AssumedRoleNeedsTwoStars(t *testing.T) { t.Error("IAM role pattern should NOT match STS assumed-role ARN — docs invariant") } } + +func TestValidateAccountID(t *testing.T) { + good := []string{"412664885516", "000000000000", "999999999999"} + for _, s := range good { + if err := validateAccountID(s); err != nil { + t.Errorf("validateAccountID(%q) = %v, want nil", s, err) + } + } + bad := []string{ + "", // empty + "1234", // too short + "4126648855161", // 13 chars + "us-east-1", // not digits + "arn:aws:iam::1:", // ARN form + "412 664885516", // space + } + for _, s := range bad { + if err := validateAccountID(s); err == nil { + t.Errorf("validateAccountID(%q) returned nil, want error", s) + } + } +} + +func TestExpandAccountGlobs(t *testing.T) { + got := expandAccountGlobs("412664885516") + want := map[string]bool{ + "arn:aws:iam::412664885516:user/*": true, + "arn:aws:iam::412664885516:role/*": true, + "arn:aws:sts::412664885516:assumed-role/*/*": true, + "arn:aws:sts::412664885516:federated-user/*": true, + } + if len(got) != len(want) { + t.Fatalf("got %d patterns, want %d", len(got), len(want)) + } + for _, g := range got { + if !want[g] { + t.Errorf("unexpected pattern %q", g) + } + } +} + +// End-to-end: each expanded pattern matches the realistic ARN it's +// supposed to cover. 
+func TestExpandedAccountPatterns_MatchRealisticARNs(t *testing.T) { + m, err := NewArnMatcher(expandAccountGlobs("412664885516")) + if err != nil { + t.Fatalf("NewArnMatcher: %v", err) + } + cases := map[string]bool{ + // Should match (in-account): + "arn:aws:iam::412664885516:user/alice": true, + "arn:aws:iam::412664885516:role/ec2-instance-role": true, + "arn:aws:sts::412664885516:assumed-role/AWSReservedSSO_PowerUserAccess_abc/naveen": true, + "arn:aws:sts::412664885516:assumed-role/ci-deploy/session-id": true, + "arn:aws:sts::412664885516:federated-user/saml-jane": true, + // Should NOT match (different account): + "arn:aws:iam::999999999999:user/eve": false, + "arn:aws:sts::999999999999:assumed-role/SomeRole/session": false, + } + for arn, want := range cases { + if got := m.Match(arn); got != want { + t.Errorf("Match(%q) = %v, want %v", arn, got, want) + } + } +} diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go index da39aae..2eae0ea 100644 --- a/forge-core/auth/providers/aws_sigv4/provider.go +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -107,6 +107,23 @@ type Config struct { // IAM role ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). AllowedPrincipals []string `yaml:"allowed_principals,omitempty"` + // AllowedAccounts is an ergonomic shortcut for the common case of + // "anyone in these AWS accounts." Each entry is an account ID + // (12 digits); New() expands each into the canonical glob set: + // + // arn:aws:iam:::user/* + // arn:aws:iam:::role/* + // arn:aws:sts:::assumed-role/*/* + // arn:aws:sts:::federated-user/* + // + // covering every shape STS returns. The expansion is appended to + // AllowedPrincipals — operators can mix the two: list specific + // roles in AllowedPrincipals AND whole accounts in AllowedAccounts. + // + // For AWS-Org-wide trust without enumerating accounts, see the + // docstring section on AWS Identity Center / SSO. 
+ AllowedAccounts []string `yaml:"allowed_accounts,omitempty"` + // IdentityCacheTTL bounds how long a verified Identity is reused // without re-checking with STS. Defaults to 60s. IdentityCacheTTL time.Duration `yaml:"identity_cache_ttl,omitempty"` @@ -149,7 +166,17 @@ func New(cfg Config) (*Provider, error) { if cfg.HTTPTimeout == 0 { cfg.HTTPTimeout = defaultHTTPTimeout } - matcher, err := NewArnMatcher(cfg.AllowedPrincipals) + // Expand the AllowedAccounts shortcut into canonical globs and merge + // with operator-supplied AllowedPrincipals. We expand here (rather + // than at Match time) so the patterns are validated once at Factory. + expanded := append([]string(nil), cfg.AllowedPrincipals...) + for _, acct := range cfg.AllowedAccounts { + if err := validateAccountID(acct); err != nil { + return nil, fmt.Errorf("aws_sigv4: allowed_accounts: %w", err) + } + expanded = append(expanded, expandAccountGlobs(acct)...) + } + matcher, err := NewArnMatcher(expanded) if err != nil { return nil, fmt.Errorf("aws_sigv4: allowed_principals: %w", err) } diff --git a/forge-core/auth/providers/aws_sigv4/provider_test.go b/forge-core/auth/providers/aws_sigv4/provider_test.go index 5d8767f..fb17dc9 100644 --- a/forge-core/auth/providers/aws_sigv4/provider_test.go +++ b/forge-core/auth/providers/aws_sigv4/provider_test.go @@ -178,6 +178,51 @@ func TestProvider_AllowlistHit_Succeeds(t *testing.T) { } } +func TestProvider_AllowedAccounts_AllowsAnyIdentityInAccount(t *testing.T) { + // The fake STS returns assumed-role ARN for account 123456789012. + // AllowedAccounts=[123456789012] expands to globs covering all + // identity shapes in that account → the assumed-role ARN matches. 
+ p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedAccounts = []string{"123456789012"} + }) + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("Verify: %v", err) + } +} + +func TestProvider_AllowedAccounts_DifferentAccountRejected(t *testing.T) { + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedAccounts = []string{"999999999999"} + }) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_AllowedAccounts_RejectsMalformedAtFactory(t *testing.T) { + _, err := New(Config{ + Region: "us-east-1", + AllowedAccounts: []string{"not-an-account"}, + }) + if err == nil { + t.Fatal("expected error on malformed account ID") + } +} + +func TestProvider_AllowedAccounts_MergesWithAllowedPrincipals(t *testing.T) { + // Mix: account-wide grant for 123456789012 (covers the test STS + // response) + a specific role pattern for some other account. + // Verify the account-wide entry takes precedence. 
+ p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedPrincipals = []string{"arn:aws:iam::999:role/specific"} + c.AllowedAccounts = []string{"123456789012"} + }) + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("Verify: %v", err) + } +} + func TestProvider_STSDown_Unavailable(t *testing.T) { p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go index e6fb8a4..654ac4e 100644 --- a/forge-core/validate/auth.go +++ b/forge-core/validate/auth.go @@ -85,6 +85,18 @@ func validateProviderSettings(prefix string, p types.AuthProvider, r *Validation if asString(p.Settings, "region") == "" { r.Errors = append(r.Errors, prefix+" (aws_sigv4): settings.region is required") } + // allowed_accounts entries must be 12-digit AWS account IDs. + // Catches typos at validate-time so a misconfig doesn't silently + // become an unreachable pattern. + if accts, ok := p.Settings["allowed_accounts"].([]any); ok { + for i, raw := range accts { + s, _ := raw.(string) + if len(s) != 12 || !isAllDigits(s) { + r.Errors = append(r.Errors, + fmt.Sprintf("%s (aws_sigv4): allowed_accounts[%d]=%q must be a 12-digit AWS account ID", prefix, i, s)) + } + } + } case "gcp_iap": if asString(p.Settings, "audience") == "" { r.Errors = append(r.Errors, prefix+" (gcp_iap): settings.audience is required (GCP backend service ID)") @@ -114,3 +126,16 @@ func asString(m map[string]any, key string) string { s, _ := v.(string) return s } + +// isAllDigits reports whether s consists only of ASCII digits. 
+func isAllDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] < '0' || s[i] > '9' { + return false + } + } + return true +} From 3364ca8f1828b3459a45f81e17309a9a3546f3fe Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 12:14:40 -0400 Subject: [PATCH 14/22] fix(cli/init): quote all-digit YAML scalars to preserve string type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Live wizard test surfaced a YAML emission bug: AWS account IDs like "412664885516" were rendered unquoted, e.g. allowed_accounts: - 412664885516 yaml.v3 decodes that as !!int → cannot unmarshal into []string in the aws_sigv4 provider's Config.AllowedAccounts field → provider construction fails at startup. needsYAMLQuoting now treats any all-digit string as requiring quotes. The renderer emits the correct form: allowed_accounts: - "412664885516" Fix is general — applies to anything in the auth-settings schema that's a digit string (ZIP-shaped IDs, version segments, etc.) so this same bug class can't surface for a new provider later. Live-validated after the fix: - auth_verify on the matching allowed_accounts (happy) - auth_fail reason=rejected on a wrong account (deny) --- forge-cli/cmd/init_auth.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index 057210b..dedc203 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -361,5 +361,24 @@ func needsYAMLQuoting(s string) bool { case "true", "false", "yes", "no", "on", "off", "null", "~": return true } + // All-digit strings would otherwise decode as integers — quote them + // to preserve string semantics (e.g. AWS account IDs are 12-digit + // strings, NOT numbers — and the provider expects []string). + if isAllDigits(s) { + return true + } return false } + +// isAllDigits returns true if s is non-empty and consists only of ASCII digits. 
+func isAllDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] < '0' || s[i] > '9' { + return false + } + } + return true +} From 774268d5b42855d30a028f32c6cfba0ea2b598c6 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:16:15 -0400 Subject: [PATCH 15/22] fix(auth): pin CheckRedirect on Phase 2 outbound clients (B1+B2+B3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 review (PR #79) flagged three same-root-cause issues across the outbound HTTP clients in azure_ad, aws_sigv4, and gcp_iap: - B1: ensureGraphHost accepted http:// nextLinks — caller's Bearer would leak in plaintext (Go strips Authorization on cross-host redirect but NOT on https→http same-host downgrade). - B2: Graph client used Go's default redirect policy (follow up to 10). ensureGraphHost only validates @odata.nextLink, not HTTP 301/302/307. A redirect from Graph → attacker URL would let the attacker control the JSON body that becomes Identity.Groups (latent until Phase 4 authz lands; closing at the boundary now). - B3: aws_sigv4's STS client had the same default-redirect issue. The parser-side same-host gate covers only the first hop; a redirect off sts..amazonaws.com would let attacker bytes become the parsed STS XML and control Identity.UserID/OrgID/Arn. Verified all three are real (Go default CheckRedirect follows up to 10 per net/http docs). Smoking gun confirmed in five outbound clients across forge-core/auth/providers/: aws_sigv4/sts_client.go ← fixed (B3) gcp_iap/iap_jwks.go ← fixed (not in review; same pattern, same blast radius — attacker JWKS means forged tokens verify) azure_ad/graph_client.go ← fixed (B1+B2) oidc/provider.go ← Phase 1 code, NOT touched here; OIDC discovery sometimes legitimately redirects, needs design discussion. Filed as follow-up on issue #80. httpverifier/provider.go ← Phase 1 code, NOT touched here; operator-configured URL may be behind LB. 
Follow-up on issue #80. Fix shape (uniform across the three Phase 2 providers): http.Client{ Timeout: timeout, CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }, } These three endpoints (STS GetCallerIdentity, IAP JWKS, Graph transitiveMemberOf) NEVER legitimately issue 3xx — the auth contract each provider validates is bound to the configured host. Returning ErrUseLastResponse causes the 3xx response to flow into our existing status-code switch, where the "unexpected status" default arm maps it to ErrProviderUnavailable (and the audit logs "provider_unavailable", not a misleading "rejected"). B1 specific (separate from the redirect-policy change): ensureGraphHost now checks both Host AND Scheme. GraphClient stores endpointScheme alongside endpointHost, pre-parsed at construction so the per-page check stays cheap. Test mode (httptest's http://) still works because the configured scheme is matched against, not hard-coded https. Regression tests added (one per provider): TestSTSClient_DoesNotFollowRedirects TestJWKSCache_DoesNotFollowRedirects TestGraphClient_DoesNotFollowRedirects Each spins up an httptest server that returns 302 → attacker URL, runs the client, asserts: - error returned (not silently followed) - error is ErrProviderUnavailable (correct audit reason) - endpoint hit exactly once (counts redirect-follows) Plus two scheme-specific cases for B1: TestEnsureGraphHost_RejectsSchemeDowngrade — https→http same-host TestEnsureGraphHost_TestModeHTTPOK — httptest http stays OK 42 packages pass go test -race -count=1; golangci-lint v2.10.1 clean; gofmt clean.
--- .../auth/providers/aws_sigv4/sts_client.go | 15 ++++- .../providers/aws_sigv4/sts_client_test.go | 29 ++++++++ .../auth/providers/azure_ad/graph_client.go | 56 ++++++++++++---- .../providers/azure_ad/graph_client_test.go | 66 ++++++++++++++++++- forge-core/auth/providers/gcp_iap/iap_jwks.go | 18 ++++- .../auth/providers/gcp_iap/provider_test.go | 28 ++++++++ 6 files changed, 191 insertions(+), 21 deletions(-) diff --git a/forge-core/auth/providers/aws_sigv4/sts_client.go b/forge-core/auth/providers/aws_sigv4/sts_client.go index 72cccf0..7e33d18 100644 --- a/forge-core/auth/providers/aws_sigv4/sts_client.go +++ b/forge-core/auth/providers/aws_sigv4/sts_client.go @@ -28,9 +28,22 @@ type STSClient struct { // `region` is informational here (the URL itself carries the region in // its credential scope and host); we keep the arg for symmetry with the // pre-rewrite API and for future per-region tuning. +// +// CheckRedirect is pinned to ErrUseLastResponse. STS never legitimately +// issues 3xx; the parser-side host gate (sigv4_parser.go's expectedHost) +// only validates the FIRST hop. If we let Go's default policy auto-follow +// a 302, an attacker (MITM with a valid cert, TLS-inspecting corporate +// proxy, DNS hijack) could redirect us to a foreign URL whose body becomes +// the parsed STS XML — and that XML controls Identity.UserID/OrgID/Arn. +// Refuse redirects outright so the same-host guard actually holds. 
func NewSTSClient(_ /* region */ string, _ /* legacyOverrideUnused */ string, timeout time.Duration) *STSClient { return &STSClient{ - http: &http.Client{Timeout: timeout}, + http: &http.Client{ + Timeout: timeout, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + }, } } diff --git a/forge-core/auth/providers/aws_sigv4/sts_client_test.go b/forge-core/auth/providers/aws_sigv4/sts_client_test.go index db88ae2..87e5127 100644 --- a/forge-core/auth/providers/aws_sigv4/sts_client_test.go +++ b/forge-core/auth/providers/aws_sigv4/sts_client_test.go @@ -140,6 +140,35 @@ func TestSTSClient_RequestCount(t *testing.T) { } } +func TestSTSClient_DoesNotFollowRedirects(t *testing.T) { + // Review B3: Go's default http.Client follows redirects up to 10 + // hops. The parser-side host gate only validates the first hop — + // auto-following a 302 to attacker-controlled bytes would let + // those bytes become the parsed STS XML and control the stamped + // Identity. Pin: any 3xx is treated as STS-unavailable (we never + // follow), so the same-host guard actually holds. + var hits atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits.Add(1) + w.Header().Set("Location", "https://attacker.example.com/") + w.WriteHeader(http.StatusFound) // 302 + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error on 302; client must not follow") + } + // 3xx falls into the "unexpected status" arm → unavailable. 
+ if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if hits.Load() != 1 { + t.Errorf("STS hit %d times, want exactly 1 (redirect was followed)", hits.Load()) + } +} + func TestSTSClient_PreservesURLQueryString(t *testing.T) { // The pre-signed URL carries the signature in query params; the // client MUST send those verbatim to STS or STS will reject. diff --git a/forge-core/auth/providers/azure_ad/graph_client.go b/forge-core/auth/providers/azure_ad/graph_client.go index 7d442df..2d48bd1 100644 --- a/forge-core/auth/providers/azure_ad/graph_client.go +++ b/forge-core/auth/providers/azure_ad/graph_client.go @@ -22,9 +22,10 @@ import ( // token is reflected to Graph, which authorizes the read against the // user's delegated permission (GroupMember.Read.All). type GraphClient struct { - endpoint string // initial page URL - endpointHost string // pre-parsed for cheap per-page nextLink validation - http *http.Client + endpoint string // initial page URL + endpointHost string // pre-parsed for cheap per-page nextLink validation + endpointScheme string // pre-parsed for the same — guards against http://-downgrade nextLinks + http *http.Client } const graphBaseURL = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id&$top=100" @@ -41,14 +42,29 @@ func NewGraphClientWithEndpoint(endpoint string, timeout time.Duration) *GraphCl } func newGraphClientFor(endpoint string, timeout time.Duration) *GraphClient { - host := "" + host, scheme := "", "" if u, err := url.Parse(endpoint); err == nil { host = u.Host + scheme = u.Scheme } return &GraphClient{ - endpoint: endpoint, - endpointHost: host, - http: &http.Client{Timeout: timeout}, + endpoint: endpoint, + endpointHost: host, + endpointScheme: scheme, + http: &http.Client{ + Timeout: timeout, + // Reject HTTP redirects unconditionally. 
Graph paginates via + // application-layer @odata.nextLink (which we validate against + // the configured host + scheme); it does NOT need transport- + // layer 301/302/307s. Allowing the default redirect-follow + // policy would bypass our same-host guard — an attacker + // returning a 302 to a foreign URL would have the bearer + // reflected there. ErrUseLastResponse returns the 3xx as-is; + // our status-code switch then classifies it as unavailable. + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + }, } } @@ -70,7 +86,7 @@ func (c *GraphClient) TransitiveMemberOf(ctx context.Context, _ string, authHead out := []string{} next := c.endpoint for next != "" { - if err := ensureGraphHost(c.endpointHost, next); err != nil { + if err := ensureGraphHost(c.endpointHost, c.endpointScheme, next); err != nil { return nil, fmt.Errorf("%w: graph nextLink host: %v", auth.ErrProviderUnavailable, err) } page, nextURL, err := c.fetchPage(ctx, next, authHeader) @@ -129,13 +145,22 @@ func (c *GraphClient) fetchPage(ctx context.Context, u, authHeader string) (ids } // ensureGraphHost rejects @odata.nextLink values that point at a foreign -// host. Real Graph paginates within graph.microsoft.com. For tests where -// the endpoint is httptest's 127.0.0.1, the test-mode endpoint host is -// what we compare against. +// host OR downgrade the scheme. Both checks matter: // -// `configuredHost` is the pre-parsed Host of GraphClient.endpoint, computed -// once at construction so we don't reparse it for every paginated request. -func ensureGraphHost(configuredHost, candidate string) error { +// - Host: Graph paginates within graph.microsoft.com. A foreign-host +// nextLink would coerce Forge into sending the caller's Bearer to +// an attacker. +// - Scheme: Go's http.Client strips Authorization on cross-host +// redirects but NOT on cross-scheme (https→http) downgrades to the +// same host. 
A `nextLink: "http://graph.microsoft.com/..."` would +// pass the host check and then leak the Bearer in plaintext to +// anyone able to MITM the connection. Require scheme match too. +// +// `configuredHost` / `configuredScheme` are pre-parsed from +// GraphClient.endpoint at construction time so we don't reparse it for +// every paginated request. For tests, the configured scheme is http +// (httptest servers) and that's intentionally allowed. +func ensureGraphHost(configuredHost, configuredScheme, candidate string) error { if candidate == "" { return nil } @@ -146,5 +171,8 @@ func ensureGraphHost(configuredHost, candidate string) error { if !strings.EqualFold(configuredHost, got.Host) { return fmt.Errorf("nextLink host %q does not match configured %q", got.Host, configuredHost) } + if !strings.EqualFold(configuredScheme, got.Scheme) { + return fmt.Errorf("nextLink scheme %q does not match configured %q (Bearer-downgrade guard)", got.Scheme, configuredScheme) + } return nil } diff --git a/forge-core/auth/providers/azure_ad/graph_client_test.go b/forge-core/auth/providers/azure_ad/graph_client_test.go index bd7799c..0ee2e6e 100644 --- a/forge-core/auth/providers/azure_ad/graph_client_test.go +++ b/forge-core/auth/providers/azure_ad/graph_client_test.go @@ -117,8 +117,38 @@ func TestGraph_DefensivePaginationCap(t *testing.T) { } } +func TestGraphClient_DoesNotFollowRedirects(t *testing.T) { + // Review B2: ensureGraphHost only validates @odata.nextLink (the + // application-layer paginator). HTTP 301/302/307 from Graph were + // being auto-followed by Go's default policy, bypassing the host + // guard. An attacker (MITM, TLS-inspecting proxy, DNS hijack) + // returning a 302 to a foreign URL would have the response body + // JSON-unmarshalled as `value: [{id: ...}]` — those IDs become + // Identity.Groups, which the future Phase 4 authz layer will + // trust. Pin CheckRedirect = ErrUseLastResponse. 
+ var hits int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits++ + w.Header().Set("Location", "https://attacker.example.com/value") + w.WriteHeader(http.StatusFound) + })) + defer srv.Close() + + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if err == nil { + t.Fatal("expected error on 302; client must not follow") + } + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if hits != 1 { + t.Errorf("graph hit %d times, want 1 (redirect was followed)", hits) + } +} + func TestEnsureGraphHost_RejectsForeignHost(t *testing.T) { - err := ensureGraphHost("graph.microsoft.com", "https://evil.example.com/me/next") + err := ensureGraphHost("graph.microsoft.com", "https", "https://evil.example.com/me/next") if err == nil { t.Fatal("expected error on foreign host") } @@ -126,7 +156,7 @@ func TestEnsureGraphHost_RejectsForeignHost(t *testing.T) { func TestEnsureGraphHost_AcceptsSameHost(t *testing.T) { err := ensureGraphHost( - "graph.microsoft.com", + "graph.microsoft.com", "https", "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$skiptoken=abc", ) if err != nil { @@ -135,11 +165,41 @@ func TestEnsureGraphHost_AcceptsSameHost(t *testing.T) { } func TestEnsureGraphHost_EmptyOK(t *testing.T) { - if err := ensureGraphHost("graph.microsoft.com", ""); err != nil { + if err := ensureGraphHost("graph.microsoft.com", "https", ""); err != nil { t.Errorf("empty nextLink should be ok, got %v", err) } } +func TestEnsureGraphHost_RejectsSchemeDowngrade(t *testing.T) { + // Review B1: http://-on-same-host MUST be rejected. Go's http.Client + // keeps Authorization across same-host redirects, so a downgrade + // would leak the caller's Bearer in plaintext to anyone with a + // network position to MITM the response. 
+ err := ensureGraphHost( + "graph.microsoft.com", "https", + "http://graph.microsoft.com/v1.0/me/transitiveMemberOf?$skiptoken=abc", + ) + if err == nil { + t.Fatal("expected error on http downgrade") + } + if !strings.Contains(err.Error(), "scheme") { + t.Errorf("err should mention scheme; got %v", err) + } +} + +func TestEnsureGraphHost_TestModeHTTPOK(t *testing.T) { + // When the configured endpoint itself is http:// (httptest servers + // in unit/integration tests), http://-same-host nextLinks must + // still validate. + err := ensureGraphHost( + "127.0.0.1:54321", "http", + "http://127.0.0.1:54321/page2", + ) + if err != nil { + t.Errorf("test-mode http nextLink should be accepted, got %v", err) + } +} + // manyIDs returns a JSON snippet for `count` entries — used for the // pagination cap test. func manyIDs(count int) string { diff --git a/forge-core/auth/providers/gcp_iap/iap_jwks.go b/forge-core/auth/providers/gcp_iap/iap_jwks.go index 032a044..c525575 100644 --- a/forge-core/auth/providers/gcp_iap/iap_jwks.go +++ b/forge-core/auth/providers/gcp_iap/iap_jwks.go @@ -55,11 +55,23 @@ type IAPJWKSCache struct { } // NewIAPJWKSCache builds an empty cache pointing at the given URL. +// +// CheckRedirect is pinned to ErrUseLastResponse. The IAP JWKS host is +// hardcoded (decision §9.4) precisely so we never trust any other source +// of public keys — auto-following a 3xx to a foreign URL would let a +// MITM / TLS-inspecting proxy / DNS-hijack scenario substitute attacker +// keys, after which any forged token signed by those keys would verify. +// Refuse redirects outright. 
func NewIAPJWKSCache(url string, ttl, timeout time.Duration) *IAPJWKSCache { return &IAPJWKSCache{ - url: url, - ttl: ttl, - http: &http.Client{Timeout: timeout}, + url: url, + ttl: ttl, + http: &http.Client{ + Timeout: timeout, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + }, keys: map[string]*ecdsa.PublicKey{}, } } diff --git a/forge-core/auth/providers/gcp_iap/provider_test.go b/forge-core/auth/providers/gcp_iap/provider_test.go index edef453..f7d9362 100644 --- a/forge-core/auth/providers/gcp_iap/provider_test.go +++ b/forge-core/auth/providers/gcp_iap/provider_test.go @@ -297,6 +297,34 @@ func TestJWKSCache_StaleGraceOnOutage(t *testing.T) { } } +func TestJWKSCache_DoesNotFollowRedirects(t *testing.T) { + // Review (sibling of B2/B3): the IAP JWKS URL is hardcoded (§9.4) + // precisely so we never trust an alternate key source. Auto- + // following a 302 to attacker bytes would let an attacker substitute + // their own keys — any token forged with those keys would then + // verify. Pin: any 3xx is treated as JWKS-unavailable (we never + // follow). + var hits int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits++ + w.Header().Set("Location", "https://attacker.example.com/") + w.WriteHeader(http.StatusFound) + })) + defer srv.Close() + + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + err := c.refresh(context.Background()) + if err == nil { + t.Fatal("expected error on 302; client must not follow") + } + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if hits != 1 { + t.Errorf("JWKS endpoint hit %d times, want 1 (redirect was followed)", hits) + } +} + func TestJWKSCache_BackoffBumps(t *testing.T) { // Endpoint always 500 — observe backoffDuration grow on each refresh. 
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { From 674bf2f2f02207f154ddd4cba37c1f2da04c7b81 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:18:22 -0400 Subject: [PATCH 16/22] fix(auth/aws_sigv4): reject token URLs with userinfo (M1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 review M1: a token whose pre-signed URL contained userinfo (https://user:pass@sts.us-east-1.amazonaws.com/...) would pass our EqualFold(u.Host, expectedHost) check because RFC 3986 separates userinfo from host — net/url puts "user:pass" into u.User and "sts.us-east-1.amazonaws.com" into u.Host independently. http.Client.Do then synthesizes Authorization: Basic from u.User and ships those attacker-controlled bytes to STS. STS ignores Basic (it validates the X-Amz-Signature in the query string), so this isn't an active auth bypass — but attacker bytes still leave our box, with potential for exfiltration via STS access logs / timing. Fix: reject u.User != nil at parse time, BEFORE the host check, in sigv4_parser.go's ParseToken. One line; defense-in-depth. Regression test TestParseToken_RejectsUserinfo confirms the userinfo case is now caught and that the error message mentions "userinfo" (so future maintainers know what tripped). 42 packages still pass; lint + gofmt clean. 
--- .../auth/providers/aws_sigv4/sigv4_parser.go | 12 +++++++++++ .../providers/aws_sigv4/sigv4_parser_test.go | 21 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go index daf43b0..0521992 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go @@ -85,6 +85,18 @@ func ParseToken(token, expectedHost string, requireHTTPS bool) (*PresignedToken, if u.Scheme != "https" && u.Scheme != "http" { return nil, fmt.Errorf("URL scheme %q is not http(s)", u.Scheme) } + // Reject userinfo BEFORE the host check. RFC 3986 separates + // userinfo from host, so net/url parses + // "https://user:pass@sts.us-east-1.amazonaws.com" into + // (u.User="user:pass", u.Host="sts.us-east-1.amazonaws.com") — + // the host check alone would let that token through. Then + // http.Client.Do would synthesize Authorization: Basic from + // u.User and ship attacker-controlled bytes to STS. STS ignores + // Basic (it uses the X-Amz-Signature query param), but we still + // don't want attacker bytes leaving the box. (Review M1.) 
+ if u.User != nil { + return nil, errors.New("URL must not contain userinfo (RFC 3986 user:pass@ section)") + } if !strings.EqualFold(u.Host, expectedHost) { return nil, fmt.Errorf("URL host %q does not match expected %q (SSRF guard)", u.Host, expectedHost) } diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go index 7e30af2..d324037 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go @@ -73,6 +73,27 @@ func TestParseToken_EmptyPayload(t *testing.T) { } } +func TestParseToken_RejectsUserinfo(t *testing.T) { + // Review M1: net/url parses "https://user:pass@sts.us-east-1.amazonaws.com" + // into u.User != nil and u.Host == "sts.us-east-1.amazonaws.com" — the + // host check alone passes. We must reject userinfo explicitly, otherwise + // http.Client.Do would synthesize Authorization: Basic and ship + // attacker bytes to STS. + hostile := strings.Replace( + validPresignedURL, + "https://sts.us-east-1.amazonaws.com/", + "https://attacker:secret@sts.us-east-1.amazonaws.com/", + 1, + ) + _, err := ParseToken(makeToken(hostile), validHost, true) + if err == nil { + t.Fatal("expected error on URL with userinfo") + } + if !strings.Contains(err.Error(), "userinfo") { + t.Errorf("err should mention userinfo; got %v", err) + } +} + func TestParseToken_RejectsForeignHost(t *testing.T) { // SSRF guard — even if base64 decodes to a syntactically valid URL, // any non-STS host is rejected. 
From c6e0aff9c794af1d3c29d344b6ceef0a4096d31d Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:25:41 -0400 Subject: [PATCH 17/22] fix(auth/aws_sigv4): parser-side token freshness (M2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 review M2: STS enforces ~15min on X-Amz-Date + X-Amz-Expires server-side, but our IdentityCache (60s default TTL) would happily serve a cached Identity even if STS somehow accepted a stale token. The cache key being hash(AKID|YYYYMMDD) gave the impression of day-long replay potential — it doesn't, because the 60s TTL bounds it — but the parser had zero own-source freshness check, so the stolen-token replay window depended entirely on STS's enforcement. Adding parser-side freshness as defense-in-depth: PresignedToken gains two fields: SigTime time.Time ← parsed from X-Amz-Date Expires time.Duration ← parsed from X-Amz-Expires ParseToken now requires both query params (rejects missing / malformed / non-positive Expires). Parsing stays pure: no clock comparison happens inside ParseToken. PresignedToken.CheckFreshness(now, maxExpires, skew) gates: - X-Amz-Expires > maxExpires → cap (default 15min — matches what all standard AWS SDKs emit for GetCallerIdentity) - now > SigTime + Expires + skew → token expired - SigTime > now + skew → token from the future skew default 5min for normal clock drift. Config additions (both default-able, no operator action needed): MaxTokenExpires time.Duration `yaml:"max_token_expires,omitempty"` ClockSkew time.Duration `yaml:"clock_skew,omitempty"` Provider gains a `now func() time.Time` field, default time.Now, overridable in tests so freshness can be exercised without clock monkey-patching. Provider.Verify calls CheckFreshness right after region check, before the cache lookup — STS round-trip and identity cache hit are both short-circuited when freshness fails. ErrTokenRejected, audit reason "rejected". 
Tests: parser-level (sigv4_parser_test.go): TestParseToken_PopulatesSigTimeAndExpires TestParseToken_RejectsMissingAmzDate TestParseToken_RejectsMalformedAmzDate TestParseToken_RejectsMissingAmzExpires TestParseToken_RejectsNonNumericExpires TestParseToken_RejectsNonPositiveExpires TestCheckFreshness_Expired TestCheckFreshness_FromTheFuture TestCheckFreshness_ExceedsExpiresCap TestCheckFreshness_HappyPathInsideSkew provider-level (provider_test.go): TestProvider_RejectsExpiredToken TestProvider_RejectsTokenFromFuture TestProvider_RejectsOverlyLongExpiresClaim TestProvider_AcceptsTokenAtEdgeOfSkewWindow Existing test fixtures updated: tokenFor() now stamps X-Amz-Date built from fixedTestTime so freshness passes by default; newTestProvider pins Provider.now to that same instant. Day-bucket rollover test advances the clock by 24h to match its day-2 token. 42 packages pass, lint + gofmt clean. --- .../auth/providers/aws_sigv4/provider.go | 39 ++++++++ .../auth/providers/aws_sigv4/provider_test.go | 96 +++++++++++++++++- .../auth/providers/aws_sigv4/sigv4_parser.go | 94 ++++++++++++++++-- .../providers/aws_sigv4/sigv4_parser_test.go | 97 +++++++++++++++++++ 4 files changed, 313 insertions(+), 13 deletions(-) diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go index 2eae0ea..1714ceb 100644 --- a/forge-core/auth/providers/aws_sigv4/provider.go +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -84,6 +84,16 @@ const ProviderName = "aws_sigv4" const ( defaultIdentityCacheTTL = 60 * time.Second defaultHTTPTimeout = 5 * time.Second + + // Cap the token's self-declared lifetime (X-Amz-Expires). AWS allows + // up to 7 days for some services; STS GetCallerIdentity in practice + // is signed for 15 min by all standard SDKs. Anything longer is + // suspect. + defaultMaxTokenExpires = 15 * time.Minute + + // Clock skew tolerance between Forge and the caller. 
Generous enough + // for typical NTP drift but tight enough to bound stolen-token replay. + defaultClockSkew = 5 * time.Minute ) // Config controls the aws_sigv4 provider. @@ -128,6 +138,16 @@ type Config struct { // without re-checking with STS. Defaults to 60s. IdentityCacheTTL time.Duration `yaml:"identity_cache_ttl,omitempty"` + // MaxTokenExpires caps the X-Amz-Expires value the caller can stamp + // into the pre-signed URL. Defaults to 15min (matches what all + // standard AWS SDKs produce for GetCallerIdentity). Belt-and-braces + // gate on top of STS's own freshness enforcement. + MaxTokenExpires time.Duration `yaml:"max_token_expires,omitempty"` + + // ClockSkew is the tolerance window for clock drift between Forge + // and the caller when evaluating token freshness. Defaults to 5min. + ClockSkew time.Duration `yaml:"clock_skew,omitempty"` + // STSEndpoint is a TEST-ONLY override that changes the expected // pre-signed URL host (and relaxes the https requirement). Production // should leave this empty. @@ -153,6 +173,9 @@ type Provider struct { cache *IdentityCache sts *STSClient matcher *ArnMatcher + + // now is injectable for tests; defaults to time.Now in New(). + now func() time.Time } // New constructs a Provider after validating cfg. @@ -166,6 +189,12 @@ func New(cfg Config) (*Provider, error) { if cfg.HTTPTimeout == 0 { cfg.HTTPTimeout = defaultHTTPTimeout } + if cfg.MaxTokenExpires == 0 { + cfg.MaxTokenExpires = defaultMaxTokenExpires + } + if cfg.ClockSkew == 0 { + cfg.ClockSkew = defaultClockSkew + } // Expand the AllowedAccounts shortcut into canonical globs and merge // with operator-supplied AllowedPrincipals. We expand here (rather // than at Match time) so the patterns are validated once at Factory. 
@@ -201,6 +230,7 @@ func New(cfg Config) (*Provider, error) { cache: NewIdentityCache(cfg.IdentityCacheTTL), sts: NewSTSClient(cfg.Region, "", cfg.HTTPTimeout), matcher: matcher, + now: time.Now, }, nil } @@ -232,6 +262,15 @@ func (p *Provider) Verify(ctx context.Context, token string, _ auth.Headers) (*a return nil, fmt.Errorf("%w: token region %q != configured %s", auth.ErrTokenRejected, parsed.Region, p.cfg.Region) } + // Parser-side freshness check (Review M2). Belt-and-braces on top of + // STS's own ~15min enforcement: caps the X-Amz-Expires that any + // caller can claim, and rejects already-expired tokens before we + // ever round-trip to STS. This bounds the stolen-token replay + // window to MaxTokenExpires + ClockSkew, independent of our cache TTL. + if err := parsed.CheckFreshness(p.now(), p.cfg.MaxTokenExpires, p.cfg.ClockSkew); err != nil { + return nil, fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) + } + cacheKey := dateBucketKey(parsed.AKID, parsed.Date) if id, ok := p.cache.Get(cacheKey); ok { return id, nil diff --git a/forge-core/auth/providers/aws_sigv4/provider_test.go b/forge-core/auth/providers/aws_sigv4/provider_test.go index fb17dc9..305a56c 100644 --- a/forge-core/auth/providers/aws_sigv4/provider_test.go +++ b/forge-core/auth/providers/aws_sigv4/provider_test.go @@ -17,16 +17,26 @@ import ( "github.com/initializ/forge/forge-core/auth" ) +// fixedTestTime is the wall-clock the test helpers pretend it is. +// Tokens minted by tokenFor() are signed at this instant (X-Amz-Date) +// and Provider.now is pinned to this instant via newTestProvider, so +// CheckFreshness sees an in-window token by default. Tests that want +// to exercise expiry/skew override Provider.now after construction. +var fixedTestTime = time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC) + // tokenFor builds a forge-aws-v1 token whose embedded URL points at the // test STS server. AKID + date + region + signature are placeholders; -// the fake STS doesn't validate them. 
+// the fake STS doesn't validate them. X-Amz-Date is pinned to +// fixedTestTime so the freshness check passes by default. func tokenFor(stsURL, akid, dateYYYYMMDD, region string) string { q := url.Values{} q.Set("Action", "GetCallerIdentity") q.Set("Version", "2011-06-15") q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") q.Set("X-Amz-Credential", fmt.Sprintf("%s/%s/%s/sts/aws4_request", akid, dateYYYYMMDD, region)) - q.Set("X-Amz-Date", dateYYYYMMDD+"T010000Z") + // Build X-Amz-Date from the (possibly day-rolled) dateYYYYMMDD plus + // fixedTestTime's HHMMSS so date-bucket rollover tests still work. + q.Set("X-Amz-Date", dateYYYYMMDD+"T"+fixedTestTime.UTC().Format("150405")+"Z") q.Set("X-Amz-Expires", "900") q.Set("X-Amz-SignedHeaders", "host") q.Set("X-Amz-Signature", "fakesig"+akid) @@ -57,6 +67,9 @@ func newTestProvider(t *testing.T, sts http.Handler, opts ...func(*Config)) (*Pr if err != nil { t.Fatalf("New: %v", err) } + // Pin the provider's clock to fixedTestTime so tokens minted by + // tokenFor() pass the M2 freshness check. + p.now = func() time.Time { return fixedTestTime } return p, srv.URL } @@ -293,10 +306,13 @@ func TestProvider_DateBucketRollover_TriggersFreshSTSCall(t *testing.T) { calls.Add(1) _, _ = io.WriteString(w, happySTSXML) })) - // Day 1 + Day 2 → different cache keys. + // Day 1: use the default fixedTestTime clock. if _, err := p.Verify(context.Background(), tokenFor(stsURL, "AKIA", "20260524", "us-east-1"), nil); err != nil { t.Fatal(err) } + // Day 2: advance clock by 24h so the day-2 token is fresh under + // CheckFreshness (post-M2 freshness gate). 
+ p.now = func() time.Time { return fixedTestTime.Add(24 * time.Hour) } if _, err := p.Verify(context.Background(), tokenFor(stsURL, "AKIA", "20260525", "us-east-1"), nil); err != nil { t.Fatal(err) } @@ -305,6 +321,80 @@ func TestProvider_DateBucketRollover_TriggersFreshSTSCall(t *testing.T) { } } +// --- Review M2: parser-side freshness --- + +func TestProvider_RejectsExpiredToken(t *testing.T) { + // Token's X-Amz-Date + Expires window is in the past relative to + // Provider.now. STS would also reject this, but we belt-and-brace. + p, stsURL := newTestProvider(t, happySTS()) + p.now = func() time.Time { + return fixedTestTime.Add(30 * time.Minute) // far past the 15min lifetime + 5min skew + } + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } + if err != nil && !strings.Contains(err.Error(), "expired") { + t.Errorf("err should mention 'expired'; got %v", err) + } +} + +func TestProvider_RejectsTokenFromFuture(t *testing.T) { + // Token signed beyond skew tolerance in the future. + p, stsURL := newTestProvider(t, happySTS()) + p.now = func() time.Time { + return fixedTestTime.Add(-1 * time.Hour) // way before the token's signing instant + } + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } + if err != nil && !strings.Contains(err.Error(), "future") { + t.Errorf("err should mention 'future'; got %v", err) + } +} + +func TestProvider_RejectsOverlyLongExpiresClaim(t *testing.T) { + // Caller crafted a token with X-Amz-Expires > 15min. STS would + // also reject the signature, but we belt-and-brace at parse-side. + p, stsURL := newTestProvider(t, happySTS()) + + // Build the token pointing at the test STS host so we pass the + // SSRF-guard host check; the freshness check then catches the + // oversized X-Amz-Expires. 
+ q := url.Values{} + q.Set("Action", "GetCallerIdentity") + q.Set("Version", "2011-06-15") + q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") + q.Set("X-Amz-Credential", "AKIA/20260524/us-east-1/sts/aws4_request") + q.Set("X-Amz-Date", "20260524T010000Z") + q.Set("X-Amz-Expires", "3600") // 1 hour — exceeds our 15min cap + q.Set("X-Amz-SignedHeaders", "host") + q.Set("X-Amz-Signature", "abc") + full := stsURL + "/?" + q.Encode() + tok := TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(full)) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } + if err != nil && !strings.Contains(err.Error(), "X-Amz-Expires") { + t.Errorf("err should mention X-Amz-Expires cap; got %v", err) + } +} + +func TestProvider_AcceptsTokenAtEdgeOfSkewWindow(t *testing.T) { + // Just barely fresh: signed at fixedTestTime, expires after 15min, + // and now is fixedTestTime + 15min + 4min skew (still within 5min). + p, stsURL := newTestProvider(t, happySTS()) + p.now = func() time.Time { + return fixedTestTime.Add(15*time.Minute + 4*time.Minute) + } + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("token within skew window should pass, got %v", err) + } +} + func TestProvider_RegisteredInRegistry(t *testing.T) { p, err := auth.Build("aws_sigv4", map[string]any{ "region": "us-east-1", diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go index 0521992..ebf6f31 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/url" + "strconv" "strings" + "time" ) // TokenPrefix is the magic token-type marker that distinguishes @@ -25,11 +27,13 @@ const TokenPrefix = "forge-aws-v1." 
// encoding of "/" in X-Amz-Credential, "+" in X-Amz-Security-Token), and // any such re-encoding invalidates the signature. type PresignedToken struct { - RawURL string // the exact URL the AWS SDK produced — preserve as-is - URL *url.URL // parsed view, for host validation and query inspection only - AKID string // for IdentityCache bucket key - Date string // YYYYMMDD scope date — for IdentityCache bucket key - Region string // from the credential scope (we cross-check against cfg.Region) + RawURL string // the exact URL the AWS SDK produced — preserve as-is + URL *url.URL // parsed view, for host validation and query inspection only + AKID string // for IdentityCache bucket key + Date string // YYYYMMDD scope date — for IdentityCache bucket key + Region string // from the credential scope (we cross-check against cfg.Region) + SigTime time.Time // parsed X-Amz-Date — used by CheckFreshness, not by ParseToken itself + Expires time.Duration // parsed X-Amz-Expires — used by CheckFreshness } // ParseToken validates a forge-aws-v1 Bearer token end-to-end and returns @@ -117,15 +121,85 @@ func ParseToken(token, expectedHost string, requireHTTPS bool) (*PresignedToken, return nil, fmt.Errorf("X-Amz-Credential: %w", err) } + sigTime, err := parseAmzDate(q.Get("X-Amz-Date")) + if err != nil { + return nil, fmt.Errorf("X-Amz-Date: %w", err) + } + expires, err := parseAmzExpires(q.Get("X-Amz-Expires")) + if err != nil { + return nil, fmt.Errorf("X-Amz-Expires: %w", err) + } + return &PresignedToken{ - RawURL: string(raw), - URL: u, - AKID: akid, - Date: date, - Region: region, + RawURL: string(raw), + URL: u, + AKID: akid, + Date: date, + Region: region, + SigTime: sigTime, + Expires: expires, }, nil } +// CheckFreshness rejects tokens whose self-declared lifetime exceeds +// maxExpires OR whose validity window has already lapsed (with skew +// for clock drift between Forge and the caller). 
This is defense in +// depth on top of STS's own ~15min enforcement: if STS ever accepts +// a stale token, our IdentityCache would happily serve the cached +// Identity for its full TTL. Parser-side freshness closes that gap. +// +// Caller passes `now` and the limits so this is unit-testable without +// time monkey-patching. Provider supplies them from its Config. +func (t *PresignedToken) CheckFreshness(now time.Time, maxExpires, skew time.Duration) error { + if t.Expires > maxExpires { + return fmt.Errorf("X-Amz-Expires=%s exceeds cap %s", t.Expires, maxExpires) + } + // Token's own self-declared expiry passed already (with skew tolerance). + if now.After(t.SigTime.Add(t.Expires).Add(skew)) { + return fmt.Errorf("token expired: signed at %s + %s lifetime + %s skew, now %s", + t.SigTime.UTC().Format(time.RFC3339), t.Expires, skew, now.UTC().Format(time.RFC3339)) + } + // Token from the future beyond our skew tolerance — either a wildly + // skewed client OR a malicious signer trying to extend the validity + // window. STS itself catches this; we belt-and-brace. + if t.SigTime.Sub(now) > skew { + return fmt.Errorf("token signed in the future: %s vs now %s (skew %s)", + t.SigTime.UTC().Format(time.RFC3339), now.UTC().Format(time.RFC3339), skew) + } + return nil +} + +// parseAmzDate parses an X-Amz-Date timestamp in its standard form +// "YYYYMMDDTHHMMSSZ" (e.g. "20260524T150405Z"). UTC by definition. +func parseAmzDate(s string) (time.Time, error) { + if s == "" { + return time.Time{}, errors.New("missing X-Amz-Date") + } + t, err := time.Parse("20060102T150405Z", s) + if err != nil { + return time.Time{}, fmt.Errorf("malformed %q: %v", s, err) + } + return t, nil +} + +// parseAmzExpires parses the X-Amz-Expires query value (seconds, as a +// decimal integer string). AWS SDKs constrain this to [1, 604800] +// (1s to 7 days) at signing time; we additionally cap at CheckFreshness +// time per the operator's maxExpires. 
+func parseAmzExpires(s string) (time.Duration, error) { + if s == "" { + return 0, errors.New("missing X-Amz-Expires") + } + n, err := strconv.Atoi(s) + if err != nil { + return 0, fmt.Errorf("not an integer: %q", s) + } + if n <= 0 { + return 0, fmt.Errorf("must be positive, got %d", n) + } + return time.Duration(n) * time.Second, nil +} + // parseCredentialScope splits X-Amz-Credential into its five segments: // // AKID/YYYYMMDD/region/service/aws4_request diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go index d324037..4ffdfb1 100644 --- a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go @@ -5,6 +5,7 @@ import ( "errors" "strings" "testing" + "time" ) // makeToken builds a forge-aws-v1 token from a complete URL. Helper for @@ -151,6 +152,102 @@ func TestParseToken_RejectsBadBase64(t *testing.T) { } } +// --- Review M2: parser surface for freshness --- + +func TestParseToken_PopulatesSigTimeAndExpires(t *testing.T) { + parsed, err := ParseToken(makeToken(validPresignedURL), validHost, true) + if err != nil { + t.Fatalf("ParseToken: %v", err) + } + wantTime := time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC) + if !parsed.SigTime.Equal(wantTime) { + t.Errorf("SigTime = %v, want %v", parsed.SigTime, wantTime) + } + if parsed.Expires != 900*time.Second { + t.Errorf("Expires = %v, want 900s", parsed.Expires) + } +} + +func TestParseToken_RejectsMissingAmzDate(t *testing.T) { + u := strings.Replace(validPresignedURL, "&X-Amz-Date=20260524T010000Z", "", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on missing X-Amz-Date") + } +} + +func TestParseToken_RejectsMalformedAmzDate(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Date=20260524T010000Z", "X-Amz-Date=not-a-date", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + 
t.Fatal("expected error on malformed X-Amz-Date") + } +} + +func TestParseToken_RejectsMissingAmzExpires(t *testing.T) { + u := strings.Replace(validPresignedURL, "&X-Amz-Expires=900", "", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on missing X-Amz-Expires") + } +} + +func TestParseToken_RejectsNonNumericExpires(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Expires=900", "X-Amz-Expires=forever", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on non-numeric X-Amz-Expires") + } +} + +func TestParseToken_RejectsNonPositiveExpires(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Expires=900", "X-Amz-Expires=0", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on zero X-Amz-Expires") + } +} + +func TestCheckFreshness_Expired(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC), + Expires: 15 * time.Minute, + } + now := time.Date(2026, 5, 24, 1, 25, 0, 0, time.UTC) // 25min later, beyond 15min+5min skew + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err == nil { + t.Fatal("expected expired error") + } +} + +func TestCheckFreshness_FromTheFuture(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 2, 0, 0, 0, time.UTC), + Expires: 15 * time.Minute, + } + now := time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC) // 1h before token's sign time + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err == nil { + t.Fatal("expected future-token error") + } +} + +func TestCheckFreshness_ExceedsExpiresCap(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC), + Expires: 1 * time.Hour, // exceeds the 15min cap + } + now := tok.SigTime + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err == nil { + t.Fatal("expected cap error") + } 
+} + +func TestCheckFreshness_HappyPathInsideSkew(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC), + Expires: 15 * time.Minute, + } + // 19 min later: past Expires but within skew tolerance. + now := tok.SigTime.Add(19 * time.Minute) + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err != nil { + t.Errorf("token within skew window should pass, got %v", err) + } +} + func TestParseCredentialScope_HappyPath(t *testing.T) { akid, date, region, err := parseCredentialScope("AKIA123/20260524/us-east-1/sts/aws4_request") if err != nil { From 57d12a5c8a033c722ed703ae6b641d44fe6b2434 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:28:08 -0400 Subject: [PATCH 18/22] fix(auth): refine token_kind post-verify when gcp_iap matches (M4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 review M4: structural token_kind detection runs BEFORE the chain, so a request with both a Bearer JWT and an X-Goog-Iap-Jwt- Assertion records kind="jwt" — masking IAP-fronted traffic in audit dashboards even when gcp_iap was the actual verifier. The structural rule answers "what bytes were on the wire?" The audit signal we want answers "which auth path verified?" Those diverge only when a Bearer is present alongside a non-Bearer provider's payload — today, gcp_iap is the only such provider. Fix: identity, err := chain.Verify(...) if err == nil { kind = refineTokenKind(kind, identity.Source) } refineTokenKind: - Source == "gcp_iap" → return "iap_jwt" - otherwise → return structural kind unchanged Other providers (oidc, azure_ad, aws_sigv4, http_verifier, static_token) don't need refinement: their structural kind already matches the auth path that verifies them. Failure paths still record the structural kind (no chain identity to read Source from), so the existing "missing_token" / "not_for_me" audit reason-code contract is unaffected. 
Tests: TestMiddleware_TokenKind_RefinedToIapJwtWhenGCPIAPVerifies — Bearer JWT + IAP header both present; chain stubbed to return Source="gcp_iap"; assert kind upgrades to iap_jwt TestMiddleware_TokenKind_JWT_NotRefinedForOIDCProviders — counter-test: Source="oidc" must NOT trigger the refinement --- forge-core/auth/middleware.go | 32 +++++++++++++++ forge-core/auth/middleware_test.go | 65 ++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index 6dae9a4..13af495 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -133,6 +133,16 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { return } + // Phase 2 (Review M4): refine token_kind from the structural + // shape to the actual provider that matched. The structural + // kind says "what bytes were on the wire"; the post-verify + // kind says "which auth path succeeded." A request with both + // a Bearer JWT AND an X-Goog-Iap-Jwt-Assertion would record + // kind="jwt" under the structural rule even though gcp_iap + // was the verifier — that mis-attributes IAP-fronted traffic + // in audit dashboards. + kind = refineTokenKind(kind, identity.Source) + notifyAuth(opts.OnAuth, r, identity, nil, kind) ctx := WithIdentity(r.Context(), identity) @@ -141,6 +151,28 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { } } +// refineTokenKind upgrades the audit token_kind from the pre-verify +// structural classification to one that reflects which auth path +// actually succeeded. +// +// Today the only refinement is gcp_iap → "iap_jwt": that provider +// reads X-Goog-Iap-Jwt-Assertion (not the Bearer slot), so the +// structural rule cannot detect it when a Bearer is ALSO present. +// Refining post-verify keeps the audit signal clean even when an +// IAP-fronted Forge instance ALSO carries a Bearer JWT for app-level +// auth chaining. 
+// +// Other providers (oidc, azure_ad, aws_sigv4, http_verifier, +// static_token) don't need refinement: their structural kind already +// matches the auth path (aws_sigv4 has its own "forge-aws-v1." prefix +// → "sigv4"; everything else is just "jwt"/"opaque"). +func refineTokenKind(structural, providerSource string) string { + if providerSource == "gcp_iap" { + return "iap_jwt" + } + return structural +} + // notifyAuth invokes the OnAuth callback if set, swallowing the nil check // at the call sites so the main middleware body stays readable. func notifyAuth(cb func(*http.Request, *Identity, error, string), r *http.Request, id *Identity, err error, kind string) { diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index be0ce6c..238d0b0 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -575,6 +575,71 @@ func TestMiddleware_TokenKind_Sigv4FromForgeAwsToken(t *testing.T) { } } +func TestMiddleware_TokenKind_RefinedToIapJwtWhenGCPIAPVerifies(t *testing.T) { + // Review M4: structural kind is "jwt" when a Bearer is present, but + // if gcp_iap was actually the verifier (because it read the IAP + // header instead of the Bearer), the audit kind must be "iap_jwt" + // so dashboards count IAP-fronted traffic correctly. + idFromIAP := &Identity{UserID: "iap-user", Source: "gcp_iap"} + chain := NewChainProvider(&headerCapturingProvider{ + identity: idFromIAP, + }) + + var gotKind string + opts := MiddlewareOptions{ + Chain: chain, + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, id *Identity, _ error, kind string) { + if id != nil && id.Source == "gcp_iap" { + gotKind = kind + } + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + // Both a Bearer JWT (structural kind="jwt") AND an IAP header present. 
+ // The chain (stubbed to return Source="gcp_iap") simulates gcp_iap + // being the actual verifier. + req.Header.Set("Authorization", "Bearer eyJ.eyJ.sig") + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJ.eyJ.iap-sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "iap_jwt" { + t.Errorf("token kind = %q, want iap_jwt (refined post-verify from provider Source)", gotKind) + } +} + +func TestMiddleware_TokenKind_JWT_NotRefinedForOIDCProviders(t *testing.T) { + // Counter-test: when oidc / azure_ad is the verifier (Source != + // gcp_iap), the structural "jwt" kind stays — we don't over-refine. + id := &Identity{UserID: "alice", Source: "oidc"} + chain := NewChainProvider(&headerCapturingProvider{identity: id}) + + var gotKind string + handler := Middleware(MiddlewareOptions{ + Chain: chain, + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, id *Identity, _ error, kind string) { + if id != nil { + gotKind = kind + } + }, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "Bearer eyJ.eyJ.sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "jwt" { + t.Errorf("token kind = %q, want jwt (no refinement for oidc Source)", gotKind) + } +} + func TestMiddleware_TokenKind_IapJwtOnIAPHeader(t *testing.T) { // Phase 2: when the only auth header is X-Goog-Iap-Jwt-Assertion, // audit emits token_kind="iap_jwt" (not "empty"). Pinning this so From 419f20f0fd606118d993da2fe32e5dfc1de6b2ad Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:34:34 -0400 Subject: [PATCH 19/22] feat(auth/azure_ad): allowed_tenants gate + fix misleading multi-tenant doc (M3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 review M3: 1. 
When allow_multi_tenant: true, both the iss check (via composed oidc.Provider's SkipIssuerCheck) AND the tid check were suppressed — any Entra tenant in the world verified. 2. resolveIssuer()'s docstring said "azure_ad enforces tenant via the tid claim instead" which was misleading: in multi-tenant mode there was no tid enforcement at all. Fix is two-part: add the missing knob, fix the doc. allowed_tenants (Config.AllowedTenants []string) Optional allowlist of Entra tenant GUIDs matched against the JWT `tid` claim. Only meaningful with allow_multi_tenant=true. Three operational modes: single-tenant (default) tid MUST equal TenantID multi-tenant + AllowedTenants set tid MUST be in list multi-tenant + AllowedTenants empty no tid check ("any tenant globally" — documented high-risk shape, opted into by deliberately omitting) Match is case-insensitive (Entra emits lowercase GUIDs; operators often paste uppercase from the portal). Factory-time validation rejects allowed_tenants in single-tenant mode (TenantID is THE gate there; the combination would silently degrade if not caught). Validator + non-interactive flag + tests - forge-core/validate/auth.go: rejects allowed_tenants in single- tenant mode at validate-time; WARNS when multi-tenant+empty so the "any tenant globally" trade-off is loud - forge-cli/cmd/init.go: --auth-azure-allowed-tenant flag (repeatable, mirrors --auth-aws-allowed-account) - forge-cli/cmd/init_auth.go: buildAuthFromFlags forwards through - forge-cli/cmd/init_auth_test.go: mock adds the new flag Five new tests (provider_test.go): AllowedTenants_AcceptsListed AllowedTenants_RejectsUnlisted AllowedTenants_CaseInsensitive AllowedTenants_MissingTidRejected SingleTenant_WithAllowedTenants_RejectedAtFactory resolveIssuer() docstring rewrite Now spells out the three operational modes explicitly, including the "any-tenant" shape with its security implication. No more "azure_ad enforces tenant via the tid claim" line that's only half-true. 
CHANGELOG gains an "allowed_tenants" section with the canonical recipe. 42 packages pass, lint + gofmt clean. --- CHANGELOG.md | 21 ++++ forge-cli/cmd/init.go | 1 + forge-cli/cmd/init_auth.go | 7 ++ forge-cli/cmd/init_auth_test.go | 1 + .../auth/providers/azure_ad/provider.go | 86 ++++++++++++-- .../auth/providers/azure_ad/provider_test.go | 112 ++++++++++++++++++ forge-core/validate/auth.go | 21 ++++ 7 files changed, 241 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e502ae9..d657540 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,6 +74,27 @@ Identity Center (SSO) — SSO permission sets gate Org membership at sign-in, and you can match Identity Center-assumed roles with the existing `allowed_principals` globs. +### `azure_ad.allowed_tenants` — explicit allowlist for multi-tenant mode + +```yaml +auth: + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true + allowed_tenants: + - "00000000-1111-2222-3333-444444444444" # partner A + - "55555555-6666-7777-8888-999999999999" # partner B +``` + +When `allow_multi_tenant: true`, the `tid` claim must be in +`allowed_tenants` (case-insensitive GUID match). Empty list + +multi-tenant remains the documented "any tenant globally" mode for +back-compat, but `forge validate` now emits a warning when the list +is empty to make the trade-off explicit. Non-interactive flag: +`--auth-azure-allowed-tenant` (repeatable). 
+ ### TUI wizard supports Phase 2 providers `forge init`'s TUI picker now includes `AWS Sigv4 (IAM)`, diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index 6c5c70a..ba1e71b 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -154,6 +154,7 @@ func init() { initCmd.Flags().String("auth-azure-tenant", "", "Entra tenant GUID (required unless --auth-azure-multi-tenant)") initCmd.Flags().String("auth-azure-audience", "", "Entra application ID URI / audience") initCmd.Flags().Bool("auth-azure-multi-tenant", false, "accept tokens from any Entra tenant (default false)") + initCmd.Flags().StringSlice("auth-azure-allowed-tenant", nil, "allowed Entra tenant GUID for multi-tenant mode (repeatable; empty list = any tenant globally)") initCmd.Flags().String("auth-azure-groups-mode", "", "azure_ad groups mode: claim (default) or graph") } diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index dedc203..f227069 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -86,6 +86,7 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an tenant, _ := cmd.Flags().GetString("auth-azure-tenant") audience, _ := cmd.Flags().GetString("auth-azure-audience") multiTenant, _ := cmd.Flags().GetBool("auth-azure-multi-tenant") + allowedTenants, _ := cmd.Flags().GetStringSlice("auth-azure-allowed-tenant") groupsMode, _ := cmd.Flags().GetString("auth-azure-groups-mode") if audience == "" { return nil, nil, fmt.Errorf("--auth=azure_ad requires --auth-azure-audience") @@ -93,6 +94,9 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an if !multiTenant && tenant == "" { return nil, nil, fmt.Errorf("--auth=azure_ad requires --auth-azure-tenant unless --auth-azure-multi-tenant=true") } + if !multiTenant && len(allowedTenants) > 0 { + return nil, nil, fmt.Errorf("--auth-azure-allowed-tenant is only meaningful with --auth-azure-multi-tenant=true") + } settings = 
map[string]any{"audience": audience} if tenant != "" { settings["tenant_id"] = tenant @@ -100,6 +104,9 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an if multiTenant { settings["allow_multi_tenant"] = true } + if len(allowedTenants) > 0 { + settings["allowed_tenants"] = allowedTenants + } if groupsMode != "" { settings["groups_mode"] = groupsMode } diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go index 3d588ad..88288a3 100644 --- a/forge-cli/cmd/init_auth_test.go +++ b/forge-cli/cmd/init_auth_test.go @@ -225,6 +225,7 @@ func mockInitCmdWithFlags() *cobra.Command { c.Flags().String("auth-azure-tenant", "", "") c.Flags().String("auth-azure-audience", "", "") c.Flags().Bool("auth-azure-multi-tenant", false, "") + c.Flags().StringSlice("auth-azure-allowed-tenant", nil, "") c.Flags().String("auth-azure-groups-mode", "", "") return c } diff --git a/forge-core/auth/providers/azure_ad/provider.go b/forge-core/auth/providers/azure_ad/provider.go index 14539e9..95108fa 100644 --- a/forge-core/auth/providers/azure_ad/provider.go +++ b/forge-core/auth/providers/azure_ad/provider.go @@ -23,6 +23,7 @@ package azure_ad import ( "context" "fmt" + "strings" "time" "github.com/initializ/forge/forge-core/auth" @@ -50,11 +51,29 @@ type Config struct { // app registration (e.g. "api://forge"). Audience string `yaml:"audience"` - // AllowMultiTenant enables accepting tokens from ANY Entra tenant. - // Defaults to false (single-tenant — safe choice). PR6 docs flag the - // security implications of flipping this on in production. + // AllowMultiTenant enables accepting tokens from Entra tenants other + // than the one in TenantID. Defaults to false (single-tenant — safe + // choice). 
When true: + // - the composed oidc.Provider's issuer-equality check is + // suppressed (the "common" issuer template has a {tenantid} + // placeholder that string-equality can't satisfy) + // - tenancy enforcement moves to AllowedTenants (below); see + // CHANGELOG for the security implications AllowMultiTenant bool `yaml:"allow_multi_tenant,omitempty"` + // AllowedTenants is an optional allowlist of Entra tenant GUIDs, + // matched against the JWT's `tid` claim. Only meaningful when + // AllowMultiTenant=true; Validate rejects it in single-tenant mode + // (TenantID is the gate there). + // + // Empty list + AllowMultiTenant=true = "any tenant globally" — + // the documented but high-risk shape. Set this list for the safer + // "these specific tenants only" semantic. + // + // Effort to set: customers know their partner tenants; operators + // just copy GUIDs in. There is no API to enumerate them. + AllowedTenants []string `yaml:"allowed_tenants,omitempty"` + // GroupsMode is "claim" (default — uses the in-JWT groups/roles // claim) or "graph" (queries Microsoft Graph when groups are missing, // i.e. AAD overage). @@ -83,6 +102,12 @@ func (c Config) Validate() error { if c.GroupsMode != "" && c.GroupsMode != "claim" && c.GroupsMode != "graph" { return fmt.Errorf("%w: groups_mode must be 'claim' or 'graph', got %q", auth.ErrProviderNotConfigured, c.GroupsMode) } + // allowed_tenants only makes sense with multi-tenant. Reject the + // combination at factory time so a typo'd config doesn't silently + // degrade to single-tenant behavior the operator didn't intend. + if !c.AllowMultiTenant && len(c.AllowedTenants) > 0 { + return fmt.Errorf("%w: allowed_tenants is only meaningful when allow_multi_tenant=true (single-tenant mode uses tenant_id directly)", auth.ErrProviderNotConfigured) + } return nil } @@ -149,7 +174,15 @@ func (p *Provider) Verify(ctx context.Context, token string, headers auth.Header return nil, err } - // Tenant gate (single-tenant default). + // Tenant gate.
Three modes: + // + // - single-tenant (AllowMultiTenant=false): tid MUST equal TenantID + // - multi-tenant + AllowedTenants set: tid MUST be in the list + // - multi-tenant + AllowedTenants empty: no tid check at all + // ("any tenant globally" + // — the documented high-risk + // shape, explicitly opted into + // by leaving the list empty) if !p.cfg.AllowMultiTenant { tid := ExtractTenantID(id.Claims) if tid == "" { @@ -158,6 +191,18 @@ func (p *Provider) Verify(ctx context.Context, token string, headers auth.Header if tid != p.cfg.TenantID { return nil, fmt.Errorf("%w: tid mismatch", auth.ErrTokenRejected) } + } else if len(p.cfg.AllowedTenants) > 0 { + // Multi-tenant with an explicit allowlist (Review M3): the + // composed oidc.Provider's iss check is suppressed and the + // single-tenant arm above is skipped, but operators who set + // AllowedTenants want explicit per-tenant trust — enforce it. + tid := ExtractTenantID(id.Claims) + if tid == "" { + return nil, fmt.Errorf("%w: AAD token missing tid claim", auth.ErrInvalidToken) + } + if !tenantInAllowlist(tid, p.cfg.AllowedTenants) { + return nil, fmt.Errorf("%w: tid %q not in allowed_tenants", auth.ErrTokenRejected, tid) + } } // Optional Graph enrichment. @@ -175,10 +220,22 @@ func (p *Provider) Verify(ctx context.Context, token string, headers auth.Header } // resolveIssuer picks the issuer URL passed to the composed OIDC -// provider. For single-tenant, that's the full per-tenant authority. -// For multi-tenant, the "common" endpoint serves JWKS for all tenants -// — we point OIDC there and disable iss equality (azure_ad enforces -// tenant via the tid claim instead). +// provider. +// +// For SINGLE-TENANT (AllowMultiTenant=false): the full per-tenant +// authority URL. oidc.Provider's iss-equality check is in force, AND +// Verify() additionally enforces tid == TenantID. Double-gate. +// +// For MULTI-TENANT (AllowMultiTenant=true): the "common" endpoint, +// which serves JWKS for all Entra tenants. 
oidc.Provider's iss check +// is suppressed via SkipIssuerCheck because "common"'s issuer template +// (`https://login.microsoftonline.com/{tenantid}/v2.0`) cannot be +// satisfied by string equality. Tenancy gating then depends on +// AllowedTenants: +// - non-empty list: Verify() enforces tid ∈ AllowedTenants +// - empty list: no tid check anywhere — ANY Entra tenant in the +// world is accepted ("any-tenant" mode, opted into +// by deliberately omitting the list) func resolveIssuer(cfg Config) string { if cfg.AllowMultiTenant { return aadAuthorityBase + "/common/v2.0" @@ -186,6 +243,19 @@ func resolveIssuer(cfg Config) string { return fmt.Sprintf("%s/%s/v2.0", aadAuthorityBase, cfg.TenantID) } +// tenantInAllowlist reports whether tid is one of the configured +// AllowedTenants entries. Match is case-insensitive because Entra +// emits GUIDs in lowercase but operators commonly paste them in +// either case from the Azure portal. +func tenantInAllowlist(tid string, allowed []string) bool { + for _, a := range allowed { + if strings.EqualFold(tid, a) { + return true + } + } + return false +} + // needsEnrichment returns true when Graph should be consulted. AAD // emits no `groups` claim (or a `_claim_names` indicator) when a user // is in too many groups. Phase 1 OIDC surfaces empty Groups in that diff --git a/forge-core/auth/providers/azure_ad/provider_test.go b/forge-core/auth/providers/azure_ad/provider_test.go index 587d790..fcdf7a1 100644 --- a/forge-core/auth/providers/azure_ad/provider_test.go +++ b/forge-core/auth/providers/azure_ad/provider_test.go @@ -12,6 +12,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -250,6 +251,8 @@ func TestProvider_WrongTenant_Rejected(t *testing.T) { } func TestProvider_MultiTenant_AcceptsArbitraryTenant(t *testing.T) { + // Documented high-risk shape: AllowMultiTenant=true with NO + // AllowedTenants list means "accept any tenant globally." 
f := newFakeAAD(t) p := newTestProviderAADMultiTenant(t, f) @@ -261,6 +264,115 @@ func TestProvider_MultiTenant_AcceptsArbitraryTenant(t *testing.T) { } } +// --- Review M3: AllowedTenants gate in multi-tenant mode --- + +func TestProvider_MultiTenant_AllowedTenants_AcceptsListed(t *testing.T) { + f := newFakeAAD(t) + wantTID := "33333333-3333-3333-3333-333333333333" + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{ + "22222222-2222-2222-2222-222222222222", + wantTID, + }, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), wantTID, "api://forge") + tok := f.sign(claims) + + if _, err := p.Verify(context.Background(), tok, auth.Headers{}); err != nil { + t.Errorf("expected success for listed tenant, got %v", err) + } +} + +func TestProvider_MultiTenant_AllowedTenants_RejectsUnlisted(t *testing.T) { + f := newFakeAAD(t) + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{"22222222-2222-2222-2222-222222222222"}, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), "deadbeef-dead-beef-dead-beefdeadbeef", "api://forge") + tok := f.sign(claims) + + _, vErr := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(vErr, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (tid not in allowlist)", vErr) + } + if vErr != nil && !strings.Contains(vErr.Error(), "allowed_tenants") { + t.Errorf("err should mention allowed_tenants; got %v", vErr) + } +} + +func TestProvider_MultiTenant_AllowedTenants_CaseInsensitive(t *testing.T) { + // Entra emits lowercase tids; operators sometimes paste uppercase + // from the portal. Match should tolerate either. 
+ f := newFakeAAD(t) + wantTID := "33333333-3333-3333-3333-333333333333" + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{"33333333-3333-3333-3333-333333333333"}, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), strings.ToUpper(wantTID), "api://forge") + tok := f.sign(claims) + + if _, err := p.Verify(context.Background(), tok, auth.Headers{}); err != nil { + t.Errorf("expected case-insensitive match, got %v", err) + } +} + +func TestProvider_MultiTenant_AllowedTenants_MissingTidRejected(t *testing.T) { + // Multi-tenant + AllowedTenants set + token has no tid → reject. + // In "any-tenant" mode (empty AllowedTenants) a missing tid would + // be allowed (we wouldn't be checking), but the moment the operator + // sets a list they want the gate enforced. + f := newFakeAAD(t) + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{"22222222-2222-2222-2222-222222222222"}, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), "ignored", "api://forge") + delete(claims, "tid") + tok := f.sign(claims) + + _, vErr := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(vErr, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (missing tid)", vErr) + } +} + +func TestProvider_SingleTenant_WithAllowedTenants_RejectedAtFactory(t *testing.T) { + // Config-level guard: AllowedTenants is meaningless in single-tenant + // mode (TenantID is THE gate). Reject the combo so a typo doesn't + // silently degrade. 
+ _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "tenant_id": "11111111-1111-1111-1111-111111111111", + "allowed_tenants": []any{"22222222-2222-2222-2222-222222222222"}, + }) + if err == nil { + t.Fatal("expected error: allowed_tenants requires allow_multi_tenant=true") + } +} + func TestProvider_MissingTid_Invalid(t *testing.T) { f := newFakeAAD(t) p := newTestProviderAADSingleTenant(t, f, "11111111-1111-1111-1111-111111111111") diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go index 654ac4e..c704857 100644 --- a/forge-core/validate/auth.go +++ b/forge-core/validate/auth.go @@ -110,6 +110,27 @@ func validateProviderSettings(prefix string, p types.AuthProvider, r *Validation if !multi && asString(p.Settings, "tenant_id") == "" { r.Errors = append(r.Errors, prefix+" (azure_ad): settings.tenant_id is required unless allow_multi_tenant=true") } + // allowed_tenants only makes sense with multi-tenant. + hasAllowed := false + switch v := p.Settings["allowed_tenants"].(type) { + case []any: + hasAllowed = len(v) > 0 + case []string: + hasAllowed = len(v) > 0 + } + if !multi && hasAllowed { + r.Errors = append(r.Errors, + prefix+" (azure_ad): allowed_tenants is only meaningful when allow_multi_tenant=true") + } + // "Any-tenant mode" warning: multi-tenant + empty allowed_tenants + // admits any Entra tenant globally. Documented trade-off, but + // warn so operators don't ship it by accident. 
+ if multi && !hasAllowed { + r.Warnings = append(r.Warnings, + prefix+" (azure_ad): allow_multi_tenant=true with no allowed_tenants list "+ + "admits any Entra tenant globally — set allowed_tenants if you want to "+ + "restrict to specific partner tenants") + } if mode := asString(p.Settings, "groups_mode"); mode != "" && mode != "claim" && mode != "graph" { r.Errors = append(r.Errors, fmt.Sprintf("%s (azure_ad): groups_mode must be 'claim' or 'graph', got %q", prefix, mode)) } From ca92921a69bfcd6e932a6b0747731588393080bc Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:39:09 -0400 Subject: [PATCH 20/22] fix(auth): closed-key whitelist for provider settings (M5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 review M5: validateProviderSettings was permissive on unknown keys. A typo like `aud:` instead of `audience:` would slip past the required-key check silently (asString returns ""), and — more critically — the forge-ui handler forwarded a.Settings unfiltered to the on-disk scaffold. A POST like {"settings": {"audience": "x", "evil_key": "y"}} would write evil_key into forge.yaml verbatim. Today's provider Config structs ignore unknown YAML fields via yaml.v3's default, so it's harmless — but one missed `yaml:"-"` tag on a future field would suddenly make it reachable from untrusted POST input. Closed at the boundary now. Two-part fix: 1. forge-core/validate/auth.go - New exported KnownAuthProviderSettings map[string]map[string]bool (closed whitelist per provider type — mirrors the yaml: tags on each provider's Config struct, kept in sync by convention). Internal yaml:"-" fields are intentionally absent. - validateProviderSettings emits a Warning per unknown settings key during `forge validate`. Loose-not-strict because some operators stash custom annotations from pre-Phase-2 days; warning surfaces the typo without breaking those configs. 
- New exported FilterKnownSettings(providerType, settings) returns a copy with unknown keys dropped. Defense-in-depth filter for use at trust boundaries. 2. forge-ui/handlers_create.go - validateAuthPayload calls FilterKnownSettings on the incoming a.Settings BEFORE validation OR scaffolding. Closes the specific exploit chain (Web UI POST → forge.yaml) regardless of yaml tag discipline drift on Config structs. Tests: validate package: TestValidateAuthConfig_WarnsOnUnknownSettingsKey TestValidateAuthConfig_NoWarningForKnownKeys TestFilterKnownSettings_DropsUnknownKeys TestFilterKnownSettings_UnknownProviderTypePassthrough TestFilterKnownSettings_AllPhase2Providers (subtests per provider/key) forge-ui: TestHandleCreateAgent_FiltersUnknownAuthSettings — end-to-end POST with evil_key in settings; assert it does NOT survive into the scaffold's captured AuthCreateOptions.Settings 42 packages still pass; lint + gofmt clean. --- forge-core/validate/auth.go | 101 ++++++++++++++++++++++++++++ forge-core/validate/auth_test.go | 111 +++++++++++++++++++++++++++++++ forge-ui/handlers_create.go | 9 +++ forge-ui/handlers_create_test.go | 45 +++++++++++++ 4 files changed, 266 insertions(+) diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go index c704857..4e3c4e4 100644 --- a/forge-core/validate/auth.go +++ b/forge-core/validate/auth.go @@ -23,6 +23,93 @@ var knownAuthProviderTypes = map[string]bool{ "azure_ad": true, } +// KnownAuthProviderSettings is the closed set of YAML keys each provider +// type accepts. Mirrors the `yaml:` tags on each provider's Config struct; +// must be kept in sync when new fields are added. Internal-only struct +// fields (those carrying `yaml:"-"`) are intentionally absent — they +// can only be set by another Go package, not via forge.yaml or the +// Web UI's create-payload. +// +// Two callers consume this map: +// +// 1. 
ValidateAuthConfig emits a *warning* per unknown key during +// `forge validate`, so a typo like `aud:` instead of `audience:` +// gets surfaced loudly. +// +// 2. The Web UI handler (forge-ui handlers_create.go) filters its +// incoming Settings map through FilterKnownSettings before +// forwarding to scaffold — closing the exploit chain where a +// malicious POST `{"settings": {"audience": "x", "evil": "y"}}` +// would otherwise drop `evil:` into forge.yaml verbatim. +var KnownAuthProviderSettings = map[string]map[string]bool{ + "http_verifier": { + "url": true, + "default_org": true, + "timeout": true, + }, + "static_token": { + "token": true, + "token_env": true, + }, + "oidc": { + "issuer": true, + "audience": true, + "client_id": true, + "jwks_url": true, + "jwks_cache_ttl": true, + "clock_skew": true, + "claim_map": true, + }, + "aws_sigv4": { + "region": true, + "audience": true, + "allowed_principals": true, + "allowed_accounts": true, + "identity_cache_ttl": true, + "sts_endpoint": true, // documented test override, intentionally YAML-reachable + "http_timeout": true, + "max_token_expires": true, + "clock_skew": true, + }, + "gcp_iap": { + "audience": true, + "jwks_refresh_ttl": true, + "http_timeout": true, + }, + "azure_ad": { + "tenant_id": true, + "audience": true, + "allow_multi_tenant": true, + "allowed_tenants": true, + "groups_mode": true, + "graph_timeout": true, + "jwks_cache_ttl": true, + // graph_endpoint intentionally omitted — yaml:"-" on the Config field + }, +} + +// FilterKnownSettings returns a copy of settings with any keys not in +// the whitelist for providerType dropped. Use this at the boundary +// between untrusted input (Web UI POST) and persistence (forge.yaml +// scaffold) so unknown keys never reach disk. +// +// For unknown providerType (returns nil from the whitelist lookup), the +// input is passed through unchanged — let the ValidateAuthConfig +// "unknown type" error catch that case instead. 
+func FilterKnownSettings(providerType string, settings map[string]any) map[string]any { + known, ok := KnownAuthProviderSettings[providerType] + if !ok { + return settings + } + out := make(map[string]any, len(settings)) + for k, v := range settings { + if known[k] { + out[k] = v + } + } + return out +} + // ValidateAuthConfig adds errors and warnings for a forge.yaml auth: block. // Empty AuthConfig is valid (legacy --auth-url path remains the fallback). func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { @@ -62,6 +149,20 @@ func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { // runs at runtime construction) so `forge validate` catches errors before // `forge run`. func validateProviderSettings(prefix string, p types.AuthProvider, r *ValidationResult) { + // Warn on any keys the provider doesn't recognize. Loose vs. error + // because some operators stash custom annotations (legacy practice + // from pre-Phase-2 configs); the Web UI handler additionally filters + // at write-time so the actual scaffold-poisoning chain is closed + // there (see forge-ui/handlers_create.go). + if known, ok := KnownAuthProviderSettings[p.Type]; ok { + for k := range p.Settings { + if !known[k] { + r.Warnings = append(r.Warnings, + fmt.Sprintf("%s (%s): unknown settings key %q — typo, or a key from a future provider version?", prefix, p.Type, k)) + } + } + } + switch p.Type { case "http_verifier": if asString(p.Settings, "url") == "" { diff --git a/forge-core/validate/auth_test.go b/forge-core/validate/auth_test.go index a4813fb..ed02b29 100644 --- a/forge-core/validate/auth_test.go +++ b/forge-core/validate/auth_test.go @@ -216,3 +216,114 @@ func containsSubstr(ss []string, substr string) bool { } return false } + +// --- Review M5: unknown-key warning + filter helper --- + +func TestValidateAuthConfig_WarnsOnUnknownSettingsKey(t *testing.T) { + // Typo: `aud` instead of `audience`. 
The required-keys check would + // fire (audience missing), but we ALSO want a loud warning about + // the unknown key so operators can spot the actual typo, not just + // the symptom. + cfg := types.AuthConfig{ + Required: true, + Providers: []types.AuthProvider{{ + Type: "oidc", + Settings: map[string]any{ + "issuer": "https://login.example.com", + "aud": "api://forge", // typo: should be 'audience' + "audience": "api://forge", // keep the required-check passing + }, + }}, + } + result := &ValidationResult{} + ValidateAuthConfig(cfg, result) + if !containsSubstr(result.Warnings, `unknown settings key "aud"`) { + t.Errorf("expected warning about unknown 'aud' key, got %v", result.Warnings) + } +} + +func TestValidateAuthConfig_NoWarningForKnownKeys(t *testing.T) { + // Every documented oidc key should be silent. + cfg := types.AuthConfig{ + Required: true, + Providers: []types.AuthProvider{{ + Type: "oidc", + Settings: map[string]any{ + "issuer": "https://x", + "audience": "y", + "client_id": "c", + "jwks_url": "https://x/jwks", + "jwks_cache_ttl": "1h", + "clock_skew": "30s", + "claim_map": map[string]any{"groups": "roles"}, + }, + }}, + } + result := &ValidationResult{} + ValidateAuthConfig(cfg, result) + for _, w := range result.Warnings { + if strings.Contains(w, "unknown settings key") { + t.Errorf("unexpected unknown-key warning for known oidc field: %q", w) + } + } +} + +func TestFilterKnownSettings_DropsUnknownKeys(t *testing.T) { + // Defense-in-depth filter that forge-ui's handler runs on + // untrusted Web UI input. Unknown keys must NOT survive. 
+ in := map[string]any{ + "audience": "api://forge", + "issuer": "https://x", + "evil_key": "attacker-value", // unknown for oidc + } + out := FilterKnownSettings("oidc", in) + if out["audience"] != "api://forge" { + t.Errorf("audience dropped: %v", out) + } + if out["issuer"] != "https://x" { + t.Errorf("issuer dropped: %v", out) + } + if _, exists := out["evil_key"]; exists { + t.Error("evil_key must be filtered out for oidc settings") + } +} + +func TestFilterKnownSettings_UnknownProviderTypePassthrough(t *testing.T) { + // If the provider type isn't in the whitelist, pass through — + // validateProviderSettings' "unknown type" error catches that case + // separately. + in := map[string]any{"x": "y"} + out := FilterKnownSettings("future_provider", in) + if out["x"] != "y" { + t.Errorf("unknown provider type should passthrough, got %v", out) + } +} + +func TestFilterKnownSettings_AllPhase2Providers(t *testing.T) { + cases := []struct { + provider string + good string // a key that SHOULD survive + }{ + {"aws_sigv4", "region"}, + {"aws_sigv4", "allowed_accounts"}, + {"aws_sigv4", "sts_endpoint"}, // test-only override, but YAML-reachable + {"gcp_iap", "audience"}, + {"azure_ad", "tenant_id"}, + {"azure_ad", "allowed_tenants"}, + {"azure_ad", "allow_multi_tenant"}, + } + for _, tc := range cases { + t.Run(tc.provider+"/"+tc.good, func(t *testing.T) { + out := FilterKnownSettings(tc.provider, map[string]any{ + tc.good: "x", + "evil_X": "bad", + }) + if _, exists := out[tc.good]; !exists { + t.Errorf("%s: known key %q was dropped", tc.provider, tc.good) + } + if _, exists := out["evil_X"]; exists { + t.Errorf("%s: unknown key 'evil_X' survived filter", tc.provider) + } + }) + } +} diff --git a/forge-ui/handlers_create.go b/forge-ui/handlers_create.go index 21f192b..2a21d3c 100644 --- a/forge-ui/handlers_create.go +++ b/forge-ui/handlers_create.go @@ -207,6 +207,15 @@ func validateAuthPayload(a *AuthCreateOptions) error { return nil } + // Filter the incoming Settings 
to the known-keys whitelist BEFORE + // validation OR scaffolding. Closes the exploit chain (review M5): + // without this, a POST with `{"settings": {"audience": "x", + // "evil_key": "y"}}` would drop evil_key into forge.yaml verbatim. + // Today provider Config structs ignore unknown YAML fields, but a + // future field added without a `yaml:"-"` tag would suddenly become + // reachable via untrusted POST. Filtering here gives defense-in-depth. + a.Settings = validate.FilterKnownSettings(a.Mode, a.Settings) + authYAML := types.AuthConfig{ // Required is set true here because the wizard's renderAuthBlock // always emits required:true when a provider is chosen. Keeping diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go index dcc40e0..8491d58 100644 --- a/forge-ui/handlers_create_test.go +++ b/forge-ui/handlers_create_test.go @@ -735,3 +735,48 @@ func TestHandleCreateAgent_WithoutAuthPayload(t *testing.T) { t.Errorf("Auth = %v, want nil for omitted field", captured.Auth) } } + +// --- Review M5: unknown-key filter at the Web UI boundary --- + +func TestHandleCreateAgent_FiltersUnknownAuthSettings(t *testing.T) { + // Defense-in-depth: a POST that carries an unknown settings key + // (typo OR malicious) must NOT survive into what the scaffold writes. + // Today the providers' Config structs ignore unknown YAML fields, so + // even an unfiltered key is harmless; this test pins the filter so + // a future config-struct field can't suddenly become reachable via + // untrusted POST. 
+ srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "filter-test", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge", + "evil_key": "attacker-supplied-value" + } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201", w.Code) + } + if captured.Auth == nil { + t.Fatal("expected Auth payload to reach scaffold") + } + if _, leaked := captured.Auth.Settings["evil_key"]; leaked { + t.Errorf("evil_key leaked to scaffold: %v", captured.Auth.Settings) + } + // Known keys must still be there. + if captured.Auth.Settings["issuer"] != "https://login.example.com" { + t.Errorf("issuer dropped or wrong: %v", captured.Auth.Settings) + } + if captured.Auth.Settings["audience"] != "api://forge" { + t.Errorf("audience dropped or wrong: %v", captured.Auth.Settings) + } +} From 745c0244846033bf9150d464d0186e6a3cfd5745 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 15:46:59 -0400 Subject: [PATCH 21/22] chore(auth): clear seven Phase 2 review NITs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Batch-clearing the "don't block merge" follow-ups from review of PR #79. #1 gcp_iap classifyJWTErr — use jwt v5 sentinels via errors.Is rather than substring matching (library wording shifts across patches; sentinels are public API). Special-case ErrTokenSignatureInvalid to split alg-confusion (→ ErrInvalidToken) from real bad-signature (→ ErrTokenRejected) because golang-jwt wraps both under that one sentinel. Three internal keyFunc message-matches retained — those are strings WE control, not the library's. #2 gcp_iap JWKS merge-on-success — switched j.keys = newKeys to a per-kid merge. A partial-but-valid JWKS response (e.g. 
one kid accidentally omitted by GCP during rotation) no longer drops kids the stale-grace contract assumes we still have. Worst case is keeping a retired kid in cache: tokens signed with that kid's private key keep verifying for as long as the entry stays cached (the merge never evicts), so this trades revocation-by-omission for availability through JWKS-API hiccups. #3 azure_ad GraphCache defensive copies — Get returns append([]string(nil), e.groups...) and Put stores a copy of its input. Caller mutating Identity.Groups (the auth.Identity layer treats it as freely mutable) can't poison the cache. #4 forge-cli needsYAMLQuoting numeric edge cases — quote anything that resembles a YAML number (hex 0x, octal 0o, binary 0b, leading-zero "010", scientific 1e10, decimal float 3.14, signed ±N, .inf / .nan in either case). Auth-setting values rarely hit these shapes but the docstring promised "false negatives are bugs" and the Web UI POST path can supply arbitrary strings. Added looksNumeric() helper with separate allHexDigits / allOctalDigits / allBinaryDigits gates. #5 aws_sigv4 identity_cache_test — replaced string(rune(i)) with strconv.Itoa(i). Surrogate code points (0xD800..0xDFFF) all map to U+FFFD, so the rune-based key scheme would silently collide if the loop bound ever grew past 0xD800; at the current bound of 10_001 the keys were still distinct, and strconv.Itoa(i) removes the hazard. #6 http.NewRequestWithContext error handling — fixed the two `req, _ := ...` antipatterns in gcp_iap/iap_jwks.go and azure_ad/graph_client.go. Hardcoded URLs make the failure currently unreachable, but errcheck-clean is the discipline. #7 gcp_iap HS256-with-EC-public-key alg-confusion test — pinned the most dangerous attack shape: attacker fetches the verifier's public key from JWKS (open by design), uses raw X/Y bytes as the HMAC "secret", signs an HS256 token. A non-whitelisting verifier would HMAC-verify it. Our keyFunc rejects on alg != "ES256" BEFORE key lookup; this test confirms. Tests added: TestGraphCache_GetReturnsDefensiveCopy, TestGraphCache_PutStoresDefensiveCopy, TestProvider_HS256WithECPublicKeyAsSecret_Rejected.
Existing TestProvider_RS256Token_Rejected still passes (alg-confusion still classified as ErrInvalidToken under the new sentinel-based path). 42 packages green, lint + gofmt clean. --- forge-cli/cmd/init_auth.go | 104 +++++++++++++++++- .../aws_sigv4/identity_cache_test.go | 8 +- .../auth/providers/azure_ad/graph_cache.go | 11 +- .../providers/azure_ad/graph_cache_test.go | 39 +++++++ .../auth/providers/azure_ad/graph_client.go | 5 +- forge-core/auth/providers/gcp_iap/iap_jwks.go | 104 ++++++++++++++---- .../auth/providers/gcp_iap/provider_test.go | 28 +++++ 7 files changed, 269 insertions(+), 30 deletions(-) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index f227069..e38e41e 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -4,6 +4,7 @@ import ( "fmt" "net/url" "sort" + "strconv" "strings" "github.com/spf13/cobra" @@ -368,15 +369,110 @@ func needsYAMLQuoting(s string) bool { case "true", "false", "yes", "no", "on", "off", "null", "~": return true } - // All-digit strings would otherwise decode as integers — quote them - // to preserve string semantics (e.g. AWS account IDs are 12-digit - // strings, NOT numbers — and the provider expects []string). - if isAllDigits(s) { + // Anything that resembles a YAML 1.1 / 1.2 number must be quoted + // to preserve string semantics. The auth-settings schema rarely + // produces these shapes (account IDs ARE all-digit; others are + // theoretical), but the docstring says "false negatives are bugs" + // and the Web UI POST path can supply arbitrary strings. 
+ // + // Covers: + // - All-digit: "412664885516" (AWS account) + // - Leading-zero: "010" (YAML 1.1 octal) + // - Hex: "0x1A" (YAML 1.1 hex) + // - Octal: "0o17" (YAML 1.2 octal) + // - Binary: "0b10" (YAML 1.2 binary) + // - Scientific notation: "1e10", "1.5E-3" + // - Decimal float: "3.14" + // - Signed: "-5", "+7" + // - Special floats: ".inf", ".nan", "-.inf" + if looksNumeric(s) { return true } return false } +// looksNumeric reports whether s would parse as a YAML number under any +// of YAML 1.1 / 1.2 / yaml.v3's relaxed coercion rules. Conservative — +// false positives just produce extra quotes. +func looksNumeric(s string) bool { + if s == "" { + return false + } + // YAML special floats: case-insensitive. + switch strings.ToLower(s) { + case ".inf", ".nan", "-.inf", "+.inf": + return true + } + // Strip a leading sign once for the remaining shape tests. + body := s + if body[0] == '+' || body[0] == '-' { + body = body[1:] + if body == "" { + return false + } + } + // Hex / Octal / Binary prefixes. + if len(body) >= 2 && body[0] == '0' { + switch body[1] { + case 'x', 'X': + return allHexDigits(body[2:]) + case 'o', 'O': + return allOctalDigits(body[2:]) + case 'b', 'B': + return allBinaryDigits(body[2:]) + } + } + // Plain decimal integer (incl. leading-zero like "010", which + // YAML 1.1 would parse as octal). + if isAllDigits(body) { + return true + } + // Decimal float / scientific notation: rely on Go's strconv.ParseFloat + // — if it accepts the bytes as a float, YAML will too. 
+ if _, err := strconv.ParseFloat(body, 64); err == nil { + return true + } + return false +} + +func allHexDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + c := s[i] + ok := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') + if !ok { + return false + } + } + return true +} + +func allOctalDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] < '0' || s[i] > '7' { + return false + } + } + return true +} + +func allBinaryDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] != '0' && s[i] != '1' { + return false + } + } + return true +} + // isAllDigits returns true if s is non-empty and consists only of ASCII digits. func isAllDigits(s string) bool { if s == "" { diff --git a/forge-core/auth/providers/aws_sigv4/identity_cache_test.go b/forge-core/auth/providers/aws_sigv4/identity_cache_test.go index f3cedcf..7e05e2f 100644 --- a/forge-core/auth/providers/aws_sigv4/identity_cache_test.go +++ b/forge-core/auth/providers/aws_sigv4/identity_cache_test.go @@ -1,6 +1,7 @@ package aws_sigv4 import ( + "strconv" "testing" "time" @@ -67,7 +68,12 @@ func TestIdentityCache_OpportunisticEviction(t *testing.T) { id := &auth.Identity{UserID: "arn"} for i := range 10_001 { - c.Put(string(rune(i)), id) + // strconv.Itoa(i), NOT string(rune(i)): the rune form maps + // surrogate code points to U+FFFD (the replacement char), so + // all of {0xD800..0xDFFF} collide on one cache key. The bound + // here (10_001) stays below 0xD800, but Itoa keeps every key + // distinct if the threshold is ever raised. (Review NIT.)
+ c.Put(strconv.Itoa(i), id) } now = now.Add(2 * time.Second) // expire everything diff --git a/forge-core/auth/providers/azure_ad/graph_cache.go b/forge-core/auth/providers/azure_ad/graph_cache.go index f102cb5..de67224 100644 --- a/forge-core/auth/providers/azure_ad/graph_cache.go +++ b/forge-core/auth/providers/azure_ad/graph_cache.go @@ -30,6 +30,10 @@ func NewGraphCache(ttl time.Duration) *GraphCache { } // Get returns the cached groups for userID, or (nil, false) on miss/expiry. +// +// The returned slice is a defensive copy — callers that subsequently mutate +// their Identity.Groups (the auth.Identity layer treats Groups as a freely- +// mutable field) MUST NOT corrupt the cache. (Review NIT.) func (c *GraphCache) Get(userID string) ([]string, bool) { c.mu.RLock() e, ok := c.data[userID] @@ -37,15 +41,18 @@ func (c *GraphCache) Get(userID string) ([]string, bool) { if !ok || c.now().After(e.expireAt) { return nil, false } - return e.groups, true + return append([]string(nil), e.groups...), true } // Put stores the groups under userID with a fresh TTL. Overwrites any // prior entry (does not extend). +// +// Stores a defensive copy so subsequent caller mutations of the input +// slice don't reach back through cache hits. func (c *GraphCache) Put(userID string, groups []string) { c.mu.Lock() c.data[userID] = graphEntry{ - groups: groups, + groups: append([]string(nil), groups...), expireAt: c.now().Add(c.ttl), } c.mu.Unlock() diff --git a/forge-core/auth/providers/azure_ad/graph_cache_test.go b/forge-core/auth/providers/azure_ad/graph_cache_test.go index 15e030e..847d00a 100644 --- a/forge-core/auth/providers/azure_ad/graph_cache_test.go +++ b/forge-core/auth/providers/azure_ad/graph_cache_test.go @@ -17,6 +17,45 @@ func TestGraphCache_HitMiss(t *testing.T) { } } +func TestGraphCache_GetReturnsDefensiveCopy(t *testing.T) { + // Caller mutating Identity.Groups must NOT corrupt the cache + // (Review NIT). 
+ c := NewGraphCache(time.Minute) + c.Put("u", []string{"g1", "g2", "g3"}) + + got, ok := c.Get("u") + if !ok { + t.Fatal("expected hit") + } + got[0] = "tampered" + //nolint:staticcheck // intentional caller-side append; the result is + // discarded because we're testing that mutations of `got` don't + // reach back into the cache. + _ = append(got, "extra") + + again, _ := c.Get("u") + if again[0] != "g1" { + t.Errorf("cache was mutated by caller: %v", again) + } + if len(again) != 3 { + t.Errorf("cache slice length changed: %d", len(again)) + } +} + +func TestGraphCache_PutStoresDefensiveCopy(t *testing.T) { + // Caller mutating the input slice after Put must NOT bleed + // through into future Gets. + c := NewGraphCache(time.Minute) + src := []string{"g1", "g2"} + c.Put("u", src) + src[0] = "tampered" + + got, _ := c.Get("u") + if got[0] != "g1" { + t.Errorf("Put didn't copy: %v", got) + } +} + func TestGraphCache_Expiry(t *testing.T) { now := time.Unix(1_700_000_000, 0) c := NewGraphCache(60 * time.Second) diff --git a/forge-core/auth/providers/azure_ad/graph_client.go b/forge-core/auth/providers/azure_ad/graph_client.go index 2d48bd1..c4f4a7e 100644 --- a/forge-core/auth/providers/azure_ad/graph_client.go +++ b/forge-core/auth/providers/azure_ad/graph_client.go @@ -103,7 +103,10 @@ func (c *GraphClient) TransitiveMemberOf(ctx context.Context, _ string, authHead } func (c *GraphClient) fetchPage(ctx context.Context, u, authHeader string) (ids []string, next string, err error) { - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, "", fmt.Errorf("%w: graph build request: %v", auth.ErrProviderUnavailable, err) + } req.Header.Set("Authorization", authHeader) req.Header.Set("Accept", "application/json") diff --git a/forge-core/auth/providers/gcp_iap/iap_jwks.go b/forge-core/auth/providers/gcp_iap/iap_jwks.go index c525575..0038845 100644 --- 
a/forge-core/auth/providers/gcp_iap/iap_jwks.go +++ b/forge-core/auth/providers/gcp_iap/iap_jwks.go @@ -196,7 +196,11 @@ func (j *IAPJWKSCache) refresh(ctx context.Context) error { j.lastAttempt = time.Now() j.mu.Unlock() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, j.url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.url, nil) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS build request: %v", auth.ErrProviderUnavailable, err) + } resp, err := j.http.Do(req) if err != nil { j.bumpBackoff() @@ -221,8 +225,24 @@ func (j *IAPJWKSCache) refresh(ctx context.Context) error { return fmt.Errorf("%w: IAP JWKS parse: %v", auth.ErrProviderUnavailable, err) } + // Merge-on-success rather than replace (Review NIT). A partial- + // but-valid JWKS (e.g. one freshly-rotated kid omitted by mistake) + // must not drop kids the stale-grace contract assumes we still + // have. New keys take precedence over old of the same kid; old + // keys that aren't in the new response are kept. + // + // Worst case: a stale key for a kid GCP has actually retired stays + // in our cache. Tokens signed with the retired private key keep + // verifying while it stays cached (this merge never evicts), so + // this trades revocation-by-omission for availability through + // JWKS-API hiccups. j.mu.Lock() - j.keys = keys + if j.keys == nil { + j.keys = map[string]*ecdsa.PublicKey{} + } + for kid, k := range keys { + j.keys[kid] = k + } j.lastSuccessful = time.Now() j.backoffDuration = 0 j.mu.Unlock() @@ -285,36 +305,76 @@ func parseECJWKSet(raw []byte) (map[string]*ecdsa.PublicKey, error) { return out, nil } -// classifyJWTErr maps golang-jwt errors to auth sentinels. The library -// returns wrapped errors with stable joinings — we string-match on the -// most reliable signals. +// classifyJWTErr maps golang-jwt errors to auth sentinels using v5's +// named sentinels via errors.Is rather than substring matching.
The +// library's error wording can shift across patch releases; the sentinels +// are part of its public API and stable. (Review NIT.) +// +// Two-tier mapping: // -// Algorithm-confusion ("signing method X is invalid") and kid-related -// errors are classified as ErrInvalidToken (token shape is wrong), -// not ErrTokenRejected (which is policy denial). These checks run -// BEFORE the generic "signature is invalid" arm because the alg error -// is wrapped in a signature-shaped outer error by golang-jwt. +// ErrInvalidToken (malformed / structurally wrong): +// - jwt.ErrTokenMalformed (bad base64, dot-count, etc.) +// - jwt.ErrTokenUnverifiable (no key found, alg mismatch +// detected in keyFunc) +// - jwt.ErrInvalidKey, jwt.ErrInvalidKeyType +// - keyFunc messages we emit ourselves ("unexpected alg", +// "missing kid", "not found in IAP JWKS") +// +// ErrTokenRejected (well-formed but cryptographically/temporally +// invalid — policy-denial shape): +// - jwt.ErrTokenSignatureInvalid +// - jwt.ErrTokenExpired +// - jwt.ErrTokenNotValidYet (nbf in future) +// - jwt.ErrTokenUsedBeforeIssued (iat in future) +// - jwt.ErrTokenInvalidClaims +// +// Default: ErrInvalidToken (conservative — unknown errors are +// classified as malformed, not as policy rejections). func classifyJWTErr(err error) error { if errors.Is(err, auth.ErrProviderUnavailable) || errors.Is(err, auth.ErrInvalidToken) || errors.Is(err, auth.ErrTokenRejected) { return err } - s := err.Error() + + // Special case: ErrTokenSignatureInvalid wraps BOTH (a) actual + // bad-signature failures (rejected) AND (b) alg-confusion errors + // where golang-jwt's WithValidMethods refused the token's alg + // before signing was even attempted. The latter is a malformed- + // shape failure (alg whitelist tripped), not a policy denial, so + // inspect the wrapped message to distinguish. 
+ if errors.Is(err, jwt.ErrTokenSignatureInvalid) { + s := err.Error() + if strings.Contains(s, "signing method") { + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + return fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) + } + switch { - case strings.Contains(s, "signing method"), - strings.Contains(s, "unexpected alg"), - strings.Contains(s, "missing kid"), - strings.Contains(s, "kid "), // "kid X not found" - strings.Contains(s, "not found"): // covers JWKS-resolution failures + case errors.Is(err, jwt.ErrTokenMalformed), + errors.Is(err, jwt.ErrTokenUnverifiable), + errors.Is(err, jwt.ErrInvalidKey), + errors.Is(err, jwt.ErrInvalidKeyType): return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) - case strings.Contains(s, "expired"), - strings.Contains(s, "iat"), - strings.Contains(s, "nbf"), - strings.Contains(s, "signature is invalid"), - strings.Contains(s, "verification error"): + case errors.Is(err, jwt.ErrTokenExpired), + errors.Is(err, jwt.ErrTokenNotValidYet), + errors.Is(err, jwt.ErrTokenUsedBeforeIssued), + errors.Is(err, jwt.ErrTokenInvalidClaims): return fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) - default: + } + + // keyFunc errors we emit ourselves get wrapped by jwt.Parse, but + // the unwrap chain doesn't preserve them as distinct sentinels. + // Fall back to substring matching for THESE messages only — they're + // strings WE control, not the library's. 
+ s := err.Error() + switch { + case strings.Contains(s, "unexpected alg"), + strings.Contains(s, "missing kid"), + strings.Contains(s, "not found in IAP JWKS"): return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) } + + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) } diff --git a/forge-core/auth/providers/gcp_iap/provider_test.go b/forge-core/auth/providers/gcp_iap/provider_test.go index f7d9362..eb74841 100644 --- a/forge-core/auth/providers/gcp_iap/provider_test.go +++ b/forge-core/auth/providers/gcp_iap/provider_test.go @@ -247,6 +247,34 @@ func TestProvider_RS256Token_Rejected(t *testing.T) { } } +func TestProvider_HS256WithECPublicKeyAsSecret_Rejected(t *testing.T) { + // The most dangerous algorithm-confusion shape (Review NIT): attacker + // takes the verifier's PUBLIC key (which the JWKS endpoint publishes + // openly), uses its raw bytes as the HMAC secret to sign an HS256 + // token, and submits it. A verifier that doesn't whitelist the alg + // would happily HMAC-verify the token against the same "secret" = + // the public key bytes. Our keyFunc rejects on alg!="ES256" BEFORE + // even looking up the key — pin that explicitly. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + // Use the signer's public EC point coordinates as the HMAC "secret." + // The attacker only needs the public key, which they can fetch from + // the JWKS endpoint anonymously. + pubBytes := append(signer.pub.X.Bytes(), signer.pub.Y.Bytes()...) 
+ tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(validClaims("test-aud"))) + tok.Header["kid"] = signer.kid + str, err := tok.SignedString(pubBytes) + if err != nil { + t.Fatalf("sign HS256: %v", err) + } + + _, err = p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": str}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (HS256-with-public-key alg confusion)", err) + } +} + func TestFactory_RejectsMissingAudience(t *testing.T) { if _, err := auth.Build("gcp_iap", map[string]any{}); err == nil { t.Fatal("expected error when audience is missing") From 4dc8ce01d1c4b92cdb735ee6a39a93c26dde5049 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sun, 24 May 2026 16:03:52 -0400 Subject: [PATCH 22/22] docs: sync Phase 2 auth provider work across affected pages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ran the /sync-docs recipe over PR #79's surface (commit 745c024). Two pre-flight fixes before the sync: - .claude/commands/sync-docs.md mapping table was outdated: pointed at flat files (`docs/commands.md`, `docs/runtime.md`, …) that don't exist; real layout is nested under core-concepts/, reference/, security/. Rewrote the table to match the actual tree and added a `forge-core/auth/` → authentication.md + audit-logging.md row. - docs/auth/ is gitignored (per earlier review feedback). Did NOT move that content into git; the new canonical home is docs/security/authentication.md inside the existing tracked docs/security/ folder. New doc: docs/security/authentication.md (423 lines) — single home for the full auth provider chain across Phase 1 + Phase 2. 
Covers: - Provider matrix and chain semantics (first-match-wins, fail-closed on reject, loopback static_token auto-prepend) - Per-provider sections for static_token, oidc, aws_sigv4, gcp_iap, azure_ad, http_verifier — including yaml shape, wire format, client-side recipe, security model - aws_sigv4 client-side gotcha (boto3.generate_presigned_url doesn't work for STS; use SigV4QueryAuth) - AWS Org-wide trust recipes (Identity Center, entry role + aws:PrincipalOrgID condition) - Egress allowlist auto-extension table per provider - Wizard / CLI examples - Mesh patterns (single-account fleet, per-pair allowlist) Updates to satellite docs: docs/security/overview.md - Added Authentication layer to the security-architecture ASCII diagram (between Guardrails and Egress) - New "Authentication" section with provider table + chain rules - Authentication row in Related Documentation footer docs/security/audit-logging.md - auth_verify and auth_fail event entries in the Event Types table - Full example shape for both events - Reason-code table (missing_token, not_for_me, rejected, invalid, provider_unavailable) with operator actions - token_kind value table (empty, opaque, jwt, sigv4, iap_jwt) - Audit pipeline jq grep recipes for common queries docs/security/egress-control.md - New "Auth-provider domain auto-extension" subsection under Allowlist Resolution with the per-provider host table - Notes the wizard's Auth-before-Egress ordering docs/reference/forge-yaml-schema.md - Full auth: block inserted in the schema example with all 6 provider types and every documented settings key docs/reference/cli-reference.md - All --auth-* flags added to the forge init flag table - Two new example invocations (aws_sigv4 with allowed-account, azure_ad multi-tenant with allowlist) - Link to docs/security/authentication.md docs/reference/web-dashboard.md - Auth step added to the wizard steps table (between Fallback and Env & Security) - New "Auth step" subsection listing all 7 picker options + 
the Web UI's settings-key filter behavior README.md - Authentication row added to the Security documentation table Broken-link scan: zero broken links in tracked docs/. 42 packages still pass go test -race; lint + gofmt clean. No code changes in this commit — docs only. --- .claude/commands/sync-docs.md | 43 +-- README.md | 1 + docs/reference/cli-reference.md | 37 +++ docs/reference/forge-yaml-schema.md | 45 +++ docs/reference/web-dashboard.md | 31 ++ docs/security/audit-logging.md | 80 ++++++ docs/security/authentication.md | 423 ++++++++++++++++++++++++++++ docs/security/egress-control.md | 20 ++ docs/security/overview.md | 32 +++ 9 files changed, 693 insertions(+), 19 deletions(-) create mode 100644 docs/security/authentication.md diff --git a/.claude/commands/sync-docs.md b/.claude/commands/sync-docs.md index f7f050d..6c4538f 100644 --- a/.claude/commands/sync-docs.md +++ b/.claude/commands/sync-docs.md @@ -6,28 +6,33 @@ After feature work, update the affected documentation to reflect code changes. 1. **Identify changed files** — Run `git diff main --name-only` to find modified Go files. -2. **Map files to docs** — Use this mapping to determine which docs need updates: +2. **Map files to docs** — Use this mapping to determine which docs need updates. + Paths match the actual `docs/` tree (nested under `core-concepts/`, `reference/`, + `security/`, `deployment/`, `skills/`). 
| Changed path pattern | Affected docs | |---------------------|---------------| - | `forge-core/runtime/` | `docs/runtime.md`, `docs/hooks.md` | - | `forge-core/security/` | `docs/security/overview.md`, `docs/security/egress.md` | - | `forge-core/tools/` | `docs/tools.md` | - | `forge-core/llm/` | `docs/runtime.md` | - | `forge-core/memory/` | `docs/memory.md` | - | `forge-core/scheduler/` | `docs/scheduling.md` | - | `forge-core/secrets/` | `docs/security/secrets.md` | - | `forge-core/skills/` | `docs/skills.md` | - | `forge-core/channels/` | `docs/channels.md` | - | `forge-cli/cmd/` | `docs/commands.md` | - | `forge-cli/runtime/` | `docs/runtime.md` | - | `forge-cli/server/` | `docs/architecture.md` | - | `forge-cli/channels/` | `docs/channels.md` | - | `forge-cli/tools/` | `docs/tools.md` | - | `forge-plugins/` | `docs/channels.md`, `docs/plugins.md` | - | `forge-ui/` | `docs/dashboard.md` | - | `forge-skills/` | `docs/skills.md` | - | `forge.yaml` / `types/` | `docs/configuration.md` | + | `forge-core/auth/` | `docs/security/authentication.md`, `docs/security/audit-logging.md` | + | `forge-core/runtime/` | `docs/core-concepts/runtime-engine.md`, `docs/core-concepts/hooks.md` | + | `forge-core/security/` | `docs/security/overview.md`, `docs/security/egress-control.md` | + | `forge-core/tools/` | `docs/core-concepts/tools-and-builtins.md` | + | `forge-core/llm/` | `docs/core-concepts/runtime-engine.md` | + | `forge-core/memory/` | `docs/core-concepts/memory-system.md` | + | `forge-core/scheduler/` | `docs/core-concepts/scheduling.md` | + | `forge-core/secrets/` | `docs/security/secret-management.md` | + | `forge-core/channels/` | `docs/core-concepts/channels.md` | + | `forge-core/validate/` | `docs/reference/forge-yaml-schema.md` | + | `forge-cli/cmd/` | `docs/reference/cli-reference.md` | + | `forge-cli/runtime/` | `docs/core-concepts/runtime-engine.md` | + | `forge-cli/server/` | `docs/core-concepts/how-forge-works.md` | + | `forge-cli/channels/` | 
`docs/core-concepts/channels.md` | + | `forge-cli/tools/` | `docs/core-concepts/tools-and-builtins.md` | + | `forge-cli/internal/tui/` | `docs/reference/cli-reference.md` (wizard flow) | + | `forge-plugins/` | `docs/core-concepts/channels.md`, `docs/reference/framework-plugins.md` | + | `forge-ui/` | `docs/reference/web-dashboard.md` | + | `forge-skills/` | `docs/skills/writing-custom-skills.md`, `docs/skills/contributing-a-skill.md` | + | `forge-core/types/` / `forge.yaml` | `docs/reference/forge-yaml-schema.md` | + | `CHANGELOG.md` | (rendered into release notes; no per-doc sync needed) | 3. **Read the diff** — For each mapped doc, read the relevant `git diff main` output to understand what changed. diff --git a/README.md b/README.md index 35c05ff..dbd8bce 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ You write a `SKILL.md`. Forge compiles it into a secure, runnable agent with egr | Document | Description | |----------|-------------| | [Security Overview](docs/security/overview.md) | Complete security architecture | +| [Authentication](docs/security/authentication.md) | Pluggable auth providers — OIDC, AWS Sigv4, GCP IAP, Azure AD | | [Egress Security](docs/security/egress-control.md) | Egress enforcement deep dive | | [Secrets](docs/security/secret-management.md) | Encrypted secret management | | [Build Signing](docs/security/build-signing.md) | Ed25519 signing and verification | diff --git a/docs/reference/cli-reference.md b/docs/reference/cli-reference.md index 3120608..4b6a077 100644 --- a/docs/reference/cli-reference.md +++ b/docs/reference/cli-reference.md @@ -39,6 +39,23 @@ forge init [name] [flags] | `--org-id` | | | OpenAI Organization ID (enterprise) | | `--from-skills` | | | Path to a SKILL.md file for auto-configuration | | `--non-interactive` | | `false` | Skip interactive prompts | +| `--auth` | | | Auth mode: `none`, `oidc`, `http_verifier`, `aws_sigv4`, `gcp_iap`, `azure_ad`, `custom` | +| `--auth-issuer` | | | OIDC issuer URL (required 
with `--auth=oidc`) | +| `--auth-audience` | | | OIDC audience (required with `--auth=oidc`) | +| `--auth-url` | | | Verifier URL (required with `--auth=http_verifier`) | +| `--auth-default-org` | | | Default `org_id` for http_verifier | +| `--auth-groups-claim` | | | Custom JWT claim name for groups (oidc, default `groups`) | +| `--auth-aws-region` | | | AWS region for `aws_sigv4` (e.g. `us-east-1`) | +| `--auth-aws-audience` | | | Informational audience for `aws_sigv4` | +| `--auth-aws-allowed-principal` | | | Allowed principal glob for `aws_sigv4` (repeatable) | +| `--auth-aws-allowed-account` | | | Allowed AWS account ID for `aws_sigv4` (repeatable, 12-digit) | +| `--auth-aws-cache-ttl` | | `60s` | `aws_sigv4` identity cache TTL | +| `--auth-gcp-iap-audience` | | | Backend service ID for `gcp_iap` | +| `--auth-azure-tenant` | | | Entra tenant GUID for `azure_ad` | +| `--auth-azure-audience` | | | Audience (Application ID URI) for `azure_ad` | +| `--auth-azure-multi-tenant` | | `false` | Accept tokens from any Entra tenant | +| `--auth-azure-allowed-tenant` | | | Allowed Entra tenant GUID for multi-tenant `azure_ad` (repeatable) | +| `--auth-azure-groups-mode` | | `claim` | `azure_ad` groups mode: `claim` or `graph` | ### Generated Files @@ -84,8 +101,28 @@ forge init my-agent \ --api-key sk-... \ --org-id org-xxxxxxxxxxxxxxxxxxxxxxxx \ --non-interactive + +# AWS IAM auth (any caller in account 412664885516) +forge init my-agent \ + --model-provider ollama \ + --auth=aws_sigv4 \ + --auth-aws-region=us-east-1 \ + --auth-aws-allowed-account=412664885516 \ + --non-interactive + +# Azure AD multi-tenant with explicit partner allowlist +forge init my-agent \ + --model-provider ollama \ + --auth=azure_ad \ + --auth-azure-audience=api://forge \ + --auth-azure-multi-tenant \ + --auth-azure-allowed-tenant=00000000-1111-... \ + --auth-azure-allowed-tenant=55555555-6666-... 
\ + --non-interactive ``` +See [Authentication](../security/authentication.md) for the full auth provider reference. + --- ## `forge build` diff --git a/docs/reference/forge-yaml-schema.md b/docs/reference/forge-yaml-schema.md index 0407ce1..64a91bb 100644 --- a/docs/reference/forge-yaml-schema.md +++ b/docs/reference/forge-yaml-schema.md @@ -61,6 +61,51 @@ package: dest: "/usr/local/bin/custom-tool" # Install destination chmod: "0755" # File permissions +auth: # a2a HTTP-server auth chain (optional) + required: true # 401 every unauthenticated request + providers: # ordered; first match wins (fail-closed on rejection) + - type: "static_token" # local dev / shared-secret + settings: + token_env: "FORGE_AUTH_TOKEN" # env var name (preferred over literal `token:`) + - type: "oidc" # any IdP with OIDC discovery + settings: + issuer: "https://login.example.com/auth/realms/forge" + audience: "api://forge" + client_id: "" # optional azp fallback + jwks_url: "" # overrides discovery + jwks_cache_ttl: "1h" + clock_skew: "30s" + claim_map: {groups: "roles"} + - type: "http_verifier" # legacy external /verify endpoint + settings: + url: "https://auth.example.com/verify" + default_org: "acme" + timeout: "10s" + - type: "aws_sigv4" # Phase 2: AWS IAM via pre-signed STS URL + settings: + region: "us-east-1" # required + audience: "api://forge" # informational, emitted in audit Claims + allowed_accounts: # ergonomic: "anyone in these AWS accounts" + - "412664885516" + allowed_principals: # explicit globs (path.Match) + - "arn:aws:sts::412664885516:assumed-role/ci-deploy/*" + identity_cache_ttl: "60s" + max_token_expires: "15m" # caps caller's X-Amz-Expires claim + clock_skew: "5m" + - type: "gcp_iap" # Phase 2: GCP IAP-fronted Forge + settings: + audience: "/projects/PNUM/global/backendServices/BACKEND_ID" + jwks_refresh_ttl: "1h" + - type: "azure_ad" # Phase 2: Microsoft Entra ID + settings: + tenant_id: "00000000-1111-..." 
# required unless allow_multi_tenant + audience: "api://forge" + allow_multi_tenant: false + allowed_tenants: # required when multi-tenant + want allowlist + - "55555555-6666-..." + groups_mode: "claim" # "claim" | "graph" + graph_timeout: "5s" + secrets: providers: # Secret providers (order matters) - "encrypted-file" # AES-256-GCM encrypted file diff --git a/docs/reference/web-dashboard.md b/docs/reference/web-dashboard.md index afca134..a92f9e9 100644 --- a/docs/reference/web-dashboard.md +++ b/docs/reference/web-dashboard.md @@ -69,11 +69,42 @@ A multi-step wizard (web equivalent of `forge init`) that walks through the full | Tools | Select builtin tools; web_search shows Tavily vs Perplexity provider choice with API key input | | Skills | Browse registry skills by category with inline required/optional env var collection | | Fallback | Select backup LLM providers with API keys for automatic failover | +| Auth | Select a2a auth provider (see below) | +| Egress | Review the outbound allowlist (including auth-provider-derived hosts) | | Env & Security | Add extra env vars; set passphrase for AES-256-GCM secret encryption | | Review | Summary of all selections before creation | The wizard collects credentials inline at each step (matching the CLI TUI behavior) and supports all the same options: model selection, OAuth, web search providers, fallback chains, and encrypted secret storage. +### Auth step + +The Auth step exposes the same provider chain as `forge.yaml`'s +`auth.providers[]` block. 
Picker options: + +| Option | Effect | +|---|---| +| **None** | Anonymous access — no `auth:` block written | +| **OIDC (JWT)** | Generic OIDC IdP (Keycloak, Auth0, Okta, Google) — collects issuer + audience | +| **HTTP Verifier** | Legacy — POST tokens to your own `/verify` endpoint | +| **AWS Sigv4 (IAM)** | AWS-IAM-based callers — collects region + optional audience + comma-separated allowed accounts | +| **GCP Identity-Aware Proxy** | Forge behind GCP LB+IAP — collects backend service ID | +| **Azure AD / Entra ID** | Microsoft Entra ID (single-tenant in the wizard; multi-tenant requires editing `forge.yaml` as a deliberate security trade-off) | +| **Custom** | Comment stub — edit `forge.yaml` yourself | + +The Auth step runs **before** Egress, so the Egress step's review screen +displays auth-provider-derived hosts (e.g. `sts.us-east-1.amazonaws.com` +when AWS Sigv4 is selected) alongside the model / channel / tool / skill +hosts. The operator sees the full outbound surface for approval in one +place. + +The Web UI's `POST /api/create` endpoint additionally filters incoming +auth `settings` through a closed-key whitelist (defined in +`forge-core/validate/auth.go`) before scaffolding — unknown keys are +silently dropped rather than written to disk. + +See [Authentication](../security/authentication.md) for the full +provider reference. 
+ ## Config Editor Edit `forge.yaml` for any agent with a Monaco-based YAML editor: diff --git a/docs/security/audit-logging.md b/docs/security/audit-logging.md index a850f6d..e376751 100644 --- a/docs/security/audit-logging.md +++ b/docs/security/audit-logging.md @@ -19,6 +19,8 @@ All runtime security events are emitted as structured NDJSON to stderr with corr | `egress_blocked` | Outbound request blocked (with domain, mode) | | `llm_call` | LLM API call completed (with token count) | | `guardrail_check` | Guardrail evaluation result | +| `auth_verify` | Inbound request authenticated successfully (with `provider`, `user_id`, `org_id`, `token_kind`) | +| `auth_fail` | Inbound request rejected (with `reason`, `token_kind`) | ### Example @@ -31,3 +33,81 @@ All runtime security events are emitted as structured NDJSON to stderr with corr ``` The `source` field distinguishes in-process enforcer events from subprocess proxy events. + +### Authentication events + +Every inbound request to `/tasks` emits exactly one of `auth_verify` or `auth_fail`. + +**Successful authentication:** + +```json +{ + "ts":"2026-05-24T00:50:01Z", + "event":"auth_verify", + "fields":{ + "method":"POST", + "path":"/tasks/send", + "provider":"aws_sigv4", + "user_id":"arn:aws:sts::412664885516:assumed-role/AWSReservedSSO_PowerUserAccess_.../Naveen", + "org_id":"412664885516", + "token_kind":"sigv4", + "groups_count":0, + "remote_addr":"[::1]:62297" + } +} +``` + +`user_id` is the canonical identifier the verifier returned (ARN for AWS, JWT +`sub` for OIDC/IAP/AAD). `org_id` is the AWS account, Entra tenant GUID, or +OIDC `tid`/`org_id`-mapped claim depending on the provider. 
+ +**Failed authentication:** + +```json +{"ts":"...","event":"auth_fail","fields":{"reason":"rejected","token_kind":"sigv4","method":"POST","path":"/tasks/send","remote_addr":"[::1]:62200"}} +``` + +### Reason codes (`auth_fail.fields.reason`) + +| Reason | What it means | Operator action | +|---|---|---| +| `missing_token` | No auth-shaped headers at all | Caller forgot to authenticate | +| `not_for_me` | Bearer present but no provider claimed it | Wrong token format for the configured providers | +| `rejected` | Provider recognized + denied (allowlist miss, expired, bad sig, scope mismatch) | Check `allowed_principals` / `tenant_id` / token freshness | +| `invalid` | Token malformed (bad base64, unsupported alg, missing required field) | Token construction bug on the caller side | +| `provider_unavailable` | Verifier endpoint down (STS / JWKS / Graph 5xx, network error) | Provider-side incident; not a token issue | + +### Token kind values (`fields.token_kind`) + +Structural classification of what bytes were on the wire — safe to log: + +| Value | Shape | +|---|---| +| `empty` | No token / no auth-shaped headers | +| `opaque` | Bearer with non-JWT, non-sigv4 shape (channel adapter loopback, custom verifier tokens) | +| `jwt` | Bearer with three base64url segments (`oidc`, `azure_ad`) | +| `sigv4` | Bearer with `forge-aws-v1.` prefix (`aws_sigv4` pre-signed URL token) | +| `iap_jwt` | `X-Goog-Iap-Jwt-Assertion` header present (`gcp_iap`) — also stamped on successful verify even if Bearer was simultaneously present | + +### Audit pipeline grep recipes + +Who called my agent, by ARN/email? (Counts every `auth_verify` entry in the log; filter on `.ts` first to restrict it to a window such as the last hour.) + +```bash +jq -r 'select(.event=="auth_verify") | .fields.user_id' forge.log | sort | uniq -c +``` + +Why are requests failing? + +```bash +jq -r 'select(.event=="auth_fail") | .fields.reason' forge.log | sort | uniq -c +``` + +Which agents called this one (in a mesh)? 
+ +```bash +jq -r 'select(.event=="auth_verify") | "\(.fields.user_id)"' forge.log | sort -u +``` + +See [Authentication](authentication.md) for the full provider chain and how +each provider populates these fields. diff --git a/docs/security/authentication.md b/docs/security/authentication.md new file mode 100644 index 0000000..3fd9a13 --- /dev/null +++ b/docs/security/authentication.md @@ -0,0 +1,423 @@ +--- +title: "Authentication Providers" +description: "Pluggable auth provider chain that gates Forge's /tasks endpoint — OIDC, AWS Sigv4, GCP IAP, Azure AD, and local-only static_token." +order: 6 +--- + +Forge's `a2a` HTTP server (the `/tasks` endpoint and friends) requires every +caller to authenticate through a pluggable provider chain configured in +`forge.yaml`. Each provider recognizes one token shape; the chain tries them +in order, first match wins, and the result lands in `Identity` for the +audit log and any downstream authz hook. + +## Provider matrix + +| Provider | Use case | Token format | Verifies against | Phase | +|---|---|---|---|---| +| `static_token` | Local dev, channel-adapter loopback | Shared secret | constant-time SHA-256 compare | 1 | +| `oidc` | Any IdP with OIDC discovery (Keycloak, Auth0, Okta, Google) | `Authorization: Bearer <JWT>` | Issuer's JWKS (TTL-cached, with backoff + stale-grace) | 1 | +| `http_verifier` | Custom verifier endpoint you operate | Opaque token | Your own `/verify` HTTP service | 1 | +| `aws_sigv4` | AWS-IAM-based callers (Lambda, EC2, EKS, IAM users) | `Authorization: Bearer forge-aws-v1.<base64url-presigned-url>` | AWS STS `GetCallerIdentity` (pre-signed URL pattern) | 2 (v0.11.0) | +| `gcp_iap` | Forge behind GCP HTTPS LB + IAP | `X-Goog-Iap-Jwt-Assertion: <JWT>` | IAP's hardcoded JWKS at `www.gstatic.com` | 2 (v0.11.0) | +| `azure_ad` | Microsoft Entra ID tokens | `Authorization: Bearer <JWT>` | AAD JWKS via composed `oidc` provider + tenant gate | 2 (v0.11.0) | + +Forge holds **no IdP secrets**. 
All providers verify a caller-minted +credential against a third party (STS / GCP JWKS / AAD JWKS / your own +`/verify`), then stamp an `Identity` from what the verifier returned. + +## Chain semantics + +Each `Verify` returns one of: + +| Return | Meaning | Chain behavior | +|---|---|---| +| `Identity, nil` | Token accepted | Stops; chain returns this Identity | +| `nil, ErrTokenNotForMe` | "Not my format" | Continues to next provider | +| `nil, ErrTokenRejected` | "My format, but denied" | **Stops; 401** | +| `nil, ErrInvalidToken` | "Malformed" | **Stops; 401** | +| `nil, ErrProviderUnavailable` | "Can't reach my IdP" | **Stops; 401** (fail-closed) | + +The critical rule is **no fall-through on rejection**: if provider A +returns `ErrTokenRejected`, the chain does NOT try provider B. Otherwise +an attacker could downgrade by presenting a malformed token of type A and +hoping to be authenticated as type B. + +### Loopback `static_token` is auto-prepended + +Forge writes a random token to `.forge/runtime.token` (mode `0600`) on +startup and auto-prepends a `static_token` provider for it to the chain. +This is how channel adapters (Slack, Telegram, MS Teams) and the local +Web UI authenticate without you configuring anything. Anyone with read +access to `.forge/runtime.token` can call the a2a server. Treat that +file like an SSH key. + +### Non-Bearer auth headers (Phase 2) + +The middleware consults the chain **even when no `Authorization: Bearer` +was extracted**, provided a non-Bearer auth header is present +(`X-Goog-Iap-Jwt-Assertion`). When there are no auth-shaped headers at +all, the audit reason stays `missing_token` rather than widening to +`not_for_me` — operators can still distinguish "client didn't auth" from +"client tried a format we don't speak." 
+ +## `forge.yaml` schema + +```yaml +auth: + required: true # 401 every unauthenticated request + providers: + - type: oidc | aws_sigv4 | gcp_iap | azure_ad | http_verifier | static_token + settings: + # provider-specific keys (see per-provider sections below) +``` + +Per-provider settings are validated by `forge validate`. Unknown keys +produce a warning (typo detection); the Web UI's `/api/create` endpoint +additionally filters to a closed-key whitelist before scaffolding so +malicious POST payloads can't drop arbitrary keys into `forge.yaml`. + +--- + +## `oidc` — Generic OIDC issuer + +The workhorse provider — any IdP with an OIDC discovery doc and JWKS. + +```yaml +auth: + required: true + providers: + - type: oidc + settings: + issuer: https://login.example.com/auth/realms/forge # required + audience: api://forge # required + client_id: my-spa # optional azp fallback + jwks_url: https://... # optional — overrides discovery + jwks_cache_ttl: 1h + clock_skew: 30s + claim_map: # remap claim names + groups: roles +``` + +- Algorithm whitelist: `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. `none` and HMAC are rejected before key lookup. +- JWKS is TTL-cached with backoff + stale-grace — token verification keeps working through brief JWKS outages. +- Issuer trailing-slash normalization handles the Auth0/Okta disagreement (`https://x/` vs `https://x`). + +--- + +## `aws_sigv4` — AWS IAM via pre-signed STS URL + +Authenticates callers by their AWS-IAM identity. Same pattern as +[`aws-iam-authenticator`](https://github.com/kubernetes-sigs/aws-iam-authenticator) +for EKS: caller pre-signs a `GetCallerIdentity` URL with their AWS SDK +and sends it as a Bearer token; Forge invokes that URL, STS validates +the signature against its own host and returns the canonical ARN. 
+ +```yaml +auth: + required: true + providers: + - type: aws_sigv4 + settings: + region: us-east-1 # required + audience: api://forge # informational; in audit Claims + allowed_accounts: ["412664885516"] # ergonomic: "anyone in these accounts" + allowed_principals: # explicit globs (path.Match syntax) + - "arn:aws:sts::412664885516:assumed-role/ci-deploy/*" + identity_cache_ttl: 60s + max_token_expires: 15m # caps caller's X-Amz-Expires claim + clock_skew: 5m +``` + +### Wire format + +``` +Authorization: Bearer forge-aws-v1.<base64url-presigned-url> +``` + +The base64-decoded payload is a complete pre-signed URL of the form: + +``` +https://sts.<region>.amazonaws.com/ + ?Action=GetCallerIdentity + &Version=2011-06-15 + &X-Amz-Algorithm=AWS4-HMAC-SHA256 + &X-Amz-Credential=<access-key-id>/<YYYYMMDD>/<region>/sts/aws4_request + &X-Amz-Date=<YYYYMMDDTHHMMSSZ> + &X-Amz-Expires=<seconds> + &X-Amz-SignedHeaders=host + &X-Amz-Signature=<hex-signature> +``` + +### Client side (3 lines) + +```python +import boto3, base64, requests +from botocore.auth import SigV4QueryAuth +from botocore.awsrequest import AWSRequest + +creds = boto3.Session().get_credentials().get_frozen_credentials() +req = AWSRequest(method="GET", + url="https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15") +SigV4QueryAuth(creds, "sts", "us-east-1", expires=900).add_auth(req) +token = "forge-aws-v1." + base64.urlsafe_b64encode(req.url.encode()).rstrip(b"=").decode() + +requests.post(forge_url, headers={"Authorization": f"Bearer {token}"}, data=msg) +``` + +> `boto3.client('sts').generate_presigned_url('get_caller_identity', ...)` +> does **not** work — it signs as if the request were a POST, STS rejects +> the GET. Use the lower-level `SigV4QueryAuth` shown above. Same quirk +> `aws-iam-authenticator` works around internally. + +Reference client ships in [`scripts/forge-aws-sign.py`](../../scripts/forge-aws-sign.py). + +### `allowed_accounts` — "anyone in this account" + +The ergonomic shortcut for whole-account trust. 
Each 12-digit account ID +expands internally to the canonical glob set covering every STS identity +shape (IAM users, IAM roles, STS assumed-roles incl. SSO, federated +users). Composable with `allowed_principals`. + +### Org-wide trust without enumerating accounts + +There's no STS API to ask "is account X in Org Y?" — AWS deliberately +doesn't expose that. Two production paths: + +1. **AWS IAM Identity Center (SSO).** Every user's session is already an + assumed-role under `AWSReservedSSO_*`. Use a glob: + ```yaml + allowed_principals: + - "arn:aws:sts::ACCT:assumed-role/AWSReservedSSO_*/*" + ``` + Org membership is enforced by Identity Center at sign-in time. + +2. **Entry role with `aws:PrincipalOrgID` condition.** Customer creates + one IAM role in one account with a trust policy that allows anyone in + their Org to assume it. Forge's allowlist contains just that one + assumed-role ARN. The Org-membership check happens at AWS IAM, not in + Forge. + +### Security model + +- **No secret keys on Forge.** STS validates signatures. +- **SSRF guard.** Pre-signed URL host must be `sts.<region>.amazonaws.com` exactly; userinfo (`user:pass@`) and foreign hosts are rejected at parse time. +- **No HTTP redirects.** `CheckRedirect` is pinned to `ErrUseLastResponse` so a redirect off `sts.…` (e.g. MITM, TLS-inspecting proxy) can't substitute attacker bytes for the STS response. +- **Freshness gate.** Tokens claiming `X-Amz-Expires > 15min` are rejected; tokens whose `X-Amz-Date + Expires` window has lapsed (with 5min clock skew) are rejected. Bounds stolen-token replay independent of STS's own enforcement. +- **Cache bucketing on `hash(AKID, YYYYMMDD)`** — bounds stolen-key replay to one day worst-case. +- **No `aws-sdk-go-v2` dependency.** STS RPC is ~80 LOC of hand-rolled HTTP + XML. 
+ +### Audit log shape + +```json +{ "event": "auth_verify", + "fields": { + "provider": "aws_sigv4", + "user_id": "arn:aws:sts::123456789012:assumed-role/ci-deploy/i-0abc", + "org_id": "123456789012", + "token_kind": "sigv4" + } +} +``` + +--- + +## `gcp_iap` — GCP Identity-Aware Proxy + +Verifies the JWT IAP forwards as `X-Goog-Iap-Jwt-Assertion` when Forge +sits behind a GCP HTTPS Load Balancer with IAP enabled. + +```yaml +auth: + required: true + providers: + - type: gcp_iap + settings: + audience: /projects/12345678/global/backendServices/9876543210 +``` + +`audience` is the backend service ID — find it in +**GCP Console → Security → IAP → Backend Services → Signed Header JWT Audience**. + +### Security model + +- **Hardcoded JWKS host** (`www.gstatic.com/iap/verify/public_key-jwk`). Operators cannot override — eliminates the "trust attacker's JWKS" failure mode. +- **ES256-only.** Any other alg rejected before key lookup. +- **JWKS merge-on-success.** A partial-but-valid JWKS response can't drop kids the stale-grace contract assumes are kept. +- **No HTTP redirects.** Same `ErrUseLastResponse` pin as `aws_sigv4`. +- **No GCP SDK dependency.** + +Sub `email` / `hd` (Workspace domain) flow through to `Identity.Claims` +for downstream policy. + +--- + +## `azure_ad` — Microsoft Entra ID + +Composes the Phase 1 `oidc` provider for signature verification; layers +AAD-specific concerns on top. + +### Single-tenant (the safe default) + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + tenant_id: 00000000-1111-2222-3333-444444444444 + audience: api://forge + groups_mode: claim # or "graph" +``` + +`tid` claim must equal `tenant_id`; iss is double-checked via OIDC. 
+ +### Multi-tenant with explicit allowlist + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true + allowed_tenants: # case-insensitive GUID match + - "00000000-1111-2222-3333-444444444444" + - "55555555-6666-7777-8888-999999999999" +``` + +### Multi-tenant "any tenant globally" (high-risk) + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true + # allowed_tenants intentionally omitted +``` + +`forge validate` emits a warning so this trade-off is loud, not silent. + +### Groups overage (graph mode) + +When `groups_mode: graph` and the JWT's `groups` claim is empty (AAD +truncates at ~200 groups), Forge calls Microsoft Graph +`/me/transitiveMemberOf` using the **caller's** Bearer to fetch the +full list. Forge holds no Graph credentials of its own. Soft-fails on +Graph 5xx (returns Identity with empty Groups rather than blocking +prod traffic). + +### Security model + +- **Composition over inheritance.** No JWT verify or JWKS code in `azure_ad/` — all crypto lives in `oidc`. +- **Tenant gate.** Single-tenant: `tid == tenant_id`. Multi-tenant + allowlist: `tid ∈ allowed_tenants`. Multi-tenant + empty: no tid check (high-risk, warned). +- **Internal `skip_issuer_check` flag** carries `yaml:"-"` — unreachable from `forge.yaml`, only set by this package when `allow_multi_tenant=true`. +- **No HTTP redirects** on Graph client. Graph nextLink scheme + host both validated to prevent Bearer-downgrade via `https→http` same-host redirects. + +--- + +## `static_token` — Shared secret + +Loopback / dev use. Provider does constant-time SHA-256 comparison so +length-leak / timing attacks are blocked. + +```yaml +auth: + required: true + providers: + - type: static_token + settings: + token_env: FORGE_AUTH_TOKEN # prefer env over literal +``` + +`token:` (literal value in YAML) is also accepted but produces a warning. 
+ +--- + +## `http_verifier` — External `/verify` endpoint + +Legacy / custom — you operate the verifier; Forge POSTs the token to it. + +```yaml +auth: + required: true + providers: + - type: http_verifier + settings: + url: https://auth.example.com/verify + default_org: acme + timeout: 10s +``` + +Same wire format as the pre-Phase-1 `--auth-url` flag. + +--- + +## Egress allowlist auto-extension + +Configuring an auth provider automatically adds the hosts it needs to +the egress allowlist: + +| Provider | Host(s) auto-added | +|---|---| +| `oidc` | `<issuer host>`, `<jwks_url host>` if explicit | +| `http_verifier` | `<url host>` | +| `aws_sigv4` | `sts.<region>.amazonaws.com` | +| `gcp_iap` | `www.gstatic.com` | +| `azure_ad` | `login.microsoftonline.com` (+ `graph.microsoft.com` when `groups_mode: graph`) | + +`forge init`'s wizard runs the Auth step before the Egress step, so +operators see the full outbound surface for review in one screen. + +## Wizard / CLI + +`forge init` interactive TUI: pick auth type → enter region / audience / +tenant / etc. → done. Non-interactive equivalent via flags: + +```bash +forge init --non-interactive \ + --name my-agent \ + --model-provider ollama \ + --auth=aws_sigv4 \ + --auth-aws-region=us-east-1 \ + --auth-aws-audience=api://forge \ + --auth-aws-allowed-account=412664885516 +``` + +See [CLI Reference](../reference/cli-reference.md) for the full flag set. + +--- + +## Mesh patterns (agent-to-agent) + +When an agent calls another agent, the receiver's auth provider gates +the call the same way it would for a human or CI. Two common patterns: + +**Single-account "fleet" model.** Every agent runs as a workload in +one dedicated AWS account with its own IAM role; every agent's +`forge.yaml` has `allowed_accounts: ["<fleet-account-id>"]`. Trust boundary = +the account. Onboarding a new agent = create one IAM role; no other +agent's config changes. 
+ +**Per-pair allowlist.** Sensitive agents (touching money, PII, customer +data) override the broad account allowlist with explicit +`allowed_principals` patterns for the specific calling agents allowed. + +See [Audit Logging](audit-logging.md) for how to grep `user_id` across +audit events to map the actual call graph. + +--- + +## Related Documentation + +| Document | Description | +|----------|-------------| +| [Audit Logging](audit-logging.md) | `auth_verify` / `auth_fail` event shape, reason codes, `token_kind` values | +| [Egress Security](egress-control.md) | Auth-host auto-allowlist and how it composes with operator-set domains | +| [Trust Model](trust-model.md) | Caller → Forge trust boundary; what Forge does and doesn't trust | +| [forge.yaml Schema](../reference/forge-yaml-schema.md) | Full YAML reference including `auth:` block | +| [CLI Reference](../reference/cli-reference.md) | `forge init` auth flags | +| [Web Dashboard](../reference/web-dashboard.md) | Auth provider options in the create flow | diff --git a/docs/security/egress-control.md b/docs/security/egress-control.md index afc58b5..b1bc00b 100644 --- a/docs/security/egress-control.md +++ b/docs/security/egress-control.md @@ -212,8 +212,28 @@ The resolver (`forge-core/security/resolver.go`) combines all domain sources: - Start with explicit domains from `forge.yaml` - Add tool-inferred domains - Add capability bundle domains + - **Add auth-provider-derived domains** (see below) - Deduplicate and sort +### Auth-provider domain auto-extension + +Configuring an `auth.providers[]` entry adds the host(s) the provider needs +to reach to the allowlist automatically — operators don't have to remember +to add `sts.us-east-1.amazonaws.com` themselves when they configure +`aws_sigv4`. 
The `security.AuthDomains` helper centralizes the mapping: + +| Provider | Host(s) added | +|---|---| +| `oidc` | host of `issuer` URL, host of explicit `jwks_url` if set | +| `http_verifier` | host of `url` | +| `aws_sigv4` | `sts..amazonaws.com` (+ test-mode `sts_endpoint` override host) | +| `gcp_iap` | `www.gstatic.com` (hardcoded, IAP JWKS lives there) | +| `azure_ad` | `login.microsoftonline.com` (+ `graph.microsoft.com` when `groups_mode: graph`) | + +`forge init`'s wizard runs the Auth step **before** Egress so the operator +sees the full outbound surface for review in a single screen. See +[Authentication](authentication.md) for the per-provider auth model. + ## Build Artifacts The `EgressStage` generates: diff --git a/docs/security/overview.md b/docs/security/overview.md index 897f604..a2e4515 100644 --- a/docs/security/overview.md +++ b/docs/security/overview.md @@ -18,6 +18,10 @@ Forge's security is organized in layers, each addressing a different threat surf │ Global Guardrails │ │ (content filtering, PII, jailbreak) │ ├──────────────────────────────────────────────────────────────┤ +│ Authentication (a2a) │ +│ (Pluggable provider chain: OIDC, AWS Sigv4, GCP IAP, │ +│ Azure AD, http_verifier, static_token loopback) │ +├──────────────────────────────────────────────────────────────┤ │ Egress Enforcement │ │ (EgressEnforcer + EgressProxy + SafeDialer + NetworkPolicy) │ ├──────────────────────────────────────────────────────────────┤ @@ -39,6 +43,7 @@ Forge's security is organized in layers, each addressing a different threat surf ## Table of Contents - [Network Posture](#network-posture) +- [Authentication](#authentication) - [Egress Enforcement](#egress-enforcement) - [Execution Sandboxing](#execution-sandboxing) - [Secrets Management](#secrets-management) @@ -48,6 +53,32 @@ Forge's security is organized in layers, each addressing a different threat surf - [Container Security](#container-security) - [Related Documentation](#related-documentation) +## 
Authentication + +The `/tasks` HTTP endpoint requires every caller to authenticate through a +pluggable provider chain configured in `forge.yaml`. Forge ships six provider +types and holds no IdP secrets — every provider verifies against a third +party (STS, GCP JWKS, AAD JWKS, your custom verifier) or a local file. + +| Provider | Use case | Wire format | +|---|---|---| +| `static_token` | Loopback (channel adapters, local Web UI dashboard) | Bearer literal | +| `oidc` | Generic OIDC IdP (Keycloak, Auth0, Okta, Google) | Bearer JWT | +| `aws_sigv4` | AWS IAM identities (Lambda, EC2, EKS, SSO users) | `Bearer forge-aws-v1.<base64url-presigned-url>` | +| `gcp_iap` | Behind GCP HTTPS LB + IAP | `X-Goog-Iap-Jwt-Assertion: <JWT>` | +| `azure_ad` | Microsoft Entra ID | Bearer JWT | +| `http_verifier` | Custom `/verify` endpoint you operate | Opaque token | + +The chain is first-match-wins with **fail-closed on rejection** — a malformed +token of type A doesn't fall through to type B. The local Web UI dashboard +and channel adapters use an auto-prepended `static_token` (the `runtime.token` +file under `.forge/`) so they keep working regardless of how external auth is +configured. + +See [Authentication Providers](authentication.md) for the complete reference, +including per-provider security model, client-side recipes, and mesh +(agent-to-agent) patterns. + --- ## Network Posture @@ -267,6 +298,7 @@ Production builds enforce: | Document | Description | |----------|-------------| +| [Authentication](authentication.md) | Pluggable auth providers (OIDC, AWS Sigv4, GCP IAP, Azure AD, etc.) gating the a2a HTTP server | | [Egress Security](egress-control.md) | Deep dive into egress enforcement: IP validation, SafeDialer, profiles, modes, domain matching, proxy architecture, NetworkPolicy | | [Secrets Management](secret-management.md) | Encrypted storage, per-agent secrets, passphrase handling | | [Build Signing & Verification](build-signing.md) | Key management, build signing, runtime verification |