From f863f9f7bf4e90a8e9ddf9114d0343b5f1274c7a Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 15:34:47 -0400 Subject: [PATCH 01/20] refactor(auth): introduce pluggable Provider chain (PR1 foundation) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the monolithic auth middleware with a Provider interface, ChainProvider composition, and a database/sql-style factory registry. Existing --auth-url logic is migrated into the new http_verifier provider; the implicit channel-loopback token becomes a static_token provider injected at the chain head. No externally observable behavior change. Backward compatible with all existing --auth-url / FORGE_AUTH_URL / FORGE_AUTH_ORG_ID flags and env vars; channel-adapter loopback continues to short-circuit before the external verifier is consulted. Why: PR2 (OIDC), PR3 (forge.yaml auth: block), and Phase 3 (Okta) plug in via Register() — no further changes to middleware or runner. Tests: - 94.3% combined coverage across forge-core/auth/... 
- 11 e2e tests in forge-cli/runtime/auth_chain_test.go verifying legacy --auth-url, loopback short-circuit, skip-paths, wire format - -race clean across all touched packages --- forge-cli/runtime/auth_chain_test.go | 318 ++++++++++++ forge-cli/runtime/runner.go | 77 ++- forge-core/auth/chain_provider.go | 80 +++ forge-core/auth/context.go | 26 + forge-core/auth/middleware.go | 213 +++----- forge-core/auth/middleware_test.go | 201 +++++--- forge-core/auth/provider.go | 101 ++++ forge-core/auth/provider_test.go | 461 ++++++++++++++++++ .../auth/providers/httpverifier/provider.go | 183 +++++++ .../providers/httpverifier/provider_test.go | 263 ++++++++++ .../auth/providers/statictoken/provider.go | 122 +++++ .../providers/statictoken/provider_test.go | 186 +++++++ forge-core/auth/registry.go | 81 +++ forge-core/auth/settings.go | 27 + go.work.sum | 6 +- 15 files changed, 2120 insertions(+), 225 deletions(-) create mode 100644 forge-cli/runtime/auth_chain_test.go create mode 100644 forge-core/auth/chain_provider.go create mode 100644 forge-core/auth/context.go create mode 100644 forge-core/auth/provider.go create mode 100644 forge-core/auth/provider_test.go create mode 100644 forge-core/auth/providers/httpverifier/provider.go create mode 100644 forge-core/auth/providers/httpverifier/provider_test.go create mode 100644 forge-core/auth/providers/statictoken/provider.go create mode 100644 forge-core/auth/providers/statictoken/provider_test.go create mode 100644 forge-core/auth/registry.go create mode 100644 forge-core/auth/settings.go diff --git a/forge-cli/runtime/auth_chain_test.go b/forge-cli/runtime/auth_chain_test.go new file mode 100644 index 0000000..53a5f5f --- /dev/null +++ b/forge-cli/runtime/auth_chain_test.go @@ -0,0 +1,318 @@ +package runtime + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/initializ/forge/forge-core/auth" +) + +// E2E tests covering the auth chain that 
the runner builds from legacy +// CLI flags (--auth-token, --auth-url, --auth-org-id). These exist to +// guarantee that the PR1 refactor preserves the pre-refactor behavior +// byte-for-byte. If any of these fail, the refactor changed externally +// observable behavior — investigate before merging. + +// fakeVerifier accepts a handler and returns a httptest.Server that +// implements the http_verifier contract. +func newFakeVerifier(t *testing.T, handler func(req map[string]any) (status int, body any)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + _ = json.Unmarshal(body, &req) + status, respBody := handler(req) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if respBody != nil { + _ = json.NewEncoder(w).Encode(respBody) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func TestBuildLegacyAuthChain_NoAuth(t *testing.T) { + chain, err := buildLegacyAuthChain("", "", "") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + if chain != nil { + t.Errorf("chain = %v, want nil (no providers configured)", chain) + } +} + +func TestBuildLegacyAuthChain_TokenOnly(t *testing.T) { + chain, err := buildLegacyAuthChain("local-token", "", "") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + if chain == nil { + t.Fatal("chain is nil, want a ChainProvider with static_token") + } + + // Valid token accepted. + id, err := chain.Verify(context.Background(), "local-token", nil) + if err != nil { + t.Errorf("valid token rejected: %v", err) + } + if id == nil || id.Source != "internal" { + t.Errorf("identity = %+v, want Source=internal", id) + } + + // Wrong token yields with ErrTokenNotForMe (and at chain end, that's + // what we see — no http_verifier behind to claim it). 
+ _, err = chain.Verify(context.Background(), "wrong", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("wrong token err = %v, want ErrTokenNotForMe", err) + } +} + +func TestBuildLegacyAuthChain_AuthURLOnly(t *testing.T) { + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + if req["token"] == "external-token" { + return http.StatusOK, map[string]any{ + "valid": true, + "user_id": "external-user", + "org_id": "external-org", + } + } + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, err := buildLegacyAuthChain("", srv.URL, "default-org") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + if chain == nil { + t.Fatal("chain is nil, want http_verifier chain") + } + + id, err := chain.Verify(context.Background(), "external-token", nil) + if err != nil { + t.Errorf("external token rejected: %v", err) + } + if id == nil || id.UserID != "external-user" || id.Source != "http_verifier" { + t.Errorf("identity = %+v", id) + } + + _, err = chain.Verify(context.Background(), "wrong", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("wrong token err = %v, want ErrTokenRejected", err) + } +} + +func TestBuildLegacyAuthChain_TokenAndAuthURL_LoopbackShortCircuits(t *testing.T) { + // This is the pre-refactor behavior: when both an internal Token AND + // AuthURL are set, the internal token is checked FIRST so loopback + // calls from channel adapters don't hit the external verifier. + var verifierCalls atomic.Int32 + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + verifierCalls.Add(1) + return http.StatusOK, map[string]any{"valid": true, "user_id": "external"} + }) + + chain, err := buildLegacyAuthChain("loopback-token", srv.URL, "") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + + // Loopback token: external verifier should NOT be called. 
+ id, err := chain.Verify(context.Background(), "loopback-token", nil) + if err != nil { + t.Fatalf("loopback verify: %v", err) + } + if id.UserID != "forge-internal" || id.Source != "internal" { + t.Errorf("loopback identity = %+v, want forge-internal/internal", id) + } + if got := verifierCalls.Load(); got != 0 { + t.Errorf("external verifier called %d times during loopback (must be 0)", got) + } + + // Non-loopback token: falls through to external verifier. + id2, err := chain.Verify(context.Background(), "other-token", nil) + if err != nil { + t.Fatalf("external verify: %v", err) + } + if id2.Source != "http_verifier" { + t.Errorf("external identity Source = %q, want http_verifier", id2.Source) + } + if got := verifierCalls.Load(); got != 1 { + t.Errorf("external verifier called %d times, want 1", got) + } +} + +func TestBuildLegacyAuthChain_OrgIDPropagation(t *testing.T) { + var capturedOrgID atomic.Value + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + capturedOrgID.Store(req["org_id"]) + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "default-org") + + // No headers → config default. + if _, err := chain.Verify(context.Background(), "tok", nil); err != nil { + t.Fatalf("Verify: %v", err) + } + if got := capturedOrgID.Load(); got != "default-org" { + t.Errorf("verifier got org_id %v, want default-org", got) + } + + // X-Org-ID header overrides default. + if _, err := chain.Verify(context.Background(), "tok", auth.Headers{"X-Org-ID": "header-org"}); err != nil { + t.Fatalf("Verify: %v", err) + } + if got := capturedOrgID.Load(); got != "header-org" { + t.Errorf("verifier got org_id %v, want header-org", got) + } +} + +// --- middleware-level end-to-end --- + +// runMiddleware is a small driver that runs an http.Handler chain through +// the auth.Middleware with a given chain and returns the recorded response. 
+func runMiddleware(t *testing.T, chain auth.Provider, method, path, authHeader string) *httptest.ResponseRecorder { + t.Helper() + handler := auth.Middleware(auth.MiddlewareOptions{ + Chain: chain, + SkipPaths: auth.DefaultSkipPaths(), + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + + req := httptest.NewRequest(method, path, nil) + if authHeader != "" { + req.Header.Set("Authorization", authHeader) + } + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + return rr +} + +func TestE2E_NoAuthConfigured_AnonymousAccess(t *testing.T) { + chain, _ := buildLegacyAuthChain("", "", "") + // nil chain → middleware passthrough + rr := runMiddleware(t, chain, "POST", "/tasks/send", "") + if rr.Code != http.StatusOK { + t.Errorf("anonymous access blocked: status = %d", rr.Code) + } +} + +func TestE2E_LegacyAuthURL_HappyPath(t *testing.T) { + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + if req["token"] == "good" { + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + } + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer good") + if rr.Code != http.StatusOK { + t.Errorf("good token rejected: status = %d, body = %s", rr.Code, rr.Body.String()) + } +} + +func TestE2E_LegacyAuthURL_BadTokenReturns401(t *testing.T) { + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer bad") + if rr.Code != http.StatusUnauthorized { + t.Errorf("bad token accepted: status = %d", rr.Code) + } +} + +func TestE2E_SkipPathsBypassAuth(t *testing.T) { + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + t.Error("verifier should not be called for skip 
paths") + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + + skipPaths := []struct { + method, path string + }{ + {"GET", "/"}, + {"GET", "/.well-known/agent.json"}, + {"GET", "/healthz"}, + {"GET", "/health"}, + } + + for _, sp := range skipPaths { + t.Run(sp.method+" "+sp.path, func(t *testing.T) { + rr := runMiddleware(t, chain, sp.method, sp.path, "") + if rr.Code != http.StatusOK { + t.Errorf("skip path %s %s blocked: status = %d", sp.method, sp.path, rr.Code) + } + }) + } +} + +func TestE2E_ChannelLoopback_DoesNotCallVerifier(t *testing.T) { + // Simulates a Slack/Telegram adapter calling the local A2A server + // with the internal loopback token. The external verifier must NOT + // be consulted — that would be wasteful and would break if the + // verifier is misconfigured/unreachable. + var verifierCalls atomic.Int32 + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + verifierCalls.Add(1) + return http.StatusOK, map[string]any{"valid": true, "user_id": "external"} + }) + + chain, _ := buildLegacyAuthChain("loopback-secret", srv.URL, "") + + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer loopback-secret") + if rr.Code != http.StatusOK { + t.Errorf("loopback token rejected: status = %d", rr.Code) + } + if got := verifierCalls.Load(); got != 0 { + t.Errorf("external verifier called %d times for loopback (must be 0)", got) + } +} + +func TestE2E_VerifierUnreachable_ReturnsInvalidNot500(t *testing.T) { + // When the external verifier is down (port closed), the middleware + // should return 401, not 500 — preserving the pre-refactor "auth + // provider error" behavior. 
+ chain, _ := buildLegacyAuthChain("", "http://127.0.0.1:1/never", "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer x") + if rr.Code != http.StatusUnauthorized { + t.Errorf("verifier-unreachable status = %d, want 401", rr.Code) + } +} + +func TestE2E_WireFormat_TokenAndOrgID(t *testing.T) { + // Black-box: confirm the verifier sees exactly `token` and `org_id` + // JSON fields. Guards against future schema drift. + var sawToken, sawOrgID atomic.Value + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + sawToken.Store(req["token"]) + sawOrgID.Store(req["org_id"]) + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer the-token") + if rr.Code != http.StatusOK { + t.Fatalf("status = %d", rr.Code) + } + if got := sawToken.Load(); got != "the-token" { + t.Errorf("verifier got token = %v, want the-token", got) + } + // No org headers sent → verifier sees empty org_id. + if got := sawOrgID.Load(); got != "" { + t.Errorf("verifier got org_id = %v, want empty", got) + } +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index de974df..68600e9 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -16,6 +16,8 @@ import ( "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/agentspec" "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/httpverifier" + "github.com/initializ/forge/forge-core/auth/providers/statictoken" "github.com/initializ/forge/forge-core/llm" "github.com/initializ/forge/forge-core/llm/oauth" "github.com/initializ/forge/forge-core/llm/providers" @@ -1719,24 +1721,33 @@ func (r *Runner) printBanner(proxyURL string) { fmt.Fprintf(os.Stderr, " Press Ctrl+C to stop\n\n") } -// resolveAuth builds the auth middleware config. 
Token resolution is done by -// ResolveAuth() (called early so channel adapters can use it); this method -// just wires the already-resolved token into a middleware Config with the audit callback. -func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Config, error) { +// resolveAuth builds the auth middleware options for the A2A server. +// +// Behavior preserved from the pre-chain implementation: +// - --no-auth (NoAuth=true) → no chain, anonymous access +// - --auth-url set → chain: [static_token (loopback), http_verifier] +// - neither set → chain: [static_token (local)] +// +// Token resolution is performed by ResolveAuth() (called earlier so channel +// adapters can use it); this method translates the resolved token + URL into +// a Provider chain. +func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.MiddlewareOptions, error) { // Ensure token is resolved (no-op if already done by ResolveAuth). if err := r.ResolveAuth(); err != nil { - return auth.Config{}, err + return auth.MiddlewareOptions{}, err } if r.cfg.NoAuth { - return auth.Config{Enabled: false}, nil + return auth.MiddlewareOptions{}, nil // nil chain → passthrough + } + + chain, err := buildLegacyAuthChain(r.authToken, r.cfg.AuthURL, r.cfg.AuthOrgID) + if err != nil { + return auth.MiddlewareOptions{}, fmt.Errorf("building auth chain: %w", err) } - cfg := auth.Config{ - Enabled: true, - Token: r.authToken, - AuthURL: r.cfg.AuthURL, - AuthOrgID: r.cfg.AuthOrgID, + return auth.MiddlewareOptions{ + Chain: chain, SkipPaths: auth.DefaultSkipPaths(), OnAuth: func(req *http.Request, success bool) { if auditLogger == nil { @@ -1755,8 +1766,50 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Config, }, }) }, + }, nil +} + +// buildLegacyAuthChain constructs the Provider chain from the legacy +// CLI-flag inputs (--auth-token / --auth-url / --auth-org-id). 
+// +// When both an internal token and an external --auth-url are present, the +// static_token provider runs first so channel-adapter loopback calls +// short-circuit before reaching the external verifier. This mirrors the +// pre-refactor "internal token then external" check in middleware.go. +// +// Returns nil chain (passthrough) when no providers are configured. +func buildLegacyAuthChain(internalToken, authURL, authOrgID string) (auth.Provider, error) { + var providers []auth.Provider + + if internalToken != "" { + p, err := statictoken.New(statictoken.Config{ + Token: internalToken, + Identity: auth.Identity{ + UserID: "forge-internal", + Source: "internal", + }, + }) + if err != nil { + return nil, fmt.Errorf("static_token provider: %w", err) + } + providers = append(providers, p) + } + + if authURL != "" { + p, err := httpverifier.New(httpverifier.Config{ + URL: authURL, + DefaultOrg: authOrgID, + }) + if err != nil { + return nil, fmt.Errorf("http_verifier provider: %w", err) + } + providers = append(providers, p) + } + + if len(providers) == 0 { + return nil, nil } - return cfg, nil + return auth.NewChainProvider(providers...), nil } // ensureGitignore makes sure .forge/ is listed in the project's .gitignore. diff --git a/forge-core/auth/chain_provider.go b/forge-core/auth/chain_provider.go new file mode 100644 index 0000000..31d3080 --- /dev/null +++ b/forge-core/auth/chain_provider.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "errors" +) + +// ChainProvider composes multiple Providers into a single Provider, calling +// them in order and returning the first non-yielding result. +// +// Behavior: +// - A Provider returning (id, nil) wins; later providers are not consulted. +// - A Provider returning ErrTokenNotForMe is skipped; the chain advances. +// - Any other error stops the chain immediately (fail-closed). The middleware +// surfaces this as 401. 
Falling through on non-yield errors would let an +// attacker bypass a temporarily-misbehaving provider. +// +// A ChainProvider is immutable after construction and safe for concurrent use. +type ChainProvider struct { + providers []Provider +} + +// NewChainProvider returns a chain that calls the given providers in order. +// A chain with zero providers verifies nothing — Verify returns ErrTokenNotForMe. +func NewChainProvider(providers ...Provider) *ChainProvider { + // Defensive copy so callers can't mutate the slice after construction. + cp := make([]Provider, len(providers)) + copy(cp, providers) + return &ChainProvider{providers: cp} +} + +// Name implements Provider. +func (c *ChainProvider) Name() string { return "chain" } + +// Providers returns a defensive copy of the configured providers, in order. +func (c *ChainProvider) Providers() []Provider { + out := make([]Provider, len(c.providers)) + copy(out, c.providers) + return out +} + +// Verify implements Provider with first-match-wins semantics. +func (c *ChainProvider) Verify(ctx context.Context, token string, headers Headers) (*Identity, error) { + for _, p := range c.providers { + id, err := p.Verify(ctx, token, headers) + if err == nil { + return id, nil + } + if errors.Is(err, ErrTokenNotForMe) { + continue + } + // Fail closed: non-yield error stops the chain. + return nil, err + } + return nil, ErrTokenNotForMe +} + +// PrependChain returns a new ChainProvider whose providers are `prepend` +// followed by the providers of `chain`. If chain is a *ChainProvider, its +// providers are flattened (no nested chains). If chain is a non-chain +// Provider, it is appended as a single element. If chain is nil, the result +// contains only the prepended providers. +// +// Useful for the runner to inject a loopback static_token at the chain head +// without callers having to know whether their chain already exists. 
+func PrependChain(chain Provider, prepend ...Provider) *ChainProvider { + var existing []Provider + switch v := chain.(type) { + case nil: + // nothing to extend + case *ChainProvider: + existing = v.providers + default: + existing = []Provider{v} + } + combined := make([]Provider, 0, len(prepend)+len(existing)) + combined = append(combined, prepend...) + combined = append(combined, existing...) + return &ChainProvider{providers: combined} +} diff --git a/forge-core/auth/context.go b/forge-core/auth/context.go new file mode 100644 index 0000000..045e581 --- /dev/null +++ b/forge-core/auth/context.go @@ -0,0 +1,26 @@ +package auth + +import "context" + +// identityKey is an unexported type used as the context key, so that no +// other package can collide with our value. +type identityKey struct{} + +// WithIdentity returns a copy of ctx that carries the given Identity. +// Storing nil is a no-op (returns ctx unchanged). +func WithIdentity(ctx context.Context, id *Identity) context.Context { + if id == nil { + return ctx + } + return context.WithValue(ctx, identityKey{}, id) +} + +// IdentityFromContext returns the Identity stored on ctx by WithIdentity, +// or nil if no Identity is present. +func IdentityFromContext(ctx context.Context) *Identity { + if ctx == nil { + return nil + } + id, _ := ctx.Value(identityKey{}).(*Identity) + return id +} diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index e94bbb3..cf7eb51 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -1,86 +1,16 @@ package auth import ( - "bytes" "encoding/json" - "fmt" + "errors" "net/http" "strings" - "time" ) -// Config controls bearer-token authentication for the A2A server. -type Config struct { - // Enabled controls whether authentication is enforced. - Enabled bool - - // Token is the expected bearer token value (local validation). - Token string - - // AuthURL is an external auth provider endpoint. 
When set, the middleware - // forwards the bearer token to this URL via POST for validation instead - // of comparing against the local Token value. - AuthURL string - - // AuthOrgID is the default org_id sent to the external auth provider. - // Overridden per-request by the X-Org-ID header when present. - AuthOrgID string - - // SkipPaths maps "METHOD /path" keys that bypass authentication. - // Example: "GET /" → true allows unauthenticated GET on root. - SkipPaths map[string]bool - - // OnAuth is an optional callback invoked on every auth decision. - // success indicates whether the request was authenticated. - OnAuth func(r *http.Request, success bool) -} - -// authHTTPClient is a shared client with reasonable timeouts for auth requests. -var authHTTPClient = &http.Client{Timeout: 10 * time.Second} - -// externalVerifyRequest is the request body sent to the external auth URL. -type externalVerifyRequest struct { - Token string `json:"token"` - OrgID string `json:"org_id"` -} - -// externalVerifyResponse is the response from the external auth URL. -type externalVerifyResponse struct { - Valid bool `json:"valid"` - Error string `json:"error,omitempty"` - UserID string `json:"user_id,omitempty"` - OrgID string `json:"org_id,omitempty"` - Email string `json:"email,omitempty"` - WorkspaceID string `json:"workspace_id,omitempty"` -} - -// validateExternal sends the bearer token to the external auth URL and returns -// whether the token is valid. 
-func validateExternal(authURL, token, orgID string) (bool, error) { - body, err := json.Marshal(externalVerifyRequest{Token: token, OrgID: orgID}) - if err != nil { - return false, fmt.Errorf("marshalling auth request: %w", err) - } - - resp, err := authHTTPClient.Post(authURL, "application/json", bytes.NewReader(body)) - if err != nil { - return false, fmt.Errorf("calling auth URL: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode == http.StatusUnauthorized { - return false, nil - } - if resp.StatusCode != http.StatusOK { - return false, fmt.Errorf("auth URL returned status %d", resp.StatusCode) - } - - var result externalVerifyResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return false, fmt.Errorf("decoding auth response: %w", err) - } - - return result.Valid, nil +// errorResponse is the JSON body returned for auth failures. +type errorResponse struct { + Error string `json:"error"` + Message string `json:"message"` } // DefaultSkipPaths returns the default set of public endpoints @@ -98,115 +28,100 @@ func DefaultSkipPaths() map[string]bool { } } -// errorResponse is the JSON body returned for auth failures. -type errorResponse struct { - Error string `json:"error"` - Message string `json:"message"` +// MiddlewareOptions configures Middleware. +type MiddlewareOptions struct { + // Chain is the provider chain that verifies bearer tokens. If nil, + // the middleware behaves as a passthrough (anonymous access). + Chain Provider + + // SkipPaths maps "METHOD /path" keys that bypass authentication. + // If nil, DefaultSkipPaths() is used. + SkipPaths map[string]bool + + // OnAuth is an optional callback invoked on every auth decision. + // success is true when the request authenticated; false otherwise. + OnAuth func(r *http.Request, success bool) } -// Middleware returns an http.Handler that enforces bearer token authentication. -// If cfg.Enabled is false, requests pass through without checks. 
-// When cfg.AuthURL is set, tokens are validated against the external provider. -func Middleware(cfg Config) func(http.Handler) http.Handler { +// Middleware returns an http.Handler that enforces bearer token authentication +// via the provided Provider chain. If opts.Chain is nil, requests pass through +// without checks (anonymous access). +func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { + if opts.Chain == nil { + return func(next http.Handler) http.Handler { return next } + } + skip := opts.SkipPaths + if skip == nil { + skip = DefaultSkipPaths() + } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !cfg.Enabled { + if skip[r.Method+" "+r.URL.Path] { next.ServeHTTP(w, r) return } - // Check if this method+path combination is public. - key := r.Method + " " + r.URL.Path - if cfg.SkipPaths[key] { - next.ServeHTTP(w, r) - return - } - - // Extract bearer token from Authorization header. token := extractBearerToken(r) if token == "" { - authFail(w, r, cfg, "valid bearer token required") + writeAuthFail(w, r, opts.OnAuth, "valid bearer token required") return } - // When external auth is configured, first check the internal - // token (used by channel adapter loopback calls), then fall - // through to the external auth provider for other tokens. - if cfg.AuthURL != "" { - // Accept internal token for loopback (channel adapters) - if cfg.Token != "" && ValidateToken(token, cfg.Token) { - if cfg.OnAuth != nil { - cfg.OnAuth(r, true) - } - next.ServeHTTP(w, r) - return - } - // Use org_id from request header (multiple header names), fall back to config. 
- orgID := extractOrgID(r, cfg.AuthOrgID) - valid, err := validateExternal(cfg.AuthURL, token, orgID) - if err != nil { - authFail(w, r, cfg, "auth provider error") - return - } - if valid { - if cfg.OnAuth != nil { - cfg.OnAuth(r, true) - } - next.ServeHTTP(w, r) - return - } - authFail(w, r, cfg, "token rejected by auth provider") + identity, err := opts.Chain.Verify(r.Context(), token, HeadersFromRequest(r)) + if err != nil || identity == nil { + writeAuthFail(w, r, opts.OnAuth, classifyAuthFailure(err)) return } - // Local token validation. - if ValidateToken(token, cfg.Token) { - if cfg.OnAuth != nil { - cfg.OnAuth(r, true) - } - next.ServeHTTP(w, r) - return + if opts.OnAuth != nil { + opts.OnAuth(r, true) } - authFail(w, r, cfg, "valid bearer token required") + ctx := WithIdentity(r.Context(), identity) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } -// authFail sends a 401 response and fires the OnAuth callback. -func authFail(w http.ResponseWriter, r *http.Request, cfg Config, msg string) { - if cfg.OnAuth != nil { - cfg.OnAuth(r, false) +// classifyAuthFailure maps a chain error into a user-visible message. +// Keep messages generic to avoid leaking information about which provider +// rejected the token or why. +func classifyAuthFailure(err error) string { + switch { + case err == nil, errors.Is(err, ErrTokenNotForMe): + return "valid bearer token required" + case errors.Is(err, ErrTokenRejected): + return "token rejected by auth provider" + case errors.Is(err, ErrInvalidToken): + return "invalid token" + default: + return "auth provider error" + } +} + +// writeAuthFail sends a 401 response and fires the OnAuth callback if set. 
+func writeAuthFail(w http.ResponseWriter, r *http.Request, onAuth func(*http.Request, bool), msg string) { + if onAuth != nil { + onAuth(r, false) } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(errorResponse{ //nolint:errcheck + _ = json.NewEncoder(w).Encode(errorResponse{ Error: "unauthorized", Message: msg, }) } -// extractOrgID reads the org ID from the request headers, checking multiple -// common header names: "X-Org-ID", "org-id", "org_id". Falls back to the -// provided default if none are set. -func extractOrgID(r *http.Request, fallback string) string { - for _, h := range []string{"X-Org-ID", "org-id", "org_id"} { - if v := r.Header.Get(h); v != "" { - return v - } - } - return fallback -} - // extractBearerToken extracts the token from "Authorization: Bearer ". +// Case-insensitive on the "Bearer" prefix to preserve historical behavior. func extractBearerToken(r *http.Request) string { - auth := r.Header.Get("Authorization") - if auth == "" { + header := r.Header.Get("Authorization") + if header == "" { return "" } const prefix = "Bearer " - if len(auth) > len(prefix) && strings.EqualFold(auth[:len(prefix)], prefix) { - return auth[len(prefix):] + if len(header) > len(prefix) && strings.EqualFold(header[:len(prefix)], prefix) { + return header[len(prefix):] } return "" } diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index 7107698..b18e4b5 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -8,125 +9,157 @@ import ( "testing" ) +// tokenProvider is a minimal test-only Provider that compares against a fixed +// token. We define it here so middleware tests don't need to import the +// statictoken provider package (which would create a cycle since it imports +// this package). 
+type tokenProvider struct { + expected string + identity Identity +} + +func (t *tokenProvider) Name() string { return "test_token" } + +func (t *tokenProvider) Verify(_ context.Context, token string, _ Headers) (*Identity, error) { + if token != t.expected { + return nil, ErrTokenNotForMe + } + id := t.identity + return &id, nil +} + +// rejectingProvider returns ErrTokenRejected for any token, simulating an +// external verifier that recognized the token but denied it. +type rejectingProvider struct{} + +func (rejectingProvider) Name() string { return "test_reject" } +func (rejectingProvider) Verify(_ context.Context, _ string, _ Headers) (*Identity, error) { + return nil, ErrTokenRejected +} + +// brokenProvider returns ErrInvalidToken (e.g., signature failure). +type brokenProvider struct{} + +func (brokenProvider) Name() string { return "test_broken" } +func (brokenProvider) Verify(_ context.Context, _ string, _ Headers) (*Identity, error) { + return nil, ErrInvalidToken +} + func TestMiddleware(t *testing.T) { const validToken = "test-secret-token" okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("ok")) //nolint:errcheck + _, _ = w.Write([]byte("ok")) + }) + + chain := NewChainProvider(&tokenProvider{ + expected: validToken, + identity: Identity{UserID: "u1", Source: "test_token"}, }) tests := []struct { name string - cfg Config + opts MiddlewareOptions method string path string authHeader string wantStatus int }{ { - name: "disabled passes through", - cfg: Config{Enabled: false}, + name: "nil chain passes through", + opts: MiddlewareOptions{Chain: nil}, method: "POST", path: "/", wantStatus: http.StatusOK, }, { - name: "valid token accepted", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "valid token accepted", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", authHeader: "Bearer " + 
validToken, wantStatus: http.StatusOK, }, { - name: "missing token rejected", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "missing token rejected", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", wantStatus: http.StatusUnauthorized, }, { - name: "wrong token rejected", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "wrong token rejected", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", authHeader: "Bearer wrong-token", wantStatus: http.StatusUnauthorized, }, { - name: "GET / is public", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "GET / is public", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "GET", path: "/", wantStatus: http.StatusOK, }, { - name: "GET /.well-known/agent.json is public", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "GET /.well-known/agent.json is public", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "GET", path: "/.well-known/agent.json", wantStatus: http.StatusOK, }, { - name: "GET /healthz is public", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "GET /healthz is public", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "GET", path: "/healthz", wantStatus: http.StatusOK, }, { - name: "POST /tasks/send requires auth", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "POST /tasks/send requires auth", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/tasks/send", wantStatus: http.StatusUnauthorized, }, { - name: "case insensitive Bearer prefix", - cfg: Config{ - Enabled: true, - Token: 
validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "case insensitive Bearer prefix", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", authHeader: "bearer " + validToken, wantStatus: http.StatusOK, }, + { + name: "rejected by provider returns 401", + opts: MiddlewareOptions{Chain: NewChainProvider(rejectingProvider{}), SkipPaths: DefaultSkipPaths()}, + method: "POST", + path: "/", + authHeader: "Bearer anything", + wantStatus: http.StatusUnauthorized, + }, + { + name: "invalid token from provider returns 401", + opts: MiddlewareOptions{Chain: NewChainProvider(brokenProvider{}), SkipPaths: DefaultSkipPaths()}, + method: "POST", + path: "/", + authHeader: "Bearer malformed", + wantStatus: http.StatusUnauthorized, + }, + { + name: "nil SkipPaths uses defaults", + opts: MiddlewareOptions{Chain: chain}, + method: "GET", + path: "/healthz", + wantStatus: http.StatusOK, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mw := Middleware(tt.cfg) + mw := Middleware(tt.opts) handler := mw(okHandler) req := httptest.NewRequest(tt.method, tt.path, nil) @@ -155,16 +188,18 @@ func TestMiddleware(t *testing.T) { } } -func TestMiddlewareOnAuthCallback(t *testing.T) { +func TestMiddleware_OnAuthCallback(t *testing.T) { const token = "callback-token" var successCount, failCount atomic.Int32 - cfg := Config{ - Enabled: true, - Token: token, + opts := MiddlewareOptions{ + Chain: NewChainProvider(&tokenProvider{ + expected: token, + identity: Identity{UserID: "u", Source: "test"}, + }), SkipPaths: DefaultSkipPaths(), - OnAuth: func(r *http.Request, success bool) { + OnAuth: func(_ *http.Request, success bool) { if success { successCount.Add(1) } else { @@ -173,7 +208,7 @@ func TestMiddlewareOnAuthCallback(t *testing.T) { }, } - handler := Middleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 
w.WriteHeader(http.StatusOK) })) @@ -193,3 +228,51 @@ func TestMiddlewareOnAuthCallback(t *testing.T) { t.Errorf("failure callbacks = %d, want 1", got) } } + +func TestMiddleware_IdentityIsAttachedToContext(t *testing.T) { + const token = "ctx-token" + + wantID := Identity{UserID: "ctx-user", Email: "ctx@example.com", Source: "test_token"} + + opts := MiddlewareOptions{ + Chain: NewChainProvider(&tokenProvider{expected: token, identity: wantID}), + SkipPaths: DefaultSkipPaths(), + } + + var gotID *Identity + handler := Middleware(opts)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotID = IdentityFromContext(r.Context()) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer "+token) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotID == nil { + t.Fatal("identity not attached to context") + } + if gotID.UserID != wantID.UserID || gotID.Email != wantID.Email { + t.Errorf("identity = %+v, want %+v", gotID, wantID) + } +} + +func TestClassifyAuthFailure(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + {"nil", nil, "valid bearer token required"}, + {"not for me", ErrTokenNotForMe, "valid bearer token required"}, + {"rejected", ErrTokenRejected, "token rejected by auth provider"}, + {"invalid", ErrInvalidToken, "invalid token"}, + {"unexpected", context.Canceled, "auth provider error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := classifyAuthFailure(tt.err); got != tt.want { + t.Errorf("classifyAuthFailure(%v) = %q, want %q", tt.err, got, tt.want) + } + }) + } +} diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go new file mode 100644 index 0000000..a7b3322 --- /dev/null +++ b/forge-core/auth/provider.go @@ -0,0 +1,101 @@ +// Package auth provides bearer-token authentication for the A2A server, +// built around a pluggable Provider chain. 
+// +// Each Provider claims tokens it recognizes and rejects the rest with +// ErrTokenNotForMe, allowing a ChainProvider to compose multiple +// providers in a first-match-wins fashion. The error return is the only +// signal for the verification outcome — there is intentionally no +// Identity.Valid field, because a nil-error return is the contract for +// "this token is valid." +// +// New providers live under forge-core/auth/providers// and register +// themselves via init() against the package-level registry, mirroring the +// database/sql driver pattern. +package auth + +import ( + "context" + "errors" + "net/http" + "strings" +) + +// Provider verifies a bearer token and returns the caller's Identity. +// +// Implementations must: +// - Return (id, nil) on a verified token. +// - Return (nil, ErrTokenNotForMe) when the token is not for this provider +// (so the ChainProvider can try the next provider). +// - Return (nil, ErrTokenRejected) when the token is recognized but denied +// (e.g., revoked, expired, untrusted issuer). +// - Return (nil, ErrInvalidToken) when the token is malformed or +// cryptographically invalid. +// - Return (nil, other-error) for transient failures (network, etc.). +// The ChainProvider treats these as fatal (fail-closed) — it does NOT +// fall through to the next provider on infrastructure errors, because +// doing so would allow attackers to evade a temporarily-down provider. +type Provider interface { + Name() string + Verify(ctx context.Context, token string, headers Headers) (*Identity, error) +} + +// Identity is the authenticated principal extracted by a Provider. +// +// There is intentionally no Valid field — a non-nil *Identity returned +// alongside a nil error is the only "valid" signal. See package comment. 
type Identity struct {
	// Principal and tenancy identifiers. Every field is optional and is
	// omitted from the JSON encoding when empty.
	UserID      string         `json:"user_id,omitempty"`
	Email       string         `json:"email,omitempty"`
	OrgID       string         `json:"org_id,omitempty"`
	WorkspaceID string         `json:"workspace_id,omitempty"`
	Groups      []string       `json:"groups,omitempty"`
	Claims      map[string]any `json:"claims,omitempty"`
	// Source records which provider verified the identity (e.g., "oidc",
	// "http_verifier", "static_token"). Useful for audit logs and debugging.
	Source string `json:"source,omitempty"`
}

// Headers is a case-insensitive view over selected request headers passed
// to providers. Providers should not assume any particular casing.
type Headers map[string]string

// Get returns the value for the given header, matched case-insensitively.
//
// An exact-case key always wins, even when its value is empty. Otherwise
// the keys are compared with strings.EqualFold; if several keys differ from
// key only by case, the lexicographically smallest key is chosen so the
// result does not depend on Go's randomized map iteration order. Get
// returns "" when nothing matches and is safe to call on a nil Headers.
func (h Headers) Get(key string) string {
	if v, ok := h[key]; ok {
		return v
	}

	var (
		bestKey string
		bestVal string
		found   bool
	)
	for k, v := range h {
		// EqualFold avoids allocating a lowered copy of every key.
		if !strings.EqualFold(k, key) {
			continue
		}
		if !found || k < bestKey {
			bestKey, bestVal, found = k, v, true
		}
	}
	return bestVal
}

// HeadersFromRequest extracts the well-known headers providers may use.
// Keep this list narrow — providers should be explicit about the contract.
//
// Every listed key is always present in the result; a header that is absent
// from the request maps to "". http.Header.Get canonicalizes the name it is
// given, so the casing the client used on the wire does not matter.
func HeadersFromRequest(r *http.Request) Headers {
	return Headers{
		"X-Org-ID":     r.Header.Get("X-Org-ID"),
		"X-Request-ID": r.Header.Get("X-Request-ID"),
		"org-id":       r.Header.Get("org-id"),
		"org_id":       r.Header.Get("org_id"),
	}
}

// Sentinel errors that Providers and the ChainProvider use to signal outcomes.
//
//   - ErrTokenNotForMe → provider does not recognize this token shape;
//     ChainProvider should try the next provider.
//   - ErrTokenRejected → provider recognized the token and denied it;
//     ChainProvider stops and the middleware writes 401.
//   - ErrInvalidToken → token is malformed or cryptographically invalid;
//     ChainProvider stops and the middleware writes 401.
//   - ErrProviderNotConfigured → returned by New(); never by Verify().
+var ( + ErrTokenNotForMe = errors.New("auth: token not for this provider") + ErrTokenRejected = errors.New("auth: token rejected") + ErrInvalidToken = errors.New("auth: invalid token") + ErrProviderNotConfigured = errors.New("auth: provider not configured") +) diff --git a/forge-core/auth/provider_test.go b/forge-core/auth/provider_test.go new file mode 100644 index 0000000..cadcaa1 --- /dev/null +++ b/forge-core/auth/provider_test.go @@ -0,0 +1,461 @@ +package auth + +import ( + "context" + "errors" + "maps" + "reflect" + "sort" + "sync" + "testing" +) + +// stubProvider is an in-package test provider that lets us script the +// outcome of Verify deterministically. +type stubProvider struct { + name string + identity *Identity + err error + calls int +} + +func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) Verify(_ context.Context, _ string, _ Headers) (*Identity, error) { + s.calls++ + return s.identity, s.err +} + +// --- ChainProvider --- + +func TestChainProvider_FirstMatchWins(t *testing.T) { + idA := &Identity{UserID: "a", Source: "first"} + idB := &Identity{UserID: "b", Source: "second"} + a := &stubProvider{name: "first", identity: idA} + b := &stubProvider{name: "second", identity: idB} + + chain := NewChainProvider(a, b) + got, err := chain.Verify(context.Background(), "tok", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if got != idA { + t.Errorf("identity = %+v, want %+v", got, idA) + } + if a.calls != 1 { + t.Errorf("first provider calls = %d, want 1", a.calls) + } + if b.calls != 0 { + t.Errorf("second provider should not be called, got %d calls", b.calls) + } +} + +func TestChainProvider_NotForMeYieldsToNext(t *testing.T) { + idB := &Identity{UserID: "b"} + a := &stubProvider{name: "first", err: ErrTokenNotForMe} + b := &stubProvider{name: "second", identity: idB} + + chain := NewChainProvider(a, b) + got, err := chain.Verify(context.Background(), "tok", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + 
} + if got != idB { + t.Errorf("identity = %+v, want %+v", got, idB) + } + if a.calls != 1 || b.calls != 1 { + t.Errorf("call counts: a=%d b=%d, want 1,1", a.calls, b.calls) + } +} + +func TestChainProvider_RejectedStopsChain(t *testing.T) { + a := &stubProvider{name: "first", err: ErrTokenRejected} + b := &stubProvider{name: "second", identity: &Identity{UserID: "b"}} + + chain := NewChainProvider(a, b) + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected", err) + } + if b.calls != 0 { + t.Errorf("second provider should not be called after rejection, got %d calls", b.calls) + } +} + +func TestChainProvider_InvalidStopsChain(t *testing.T) { + a := &stubProvider{name: "first", err: ErrInvalidToken} + b := &stubProvider{name: "second", identity: &Identity{UserID: "b"}} + + chain := NewChainProvider(a, b) + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken", err) + } + if b.calls != 0 { + t.Errorf("second provider should not be called after invalid, got %d calls", b.calls) + } +} + +func TestChainProvider_GenericErrorFailsClosed(t *testing.T) { + // Infrastructure errors (network, etc.) must NOT fall through to the + // next provider — that would let attackers evade a temporarily-down + // upstream by waiting for it to fail. 
+ netErr := errors.New("connection refused") + a := &stubProvider{name: "first", err: netErr} + b := &stubProvider{name: "second", identity: &Identity{UserID: "b"}} + + chain := NewChainProvider(a, b) + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, netErr) { + t.Fatalf("err = %v, want network error", err) + } + if b.calls != 0 { + t.Errorf("fail-closed violated: second provider was called after generic error") + } +} + +func TestChainProvider_EmptyChainYields(t *testing.T) { + chain := NewChainProvider() + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrTokenNotForMe) { + t.Errorf("empty chain err = %v, want ErrTokenNotForMe", err) + } +} + +func TestChainProvider_AllYieldFinalErrorIsNotForMe(t *testing.T) { + a := &stubProvider{name: "a", err: ErrTokenNotForMe} + b := &stubProvider{name: "b", err: ErrTokenNotForMe} + chain := NewChainProvider(a, b) + + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestChainProvider_DefensiveSliceCopy(t *testing.T) { + a := &stubProvider{name: "a", identity: &Identity{UserID: "a"}} + b := &stubProvider{name: "b", identity: &Identity{UserID: "b"}} + srcSlice := []Provider{a, b} + + chain := NewChainProvider(srcSlice...) + // Mutate the source — chain must be unaffected. + srcSlice[0] = &stubProvider{name: "tamper", err: ErrTokenRejected} + + got, err := chain.Verify(context.Background(), "tok", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if got.UserID != "a" { + t.Errorf("chain saw mutated source; identity = %+v", got) + } +} + +func TestChainProvider_ProvidersReturnsCopy(t *testing.T) { + a := &stubProvider{name: "a"} + chain := NewChainProvider(a) + + listed := chain.Providers() + // Mutate the returned slice — internal state must be unchanged. 
+ listed[0] = &stubProvider{name: "tamper"} + + listed2 := chain.Providers() + if listed2[0].Name() != "a" { + t.Errorf("Providers() did not return a defensive copy; got %q", listed2[0].Name()) + } +} + +func TestChainProvider_NameIsChain(t *testing.T) { + if got := (&ChainProvider{}).Name(); got != "chain" { + t.Errorf("Name = %q, want chain", got) + } +} + +// --- PrependChain --- + +func TestPrependChain_FlattenedChain(t *testing.T) { + inner := NewChainProvider( + &stubProvider{name: "b"}, + &stubProvider{name: "c"}, + ) + combined := PrependChain(inner, &stubProvider{name: "a"}) + + names := providerNames(combined.Providers()) + want := []string{"a", "b", "c"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v (flatten failed)", names, want) + } +} + +func TestPrependChain_NonChainSingle(t *testing.T) { + single := &stubProvider{name: "single"} + combined := PrependChain(single, &stubProvider{name: "first"}) + names := providerNames(combined.Providers()) + want := []string{"first", "single"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v", names, want) + } +} + +func TestPrependChain_NilChain(t *testing.T) { + combined := PrependChain(nil, &stubProvider{name: "first"}) + names := providerNames(combined.Providers()) + want := []string{"first"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v", names, want) + } +} + +func TestPrependChain_NoPrependProviders(t *testing.T) { + inner := NewChainProvider(&stubProvider{name: "x"}) + combined := PrependChain(inner) + names := providerNames(combined.Providers()) + want := []string{"x"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v", names, want) + } +} + +// --- Headers --- + +func TestHeaders_GetCaseInsensitive(t *testing.T) { + h := Headers{"X-Org-ID": "acme"} + tests := []struct { + key string + want string + }{ + {"X-Org-ID", "acme"}, + {"x-org-id", "acme"}, + {"X-ORG-ID", "acme"}, + {"X-Other", ""}, + } + for _, tt := range 
tests { + t.Run(tt.key, func(t *testing.T) { + if got := h.Get(tt.key); got != tt.want { + t.Errorf("Get(%q) = %q, want %q", tt.key, got, tt.want) + } + }) + } +} + +func TestHeaders_EmptyMap(t *testing.T) { + var h Headers + if got := h.Get("anything"); got != "" { + t.Errorf("Get on nil Headers = %q, want empty", got) + } +} + +// --- Context --- + +func TestContext_RoundTrip(t *testing.T) { + want := &Identity{UserID: "alice", Source: "test"} + ctx := WithIdentity(context.Background(), want) + got := IdentityFromContext(ctx) + if got != want { + t.Errorf("IdentityFromContext = %+v, want %+v", got, want) + } +} + +func TestContext_NilIdentityIsNoOp(t *testing.T) { + parent := context.Background() + ctx := WithIdentity(parent, nil) + if ctx != parent { + t.Error("WithIdentity(nil) should return parent ctx unchanged") + } +} + +func TestContext_MissingReturnsNil(t *testing.T) { + if got := IdentityFromContext(context.Background()); got != nil { + t.Errorf("IdentityFromContext on empty ctx = %+v, want nil", got) + } +} + +func TestContext_NilContextReturnsNil(t *testing.T) { + // Intentional: confirm we don't panic on a defensive nil-context call. + var ctx context.Context //nolint:gocritic // intentional nil-context safety check + if got := IdentityFromContext(ctx); got != nil { + t.Errorf("IdentityFromContext(nil) = %+v, want nil", got) + } +} + +// --- Registry --- + +// registryGuard isolates registry mutations from the global state so the +// concurrently-running provider package tests (httpverifier, statictoken) +// don't interfere. +// +// Strategy: snapshot the registry, run the test against a wiped registry, +// then restore. 
+func registryGuard(t *testing.T) { + t.Helper() + registryMu.Lock() + snapshot := maps.Clone(registry) + registryMu.Unlock() + t.Cleanup(func() { + registryMu.Lock() + registry = snapshot + registryMu.Unlock() + }) + resetRegistryForTest() +} + +func TestRegister_RoundTrip(t *testing.T) { + registryGuard(t) + + Register("custom", func(_ map[string]any) (Provider, error) { + return &stubProvider{name: "custom"}, nil + }) + + p, err := Build("custom", nil) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "custom" { + t.Errorf("Name = %q, want custom", p.Name()) + } +} + +func TestRegister_DuplicatePanics(t *testing.T) { + registryGuard(t) + + Register("dup", func(_ map[string]any) (Provider, error) { + return &stubProvider{}, nil + }) + + defer func() { + if r := recover(); r == nil { + t.Error("duplicate Register did not panic") + } + }() + Register("dup", func(_ map[string]any) (Provider, error) { + return &stubProvider{}, nil + }) +} + +func TestRegister_EmptyNamePanics(t *testing.T) { + registryGuard(t) + + defer func() { + if r := recover(); r == nil { + t.Error("Register with empty name did not panic") + } + }() + Register("", func(_ map[string]any) (Provider, error) { + return &stubProvider{}, nil + }) +} + +func TestRegister_NilFactoryPanics(t *testing.T) { + registryGuard(t) + + defer func() { + if r := recover(); r == nil { + t.Error("Register with nil factory did not panic") + } + }() + Register("nil_fac", nil) +} + +func TestBuild_UnknownTypeReturnsError(t *testing.T) { + registryGuard(t) + + _, err := Build("does-not-exist", nil) + if err == nil { + t.Fatal("Build of unknown type returned nil error") + } +} + +func TestBuild_FactoryErrorIsWrapped(t *testing.T) { + registryGuard(t) + + custom := errors.New("custom failure") + Register("failing", func(_ map[string]any) (Provider, error) { + return nil, custom + }) + + _, err := Build("failing", nil) + if !errors.Is(err, custom) { + t.Fatalf("err = %v, want wrapped custom failure", err) 
+ } +} + +func TestRegisteredTypes_SortedAndDeduplicated(t *testing.T) { + registryGuard(t) + + Register("zeta", func(_ map[string]any) (Provider, error) { return &stubProvider{}, nil }) + Register("alpha", func(_ map[string]any) (Provider, error) { return &stubProvider{}, nil }) + Register("mu", func(_ map[string]any) (Provider, error) { return &stubProvider{}, nil }) + + got := RegisteredTypes() + want := []string{"alpha", "mu", "zeta"} + if !sort.StringsAreSorted(got) { + t.Errorf("RegisteredTypes not sorted: %v", got) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("RegisteredTypes = %v, want %v", got, want) + } +} + +func TestRegistry_ConcurrentReads(t *testing.T) { + registryGuard(t) + + Register("a", func(_ map[string]any) (Provider, error) { return &stubProvider{name: "a"}, nil }) + Register("b", func(_ map[string]any) (Provider, error) { return &stubProvider{name: "b"}, nil }) + + var wg sync.WaitGroup + for range 50 { + wg.Go(func() { + _, _ = Build("a", nil) + _, _ = Build("b", nil) + _ = RegisteredTypes() + }) + } + wg.Wait() +} + +// --- UnmarshalSettings --- + +type sampleSettings struct { + Issuer string `yaml:"issuer"` + Audience string `yaml:"audience"` + Nested map[string]any `yaml:"nested,omitempty"` +} + +func TestUnmarshalSettings(t *testing.T) { + in := map[string]any{ + "issuer": "https://example.com", + "audience": "api://forge", + "nested": map[string]any{ + "k": "v", + }, + } + + var out sampleSettings + if err := UnmarshalSettings(in, &out); err != nil { + t.Fatalf("UnmarshalSettings: %v", err) + } + if out.Issuer != "https://example.com" { + t.Errorf("Issuer = %q, want https://example.com", out.Issuer) + } + if out.Audience != "api://forge" { + t.Errorf("Audience = %q, want api://forge", out.Audience) + } + if out.Nested["k"] != "v" { + t.Errorf("Nested[k] = %v, want v", out.Nested["k"]) + } +} + +func TestUnmarshalSettings_NilOut(t *testing.T) { + if err := UnmarshalSettings(map[string]any{"x": 1}, nil); err == nil { + 
t.Error("UnmarshalSettings with nil out returned nil error") + } +} + +// --- helpers --- + +func providerNames(ps []Provider) []string { + out := make([]string, len(ps)) + for i, p := range ps { + out[i] = p.Name() + } + return out +} diff --git a/forge-core/auth/providers/httpverifier/provider.go b/forge-core/auth/providers/httpverifier/provider.go new file mode 100644 index 0000000..ec0cc35 --- /dev/null +++ b/forge-core/auth/providers/httpverifier/provider.go @@ -0,0 +1,183 @@ +// Package httpverifier implements the legacy external auth provider: +// POST a JSON envelope to a verifier URL and trust its response. +// +// This is the provider that the historical --auth-url / FORGE_AUTH_URL flag +// configures. The request/response shape is preserved byte-for-byte so +// existing custom verifier services keep working unchanged. +// +// Request: POST {URL} +// Content-Type: application/json +// { "token": "", "org_id": "" } +// +// Response: 200 OK +// { "valid": bool, "error": "...", "user_id": "...", +// "org_id": "...", "email": "...", "workspace_id": "..." } +// +// The HTTP verifier claims every token presented to it — it never returns +// ErrTokenNotForMe. When placed in a chain, it is typically the terminator. +package httpverifier + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this provider. +const ProviderName = "http_verifier" + +// DefaultTimeout is the per-request timeout used when Config.Timeout is unset. +const DefaultTimeout = 10 * time.Second + +// Config controls the http_verifier provider. +type Config struct { + // URL is the verifier endpoint. Required. + URL string `yaml:"url"` + + // DefaultOrg is the org_id sent to the verifier when no X-Org-ID + // header (or org-id / org_id variant) is present on the request. 
+ DefaultOrg string `yaml:"default_org,omitempty"` + + // Timeout caps each verify call. Defaults to DefaultTimeout. + Timeout time.Duration `yaml:"timeout,omitempty"` + + // HTTPClient overrides the default client. Injectable for tests. + HTTPClient *http.Client `yaml:"-"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.URL == "" { + return fmt.Errorf("%w: url required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider against a remote HTTP verifier. +type Provider struct { + cfg Config + client *http.Client +} + +// New constructs a Provider after validating cfg. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.Timeout == 0 { + cfg.Timeout = DefaultTimeout + } + client := cfg.HTTPClient + if client == nil { + client = &http.Client{Timeout: cfg.Timeout} + } + return &Provider{cfg: cfg, client: client}, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// verifyRequest is the JSON envelope POSTed to the verifier URL. +// Preserved byte-for-byte from the pre-refactor implementation. +type verifyRequest struct { + Token string `json:"token"` + OrgID string `json:"org_id"` +} + +// verifyResponse is the JSON envelope returned by the verifier URL. +// Preserved byte-for-byte from the pre-refactor implementation. +type verifyResponse struct { + Valid bool `json:"valid"` + Error string `json:"error,omitempty"` + UserID string `json:"user_id,omitempty"` + OrgID string `json:"org_id,omitempty"` + Email string `json:"email,omitempty"` + WorkspaceID string `json:"workspace_id,omitempty"` +} + +// Verify implements auth.Provider. It POSTs the bearer token (with the +// caller's org_id) to the configured verifier URL and translates the +// response into an Identity or one of the sentinel errors. 
+// +// Mapping: +// - HTTP 200 + valid:true → (Identity, nil) +// - HTTP 200 + valid:false → ErrTokenRejected +// - HTTP 401 → ErrTokenRejected +// - HTTP other / I/O error → ErrInvalidToken (wrapped cause) +// +// This provider does not return ErrTokenNotForMe. +func (p *Provider) Verify(ctx context.Context, token string, headers auth.Headers) (*auth.Identity, error) { + orgID := resolveOrgID(headers, p.cfg.DefaultOrg) + + body, err := json.Marshal(verifyRequest{Token: token, OrgID: orgID}) + if err != nil { + return nil, fmt.Errorf("%w: marshal request: %w", auth.ErrInvalidToken, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.URL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("%w: build request: %w", auth.ErrInvalidToken, err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: call verifier: %w", auth.ErrInvalidToken, err) + } + defer func() { _ = resp.Body.Close() }() + + switch resp.StatusCode { + case http.StatusOK: + // fall through to decode body + case http.StatusUnauthorized: + return nil, auth.ErrTokenRejected + default: + // Drain a bounded amount of body for diagnostics, then fail. + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return nil, fmt.Errorf("%w: verifier returned status %d", auth.ErrInvalidToken, resp.StatusCode) + } + + var result verifyResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("%w: decode response: %w", auth.ErrInvalidToken, err) + } + if !result.Valid { + return nil, auth.ErrTokenRejected + } + + return &auth.Identity{ + UserID: result.UserID, + Email: result.Email, + OrgID: result.OrgID, + WorkspaceID: result.WorkspaceID, + Source: ProviderName, + }, nil +} + +// resolveOrgID picks the org_id sent to the verifier with this precedence: +// X-Org-ID header > "org-id" > "org_id" > config default. 
Matches the +// pre-refactor behavior in middleware.go. +func resolveOrgID(headers auth.Headers, fallback string) string { + for _, key := range []string{"X-Org-ID", "org-id", "org_id"} { + if v := headers.Get(key); v != "" { + return v + } + } + return fallback +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/httpverifier/provider_test.go b/forge-core/auth/providers/httpverifier/provider_test.go new file mode 100644 index 0000000..fd70703 --- /dev/null +++ b/forge-core/auth/providers/httpverifier/provider_test.go @@ -0,0 +1,263 @@ +package httpverifier_test + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/httpverifier" +) + +// fakeVerifier returns a test server that implements the verifier contract. +// `handler` lets each test customize the response. 
+func fakeVerifier(t *testing.T, handler func(req map[string]any) (status int, body any)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("verifier got method %q, want POST", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("verifier got Content-Type %q, want application/json", ct) + } + body, _ := io.ReadAll(r.Body) + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("verifier got non-JSON body: %v", err) + } + status, respBody := handler(req) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if respBody != nil { + _ = json.NewEncoder(w).Encode(respBody) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func TestNew_ValidationErrors(t *testing.T) { + t.Run("missing url", func(t *testing.T) { + _, err := httpverifier.New(httpverifier.Config{}) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } + }) +} + +func TestVerify_HappyPath(t *testing.T) { + srv := fakeVerifier(t, func(req map[string]any) (int, any) { + if req["token"] != "good-token" { + t.Errorf("verifier got token %q, want good-token", req["token"]) + } + return http.StatusOK, map[string]any{ + "valid": true, + "user_id": "user-1", + "org_id": "org-1", + "email": "user@example.com", + "workspace_id": "ws-1", + } + }) + + p, err := httpverifier.New(httpverifier.Config{URL: srv.URL}) + if err != nil { + t.Fatalf("New: %v", err) + } + + id, err := p.Verify(context.Background(), "good-token", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } + if id.UserID != "user-1" || id.Email != "user@example.com" || id.OrgID != "org-1" || id.WorkspaceID != "ws-1" { + t.Errorf("identity = %+v, fields missing", id) + } + if id.Source != 
httpverifier.ProviderName { + t.Errorf("identity.Source = %q, want %q", id.Source, httpverifier.ProviderName) + } +} + +func TestVerify_ValidFalse_ReturnsRejected(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusOK, map[string]any{"valid": false, "error": "bad token"} + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "bad-token", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected", err) + } +} + +func TestVerify_401_ReturnsRejected(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusUnauthorized, nil + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected", err) + } +} + +func TestVerify_500_ReturnsInvalidToken(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusInternalServerError, map[string]any{"error": "boom"} + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken", err) + } +} + +func TestVerify_NetworkError_ReturnsInvalidToken(t *testing.T) { + p, _ := httpverifier.New(httpverifier.Config{ + URL: "http://127.0.0.1:0/never", // port 0 is invalid, will fail to connect + Timeout: 100 * time.Millisecond, + }) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken", err) + } +} + +func TestVerify_OrgID_Precedence(t *testing.T) { + tests := []struct { + name string + headers auth.Headers + want string + }{ + { + name: "header X-Org-ID wins", + headers: auth.Headers{"X-Org-ID": "header-org", "org-id": "lower-org"}, + want: "header-org", + }, + 
{ + name: "lowercase org-id when X-Org-ID absent", + headers: auth.Headers{"org-id": "lower-org"}, + want: "lower-org", + }, + { + name: "snake_case org_id when others absent", + headers: auth.Headers{"org_id": "snake-org"}, + want: "snake-org", + }, + { + name: "config default when no headers", + headers: auth.Headers{}, + want: "default-org", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotOrgID atomic.Value + srv := fakeVerifier(t, func(req map[string]any) (int, any) { + gotOrgID.Store(req["org_id"]) + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL, DefaultOrg: "default-org"}) + if _, err := p.Verify(context.Background(), "tok", tt.headers); err != nil { + t.Fatalf("Verify: %v", err) + } + if got := gotOrgID.Load(); got != tt.want { + t.Errorf("verifier received org_id = %v, want %q", got, tt.want) + } + }) + } +} + +func TestVerify_WireFormat_PreservedExactly(t *testing.T) { + // Black-box check: the JSON body sent must contain exactly "token" and + // "org_id" fields — no more, no less. This guards against accidental + // schema drift from the legacy --auth-url contract. 
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var got map[string]any + _ = json.Unmarshal(body, &got) + if len(got) != 2 { + t.Errorf("request body has %d fields, want 2: %v", len(got), got) + } + if _, ok := got["token"]; !ok { + t.Error("request body missing 'token' field") + } + if _, ok := got["org_id"]; !ok { + t.Error("request body missing 'org_id' field") + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"valid": true}) + })) + defer srv.Close() + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + if _, err := p.Verify(context.Background(), "tok", nil); err != nil { + t.Fatalf("Verify: %v", err) + } +} + +func TestVerify_RespectsContextCancellation(t *testing.T) { + // Server blocks until either the request context is cancelled OR the + // test ends. The serverDone channel guarantees srv.Close() doesn't + // hang waiting for the handler to return under -race scheduling. + serverDone := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-serverDone: + } + })) + defer func() { + close(serverDone) + srv.Close() + }() + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL, Timeout: 5 * time.Second}) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := p.Verify(ctx, "tok", nil) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("Verify returned nil error despite context timeout") + } + if elapsed > 1*time.Second { + t.Errorf("Verify took %v, expected <1s (context should have cancelled)", elapsed) + } +} + +func TestRegisteredViaFactory(t *testing.T) { + // Confirm the package registered itself with the auth package on import. 
+ p, err := auth.Build("http_verifier", map[string]any{"url": "https://example.com/verify"}) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "http_verifier" { + t.Errorf("Name = %q, want http_verifier", p.Name()) + } +} + +func TestFactory_UnknownSettings_AreIgnored(t *testing.T) { + // Forward-compat: unknown YAML keys must not break construction. + _, err := auth.Build("http_verifier", map[string]any{ + "url": "https://example.com/verify", + "unknown_field": "future", + }) + if err != nil { + t.Fatalf("Build with unknown field: %v", err) + } +} diff --git a/forge-core/auth/providers/statictoken/provider.go b/forge-core/auth/providers/statictoken/provider.go new file mode 100644 index 0000000..ab096ad --- /dev/null +++ b/forge-core/auth/providers/statictoken/provider.go @@ -0,0 +1,122 @@ +// Package statictoken implements a Provider that matches the presented +// bearer token against a single expected value using constant-time +// comparison. +// +// Two intended uses: +// +// 1. Channel adapter loopback. The runner generates a random per-process +// token, configures a statictoken Provider with it (placed at the head +// of the chain), and shares the same token with Slack/Telegram adapters +// so their callbacks into the local A2A server authenticate cheaply +// without touching an upstream IdP. +// +// 2. Local dev / CI. A fixed token configured via env var lets developers +// hit a running agent with `curl -H "Authorization: Bearer $FORGE_DEV_TOKEN"` +// without setting up an IdP. +// +// Mismatch returns ErrTokenNotForMe (yield to next provider), not +// ErrTokenRejected — the loopback token is "not for me" from the +// perspective of an external client, and chain semantics require yielding +// in that case. +package statictoken + +import ( + "context" + "crypto/subtle" + "fmt" + "maps" + "os" + "slices" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this provider. 
+const ProviderName = "static_token" + +// Config controls the static_token provider. +type Config struct { + // Token is the expected bearer value (literal). Prefer TokenEnv for + // non-test use — putting a secret in YAML is a footgun. + Token string `yaml:"token,omitempty"` + + // TokenEnv names an environment variable that holds the token at + // construction time. Read once in New(); subsequent env changes do + // not affect the running provider. + TokenEnv string `yaml:"token_env,omitempty"` + + // Identity is returned on a successful match (a defensive copy is + // returned to callers). If Source is empty it defaults to "static_token". + Identity auth.Identity `yaml:"identity,omitempty"` +} + +// Validate returns ErrProviderNotConfigured when both Token and TokenEnv are +// empty. It does not read the environment; New checks the resolved value. +func (c Config) Validate() error { + if c.Token == "" && c.TokenEnv == "" { + return fmt.Errorf("%w: token or token_env required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider with a constant-time token compare. +type Provider struct { + expected []byte + identity auth.Identity +} + +// New constructs a Provider after resolving the token (a non-empty TokenEnv +// value wins; otherwise the Token literal is used). Returns +// ErrProviderNotConfigured if the resolved token is empty. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + token := cfg.Token + if cfg.TokenEnv != "" { + if v := os.Getenv(cfg.TokenEnv); v != "" { + token = v + } + } + if token == "" { + return nil, fmt.Errorf("%w: resolved token is empty", auth.ErrProviderNotConfigured) + } + id := cfg.Identity + if id.Source == "" { + id.Source = ProviderName + } + return &Provider{ + expected: []byte(token), + identity: id, + }, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. Constant-time compare against the +// configured token.
Mismatch yields to the next provider via ErrTokenNotForMe. +func (p *Provider) Verify(_ context.Context, token string, _ auth.Headers) (*auth.Identity, error) { + if subtle.ConstantTimeCompare([]byte(token), p.expected) != 1 { + return nil, auth.ErrTokenNotForMe + } + // Return a defensive copy so callers can't mutate the configured identity. + id := p.identity + if id.Groups != nil { + id.Groups = slices.Clone(id.Groups) + } + if id.Claims != nil { + id.Claims = maps.Clone(id.Claims) + } + return &id, nil +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/statictoken/provider_test.go b/forge-core/auth/providers/statictoken/provider_test.go new file mode 100644 index 0000000..290db08 --- /dev/null +++ b/forge-core/auth/providers/statictoken/provider_test.go @@ -0,0 +1,186 @@ +package statictoken_test + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/statictoken" +) + +func TestNew_ValidationErrors(t *testing.T) { + t.Run("empty config", func(t *testing.T) { + _, err := statictoken.New(statictoken.Config{}) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } + }) + + t.Run("token_env points at unset variable", func(t *testing.T) { + t.Setenv("FORGE_TEST_TOKEN_DOES_NOT_EXIST", "") + _, err := statictoken.New(statictoken.Config{TokenEnv: "FORGE_TEST_TOKEN_DOES_NOT_EXIST"}) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } + }) +} + +func TestVerify_MatchReturnsIdentity(t *testing.T) { + p, err := statictoken.New(statictoken.Config{ + Token: "secret-token", + Identity: auth.Identity{ + UserID: "internal", + 
Email: "system@forge", + Source: "internal", + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + id, err := p.Verify(context.Background(), "secret-token", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } + if id.UserID != "internal" || id.Email != "system@forge" { + t.Errorf("identity = %+v", id) + } + if id.Source != "internal" { + t.Errorf("identity.Source = %q, want %q", id.Source, "internal") + } +} + +func TestVerify_MatchDefaultSource(t *testing.T) { + p, _ := statictoken.New(statictoken.Config{Token: "secret-token"}) + id, err := p.Verify(context.Background(), "secret-token", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "static_token" { + t.Errorf("identity.Source = %q, want %q (default)", id.Source, "static_token") + } +} + +func TestVerify_MismatchYieldsToNext(t *testing.T) { + p, _ := statictoken.New(statictoken.Config{Token: "expected"}) + + _, err := p.Verify(context.Background(), "different", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Fatalf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestVerify_EmptyTokenYields(t *testing.T) { + p, _ := statictoken.New(statictoken.Config{Token: "expected"}) + + _, err := p.Verify(context.Background(), "", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Fatalf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestTokenEnv_PrecedenceOverLiteral(t *testing.T) { + t.Setenv("FORGE_TEST_OVERRIDE_TOKEN", "from-env") + + p, err := statictoken.New(statictoken.Config{ + Token: "from-literal", + TokenEnv: "FORGE_TEST_OVERRIDE_TOKEN", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Env value should win. 
+ if _, err := p.Verify(context.Background(), "from-env", nil); err != nil { + t.Errorf("env token rejected: %v", err) + } + if _, err := p.Verify(context.Background(), "from-literal", nil); !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("literal token accepted (should be overridden by env): %v", err) + } +} + +func TestTokenEnv_FallsBackToLiteralWhenEnvEmpty(t *testing.T) { + t.Setenv("FORGE_TEST_UNSET_TOKEN", "") + + p, err := statictoken.New(statictoken.Config{ + Token: "from-literal", + TokenEnv: "FORGE_TEST_UNSET_TOKEN", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + if _, err := p.Verify(context.Background(), "from-literal", nil); err != nil { + t.Errorf("literal token rejected when env was empty: %v", err) + } +} + +func TestIdentityIsDefensivelyCopied(t *testing.T) { + cfg := statictoken.Config{ + Token: "tok", + Identity: auth.Identity{ + UserID: "u", + Groups: []string{"a", "b"}, + Claims: map[string]any{"role": "admin"}, + }, + } + p, _ := statictoken.New(cfg) + + id1, _ := p.Verify(context.Background(), "tok", nil) + // Mutate the returned identity. + id1.Groups[0] = "MUTATED" + id1.Claims["role"] = "MUTATED" + + id2, _ := p.Verify(context.Background(), "tok", nil) + if id2.Groups[0] != "a" { + t.Errorf("Groups[0] = %q after mutation, want %q (defensive copy failed)", id2.Groups[0], "a") + } + if id2.Claims["role"] != "admin" { + t.Errorf("Claims[role] = %v after mutation, want admin (defensive copy failed)", id2.Claims["role"]) + } +} + +func TestConcurrentVerify_RaceSafe(t *testing.T) { + // Race-safety smoke. Run with `go test -race`. 
+ p, _ := statictoken.New(statictoken.Config{Token: "tok"}) + + var wg sync.WaitGroup + for range 100 { + wg.Go(func() { + _, _ = p.Verify(context.Background(), "tok", nil) + _, _ = p.Verify(context.Background(), "wrong", nil) + }) + } + wg.Wait() +} + +func TestRegisteredViaFactory(t *testing.T) { + p, err := auth.Build("static_token", map[string]any{ + "token": "abc", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "static_token" { + t.Errorf("Name = %q, want static_token", p.Name()) + } +} + +func TestFactory_UsesTokenEnv(t *testing.T) { + t.Setenv("FORGE_TEST_FACTORY_TOKEN", "via-env") + + p, err := auth.Build("static_token", map[string]any{ + "token_env": "FORGE_TEST_FACTORY_TOKEN", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if _, err := p.Verify(context.Background(), "via-env", nil); err != nil { + t.Errorf("env-resolved token rejected: %v", err) + } +} diff --git a/forge-core/auth/registry.go b/forge-core/auth/registry.go new file mode 100644 index 0000000..4e6ff9d --- /dev/null +++ b/forge-core/auth/registry.go @@ -0,0 +1,81 @@ +package auth + +import ( + "fmt" + "sort" + "sync" +) + +// Factory constructs a Provider from a freeform settings map (typically +// the `settings` block of an `auth.providers[]` entry in forge.yaml). +// +// A Factory should: +// - Unmarshal settings into its typed Config via UnmarshalSettings. +// - Call Config.Validate() and return clear errors. +// - Return ErrProviderNotConfigured for missing required fields. +// +// A Factory must not perform network I/O — Verify() does that lazily. +type Factory func(settings map[string]any) (Provider, error) + +var ( + registryMu sync.RWMutex + registry = map[string]Factory{} +) + +// Register adds a provider factory under the given type name. Intended to be +// called from package init() in each provider subpackage. Panics on duplicate +// registration — a duplicate is a programming error that must fail loud at +// startup, not at the first request. 
+func Register(typeName string, factory Factory) { + if typeName == "" { + panic("auth: Register called with empty typeName") + } + if factory == nil { + panic("auth: Register called with nil factory: " + typeName) + } + registryMu.Lock() + defer registryMu.Unlock() + if _, exists := registry[typeName]; exists { + panic("auth: provider already registered: " + typeName) + } + registry[typeName] = factory +} + +// Build constructs a Provider for the given type name using the registered +// factory. Returns an error if the type is not registered. +func Build(typeName string, settings map[string]any) (Provider, error) { + registryMu.RLock() + factory, ok := registry[typeName] + registryMu.RUnlock() + if !ok { + return nil, fmt.Errorf("auth: unknown provider type %q (registered: %v)", typeName, RegisteredTypes()) + } + p, err := factory(settings) + if err != nil { + return nil, fmt.Errorf("auth: build %q: %w", typeName, err) + } + return p, nil +} + +// RegisteredTypes returns a sorted, deduplicated slice of registered provider +// type names. Used by config validation and the wizard meta endpoint to expose +// the set of available provider types. +func RegisteredTypes() []string { + registryMu.RLock() + defer registryMu.RUnlock() + out := make([]string, 0, len(registry)) + for name := range registry { + out = append(out, name) + } + sort.Strings(out) + return out +} + +// resetRegistryForTest clears all registered providers. Test-only helper — +// not exported in non-test builds. (Kept package-private here; tests in the +// same package can call it directly.) 
+func resetRegistryForTest() { + registryMu.Lock() + defer registryMu.Unlock() + registry = map[string]Factory{} +} diff --git a/forge-core/auth/settings.go b/forge-core/auth/settings.go new file mode 100644 index 0000000..2c61b8d --- /dev/null +++ b/forge-core/auth/settings.go @@ -0,0 +1,27 @@ +package auth + +import ( + "fmt" + + "gopkg.in/yaml.v3" +) + +// UnmarshalSettings decodes a freeform settings map into a typed Config +// struct, honoring `yaml:"..."` tags on the destination. Implemented as a +// yaml.Marshal + yaml.Unmarshal roundtrip so each provider can keep its +// Config strongly typed while the public schema stays map[string]any. +// +// Pass a pointer to your Config struct as `out`. +func UnmarshalSettings(in map[string]any, out any) error { + if out == nil { + return fmt.Errorf("auth: UnmarshalSettings: out is nil") + } + data, err := yaml.Marshal(in) + if err != nil { + return fmt.Errorf("auth: settings marshal: %w", err) + } + if err := yaml.Unmarshal(data, out); err != nil { + return fmt.Errorf("auth: settings unmarshal: %w", err) + } + return nil +} diff --git a/go.work.sum b/go.work.sum index b5df3da..903f097 100644 --- a/go.work.sum +++ b/go.work.sum @@ -3,16 +3,12 @@ github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/sahilm/fuzzy 
v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= From 39f4a6753c2f26aaec04c8be909cfe91bb938c3a Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 19:20:19 -0400 Subject: [PATCH 02/20] feat(auth): add OIDC provider with JWKS + claim mapping (PR2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements auth.Provider for any OIDC-compliant issuer (Auth0, Keycloak, Azure AD, Google Workspace, Ping, JumpCloud, Okta in OIDC-only mode). 
Security posture: - Reject alg=none and HMAC algorithms (HS256/384/512) outright - Whitelist asymmetric algs: RS256/384/512, PS256/384/512, ES256/384/512 - Algorithm derived from JWKS-declared alg, NOT token header — defends against algorithm-confusion attacks - iss exact match, aud contains check with optional azp == ClientID fallback, exp/nbf with configurable ClockSkew Performance: - Lazy discovery + JWKS load (no network on New) - Discovery cached forever (per OIDC spec, issuer is stable) - JWKS cache with sync.RWMutex hot-path; refresh on unknown kid - Stampede protection: refetch-grace window suppresses repeated JWKS fetches for the same bad kid Why: PR3 (forge.yaml auth: block) can now plug `type: oidc` directly into the auth.providers list. Phase 3 (Okta) composes oidc.Provider to add introspection + group sync without rewriting JWT verification. Tests (coverage 83.3% on oidc, 88.0% combined auth): - Happy paths: RS256, ECDSA P-256 - Security: alg=none reject, HMAC reject, alg-confusion reject, missing-kid reject, HMAC-in-JWKS reject - Standard claims: wrong iss, wrong aud, azp fallback, expired exp, not-yet-valid nbf, ClockSkew tolerance, missing exp, missing aud, multi-audience array - JWKS: lazy load, rotation triggers refresh, unknown kid after refresh, stampede prevention, malformed RSA key skipped, 500 response error - Discovery: issuer-mismatch rejected, JWKSURL override bypasses discovery - ClaimMap: defaults, custom names, X-Org-ID header override, groups as single-string tolerated - Concurrency: 50 parallel Verify calls under -race - Context: respects cancellation - Factory: registers as "oidc", forward-compat with unknown YAML keys No regressions: - All forge-core packages pass (28/28) - All forge-cli packages pass (14/14) - All forge-plugins packages pass (4/4) - Race detector clean across auth/... 
+ cli/runtime - gofmt + go vet clean - Anti-pattern greps clean (no http.DefaultClient, no plain == on tokens, no token/secret in logs, no disallowed algs accepted) --- forge-core/auth/providers/oidc/claims.go | 127 +++ forge-core/auth/providers/oidc/discovery.go | 99 +++ forge-core/auth/providers/oidc/jwks.go | 340 ++++++++ forge-core/auth/providers/oidc/provider.go | 319 ++++++++ .../auth/providers/oidc/provider_test.go | 772 ++++++++++++++++++ .../providers/oidc/testing_helper_test.go | 231 ++++++ forge-core/go.mod | 1 + forge-core/go.sum | 2 + 8 files changed, 1891 insertions(+) create mode 100644 forge-core/auth/providers/oidc/claims.go create mode 100644 forge-core/auth/providers/oidc/discovery.go create mode 100644 forge-core/auth/providers/oidc/jwks.go create mode 100644 forge-core/auth/providers/oidc/provider.go create mode 100644 forge-core/auth/providers/oidc/provider_test.go create mode 100644 forge-core/auth/providers/oidc/testing_helper_test.go diff --git a/forge-core/auth/providers/oidc/claims.go b/forge-core/auth/providers/oidc/claims.go new file mode 100644 index 0000000..ef89c0e --- /dev/null +++ b/forge-core/auth/providers/oidc/claims.go @@ -0,0 +1,127 @@ +package oidc + +import ( + "maps" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// ClaimMap configures which JWT claim names are read into each Identity +// field. Empty fields fall back to OIDC standard claim names. +type ClaimMap struct { + UserID string `yaml:"user_id,omitempty"` + Email string `yaml:"email,omitempty"` + OrgID string `yaml:"org_id,omitempty"` + WorkspaceID string `yaml:"workspace_id,omitempty"` + Groups string `yaml:"groups,omitempty"` +} + +// defaults fills empty ClaimMap fields with OIDC-standard claim names. +// Returns a copy — the input is not mutated. 
+func (m ClaimMap) defaults() ClaimMap { + if m.UserID == "" { + m.UserID = "sub" + } + if m.Email == "" { + m.Email = "email" + } + if m.OrgID == "" { + m.OrgID = "org_id" + } + if m.WorkspaceID == "" { + m.WorkspaceID = "workspace_id" + } + if m.Groups == "" { + m.Groups = "groups" + } + return m +} + +// mapClaims builds an Identity from JWT claims using the configured +// ClaimMap. Header overrides (X-Org-ID and variants) take precedence over +// claim-derived OrgID — this preserves the per-request tenant routing +// behavior the http_verifier provider had. +// +// The raw claim map is also copied into Identity.Claims so that downstream +// authorization layers (Phase 4+) can read provider-specific claims +// without going back through the JWT. +func mapClaims(m ClaimMap, claims jwt.MapClaims, headers auth.Headers) *auth.Identity { + cm := m.defaults() + id := &auth.Identity{ + UserID: stringClaim(claims, cm.UserID), + Email: stringClaim(claims, cm.Email), + OrgID: stringClaim(claims, cm.OrgID), + WorkspaceID: stringClaim(claims, cm.WorkspaceID), + Groups: stringSliceClaim(claims, cm.Groups), + Claims: copyClaims(claims), + Source: ProviderName, + } + + // Header override: X-Org-ID (and lowercase variants) takes precedence + // over the claim-derived OrgID. The http_verifier provider has the + // same behavior for parity. + for _, key := range []string{"X-Org-ID", "org-id", "org_id"} { + if v := headers.Get(key); v != "" { + id.OrgID = v + break + } + } + + return id +} + +// stringClaim returns the named claim as a string, or "" if missing or +// not a string. +func stringClaim(claims jwt.MapClaims, name string) string { + v, ok := claims[name] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +// stringSliceClaim returns the named claim as a string slice. Tolerates +// the common shapes IdPs use: []string, []any of strings, or a single +// string (which becomes a one-element slice). 
+func stringSliceClaim(claims jwt.MapClaims, name string) []string { + v, ok := claims[name] + if !ok { + return nil + } + switch val := v.(type) { + case []string: + out := make([]string, len(val)) + copy(out, val) + return out + case []any: + out := make([]string, 0, len(val)) + for _, item := range val { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return out + case string: + if val == "" { + return nil + } + return []string{val} + default: + return nil + } +} + +// copyClaims returns a shallow copy of the claim map so the caller can +// expose it on Identity without sharing the underlying map with the +// parsed-token state. +func copyClaims(claims jwt.MapClaims) map[string]any { + if len(claims) == 0 { + return nil + } + out := make(map[string]any, len(claims)) + maps.Copy(out, claims) + return out +} diff --git a/forge-core/auth/providers/oidc/discovery.go b/forge-core/auth/providers/oidc/discovery.go new file mode 100644 index 0000000..0537fb1 --- /dev/null +++ b/forge-core/auth/providers/oidc/discovery.go @@ -0,0 +1,99 @@ +package oidc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" +) + +// providerMetadata is the subset of the OIDC discovery document we read. +// Per the OIDC spec, this document is stable for the lifetime of an issuer, +// so we cache it forever after the first successful load. +type providerMetadata struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` +} + +// discovery lazily fetches and caches the OIDC discovery document. +// +// Concurrency model: mu serializes EnsureLoaded, so at most one fetch is +// in flight and concurrent first requests don't race. Once `loaded` is +// true, later calls return the cached result after a brief lock. A failed +// load is not cached: the caller gets the error and the next call +// re-attempts (a transient startup failure must not brick the provider).
+type discovery struct { + issuer string + client *http.Client + + mu sync.Mutex + loaded bool + meta providerMetadata +} + +func newDiscovery(issuer string, client *http.Client) *discovery { + return &discovery{ + issuer: strings.TrimRight(issuer, "/"), + client: client, + } +} + +// EnsureLoaded fetches the discovery doc on first call (or after a failed +// previous attempt). Safe for concurrent use; only one HTTP call is in +// flight per discovery instance at a time. +func (d *discovery) EnsureLoaded(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + if d.loaded { + return nil + } + meta, err := d.fetch(ctx) + if err != nil { + return err + } + // Per OIDC spec, the discovery doc's `issuer` must match the issuer the + // client used to fetch it (catches a doc for a different tenant). d.issuer + // was stored without trailing slashes, so normalize meta.Issuer the same way. + if strings.TrimRight(meta.Issuer, "/") != d.issuer { + return fmt.Errorf("oidc: discovery issuer %q does not match configured issuer %q", meta.Issuer, d.issuer) + } + if meta.JWKSURI == "" { + return fmt.Errorf("oidc: discovery doc for %q has empty jwks_uri", d.issuer) + } + d.meta = meta + d.loaded = true + return nil +} + +// JWKSURI returns the JWKS endpoint URL from the loaded discovery doc. +// Callers must call EnsureLoaded first.
+func (d *discovery) JWKSURI() string { + d.mu.Lock() + defer d.mu.Unlock() + return d.meta.JWKSURI +} + +func (d *discovery) fetch(ctx context.Context) (providerMetadata, error) { + url := d.issuer + "/.well-known/openid-configuration" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return providerMetadata{}, fmt.Errorf("oidc: build discovery request: %w", err) + } + resp, err := d.client.Do(req) + if err != nil { + return providerMetadata{}, fmt.Errorf("oidc: discovery fetch: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return providerMetadata{}, fmt.Errorf("oidc: discovery returned status %d", resp.StatusCode) + } + var meta providerMetadata + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return providerMetadata{}, fmt.Errorf("oidc: discovery decode: %w", err) + } + return meta, nil +} diff --git a/forge-core/auth/providers/oidc/jwks.go b/forge-core/auth/providers/oidc/jwks.go new file mode 100644 index 0000000..e12d075 --- /dev/null +++ b/forge-core/auth/providers/oidc/jwks.go @@ -0,0 +1,340 @@ +package oidc + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "sync" + "time" +) + +// jwk is one entry of a JWKS document. We only model the fields we read. +// +// Reference: RFC 7517 (JWK) and RFC 7518 (JWA). 
+type jwk struct { + Kty string `json:"kty"` // key type ("RSA" or "EC") + Use string `json:"use,omitempty"` // "sig" expected + Alg string `json:"alg,omitempty"` // algorithm hint (advisory) + Kid string `json:"kid,omitempty"` // key ID + + // RSA fields + N string `json:"n,omitempty"` + E string `json:"e,omitempty"` + + // EC fields + Crv string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` +} + +// parsedKey is a JWK after decoding the public key material plus the +// algorithm advertised by the JWKS. The provider trusts the JWKS-declared +// algorithm — NOT the algorithm asserted in the token header. +type parsedKey struct { + Public any // *rsa.PublicKey or *ecdsa.PublicKey + Alg string // RS256, ES256, etc. +} + +// jwksCache is a TTL-bounded cache of parsed JWK keys keyed by `kid`. +// +// Refresh strategy: +// - First lookup ever: fetch JWKS. +// - kid not in cache and last fetch is older than the small "refetch +// grace window" (avoids stampedes on bad input): fetch JWKS again. +// - kid still not in cache after refresh: return ErrKeyNotFound. +// - kid in cache: return immediately. +// +// The cache does not eagerly expire entries by TTL — keys are durable; we +// refresh on miss. This matches how real-world identity providers rotate +// JWKS (add new key, then later remove old key — the new kid triggers a +// refresh that picks up both). +type jwksCache struct { + jwksURL func(ctx context.Context) (string, error) + httpClient *http.Client + ttl time.Duration + // refetchGrace prevents a malformed kid from triggering one JWKS fetch + // per request. Within this window, repeated unknown-kid lookups skip + // the refetch. + refetchGrace time.Duration + + mu sync.RWMutex + keys map[string]parsedKey + lastFetch time.Time + // lastMissRefresh is the last time we refreshed *because of* an + // unknown kid and the kid was still not present after refresh. + // Used to suppress stampedes of refetches for the same bad kid. 
+ lastMissRefresh time.Time + lastErr error // last fetch error, for diagnostics +} + +const ( + defaultJWKSTTL = 1 * time.Hour + minJWKSTTL = 5 * time.Minute + defaultRefetchGrace = 30 * time.Second +) + +func newJWKSCache(jwksURL func(ctx context.Context) (string, error), client *http.Client, ttl time.Duration) *jwksCache { + if ttl == 0 { + ttl = defaultJWKSTTL + } + if ttl < minJWKSTTL { + ttl = minJWKSTTL + } + return &jwksCache{ + jwksURL: jwksURL, + httpClient: client, + ttl: ttl, + refetchGrace: defaultRefetchGrace, + keys: map[string]parsedKey{}, + } +} + +// ErrKeyNotFound is returned when a `kid` is not in the cache and a +// refresh did not produce it. +var ErrKeyNotFound = fmt.Errorf("oidc: signing key not found") + +// Get returns the parsed key for the given kid, refreshing the JWKS cache +// if necessary. Safe for concurrent use. +// +// Refresh semantics: +// - kid in cache → return immediately. +// - kid missing, never refreshed for a miss before → refresh once. +// - kid missing AND we already refreshed for a miss recently (within +// refetchGrace) → deny without re-fetching, to prevent a stream of +// bad kids from hammering the JWKS endpoint. +func (c *jwksCache) Get(ctx context.Context, kid string) (parsedKey, error) { + // Fast path: shared read lock only, no network I/O. + c.mu.RLock() + if key, ok := c.keys[kid]; ok { + c.mu.RUnlock() + return key, nil + } + cooldownActive := !c.lastMissRefresh.IsZero() && time.Since(c.lastMissRefresh) < c.refetchGrace + c.mu.RUnlock() + + if cooldownActive { + return parsedKey{}, ErrKeyNotFound + } + + // Miss — refresh. Take the write lock first so concurrent misses + // coalesce into one fetch. + c.mu.Lock() + defer c.mu.Unlock() + + // Re-check after acquiring the write lock — another goroutine may + // have just refreshed and populated this kid.
+ if key, ok := c.keys[kid]; ok { + return key, nil + } + if !c.lastMissRefresh.IsZero() && time.Since(c.lastMissRefresh) < c.refetchGrace { + return parsedKey{}, ErrKeyNotFound + } + + if err := c.refreshLocked(ctx); err != nil { + return parsedKey{}, err + } + if key, ok := c.keys[kid]; ok { + return key, nil + } + // Refresh happened but kid is still missing — remember that so we + // don't hammer JWKS on the next unknown-kid request. + c.lastMissRefresh = time.Now() + return parsedKey{}, ErrKeyNotFound +} + +// refreshLocked fetches the JWKS endpoint and replaces the cache. Must be +// called with c.mu held for writing. +func (c *jwksCache) refreshLocked(ctx context.Context) error { + url, err := c.jwksURL(ctx) + if err != nil { + c.lastErr = err + c.lastFetch = time.Now() + return fmt.Errorf("oidc: resolve jwks_uri: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + c.lastErr = err + c.lastFetch = time.Now() + return fmt.Errorf("oidc: build jwks request: %w", err) + } + resp, err := c.httpClient.Do(req) + if err != nil { + c.lastErr = err + c.lastFetch = time.Now() + return fmt.Errorf("oidc: jwks fetch: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + err := fmt.Errorf("oidc: jwks returned status %d", resp.StatusCode) + c.lastErr = err + c.lastFetch = time.Now() + return err + } + + var body struct { + Keys []jwk `json:"keys"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + c.lastErr = err + c.lastFetch = time.Now() + return fmt.Errorf("oidc: jwks decode: %w", err) + } + + parsed := make(map[string]parsedKey, len(body.Keys)) + for _, k := range body.Keys { + // Only signature-use keys. Empty `use` is permitted (treated as + // "sig" since `key_ops` is not modeled). 
+ if k.Use != "" && k.Use != "sig" { + continue + } + // Reject HMAC algorithms — using shared-secret HMAC with JWKS is + // a footgun: anyone who can read the JWKS can sign tokens. + if k.Alg != "" && !isAllowedAlg(k.Alg) { + continue + } + pk, alg, err := parseJWK(k) + if err != nil { + // Skip unparseable keys rather than failing the whole load — + // real-world JWKS may include keys for algorithms we don't + // support (e.g., RSA-OAEP for encryption). + continue + } + if k.Kid == "" { + // Without kid we can't look up; skip. Tokens without kid are + // also rejected at the provider level. + continue + } + parsed[k.Kid] = parsedKey{Public: pk, Alg: alg} + } + + c.keys = parsed + c.lastFetch = time.Now() + c.lastErr = nil + return nil +} + +// parseJWK extracts the public key material and derives the signing +// algorithm from the JWK. The returned alg is the value the provider +// will require the token header to match. +func parseJWK(k jwk) (any, string, error) { + switch k.Kty { + case "RSA": + pk, err := parseRSAPublicKey(k) + if err != nil { + return nil, "", err + } + alg := k.Alg + if alg == "" { + alg = "RS256" // RFC 7518 default for RSA when alg is absent + } + return pk, alg, nil + case "EC": + pk, alg, err := parseECDSAPublicKey(k) + if err != nil { + return nil, "", err + } + return pk, alg, nil + default: + return nil, "", fmt.Errorf("unsupported kty %q", k.Kty) + } +} + +func parseRSAPublicKey(k jwk) (*rsa.PublicKey, error) { + nBytes, err := base64.RawURLEncoding.DecodeString(k.N) + if err != nil { + return nil, fmt.Errorf("decode RSA n: %w", err) + } + eBytes, err := base64.RawURLEncoding.DecodeString(k.E) + if err != nil { + return nil, fmt.Errorf("decode RSA e: %w", err) + } + // e is a big-endian variable-length integer. 
+ var e int + switch len(eBytes) { + case 0: + return nil, fmt.Errorf("empty RSA exponent") + case 1: + e = int(eBytes[0]) + case 2: + e = int(binary.BigEndian.Uint16(eBytes)) + case 3: + padded := make([]byte, 4) + copy(padded[1:], eBytes) + e = int(binary.BigEndian.Uint32(padded)) + case 4: + e = int(binary.BigEndian.Uint32(eBytes)) + default: + return nil, fmt.Errorf("RSA exponent too large: %d bytes", len(eBytes)) + } + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(nBytes), + E: e, + }, nil +} + +func parseECDSAPublicKey(k jwk) (*ecdsa.PublicKey, string, error) { + var ( + curve elliptic.Curve + alg string + ) + switch k.Crv { + case "P-256": + curve = elliptic.P256() + alg = "ES256" + case "P-384": + curve = elliptic.P384() + alg = "ES384" + case "P-521": + curve = elliptic.P521() + alg = "ES512" + default: + return nil, "", fmt.Errorf("unsupported curve %q", k.Crv) + } + xBytes, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + return nil, "", fmt.Errorf("decode EC x: %w", err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + return nil, "", fmt.Errorf("decode EC y: %w", err) + } + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(xBytes), + Y: new(big.Int).SetBytes(yBytes), + }, alg, nil +} + +// allowedAlgs lists the asymmetric signing algorithms accepted by the +// provider. Symmetric (HMAC) and "none" are NEVER accepted, even if a +// JWKS advertises them. +var allowedAlgs = map[string]struct{}{ + "RS256": {}, "RS384": {}, "RS512": {}, + "PS256": {}, "PS384": {}, "PS512": {}, + "ES256": {}, "ES384": {}, "ES512": {}, +} + +func isAllowedAlg(alg string) bool { + _, ok := allowedAlgs[alg] + return ok +} + +// allowedAlgNames returns the algorithm names in a slice — for passing to +// jwt.WithValidMethods. 
+func allowedAlgNames() []string { + out := make([]string, 0, len(allowedAlgs)) + for a := range allowedAlgs { + out = append(out, a) + } + return out +} diff --git a/forge-core/auth/providers/oidc/provider.go b/forge-core/auth/providers/oidc/provider.go new file mode 100644 index 0000000..924c7df --- /dev/null +++ b/forge-core/auth/providers/oidc/provider.go @@ -0,0 +1,319 @@ +// Package oidc implements an auth.Provider that verifies JWT bearer tokens +// against an OpenID Connect issuer. +// +// It works with any compliant OIDC provider — Auth0, Keycloak, Azure AD, +// Google Workspace, Ping, JumpCloud, and (in OIDC-only mode) Okta. +// +// Behavior at a glance: +// +// - On first verify, the discovery document is fetched once from +// {issuer}/.well-known/openid-configuration and cached forever (the +// issuer is stable per OIDC spec). +// +// - JWKS is fetched lazily and cached. On unknown-kid, the JWKS is +// refetched once before declaring the key unknown. +// +// - The signing algorithm is taken from the JWKS entry, NOT from the +// token header. This defends against algorithm-confusion attacks +// (e.g., a token claiming `alg: HS256` against an RSA JWKS). +// +// - alg=none and HMAC algorithms (HS256/384/512) are never accepted. +// +// - Standard claims are validated: iss exact match, aud contains +// configured Audience (with optional azp == ClientID fallback), +// exp/nbf within ClockSkew leeway. +// +// - Claim → Identity mapping is configurable via ClaimMap. The X-Org-ID +// header overrides the claim-derived OrgID for per-request tenant +// routing. +package oidc + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this +// provider in the auth registry. +const ProviderName = "oidc" + +// Default operational tunables. 
+const ( + DefaultClockSkew = 30 * time.Second + DefaultHTTPTimeout = 10 * time.Second +) + +// Config controls the OIDC provider. +type Config struct { + // Issuer is the full OIDC issuer URL (no trailing slash). Required. + // The token's `iss` claim must match this value exactly. + Issuer string `yaml:"issuer"` + + // Audience is the expected `aud` claim value. Required. If the token + // has multiple audiences, the configured Audience must be one of them. + Audience string `yaml:"audience"` + + // ClientID is an optional secondary audience check: if set, a token + // whose `aud` does not contain Audience is still accepted when its + // `azp` (authorized party) claim equals ClientID. + ClientID string `yaml:"client_id,omitempty"` + + // JWKSURL overrides the JWKS endpoint discovered via the OIDC + // discovery document. Most users leave this empty. + JWKSURL string `yaml:"jwks_url,omitempty"` + + // JWKSCacheTTL caps the maximum age of cached JWKS keys. Defaults to + // 1 hour. Values below 5 minutes are silently clamped up to avoid + // hammering the IdP. + JWKSCacheTTL time.Duration `yaml:"jwks_cache_ttl,omitempty"` + + // ClockSkew is the leeway applied to `exp` and `nbf` validation. + // Defaults to 30 seconds. + ClockSkew time.Duration `yaml:"clock_skew,omitempty"` + + // ClaimMap configures which JWT claim names map to Identity fields. + // Empty fields use OIDC defaults (sub, email, org_id, …). + ClaimMap ClaimMap `yaml:"claim_map,omitempty"` + + // HTTPClient overrides the default client. Injectable for tests. + HTTPClient *http.Client `yaml:"-"` +} + +// Validate returns ErrProviderNotConfigured when required fields are +// missing. +func (c Config) Validate() error { + if c.Issuer == "" { + return fmt.Errorf("%w: issuer required", auth.ErrProviderNotConfigured) + } + if c.Audience == "" { + return fmt.Errorf("%w: audience required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider for OIDC issuers. 
+type Provider struct { + cfg Config + client *http.Client + discovery *discovery + jwks *jwksCache + parser *jwt.Parser + clockSkew time.Duration +} + +// New constructs a Provider after validating cfg. No network I/O happens +// here — discovery and JWKS are fetched lazily on first Verify. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.ClockSkew == 0 { + cfg.ClockSkew = DefaultClockSkew + } + client := cfg.HTTPClient + if client == nil { + client = &http.Client{Timeout: DefaultHTTPTimeout} + } + + p := &Provider{ + cfg: cfg, + client: client, + clockSkew: cfg.ClockSkew, + } + + p.discovery = newDiscovery(cfg.Issuer, client) + p.jwks = newJWKSCache(p.resolveJWKSURL, client, cfg.JWKSCacheTTL) + + // Configure the JWT parser once. + parserOpts := []jwt.ParserOption{ + jwt.WithValidMethods(allowedAlgNames()), + jwt.WithIssuer(cfg.Issuer), + jwt.WithLeeway(cfg.ClockSkew), + jwt.WithExpirationRequired(), + } + // Audience validation handled manually so we can implement the azp + // fallback below. (jwt.WithAudience rejects when aud lacks the + // expected value without considering azp.) + p.parser = jwt.NewParser(parserOpts...) + + return p, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// resolveJWKSURL returns the configured override JWKSURL if set, otherwise +// loads discovery and returns the discovered jwks_uri. +func (p *Provider) resolveJWKSURL(ctx context.Context) (string, error) { + if p.cfg.JWKSURL != "" { + return p.cfg.JWKSURL, nil + } + if err := p.discovery.EnsureLoaded(ctx); err != nil { + return "", err + } + return p.discovery.JWKSURI(), nil +} + +// Verify implements auth.Provider. +// +// The verification flow: +// 1. Parse token structurally (without signature verification yet). +// 2. Extract kid from header; reject tokens without kid. +// 3. Look up kid in JWKS cache (refreshing on miss). +// 4. 
Cross-check token's `alg` against JWKS-declared alg (defends against +// algorithm confusion). +// 5. Re-parse the token with the resolved key, validating iss/exp/nbf. +// 6. Validate audience (with azp fallback if ClientID is configured). +// 7. Map claims to Identity, applying header overrides. +func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Headers) (*auth.Identity, error) { + // Quick structural check before doing crypto work. JWTs are three + // base64url segments separated by dots. + if !looksLikeJWT(tokenStr) { + return nil, auth.ErrTokenNotForMe + } + + // First pass: parse the header to extract kid + alg without verifying + // the signature. + parsedHeader, _, err := jwt.NewParser().ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("%w: parse header: %w", auth.ErrInvalidToken, err) + } + kid, _ := parsedHeader.Header["kid"].(string) + if kid == "" { + return nil, fmt.Errorf("%w: token header missing kid", auth.ErrInvalidToken) + } + headerAlg, _ := parsedHeader.Header["alg"].(string) + if !isAllowedAlg(headerAlg) { + // Catches alg=none, HMAC, etc. before any further work. + return nil, fmt.Errorf("%w: disallowed alg %q", auth.ErrInvalidToken, headerAlg) + } + + // Look up the key. + key, err := p.jwks.Get(ctx, kid) + if err != nil { + if errors.Is(err, ErrKeyNotFound) { + return nil, fmt.Errorf("%w: kid %q not found in JWKS", auth.ErrInvalidToken, kid) + } + return nil, err // transport / decode error — fail-closed via chain + } + + // Cross-check: the JWKS-declared alg must match the token's alg. + // This is the algorithm-confusion defense. + if key.Alg != headerAlg { + return nil, fmt.Errorf("%w: token alg %q does not match JWKS alg %q for kid %q", + auth.ErrInvalidToken, headerAlg, key.Alg, kid) + } + + // Second pass: verify signature + standard claims (iss/exp/nbf). 
+ claims := jwt.MapClaims{} + _, err = p.parser.ParseWithClaims(tokenStr, claims, func(_ *jwt.Token) (any, error) { + return key.Public, nil + }) + if err != nil { + return nil, classifyJWTError(err) + } + + // Audience validation (with azp fallback). + if err := p.checkAudience(claims); err != nil { + return nil, err + } + + return mapClaims(p.cfg.ClaimMap, claims, headers), nil +} + +// looksLikeJWT does a cheap structural check: exactly three non-empty +// dot-separated segments. Used to yield to the next provider on non-JWT +// tokens (e.g., opaque tokens that another chain entry would handle). +func looksLikeJWT(s string) bool { + if len(s) < 5 { + return false + } + segments := 1 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + segments++ + if segments > 3 { + return false + } + } + } + return segments == 3 +} + +// checkAudience validates that the token's `aud` claim contains the +// configured Audience. If ClientID is set and aud does not contain +// Audience, accepts when `azp` claim equals ClientID. +func (p *Provider) checkAudience(claims jwt.MapClaims) error { + auds := audienceSlice(claims) + if slices.Contains(auds, p.cfg.Audience) { + return nil + } + if p.cfg.ClientID != "" { + if azp, _ := claims["azp"].(string); azp == p.cfg.ClientID { + return nil + } + } + return fmt.Errorf("%w: audience %v does not match configured %q", auth.ErrTokenRejected, auds, p.cfg.Audience) +} + +// audienceSlice normalizes the `aud` claim into []string. RFC 7519 allows +// `aud` to be either a string or a string array. +func audienceSlice(claims jwt.MapClaims) []string { + switch v := claims["aud"].(type) { + case string: + return []string{v} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return out + case []string: + return v + default: + return nil + } +} + +// classifyJWTError maps a parsing/verification error from jwt/v5 to one of +// the auth-package sentinels. 
+func classifyJWTError(err error) error { + switch { + case errors.Is(err, jwt.ErrTokenExpired): + return fmt.Errorf("%w: token expired", auth.ErrTokenRejected) + case errors.Is(err, jwt.ErrTokenNotValidYet): + return fmt.Errorf("%w: token not yet valid", auth.ErrTokenRejected) + case errors.Is(err, jwt.ErrTokenInvalidIssuer): + return fmt.Errorf("%w: issuer mismatch", auth.ErrTokenRejected) + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + return fmt.Errorf("%w: signature invalid", auth.ErrInvalidToken) + case errors.Is(err, jwt.ErrTokenMalformed): + return fmt.Errorf("%w: token malformed", auth.ErrInvalidToken) + case errors.Is(err, jwt.ErrTokenUnverifiable): + return fmt.Errorf("%w: token unverifiable", auth.ErrInvalidToken) + case errors.Is(err, jwt.ErrTokenRequiredClaimMissing): + return fmt.Errorf("%w: required claim missing", auth.ErrInvalidToken) + default: + return fmt.Errorf("%w: %w", auth.ErrInvalidToken, err) + } +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/oidc/provider_test.go b/forge-core/auth/providers/oidc/provider_test.go new file mode 100644 index 0000000..82fe138 --- /dev/null +++ b/forge-core/auth/providers/oidc/provider_test.go @@ -0,0 +1,772 @@ +package oidc_test + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/oidc" +) + +const testAudience = "api://forge" + +func newProvider(t *testing.T, fi *fakeIssuer, mod func(*oidc.Config)) *oidc.Provider { + t.Helper() + cfg := oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + } + if mod != nil { + mod(&cfg) + } + p, err := oidc.New(cfg) + if err != 
nil { + t.Fatalf("oidc.New: %v", err) + } + return p +} + +// --- Construction / Validation --- + +func TestNew_Validation(t *testing.T) { + tests := []struct { + name string + cfg oidc.Config + wantErr error + }{ + { + name: "empty config", + cfg: oidc.Config{}, + wantErr: auth.ErrProviderNotConfigured, + }, + { + name: "missing audience", + cfg: oidc.Config{Issuer: "https://example.com"}, + wantErr: auth.ErrProviderNotConfigured, + }, + { + name: "missing issuer", + cfg: oidc.Config{Audience: "api://forge"}, + wantErr: auth.ErrProviderNotConfigured, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := oidc.New(tt.cfg) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("err = %v, want %v", err, tt.wantErr) + } + }) + } +} + +// --- Happy paths --- + +func TestVerify_RSA256_HappyPath(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["email"] = "alice@example.com" + claims["groups"] = []string{"engineers", "oncall"} + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } + if id.UserID != "user-123" { + t.Errorf("UserID = %q, want user-123", id.UserID) + } + if id.Email != "alice@example.com" { + t.Errorf("Email = %q, want alice@example.com", id.Email) + } + if len(id.Groups) != 2 || id.Groups[0] != "engineers" || id.Groups[1] != "oncall" { + t.Errorf("Groups = %v, want [engineers oncall]", id.Groups) + } + if id.Source != "oidc" { + t.Errorf("Source = %q, want oidc", id.Source) + } +} + +func TestVerify_ECDSA_HappyPath(t *testing.T) { + fi := newFakeIssuer(t) + fi.addECDSAKey("ec-1") + p := newProvider(t, fi, nil) + + tok := fi.SignWith("ec-1", fi.DefaultClaims(testAudience)) + id, err := p.Verify(context.Background(), tok, nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.UserID != "user-123" { + 
t.Errorf("UserID = %q, want user-123", id.UserID) + } +} + +// --- Algorithm-confusion / security --- + +func TestVerify_RejectsAlgNone(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + tok := SignUnsigned(fi.DefaultClaims(testAudience)) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (alg=none must be rejected)", err) + } +} + +func TestVerify_RejectsHMAC(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + tok := SignHMAC("key-1", fi.DefaultClaims(testAudience), []byte("any-secret")) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (HMAC must be rejected)", err) + } +} + +func TestVerify_AlgInHeaderMustMatchJWKS(t *testing.T) { + // Algorithm-confusion defense: if the token claims an alg that + // differs from the JWKS-declared alg for that kid, reject. + fi := newFakeIssuer(t) + // Replace the JWKS-declared alg for key-1 with RS512 while still + // signing the token with RS256 (the key's actual method). + fi.keys["key-1"].algInJWKS = "RS512" + + p := newProvider(t, fi, nil) + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (alg mismatch)", err) + } +} + +func TestVerify_RejectsTokenWithoutKid(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Sign manually, without kid in header. 
+ claims := fi.DefaultClaims(testAudience) + priv := fi.keys["key-1"].priv + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signed, err := tok.SignedString(priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + _, err = p.Verify(context.Background(), signed, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (missing kid)", err) + } +} + +// --- iss / aud / azp / exp / nbf --- + +func TestVerify_RejectsWrongIssuer(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["iss"] = "https://different-issuer.example.com" + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (iss mismatch)", err) + } +} + +func TestVerify_RejectsWrongAudience(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims("api://different-service") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (aud mismatch)", err) + } +} + +func TestVerify_AcceptsAudienceInArray(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["aud"] = []string{"api://other", testAudience, "api://third"} + tok := fi.SignWith("key-1", claims) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("multi-audience token rejected: %v", err) + } +} + +func TestVerify_AcceptsAZPAsFallback(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClientID = "client-abc" + }) + + claims := fi.DefaultClaims("api://different") // aud doesn't match + claims["azp"] = "client-abc" // but azp does + tok := fi.SignWith("key-1", claims) + + if _, err := 
p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("azp-fallback token rejected: %v", err) + } +} + +func TestVerify_RejectsExpiredToken(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClockSkew = 1 * time.Second + }) + + claims := fi.DefaultClaims(testAudience) + claims["exp"] = time.Now().Add(-10 * time.Minute).Unix() + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (expired)", err) + } +} + +func TestVerify_ClockSkewToleratesNearExpiration(t *testing.T) { + fi := newFakeIssuer(t) + // Configure 60s leeway; token expired 30s ago — should be accepted. + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClockSkew = 60 * time.Second + }) + + claims := fi.DefaultClaims(testAudience) + claims["exp"] = time.Now().Add(-30 * time.Second).Unix() + tok := fi.SignWith("key-1", claims) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("token within ClockSkew rejected: %v", err) + } +} + +func TestVerify_RejectsNotYetValid(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClockSkew = 1 * time.Second + }) + + claims := fi.DefaultClaims(testAudience) + claims["nbf"] = time.Now().Add(10 * time.Minute).Unix() + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (nbf in future)", err) + } +} + +// --- JWKS cache / rotation --- + +func TestJWKS_LazyLoad(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Provider construction must not fetch JWKS (lazy). 
+ if fi.jwksFetches != 0 { + t.Errorf("JWKS fetched %d times during New (must be 0)", fi.jwksFetches) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("first Verify: %v", err) + } + if fi.jwksFetches != 1 { + t.Errorf("JWKS fetched %d times after first Verify, want 1", fi.jwksFetches) + } + + // Second Verify must reuse the cache. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("second Verify: %v", err) + } + if fi.jwksFetches != 1 { + t.Errorf("JWKS fetched %d times after second Verify, want 1 (cache miss)", fi.jwksFetches) + } +} + +func TestJWKS_RotationTriggersRefresh(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Initial token with key-1. + tok1 := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok1, nil); err != nil { + t.Fatalf("initial Verify: %v", err) + } + + // IdP rotates: adds key-2, then a token signed with key-2 arrives. + fi.addRSAKey("key-2") + tok2 := fi.SignWith("key-2", fi.DefaultClaims(testAudience)) + + if _, err := p.Verify(context.Background(), tok2, nil); err != nil { + t.Fatalf("post-rotation Verify: %v", err) + } + // Provider should have refreshed JWKS once on the unknown kid. + if fi.jwksFetches != 2 { + t.Errorf("JWKS fetched %d times after rotation, want 2 (lazy + refresh)", fi.jwksFetches) + } +} + +func TestJWKS_UnknownKidAfterRefresh(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Sign a token whose header advertises a kid that doesn't exist on + // the issuer. We do this by signing with key-1 then editing the + // header — easier: build the signed-string manually. 
+ claims := fi.DefaultClaims(testAudience) + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = "ghost-kid" + signed, err := tok.SignedString(fi.keys["key-1"].priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + + _, err = p.Verify(context.Background(), signed, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (kid not in JWKS)", err) + } +} + +func TestJWKS_RefetchGracePreventsStampede(t *testing.T) { + // Two requests with the same unknown kid in quick succession should + // only trigger ONE JWKS fetch — the grace window should suppress + // the second. + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Warm cache with a valid request first. + if _, err := p.Verify(context.Background(), fi.SignWith("key-1", fi.DefaultClaims(testAudience)), nil); err != nil { + t.Fatalf("warmup: %v", err) + } + startFetches := fi.jwksFetches + + // Two back-to-back unknown-kid requests. + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, fi.DefaultClaims(testAudience)) + tok.Header["kid"] = "ghost" + signed, _ := tok.SignedString(fi.keys["key-1"].priv) + + _, _ = p.Verify(context.Background(), signed, nil) + _, _ = p.Verify(context.Background(), signed, nil) + + addedFetches := fi.jwksFetches - startFetches + if addedFetches > 1 { + t.Errorf("JWKS fetched %d times for two unknown-kid requests, want ≤1 (stampede prevention)", addedFetches) + } +} + +// --- ClaimMap --- + +func TestClaimMap_Defaults(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["email"] = "x@y" + claims["org_id"] = "tenant-1" + claims["workspace_id"] = "ws-2" + claims["groups"] = []string{"a"} + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.UserID != "user-123" || id.Email != "x@y" || id.OrgID != "tenant-1" || id.WorkspaceID != "ws-2" { + t.Errorf("default 
ClaimMap mapping wrong: %+v", id) + } +} + +func TestClaimMap_CustomNames(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClaimMap = oidc.ClaimMap{ + UserID: "uid", + Email: "mail", + OrgID: "tenant", + Groups: "roles", + } + }) + + claims := fi.DefaultClaims(testAudience) + claims["uid"] = "custom-user" + claims["mail"] = "x@y" + claims["tenant"] = "tenant-99" + claims["roles"] = []string{"admin"} + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.UserID != "custom-user" || id.Email != "x@y" || id.OrgID != "tenant-99" || len(id.Groups) != 1 || id.Groups[0] != "admin" { + t.Errorf("custom ClaimMap mapping wrong: %+v", id) + } +} + +func TestClaimMap_HeaderOverridesOrgID(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["org_id"] = "claim-org" + + id, err := p.Verify( + context.Background(), + fi.SignWith("key-1", claims), + auth.Headers{"X-Org-ID": "header-org"}, + ) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.OrgID != "header-org" { + t.Errorf("OrgID = %q, want header-org (header should override claim)", id.OrgID) + } +} + +func TestClaimMap_GroupsToleratesSingleString(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["groups"] = "single-group" + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if len(id.Groups) != 1 || id.Groups[0] != "single-group" { + t.Errorf("Groups = %v, want [single-group]", id.Groups) + } +} + +// --- Non-JWT tokens yield --- + +func TestVerify_NonJWTYields(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Opaque-token shapes another provider might handle. 
+ cases := []string{"", "opaque-token", "abc.def", "abc.def.ghi.jkl", "no-dots"} + for _, tok := range cases { + t.Run(tok, func(t *testing.T) { + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("token %q: err = %v, want ErrTokenNotForMe", tok, err) + } + }) + } +} + +// --- Discovery --- + +func TestDiscovery_IssuerMismatchFails(t *testing.T) { + // Build an issuer whose discovery doc returns a DIFFERENT issuer URL + // than the one we hit. Catches misconfiguration where someone copies + // an Okta discovery URL but configures the wrong issuer string. + mismatchSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{"issuer":"https://different.example.com","jwks_uri":"https://different.example.com/jwks"}`)) + })) + defer mismatchSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: mismatchSrv.URL, + Audience: testAudience, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Use a syntactically-valid JWT — header includes kid + alg so we + // reach the JWKS-fetch step where discovery is consulted. + // Header: {"alg":"RS256","kid":"k1"} Payload: {"iss":"x"} + dummyJWT := "eyJhbGciOiJSUzI1NiIsImtpZCI6ImsxIn0.eyJpc3MiOiJ4In0.AAAA" + _, err = p.Verify(context.Background(), dummyJWT, nil) + if err == nil { + t.Fatal("expected discovery mismatch error, got nil") + } + if !strings.Contains(err.Error(), "discovery issuer") { + t.Errorf("err = %v, want discovery-issuer-mismatch message", err) + } +} + +func TestDiscovery_JWKSURLOverrideSkipsDiscovery(t *testing.T) { + fi := newFakeIssuer(t) + jwksFetchesBefore := fi.jwksFetches + + // Configure JWKSURL explicitly — discovery should never be called. + // Implicitly verified by the fact that the fakeIssuer's discovery + // handler increments jwksFetches only on /jwks (not /.well-known/), + // so we check that signing still works. 
+ p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: fi.IssuerURL() + "/jwks", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("Verify: %v", err) + } + if fi.jwksFetches != jwksFetchesBefore+1 { + t.Errorf("JWKS fetches = %d, want %d", fi.jwksFetches, jwksFetchesBefore+1) + } +} + +// --- Concurrency / race safety --- + +func TestVerify_Concurrent(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + const N = 50 + errs := make(chan error, N) + for range N { + go func() { + _, err := p.Verify(context.Background(), tok, nil) + errs <- err + }() + } + for range N { + if err := <-errs; err != nil { + t.Fatalf("concurrent Verify: %v", err) + } + } +} + +// --- HTTP client / context --- + +func TestVerify_RespectsContextCancellation(t *testing.T) { + // Slow JWKS endpoint: handler blocks until the test ends. 
+ done := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-done: + } + })) + defer func() { close(done); srv.Close() }() + + p, err := oidc.New(oidc.Config{ + Issuer: srv.URL, + Audience: testAudience, + JWKSURL: srv.URL + "/jwks", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + dummyJWT := "eyJhbGciOiJSUzI1NiIsImtpZCI6Inh4eCJ9.eyJpc3MiOiJ4In0.signature" + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + start := time.Now() + _, _ = p.Verify(ctx, dummyJWT, nil) + elapsed := time.Since(start) + + if elapsed > 2*time.Second { + t.Errorf("Verify took %v after ctx cancellation; expected <2s", elapsed) + } +} + +// --- Factory / registration --- + +func TestRegisteredViaFactory(t *testing.T) { + p, err := auth.Build("oidc", map[string]any{ + "issuer": "https://example.com", + "audience": "api://forge", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "oidc" { + t.Errorf("Name = %q, want oidc", p.Name()) + } +} + +func TestFactory_MissingIssuerErrors(t *testing.T) { + _, err := auth.Build("oidc", map[string]any{ + "audience": "api://forge", + }) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } +} + +func TestFactory_UnknownKeysAreIgnored(t *testing.T) { + // Forward compat: future YAML keys must not break construction. 
+ _, err := auth.Build("oidc", map[string]any{ + "issuer": "https://example.com", + "audience": "api://forge", + "future_setting": "ignored", + "another_unknown": 42, + }) + if err != nil { + t.Fatalf("Build with unknown keys: %v", err) + } +} + +// --- Smoke: identity Claims is a copy --- + +func TestIdentity_ClaimsIsCopiedFromToken(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["custom"] = "value" + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Claims["custom"] != "value" { + t.Errorf("Claims[custom] = %v, want value", id.Claims["custom"]) + } + // Mutate the returned map — provider state must not be affected. + id.Claims["custom"] = "MUTATED" + + id2, _ := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if id2.Claims["custom"] != "value" { + t.Errorf("Claims map shared across Verify calls — got %v after mutation", id2.Claims["custom"]) + } +} + +// Make atomic referenced so the import isn't dropped if we later remove +// the concurrent-fetches counter. +var _ atomic.Int32 + +// --- Error-path coverage --- + +func TestJWKS_MalformedRSAKey_IsSkipped(t *testing.T) { + // JWKS contains a valid key AND a malformed one. The malformed one + // is silently skipped — load must succeed and the valid one usable. 
+ fi := newFakeIssuer(t) + jwksSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{ + "keys": [ + {"kty":"RSA","kid":"broken","alg":"RS256","n":"!!!not-base64!!!","e":"AQAB"}, + {"kty":"unknown","kid":"weird"}, + {"kty":"RSA","kid":"no-kid-here","alg":"RS256","n":"","e":""} + ] + }`)) + })) + defer jwksSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: jwksSrv.URL, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + _, err = p.Verify(context.Background(), tok, nil) + // Should fail because the broken JWKS has no usable keys. + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (no valid keys in JWKS)", err) + } +} + +func TestJWKS_EndpointReturns500(t *testing.T) { + fi := newFakeIssuer(t) + jwksSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer jwksSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: jwksSrv.URL, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + _, err = p.Verify(context.Background(), tok, nil) + if err == nil { + t.Fatal("expected error when JWKS endpoint returns 500") + } +} + +func TestJWKS_HMACAlgInJWKS_IsRejected(t *testing.T) { + // If the JWKS advertises an HMAC alg, we must skip that key entirely. 
+ fi := newFakeIssuer(t) + jwksSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{ + "keys": [ + {"kty":"oct","kid":"hmac-key","alg":"HS256","k":"some-secret"} + ] + }`)) + })) + defer jwksSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: jwksSrv.URL, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Build a token claiming HS256 with kid hmac-key. + tok := SignHMAC("hmac-key", fi.DefaultClaims(testAudience), []byte("some-secret")) + _, err = p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (HMAC in JWKS must be skipped)", err) + } +} + +func TestVerify_TokenWithoutAudClaim(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + delete(claims, "aud") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (missing aud)", err) + } +} + +func TestVerify_TokenWithoutExp(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + delete(claims, "exp") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (missing exp)", err) + } +} diff --git a/forge-core/auth/providers/oidc/testing_helper_test.go b/forge-core/auth/providers/oidc/testing_helper_test.go new file mode 100644 index 0000000..49cefbe --- /dev/null +++ b/forge-core/auth/providers/oidc/testing_helper_test.go @@ -0,0 +1,231 @@ +package oidc_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "math/big" + "net/http" + 
"net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// fakeIssuer is a minimal in-memory OIDC issuer: serves the discovery +// document and JWKS, and helps tests sign tokens with the matching key. +type fakeIssuer struct { + t *testing.T + server *httptest.Server + // keys keyed by kid → signer + keys map[string]*signingKey + // keyOrder remembers insertion order so JWKS responses are stable. + keyOrder []string + // jwksFetches counts how many times /jwks was hit (race-safe via the + // server's serialized handler). + jwksFetches int +} + +type signingKey struct { + kid string + method jwt.SigningMethod + priv any // *rsa.PrivateKey or *ecdsa.PrivateKey + pub any // *rsa.PublicKey or *ecdsa.PublicKey + // alg is what the JWKS will declare; mismatch with `method` is what + // algorithm-confusion tests need. + algInJWKS string +} + +// newFakeIssuer builds a default RSA-keyed issuer with kid "key-1". +func newFakeIssuer(t *testing.T) *fakeIssuer { + t.Helper() + fi := &fakeIssuer{ + t: t, + keys: map[string]*signingKey{}, + } + fi.addRSAKey("key-1") + fi.server = httptest.NewServer(http.HandlerFunc(fi.serve)) + t.Cleanup(fi.server.Close) + return fi +} + +// IssuerURL is the canonical issuer URL (no trailing slash). +func (fi *fakeIssuer) IssuerURL() string { return fi.server.URL } + +// addRSAKey generates a fresh RSA key and registers it with the given kid. +func (fi *fakeIssuer) addRSAKey(kid string) *signingKey { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + fi.t.Fatalf("generate RSA key: %v", err) + } + k := &signingKey{ + kid: kid, + method: jwt.SigningMethodRS256, + priv: priv, + pub: &priv.PublicKey, + algInJWKS: "RS256", + } + fi.keys[kid] = k + fi.keyOrder = append(fi.keyOrder, kid) + return k +} + +// addECDSAKey generates a fresh P-256 ECDSA key. 
+func (fi *fakeIssuer) addECDSAKey(kid string) *signingKey {
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) // curve is fixed to P-256, matching ES256 below
+	if err != nil {
+		fi.t.Fatalf("generate ECDSA key: %v", err)
+	}
+	k := &signingKey{
+		kid:       kid,
+		method:    jwt.SigningMethodES256,
+		priv:      priv,
+		pub:       &priv.PublicKey,
+		algInJWKS: "ES256",
+	}
+	fi.keys[kid] = k
+	fi.keyOrder = append(fi.keyOrder, kid) // keyOrder drives the order of keys in the served JWKS
+	return k
+}
+
+// removeKey drops kid from both fi.keys and fi.keyOrder, simulating an IdP rotating out an old key.
+func (fi *fakeIssuer) removeKey(kid string) {
+	delete(fi.keys, kid)
+	for i, k := range fi.keyOrder {
+		if k == kid {
+			fi.keyOrder = append(fi.keyOrder[:i], fi.keyOrder[i+1:]...)
+			break // only the first occurrence is removed
+		}
+	}
+}
+
+// SignWith returns a JWT over claims signed with the key registered as kid and
+// sets the "kid" header; an unknown kid or a signing error fails the test.
+func (fi *fakeIssuer) SignWith(kid string, claims jwt.MapClaims) string {
+	fi.t.Helper()
+	k, ok := fi.keys[kid]
+	if !ok {
+		fi.t.Fatalf("unknown signing key %q", kid)
+	}
+	tok := jwt.NewWithClaims(k.method, claims)
+	tok.Header["kid"] = kid
+	s, err := tok.SignedString(k.priv)
+	if err != nil {
+		fi.t.Fatalf("sign: %v", err)
+	}
+	return s
+}
+
+// SignUnsigned returns a token with `alg: none` and an empty signature segment (it ends in ".").
+// jwt/v5 doesn't expose a clean way to construct these; we do it manually.
+func SignUnsigned(claims jwt.MapClaims) string {
+	header := map[string]any{"alg": "none", "typ": "JWT", "kid": "none-kid"}
+	hb, _ := json.Marshal(header) // header holds only string values
+	cb, _ := json.Marshal(claims) // error ignored — NOTE(review): unmarshalable claims would produce an empty payload segment
+	enc := base64.RawURLEncoding.EncodeToString
+	return enc(hb) + "." + enc(cb) + "."
+}
+
+// SignHMAC returns a token signed with HS256 — must be rejected by the
+// provider since HMAC + JWKS is a footgun (anyone with the JWKS could
+// produce one).
+func SignHMAC(kid string, claims jwt.MapClaims, secret []byte) string {
+	tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+	tok.Header["kid"] = kid
+	s, _ := tok.SignedString(secret) // signing error is discarded — NOTE(review): confirm that is intended
+	return s
+}
+
+// DefaultClaims returns a claim set that satisfies the default fakeIssuer and a
+// single audience: iss = IssuerURL(), aud = audience, sub, iat = now, exp = now + 15 minutes.
+func (fi *fakeIssuer) DefaultClaims(audience string) jwt.MapClaims {
+	now := time.Now()
+	return jwt.MapClaims{
+		"iss": fi.IssuerURL(),
+		"aud": audience,
+		"sub": "user-123",
+		"iat": now.Unix(),
+		"exp": now.Add(15 * time.Minute).Unix(),
+	}
+}
+
+// serve answers the discovery path and the /jwks path; every other path gets http.NotFound.
+func (fi *fakeIssuer) serve(w http.ResponseWriter, r *http.Request) {
+	switch r.URL.Path {
+	case "/.well-known/openid-configuration":
+		_ = json.NewEncoder(w).Encode(map[string]any{
+			"issuer":   fi.IssuerURL(),
+			"jwks_uri": fi.IssuerURL() + "/jwks", // points back at the "/jwks" case below
+		})
+	case "/jwks":
+		fi.jwksFetches++ // unsynchronized int — NOTE(review): net/http serves each connection on its own goroutine; confirm it is only read after the request has returned
+		_ = json.NewEncoder(w).Encode(fi.jwksDoc())
+	default:
+		http.NotFound(w, r)
+	}
+}
+
+// jwksDoc constructs the JWKS document from the current key set.
+func (fi *fakeIssuer) jwksDoc() map[string]any {
+	keys := make([]map[string]any, 0, len(fi.keyOrder)) // one JWK per registered kid, in keyOrder order
+	for _, kid := range fi.keyOrder {
+		k := fi.keys[kid]
+		switch pub := k.pub.(type) {
+		case *rsa.PublicKey:
+			keys = append(keys, map[string]any{
+				"kty": "RSA",
+				"kid": kid,
+				"use": "sig",
+				"alg": k.algInJWKS, // the declared alg, taken from the key record rather than from its signing method
+				"n":   base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
+				"e":   base64.RawURLEncoding.EncodeToString(intToBigEndian(pub.E)),
+			})
+		case *ecdsa.PublicKey: // x/y are left-padded to the curve's coordinate size, as RFC 7518 §6.2.1.2–6.2.1.3 requires
+			keys = append(keys, map[string]any{
+				"kty": "EC",
+				"kid": kid,
+				"use": "sig",
+				"alg": k.algInJWKS,
+				"crv": ecCurveName(pub.Curve),
+				"x":   base64.RawURLEncoding.EncodeToString(pub.X.FillBytes(make([]byte, (pub.Curve.Params().BitSize+7)/8))),
+				"y":   base64.RawURLEncoding.EncodeToString(pub.Y.FillBytes(make([]byte, (pub.Curve.Params().BitSize+7)/8))),
+			})
+		default:
+			panic(fmt.Sprintf("unsupported key type: %T", pub))
+		}
+	}
+	return map[string]any{"keys": keys}
+}
+
+func intToBigEndian(e int) []byte {
+	buf := make([]byte, 4)
+	binary.BigEndian.PutUint32(buf, uint32(e))
+	// Strip leading zero bytes but keep at least one: RFC 7518 wants the minimal big-endian encoding of "e".
+	for len(buf) > 1 && buf[0] == 0 {
+		buf = buf[1:]
+	}
+	return buf
+}
+
+func ecCurveName(c elliptic.Curve) string {
+	switch c {
+	case elliptic.P256():
+		return "P-256"
+	case elliptic.P384():
+		return "P-384"
+	case elliptic.P521():
+		return "P-521"
+	default:
+		return "" // any other curve yields an empty "crv" value
+	}
+}
+
+// Compile-time assertion: tests reference math/big to ensure the import
+// graph is exercised in case future refactors lose the dep.
+var _ = big.NewInt diff --git a/forge-core/go.mod b/forge-core/go.mod index 1b2f92b..bae6d93 100644 --- a/forge-core/go.mod +++ b/forge-core/go.mod @@ -9,6 +9,7 @@ require ( ) require ( + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect golang.org/x/crypto v0.48.0 // indirect diff --git a/forge-core/go.sum b/forge-core/go.sum index f2cd70a..7e485de 100644 --- a/forge-core/go.sum +++ b/forge-core/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From 3f65871971d1941ad1f8bb7eb9d2cb02c6702210 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 19:34:43 -0400 Subject: [PATCH 03/20] feat(auth): wire forge.yaml auth: block + egress allowlist (PR3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the typed AuthConfig schema, validation, and the AuthDomains egress hook. The runner now builds the provider chain via auth.Registry.BuildChain(cfg.Auth.Providers) when forge.yaml has an auth: block, falling back to the legacy --auth-url path otherwise. 
Schema:

  auth:
    required: true
    providers:
      - type: oidc
        name: corporate-sso
        settings:
          issuer: https://login.example.com
          audience: api://forge

Precedence:

  --no-auth                    → anonymous (unchanged)
  forge.yaml auth.providers    → Registry.BuildChain
  --auth-url / FORGE_AUTH_URL  → legacy http_verifier
  nothing                      → loopback-only chain

When both forge.yaml auth: and --auth-url are configured, the YAML block
wins and a warning is logged. Loopback static_token is always prepended
for channel-adapter callbacks.

Egress integration:

  security.AuthDomains(cfg.Auth) returns issuer/verifier hosts. The
  runner merges these into the egress allowlist BEFORE the enforcer is
  constructed — without this, an OIDC issuer configured in forge.yaml
  would be silently blocked at JWKS-fetch time.

Validation:

  validate.ValidateAuthConfig is called from ValidateForgeConfig. Each
  provider type's required settings are checked at `forge validate` time
  (not just at runtime construction):

    - oidc: issuer + audience
    - http_verifier: url
    - static_token: token or token_env (literal token warns)
    - unknown type: error with list of known types

Side-effect import:

  cmd-side runner.go imports providers/oidc as `_` so
  auth.Build("oidc", settings) succeeds via factory registration.
Tests (88.6% validate, 90.8% security, 88% combined auth): - Validation: every provider type, missing required fields, unknown type, missing type, duplicate names warning, ValidateForgeConfig integration - AuthDomains: oidc issuer, http_verifier URL, JWKS-URL override, multi-provider dedup, static_token contributes nothing, malformed URLs skipped, port stripping, unknown provider returns nothing - Runner: buildChainFromConfig (empty, single provider, ordered chain with first-match-wins, invalid type, factory error), buildLegacyHTTPVerifierChain, buildUserAuthChain precedence rules (neither / only legacy / only YAML / both prefers YAML / oidc via registry) No regressions: - All forge-core packages pass (28/28) - All forge-cli packages pass (15/15) - All forge-plugins packages pass (4/4) - Race detector clean across auth/.../validate/security/runtime - gofmt + go vet clean - Anti-pattern greps clean (no DefaultClient, no token-in-logs, no plain == on tokens, oidc uses injected client only) --- forge-cli/runtime/auth_config_test.go | 274 +++++++++++++++++++++++ forge-cli/runtime/runner.go | 147 +++++++++--- forge-core/security/auth_domains.go | 89 ++++++++ forge-core/security/auth_domains_test.go | 136 +++++++++++ forge-core/types/config.go | 30 +++ forge-core/validate/auth.go | 93 ++++++++ forge-core/validate/auth_test.go | 218 ++++++++++++++++++ forge-core/validate/forge_config.go | 2 + 8 files changed, 955 insertions(+), 34 deletions(-) create mode 100644 forge-cli/runtime/auth_config_test.go create mode 100644 forge-core/security/auth_domains.go create mode 100644 forge-core/security/auth_domains_test.go create mode 100644 forge-core/validate/auth.go create mode 100644 forge-core/validate/auth_test.go diff --git a/forge-cli/runtime/auth_config_test.go b/forge-cli/runtime/auth_config_test.go new file mode 100644 index 0000000..607f951 --- /dev/null +++ b/forge-cli/runtime/auth_config_test.go @@ -0,0 +1,274 @@ +package runtime + +import ( + "context" + 
"encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/types" +) + +// E2E coverage for PR3's forge.yaml `auth:` block → ChainProvider wiring. + +// fakeVerifierFor is a test verifier that accepts `goodToken` and rejects +// everything else. +func fakeVerifierFor(t *testing.T, goodToken string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + _ = json.Unmarshal(body, &req) + if req["token"] == goodToken { + _ = json.NewEncoder(w).Encode(map[string]any{"valid": true, "user_id": "u"}) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{"valid": false}) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestBuildChainFromConfig_Empty(t *testing.T) { + chain, err := buildChainFromConfig(types.AuthConfig{}) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if chain != nil { + t.Errorf("chain = %v, want nil for empty config", chain) + } +} + +func TestBuildChainFromConfig_SingleHTTPVerifier(t *testing.T) { + srv := fakeVerifierFor(t, "good") + chain, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srv.URL}}, + }, + }) + if err != nil { + t.Fatalf("buildChainFromConfig: %v", err) + } + id, err := chain.Verify(context.Background(), "good", nil) + if err != nil || id == nil { + t.Errorf("good token rejected: id=%v err=%v", id, err) + } + if _, err := chain.Verify(context.Background(), "bad", nil); !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("bad token err = %v, want ErrTokenRejected", err) + } +} + +func TestBuildChainFromConfig_StaticToken(t *testing.T) { + t.Setenv("FORGE_TEST_STATIC_TOKEN", "dev-secret") + + chain, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: 
"static_token", Settings: map[string]any{"token_env": "FORGE_TEST_STATIC_TOKEN"}}, + }, + }) + if err != nil { + t.Fatalf("buildChainFromConfig: %v", err) + } + if _, err := chain.Verify(context.Background(), "dev-secret", nil); err != nil { + t.Errorf("dev-secret rejected: %v", err) + } +} + +func TestBuildChainFromConfig_OrderedChain(t *testing.T) { + // Verify first-match-wins ordering: a static_token at position 1 + // short-circuits an http_verifier at position 2. + verifierHit := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + verifierHit = true + _ = json.NewEncoder(w).Encode(map[string]any{"valid": true, "user_id": "external"}) + })) + defer srv.Close() + + t.Setenv("FORGE_TEST_LOOPBACK_TOKEN", "loopback") + + chain, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "static_token", Settings: map[string]any{"token_env": "FORGE_TEST_LOOPBACK_TOKEN"}}, + {Type: "http_verifier", Settings: map[string]any{"url": srv.URL}}, + }, + }) + if err != nil { + t.Fatalf("buildChainFromConfig: %v", err) + } + + // Loopback token matches first provider — verifier must NOT be called. + id, err := chain.Verify(context.Background(), "loopback", nil) + if err != nil { + t.Fatalf("loopback verify: %v", err) + } + if id.Source != "static_token" { + t.Errorf("identity.Source = %q, want static_token", id.Source) + } + if verifierHit { + t.Error("http_verifier was called despite static_token match (chain ordering broken)") + } + + // Different token falls through to verifier. 
+ if _, err := chain.Verify(context.Background(), "other", nil); err != nil { + t.Fatalf("fallthrough verify: %v", err) + } + if !verifierHit { + t.Error("http_verifier was not called for non-loopback token (chain ordering broken)") + } +} + +func TestBuildChainFromConfig_InvalidProviderType(t *testing.T) { + _, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "ldap"}, // not registered + }, + }) + if err == nil { + t.Fatal("expected error for unknown provider type") + } +} + +func TestBuildChainFromConfig_FactoryError(t *testing.T) { + _, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc"}, // missing issuer + audience + }, + }) + if err == nil { + t.Fatal("expected error for oidc without issuer/audience") + } + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Errorf("err = %v, want wrapped ErrProviderNotConfigured", err) + } +} + +func TestBuildLegacyHTTPVerifierChain_HappyPath(t *testing.T) { + srv := fakeVerifierFor(t, "ok") + chain, err := buildLegacyHTTPVerifierChain(srv.URL, "default-org") + if err != nil { + t.Fatalf("buildLegacyHTTPVerifierChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "ok", nil); err != nil { + t.Errorf("verify failed: %v", err) + } +} + +func TestBuildLegacyHTTPVerifierChain_EmptyURLFails(t *testing.T) { + if _, err := buildLegacyHTTPVerifierChain("", ""); err == nil { + t.Error("expected error for empty URL") + } +} + +// --- buildUserAuthChain (precedence rules) --- + +// runnerWith returns a Runner whose Config has the given AuthConfig and +// CLI fields. Minimal — only what buildUserAuthChain needs. 
+func runnerWith(authYAML types.AuthConfig, authURL string) *Runner { + r := &Runner{logger: nopLogger{}} + r.cfg.Config = &types.ForgeConfig{Auth: authYAML} + r.cfg.AuthURL = authURL + return r +} + +// nopLogger is a Logger that does nothing — keeps the warning path out +// of test output for the prefer-YAML-over-flag test below. +type nopLogger struct{} + +func (nopLogger) Debug(string, map[string]any) {} +func (nopLogger) Info(string, map[string]any) {} +func (nopLogger) Warn(string, map[string]any) {} +func (nopLogger) Error(string, map[string]any) {} + +func TestBuildUserAuthChain_NeitherSourcePresent(t *testing.T) { + r := runnerWith(types.AuthConfig{}, "") + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if chain != nil { + t.Errorf("chain = %v, want nil when neither YAML nor flag is set", chain) + } +} + +func TestBuildUserAuthChain_OnlyLegacyURL(t *testing.T) { + srv := fakeVerifierFor(t, "ok") + r := runnerWith(types.AuthConfig{}, srv.URL) + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "ok", nil); err != nil { + t.Errorf("verify failed: %v", err) + } +} + +func TestBuildUserAuthChain_OnlyYAMLAuth(t *testing.T) { + srv := fakeVerifierFor(t, "yaml-token") + r := runnerWith(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srv.URL}}, + }, + }, "") + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "yaml-token", nil); err != nil { + t.Errorf("yaml-source verifier rejected: %v", err) + } +} + +func TestBuildUserAuthChain_BothSourcesPrefersYAML(t *testing.T) { + // YAML auth points at srvYAML (accepts "yaml-only"). + // Legacy --auth-url points at srvLegacy (accepts "legacy-only"). + // Expect: YAML wins, legacy verifier never consulted. 
+ srvYAML := fakeVerifierFor(t, "yaml-only") + srvLegacy := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("legacy verifier should not be called when YAML auth: is present") + })) + defer srvLegacy.Close() + + r := runnerWith(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srvYAML.URL}}, + }, + }, srvLegacy.URL) + + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + // YAML accepts "yaml-only". + if _, err := chain.Verify(context.Background(), "yaml-only", nil); err != nil { + t.Errorf("yaml-only rejected: %v", err) + } + // Legacy token rejected because YAML verifier doesn't recognize it. + if _, err := chain.Verify(context.Background(), "legacy-only", nil); !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("legacy-only token err = %v, want ErrTokenRejected", err) + } +} + +func TestBuildUserAuthChain_OIDCFromYAML(t *testing.T) { + // Builds an oidc provider successfully — confirms the side-effect + // import wired it into the registry. Network isn't actually exercised + // because we don't call Verify with a real JWT. 
+ r := runnerWith(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }}, + }, + }, "") + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if chain == nil { + t.Fatal("chain is nil; OIDC factory may not be registered") + } +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 68600e9..b43ada3 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -17,6 +17,10 @@ import ( "github.com/initializ/forge/forge-core/agentspec" "github.com/initializ/forge/forge-core/auth" "github.com/initializ/forge/forge-core/auth/providers/httpverifier" + // Side-effect import: registers the "oidc" provider with the auth registry + // so forge.yaml `auth: { type: oidc }` blocks construct successfully via + // auth.Build("oidc", settings). The package is not referenced directly. + _ "github.com/initializ/forge/forge-core/auth/providers/oidc" "github.com/initializ/forge/forge-core/auth/providers/statictoken" "github.com/initializ/forge/forge-core/llm" "github.com/initializ/forge/forge-core/llm/oauth" @@ -248,6 +252,10 @@ func (r *Runner) Run(ctx context.Context) error { egressDomains = append(egressDomains, expandEgressDomains(d, envVars)...) } } + // Auto-merge auth-provider issuer/verifier hosts. Without this, an + // OIDC issuer or http_verifier URL configured in forge.yaml would be + // silently blocked at runtime by the egress enforcer. + egressDomains = append(egressDomains, security.AuthDomains(r.cfg.Config.Auth)...) egressCfg, egressErr := security.Resolve( r.cfg.Config.Egress.Profile, r.cfg.Config.Egress.Mode, @@ -1723,14 +1731,18 @@ func (r *Runner) printBanner(proxyURL string) { // resolveAuth builds the auth middleware options for the A2A server. 
// -// Behavior preserved from the pre-chain implementation: -// - --no-auth (NoAuth=true) → no chain, anonymous access -// - --auth-url set → chain: [static_token (loopback), http_verifier] -// - neither set → chain: [static_token (local)] +// Source precedence (highest first): +// 1. --no-auth flag → nil chain, anonymous access +// 2. forge.yaml auth: → Registry.BuildChain(cfg.Auth.Providers) +// 3. --auth-url / env → legacy http_verifier chain +// 4. nothing → loopback-token-only chain (or nil if also no channels) // -// Token resolution is performed by ResolveAuth() (called earlier so channel -// adapters can use it); this method translates the resolved token + URL into -// a Provider chain. +// If BOTH forge.yaml auth: AND --auth-url are configured, the YAML block +// wins and a warning is logged — silent merging would be surprising. +// +// The internal loopback static_token is ALWAYS prepended at the chain +// head (regardless of source) so channel-adapter callbacks short-circuit +// before any external provider is invoked. func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.MiddlewareOptions, error) { // Ensure token is resolved (no-op if already done by ResolveAuth). if err := r.ResolveAuth(); err != nil { @@ -1741,11 +1753,28 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew return auth.MiddlewareOptions{}, nil // nil chain → passthrough } - chain, err := buildLegacyAuthChain(r.authToken, r.cfg.AuthURL, r.cfg.AuthOrgID) + userChain, err := r.buildUserAuthChain() if err != nil { return auth.MiddlewareOptions{}, fmt.Errorf("building auth chain: %w", err) } + // Prepend the loopback static_token so channel adapter callbacks + // short-circuit before any user-configured provider runs. 
+ chain := userChain + if r.authToken != "" { + loopback, err := statictoken.New(statictoken.Config{ + Token: r.authToken, + Identity: auth.Identity{ + UserID: "forge-internal", + Source: "internal", + }, + }) + if err != nil { + return auth.MiddlewareOptions{}, fmt.Errorf("loopback static_token: %w", err) + } + chain = auth.PrependChain(userChain, loopback) + } + return auth.MiddlewareOptions{ Chain: chain, SkipPaths: auth.DefaultSkipPaths(), @@ -1769,20 +1798,82 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew }, nil } -// buildLegacyAuthChain constructs the Provider chain from the legacy -// CLI-flag inputs (--auth-token / --auth-url / --auth-org-id). +// buildUserAuthChain returns the user-facing portion of the auth chain +// (everything EXCEPT the loopback static_token, which the caller prepends). // -// When both an internal token and an external --auth-url are present, the -// static_token provider runs first so channel-adapter loopback calls -// short-circuit before reaching the external verifier. This mirrors the -// pre-refactor "internal token then external" check in middleware.go. +// - forge.yaml auth.providers populated → Registry.BuildChain +// - --auth-url / FORGE_AUTH_URL only → legacy http_verifier +// - neither → nil (no user chain) +func (r *Runner) buildUserAuthChain() (auth.Provider, error) { + hasYAMLAuth := len(r.cfg.Config.Auth.Providers) > 0 + hasLegacyURL := r.cfg.AuthURL != "" + + if hasYAMLAuth && hasLegacyURL { + r.logger.Warn("both --auth-url and forge.yaml 'auth:' block configured; preferring 'auth:' block (--auth-url ignored)", nil) + } + + if hasYAMLAuth { + return buildChainFromConfig(r.cfg.Config.Auth) + } + if hasLegacyURL { + return buildLegacyHTTPVerifierChain(r.cfg.AuthURL, r.cfg.AuthOrgID) + } + return nil, nil +} + +// buildChainFromConfig builds a ChainProvider from a typed AuthConfig by +// delegating to the package-level registry. 
Each AuthProvider entry is +// constructed via its registered factory. +func buildChainFromConfig(cfg types.AuthConfig) (auth.Provider, error) { + if len(cfg.Providers) == 0 { + return nil, nil + } + providers := make([]auth.Provider, 0, len(cfg.Providers)) + for i, spec := range cfg.Providers { + p, err := auth.Build(spec.Type, spec.Settings) + if err != nil { + return nil, fmt.Errorf("auth.providers[%d] (%s): %w", i, spec.Type, err) + } + providers = append(providers, p) + } + return auth.NewChainProvider(providers...), nil +} + +// buildLegacyHTTPVerifierChain constructs the legacy --auth-url chain +// (single http_verifier provider). Kept separate from +// buildChainFromConfig so the legacy code path is obvious. +func buildLegacyHTTPVerifierChain(authURL, authOrgID string) (auth.Provider, error) { + p, err := httpverifier.New(httpverifier.Config{ + URL: authURL, + DefaultOrg: authOrgID, + }) + if err != nil { + return nil, fmt.Errorf("legacy http_verifier: %w", err) + } + return auth.NewChainProvider(p), nil +} + +// buildLegacyAuthChain is a test-friendly helper that combines the +// loopback-token prepend and the legacy http_verifier chain. Mirrors the +// production wiring in resolveAuth so PR1's e2e tests keep working +// unchanged. // -// Returns nil chain (passthrough) when no providers are configured. +// Order: +// 1. static_token (loopback) — if internalToken non-empty +// 2. http_verifier — if authURL non-empty +// +// Returns nil chain when both inputs are empty. 
func buildLegacyAuthChain(internalToken, authURL, authOrgID string) (auth.Provider, error) { - var providers []auth.Provider - + var userChain auth.Provider + if authURL != "" { + c, err := buildLegacyHTTPVerifierChain(authURL, authOrgID) + if err != nil { + return nil, err + } + userChain = c + } if internalToken != "" { - p, err := statictoken.New(statictoken.Config{ + loopback, err := statictoken.New(statictoken.Config{ Token: internalToken, Identity: auth.Identity{ UserID: "forge-internal", @@ -1792,24 +1883,12 @@ func buildLegacyAuthChain(internalToken, authURL, authOrgID string) (auth.Provid if err != nil { return nil, fmt.Errorf("static_token provider: %w", err) } - providers = append(providers, p) - } - - if authURL != "" { - p, err := httpverifier.New(httpverifier.Config{ - URL: authURL, - DefaultOrg: authOrgID, - }) - if err != nil { - return nil, fmt.Errorf("http_verifier provider: %w", err) + if userChain == nil { + return auth.NewChainProvider(loopback), nil } - providers = append(providers, p) - } - - if len(providers) == 0 { - return nil, nil + return auth.PrependChain(userChain, loopback), nil } - return auth.NewChainProvider(providers...), nil + return userChain, nil } // ensureGitignore makes sure .forge/ is listed in the project's .gitignore. diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go new file mode 100644 index 0000000..b82957f --- /dev/null +++ b/forge-core/security/auth_domains.go @@ -0,0 +1,89 @@ +package security + +import ( + "net/url" + "sort" + + "github.com/initializ/forge/forge-core/types" +) + +// AuthDomains extracts the outbound hosts that auth providers must be able +// to reach: OIDC issuers, http_verifier URLs, future Okta tenants, etc. +// +// These domains are merged into the egress allowlist BEFORE the egress +// enforcer is constructed, so configuring an OIDC provider does not +// silently fail at runtime with a network-blocked JWKS fetch. 
+// +// Returned hosts are deduplicated and sorted for stable test output. +// Empty/malformed URLs are skipped (validation happens elsewhere). +func AuthDomains(cfg types.AuthConfig) []string { + if len(cfg.Providers) == 0 { + return nil + } + seen := map[string]struct{}{} + for _, p := range cfg.Providers { + for _, raw := range authProviderURLs(p) { + host := hostFromURL(raw) + if host == "" { + continue + } + seen[host] = struct{}{} + } + } + if len(seen) == 0 { + return nil + } + out := make([]string, 0, len(seen)) + for h := range seen { + out = append(out, h) + } + sort.Strings(out) + return out +} + +// authProviderURLs returns the URLs that a given provider needs to reach. +// Centralizes per-provider extraction so adding a new provider type in +// the registry only requires touching this one switch. +func authProviderURLs(p types.AuthProvider) []string { + switch p.Type { + case "oidc": + return []string{ + settingString(p.Settings, "issuer"), + settingString(p.Settings, "jwks_url"), + } + case "http_verifier": + return []string{ + settingString(p.Settings, "url"), + } + // static_token has no outbound; not listed + // okta (Phase 3) will be added here with issuer + api domain + default: + return nil + } +} + +func settingString(m map[string]any, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +// hostFromURL extracts the bare host (without port) from a URL string. +// Returns "" for unparseable values — the caller skips those silently +// since real validation happens in validate.ValidateAuthConfig. 
+func hostFromURL(raw string) string { + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil || u.Host == "" { + return "" + } + return u.Hostname() +} diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go new file mode 100644 index 0000000..a917e07 --- /dev/null +++ b/forge-core/security/auth_domains_test.go @@ -0,0 +1,136 @@ +package security_test + +import ( + "reflect" + "testing" + + "github.com/initializ/forge/forge-core/security" + "github.com/initializ/forge/forge-core/types" +) + +func TestAuthDomains_Empty(t *testing.T) { + if got := security.AuthDomains(types.AuthConfig{}); got != nil { + t.Errorf("AuthDomains(empty) = %v, want nil", got) + } +} + +func TestAuthDomains_OIDCIssuer(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }}, + }, + }) + want := []string{"login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_HTTPVerifier(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{ + "url": "https://verify.example.com/verify", + }}, + }, + }) + want := []string{"verify.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_OIDCWithExplicitJWKSURL(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + "jwks_url": "https://keys.example.com/.well-known/jwks.json", + }}, + }, + }) + want := []string{"keys.example.com", "login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, 
want) + } +} + +func TestAuthDomains_MultiProviderDedup(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }}, + {Type: "http_verifier", Settings: map[string]any{ + "url": "https://login.example.com/verify", + }}, + {Type: "static_token", Settings: map[string]any{"token_env": "X"}}, + }, + }) + want := []string{"login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v (must dedup across providers)", got, want) + } +} + +func TestAuthDomains_StaticTokenContributesNothing(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "static_token", Settings: map[string]any{"token_env": "X"}}, + }, + }) + if got != nil { + t.Errorf("static_token-only config should contribute no domains, got %v", got) + } +} + +func TestAuthDomains_MalformedURLsSkipped(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "://not a url", + "audience": "api://forge", + }}, + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://valid.example.com", + "audience": "api://forge", + }}, + }, + }) + // Malformed URL is silently skipped (the validate package surfaces the error). 
+ want := []string{"valid.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_PortStripped(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "http://localhost:8080/realms/dev", + "audience": "api://forge", + }}, + }, + }) + want := []string{"localhost"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v (port must be stripped)", got, want) + } +} + +func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "future_provider", Settings: map[string]any{"url": "https://x.example.com"}}, + }, + }) + if got != nil { + t.Errorf("unknown provider type returned domains: %v (must be nil — extractor not registered)", got) + } +} diff --git a/forge-core/types/config.go b/forge-core/types/config.go index faf90af..2294c53 100644 --- a/forge-core/types/config.go +++ b/forge-core/types/config.go @@ -22,12 +22,42 @@ type ForgeConfig struct { Skills SkillsRef `yaml:"skills,omitempty"` Memory MemoryConfig `yaml:"memory,omitempty"` Secrets SecretsConfig `yaml:"secrets,omitempty"` + Auth AuthConfig `yaml:"auth,omitempty"` Schedules []ScheduleConfig `yaml:"schedules,omitempty"` CORSOrigins []string `yaml:"cors_origins,omitempty"` Package PackageConfig `yaml:"package,omitempty"` GuardrailsPath string `yaml:"guardrails_path,omitempty"` // path to guardrails.json (default: "guardrails.json") } +// AuthConfig declares the auth provider chain for the A2A server. Mirrors +// the secrets.providers pattern: each entry is { type, settings } and the +// runner builds them in order via auth.Registry.BuildChain. 
+// +// Backward compatibility: if AuthConfig.Providers is empty, the legacy +// --auth-url / FORGE_AUTH_URL / FORGE_AUTH_ORG_ID flow synthesizes a +// single-element http_verifier chain (unchanged from pre-PR3 behavior). +type AuthConfig struct { + // Required indicates whether auth is mandatory. When false (default), + // the runtime treats Providers as the source of truth — operators may + // still opt out via --no-auth on localhost. Reserved for future + // TUI/UI gating logic. + Required bool `yaml:"required,omitempty"` + + // Providers is the ordered list of auth providers that compose into + // the A2A server's auth chain. First-match wins. + Providers []AuthProvider `yaml:"providers,omitempty"` +} + +// AuthProvider is one entry in AuthConfig.Providers. The Type names a +// factory registered with the auth package (e.g., "oidc", "http_verifier", +// "static_token", and — in Phase 3 — "okta"). Settings is unmarshaled +// into the provider-specific Config struct via auth.UnmarshalSettings. +type AuthProvider struct { + Type string `yaml:"type"` + Name string `yaml:"name,omitempty"` + Settings map[string]any `yaml:"settings,omitempty"` +} + // ScheduleConfig defines a recurring scheduled task in forge.yaml. type ScheduleConfig struct { ID string `yaml:"id"` diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go new file mode 100644 index 0000000..78ead5b --- /dev/null +++ b/forge-core/validate/auth.go @@ -0,0 +1,93 @@ +package validate + +import ( + "fmt" + + "github.com/initializ/forge/forge-core/types" +) + +// knownAuthProviderTypes is the closed set of auth provider types that +// validate accepts. It is intentionally NOT derived from auth.RegisteredTypes() +// — validation must work at build time (forge validate) without importing +// the runtime provider implementations, which is what allows the forge-core +// `validate` package to stay free of side-effect imports. +// +// New provider types must be added here when they ship. 
+var knownAuthProviderTypes = map[string]bool{ + "http_verifier": true, + "static_token": true, + "oidc": true, + // "okta": true, // Phase 3 (v0.11.0) +} + +// ValidateAuthConfig adds errors and warnings for a forge.yaml auth: block. +// Empty AuthConfig is valid (legacy --auth-url path remains the fallback). +func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { + if len(cfg.Providers) == 0 { + if cfg.Required { + r.Errors = append(r.Errors, "auth.required is true but auth.providers is empty") + } + return + } + + seenNames := map[string]int{} + for i, p := range cfg.Providers { + prefix := fmt.Sprintf("auth.providers[%d]", i) + + if p.Type == "" { + r.Errors = append(r.Errors, prefix+": type is required") + continue + } + if !knownAuthProviderTypes[p.Type] { + r.Errors = append(r.Errors, fmt.Sprintf("%s: unknown type %q (known: http_verifier, static_token, oidc)", prefix, p.Type)) + continue + } + + if p.Name != "" { + seenNames[p.Name]++ + if seenNames[p.Name] > 1 { + r.Warnings = append(r.Warnings, fmt.Sprintf("%s: duplicate name %q across providers", prefix, p.Name)) + } + } + + validateProviderSettings(prefix, p, r) + } +} + +// validateProviderSettings checks the type-specific required keys in the +// settings block. We re-validate here (the provider's own Validate() +// runs at runtime construction) so `forge validate` catches errors before +// `forge run`. 
+func validateProviderSettings(prefix string, p types.AuthProvider, r *ValidationResult) { + switch p.Type { + case "http_verifier": + if asString(p.Settings, "url") == "" { + r.Errors = append(r.Errors, prefix+" (http_verifier): settings.url is required") + } + case "static_token": + if asString(p.Settings, "token") == "" && asString(p.Settings, "token_env") == "" { + r.Errors = append(r.Errors, prefix+" (static_token): settings.token or settings.token_env is required") + } + if asString(p.Settings, "token") != "" { + r.Warnings = append(r.Warnings, prefix+" (static_token): literal 'token' in YAML is a footgun — prefer 'token_env'") + } + case "oidc": + if asString(p.Settings, "issuer") == "" { + r.Errors = append(r.Errors, prefix+" (oidc): settings.issuer is required") + } + if asString(p.Settings, "audience") == "" { + r.Errors = append(r.Errors, prefix+" (oidc): settings.audience is required") + } + } +} + +// asString reads a string-valued setting, returning "" for missing or +// non-string values. 
+func asString(m map[string]any, key string) string { + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} diff --git a/forge-core/validate/auth_test.go b/forge-core/validate/auth_test.go new file mode 100644 index 0000000..a4813fb --- /dev/null +++ b/forge-core/validate/auth_test.go @@ -0,0 +1,218 @@ +package validate + +import ( + "strings" + "testing" + + "github.com/initializ/forge/forge-core/types" +) + +func TestValidateAuthConfig_EmptyIsValid(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{}, r) + if !r.IsValid() { + t.Errorf("empty AuthConfig should be valid, got errors: %v", r.Errors) + } +} + +func TestValidateAuthConfig_RequiredWithoutProviders(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{Required: true}, r) + if r.IsValid() { + t.Error("required=true with no providers should be an error") + } + if !containsSubstr(r.Errors, "required") { + t.Errorf("missing 'required' in error message: %v", r.Errors) + } +} + +func TestValidateAuthConfig_HTTPVerifier(t *testing.T) { + tests := []struct { + name string + provider types.AuthProvider + wantErr string + }{ + { + name: "valid", + provider: types.AuthProvider{ + Type: "http_verifier", + Settings: map[string]any{"url": "https://verify.example.com"}, + }, + }, + { + name: "missing url", + provider: types.AuthProvider{ + Type: "http_verifier", + Settings: map[string]any{}, + }, + wantErr: "settings.url is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{Providers: []types.AuthProvider{tt.provider}}, r) + if tt.wantErr == "" { + if !r.IsValid() { + t.Errorf("want valid, got errors: %v", r.Errors) + } + } else { + if !containsSubstr(r.Errors, tt.wantErr) { + t.Errorf("want error containing %q, got %v", tt.wantErr, r.Errors) + } + } + }) + } +} + +func TestValidateAuthConfig_OIDC(t *testing.T) { + tests := []struct { + 
name string + settings map[string]any + wantErr string + }{ + { + name: "valid", + settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }, + }, + { + name: "missing issuer", + settings: map[string]any{"audience": "api://forge"}, + wantErr: "settings.issuer is required", + }, + { + name: "missing audience", + settings: map[string]any{"issuer": "https://login.example.com"}, + wantErr: "settings.audience is required", + }, + { + name: "both missing", + settings: map[string]any{}, + wantErr: "issuer is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Type: "oidc", Settings: tt.settings}}, + }, r) + if tt.wantErr == "" { + if !r.IsValid() { + t.Errorf("want valid, got errors: %v", r.Errors) + } + } else { + if !containsSubstr(r.Errors, tt.wantErr) { + t.Errorf("want error containing %q, got %v", tt.wantErr, r.Errors) + } + } + }) + } +} + +func TestValidateAuthConfig_StaticToken(t *testing.T) { + tests := []struct { + name string + settings map[string]any + wantErr string + wantWarn string + }{ + { + name: "valid with token_env", + settings: map[string]any{"token_env": "FORGE_TOKEN"}, + }, + { + name: "valid with token (warns)", + settings: map[string]any{"token": "literal-secret"}, + wantWarn: "footgun", + }, + { + name: "neither token nor token_env", + settings: map[string]any{}, + wantErr: "settings.token or settings.token_env is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Type: "static_token", Settings: tt.settings}}, + }, r) + if tt.wantErr != "" && !containsSubstr(r.Errors, tt.wantErr) { + t.Errorf("want error containing %q, got %v", tt.wantErr, r.Errors) + } + if tt.wantWarn != "" && !containsSubstr(r.Warnings, tt.wantWarn) { + 
t.Errorf("want warning containing %q, got %v", tt.wantWarn, r.Warnings) + } + }) + } +} + +func TestValidateAuthConfig_UnknownType(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Type: "ldap"}}, + }, r) + if r.IsValid() { + t.Error("unknown provider type should fail validation") + } + if !containsSubstr(r.Errors, "unknown type") { + t.Errorf("want 'unknown type' in errors, got %v", r.Errors) + } +} + +func TestValidateAuthConfig_MissingType(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Settings: map[string]any{"url": "https://x"}}}, + }, r) + if r.IsValid() { + t.Error("provider without type should fail validation") + } +} + +func TestValidateAuthConfig_DuplicateNamesWarn(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Name: "sso", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + {Type: "oidc", Name: "sso", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }, r) + if !containsSubstr(r.Warnings, "duplicate name") { + t.Errorf("want duplicate-name warning, got %v", r.Warnings) + } +} + +func TestValidateForgeConfig_IncludesAuth(t *testing.T) { + // Ensure the top-level ValidateForgeConfig calls ValidateAuthConfig. + cfg := &types.ForgeConfig{ + AgentID: "test-agent", + Version: "0.1.0", + Auth: types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc"}, // missing issuer + audience + }, + }, + } + r := ValidateForgeConfig(cfg) + if r.IsValid() { + t.Fatal("config with invalid oidc settings should fail validation") + } + if !containsSubstr(r.Errors, "issuer is required") { + t.Errorf("want oidc-issuer error to surface via ValidateForgeConfig, got %v", r.Errors) + } +} + +// containsSubstr returns true if any string in ss contains substr. 
+func containsSubstr(ss []string, substr string) bool { + for _, s := range ss { + if strings.Contains(s, substr) { + return true + } + } + return false +} diff --git a/forge-core/validate/forge_config.go b/forge-core/validate/forge_config.go index 7ff5da5..e6fa43f 100644 --- a/forge-core/validate/forge_config.go +++ b/forge-core/validate/forge_config.go @@ -118,5 +118,7 @@ func ValidateForgeConfig(cfg *types.ForgeConfig) *ValidationResult { } } + ValidateAuthConfig(cfg.Auth, r) + return r } From 381af95dd15e8469f01d024cbfab28efe54bfcc2 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 19:41:08 -0400 Subject: [PATCH 04/20] chore(auth/oidc): drop unused removeKey test helper (lint fix) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit golangci-lint v2.10.1 flagged removeKey as unused — it was scaffolded for a key-rotation test case that was implemented differently (by adding a new key without removing the old). Dropping the dead code keeps CI lint green. --- forge-core/auth/providers/oidc/testing_helper_test.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/forge-core/auth/providers/oidc/testing_helper_test.go b/forge-core/auth/providers/oidc/testing_helper_test.go index 49cefbe..3185912 100644 --- a/forge-core/auth/providers/oidc/testing_helper_test.go +++ b/forge-core/auth/providers/oidc/testing_helper_test.go @@ -94,17 +94,6 @@ func (fi *fakeIssuer) addECDSAKey(kid string) *signingKey { return k } -// removeKey simulates an IdP rotating out an old key. -func (fi *fakeIssuer) removeKey(kid string) { - delete(fi.keys, kid) - for i, k := range fi.keyOrder { - if k == kid { - fi.keyOrder = append(fi.keyOrder[:i], fi.keyOrder[i+1:]...) - break - } - } -} - // SignWith builds a signed token using the named key. Claims is the full // claim map; helpers below build common claim sets. 
func (fi *fakeIssuer) SignWith(kid string, claims jwt.MapClaims) string { From 8e7747f49ea913934c7e0d71888887aad00779f1 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 19:48:01 -0400 Subject: [PATCH 05/20] feat(auth/audit): structured auth_verify + auth_fail events (PR4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the thin {method, path, remote_addr} payload of the legacy auth_success/auth_failure events with structured fields suitable for operators to dashboard and alert on, without leaking PII. Events: auth_verify (success) { provider, user_id, org_id, groups_count, token_kind, method, path, remote_addr } auth_fail (failure) { reason, token_kind, method, path, remote_addr } Reason codes (stable contract — changing is a breaking change for downstream consumers): missing_token no Authorization header rejected provider recognized + denied (revoked, expired, wrong iss/aud) invalid token malformed / cryptographically invalid not_for_me chain exhausted, no provider claimed the token infrastructure transient provider error (network, etc.) Privacy invariants enforced by test: - email, claims, token bytes NEVER appear in either event - test asserts the literal absence of email / SSN / secret values from the serialized audit line Interface changes: auth.MiddlewareOptions.OnAuth signature widened from func(*http.Request, bool) to func(*http.Request, *auth.Identity, error, string) so the caller can read the verified principal + chain error + token kind, instead of just a bool. Only caller (runner.go) updated. New: auth.ErrMissingBearer sentinel for the no-header case so callers can distinguish missing-token from chain errors. New: auth.TokenKind(token) returns "jwt" | "opaque" | "empty" — cheap structural classification suitable for logging. 
Backward compat: AuditAuthSuccess and AuditAuthFailure constants kept as deprecated aliases (point to the new "auth_verify"/"auth_fail" strings) so any consumer that grep'd for the old constants compiles but emits the new string. Scheduled for removal in v0.11.0. Tests (100% coverage on TokenKind, makeAuthAuditCallback, authFailReason): - TokenKind: jwt/opaque/empty/edge-cases - Middleware: OnAuth receives (Identity, err, kind) on every decision - Middleware: ErrMissingBearer surfaces on no-header requests - Audit: auth_verify carries provider/user_id/org_id/groups_count/kind - Audit: auth_fail carries stable reason codes for each sentinel - Audit: PII-leak test (email + secret claims never appear in payload) - Audit: CorrelationID propagates from request context - E2E: full middleware → callback → audit log path No regressions: - All forge-core packages pass (28/28) - All forge-cli packages pass (15/15) - All forge-plugins packages pass (4/4) - go test -race clean across affected packages - golangci-lint v2.10.1: 0 issues - gofmt + go vet clean --- forge-cli/runtime/auth_audit_test.go | 255 +++++++++++++++++++++++++++ forge-cli/runtime/runner.go | 93 ++++++++-- forge-core/auth/middleware.go | 53 ++++-- forge-core/auth/middleware_test.go | 110 +++++++++++- forge-core/auth/provider.go | 27 +++ forge-core/auth/provider_test.go | 24 +++ forge-core/runtime/audit.go | 15 +- 7 files changed, 546 insertions(+), 31 deletions(-) create mode 100644 forge-cli/runtime/auth_audit_test.go diff --git a/forge-cli/runtime/auth_audit_test.go b/forge-cli/runtime/auth_audit_test.go new file mode 100644 index 0000000..a81d0b0 --- /dev/null +++ b/forge-cli/runtime/auth_audit_test.go @@ -0,0 +1,255 @@ +package runtime + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/initializ/forge/forge-core/auth" + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// PR4 audit-event contract tests. 
These guarantee: +// 1. auth_verify is emitted on success with the right fields. +// 2. auth_fail is emitted on failure with a stable reason code. +// 3. Neither event ever carries PII (token bytes, email, claim payloads). + +// captureAudit decodes each NDJSON line emitted to buf into an AuditEvent. +func captureAudit(t *testing.T, buf *bytes.Buffer) []coreruntime.AuditEvent { + t.Helper() + var events []coreruntime.AuditEvent + for _, line := range strings.Split(strings.TrimSpace(buf.String()), "\n") { + if line == "" { + continue + } + var ev coreruntime.AuditEvent + if err := json.Unmarshal([]byte(line), &ev); err != nil { + t.Fatalf("decode audit line %q: %v", line, err) + } + events = append(events, ev) + } + return events +} + +func TestAuthAudit_EmitsAuthVerifyOnSuccess(t *testing.T) { + var buf bytes.Buffer + logger := coreruntime.NewAuditLogger(&buf) + + cb := makeAuthAuditCallback(logger) + if cb == nil { + t.Fatal("callback nil with non-nil logger") + } + + req := httptest.NewRequest("POST", "/tasks", nil) + id := &auth.Identity{ + UserID: "alice-123", + Email: "alice@example.com", // must NOT leak into the event + OrgID: "tenant-1", + Groups: []string{"a", "b", "c"}, + Source: "oidc", + } + cb(req, id, nil, "jwt") + + events := captureAudit(t, &buf) + if len(events) != 1 { + t.Fatalf("got %d events, want 1", len(events)) + } + ev := events[0] + + if ev.Event != coreruntime.EventAuthVerify { + t.Errorf("Event = %q, want %q", ev.Event, coreruntime.EventAuthVerify) + } + if ev.Fields["provider"] != "oidc" { + t.Errorf("provider = %v, want oidc", ev.Fields["provider"]) + } + if ev.Fields["user_id"] != "alice-123" { + t.Errorf("user_id = %v, want alice-123", ev.Fields["user_id"]) + } + if ev.Fields["org_id"] != "tenant-1" { + t.Errorf("org_id = %v, want tenant-1", ev.Fields["org_id"]) + } + // JSON unmarshal puts numbers in float64. 
+ if g, _ := ev.Fields["groups_count"].(float64); g != 3 { + t.Errorf("groups_count = %v, want 3", ev.Fields["groups_count"]) + } + if ev.Fields["token_kind"] != "jwt" { + t.Errorf("token_kind = %v, want jwt", ev.Fields["token_kind"]) + } +} + +func TestAuthAudit_EmitsAuthFailOnError(t *testing.T) { + tests := []struct { + name string + err error + wantReason string + }{ + {"missing token", auth.ErrMissingBearer, "missing_token"}, + {"rejected", auth.ErrTokenRejected, "rejected"}, + {"invalid", auth.ErrInvalidToken, "invalid"}, + {"not for me", auth.ErrTokenNotForMe, "not_for_me"}, + {"infrastructure", errBoom, "infrastructure"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest("POST", "/tasks", nil) + cb(req, nil, tt.err, "opaque") + + events := captureAudit(t, &buf) + if len(events) != 1 { + t.Fatalf("got %d events, want 1", len(events)) + } + ev := events[0] + if ev.Event != coreruntime.EventAuthFail { + t.Errorf("Event = %q, want %q", ev.Event, coreruntime.EventAuthFail) + } + if got := ev.Fields["reason"]; got != tt.wantReason { + t.Errorf("reason = %v, want %q", got, tt.wantReason) + } + if ev.Fields["token_kind"] != "opaque" { + t.Errorf("token_kind = %v, want opaque", ev.Fields["token_kind"]) + } + }) + } +} + +func TestAuthAudit_NoPIIInPayload(t *testing.T) { + // Strong negative test: even if the Identity carries an email and + // rich Claims, the emitted audit line MUST NOT include them. 
+ var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest("POST", "/tasks", nil) + id := &auth.Identity{ + UserID: "u", + Email: "secret@hr.example.com", + Claims: map[string]any{ + "ssn": "999-99-9999", + "secret": "deadbeef", + }, + Source: "oidc", + } + cb(req, id, nil, "jwt") + + line := buf.String() + forbidden := []string{ + "secret@hr.example.com", + "999-99-9999", + "deadbeef", + "\"email\"", + "\"claims\"", + } + for _, f := range forbidden { + if strings.Contains(line, f) { + t.Errorf("audit line leaked %q:\n%s", f, line) + } + } +} + +func TestAuthAudit_NilLoggerReturnsNilCallback(t *testing.T) { + if cb := makeAuthAuditCallback(nil); cb != nil { + t.Error("expected nil callback for nil logger") + } +} + +func TestAuthAudit_CorrelationIDPropagated(t *testing.T) { + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest("POST", "/tasks", nil) + ctx := coreruntime.WithCorrelationID(req.Context(), "corr-abc-123") + req = req.WithContext(ctx) + + cb(req, &auth.Identity{UserID: "u", Source: "oidc"}, nil, "jwt") + + events := captureAudit(t, &buf) + if events[0].CorrelationID != "corr-abc-123" { + t.Errorf("CorrelationID = %q, want corr-abc-123", events[0].CorrelationID) + } +} + +func TestAuthFailReason_UnknownError(t *testing.T) { + // Unknown error types map to "infrastructure" (catch-all for things + // like network failures, JWKS-fetch errors). + if got := authFailReason(errBoom); got != "infrastructure" { + t.Errorf("authFailReason(custom err) = %q, want infrastructure", got) + } + if got := authFailReason(nil); got != "unknown" { + t.Errorf("authFailReason(nil) = %q, want unknown", got) + } +} + +// Smoke test: making sure the OnAuth callback hooks into auth.Middleware +// end-to-end and produces audit events on real HTTP traffic. 
+func TestAuthAudit_E2EThroughMiddleware(t *testing.T) { + var buf bytes.Buffer + logger := coreruntime.NewAuditLogger(&buf) + + chain := auth.NewChainProvider(&testProvider{ + token: "ok", + identity: auth.Identity{UserID: "u", OrgID: "o", Source: "test"}, + }) + + handler := auth.Middleware(auth.MiddlewareOptions{ + Chain: chain, + SkipPaths: auth.DefaultSkipPaths(), + OnAuth: makeAuthAuditCallback(logger), + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // 1. Successful request. + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "Bearer ok") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // 2. Missing token. + req2 := httptest.NewRequest("POST", "/tasks", nil) + handler.ServeHTTP(httptest.NewRecorder(), req2) + + events := captureAudit(t, &buf) + if len(events) != 2 { + t.Fatalf("got %d events, want 2", len(events)) + } + if events[0].Event != coreruntime.EventAuthVerify { + t.Errorf("first event = %q, want auth_verify", events[0].Event) + } + if events[1].Event != coreruntime.EventAuthFail { + t.Errorf("second event = %q, want auth_fail", events[1].Event) + } + if events[1].Fields["reason"] != "missing_token" { + t.Errorf("missing-token reason = %v, want missing_token", events[1].Fields["reason"]) + } +} + +// --- helpers --- + +// customErr is an opaque non-sentinel error type; errBoom (below) uses it to +// exercise the catch-all "infrastructure" reason mapping. +type customErr struct{ s string } + +func (e customErr) Error() string { return e.s } + +var errBoom error = customErr{s: "network is on fire"} + +// testProvider is a tiny Provider that accepts one token. 
+type testProvider struct { + token string + identity auth.Identity +} + +func (t *testProvider) Name() string { return "test" } +func (t *testProvider) Verify(_ context.Context, token string, _ auth.Headers) (*auth.Identity, error) { + if token != t.token { + return nil, auth.ErrTokenNotForMe + } + id := t.identity + return &id, nil +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index b43ada3..82f4dfd 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -3,6 +3,7 @@ package runtime import ( "context" "encoding/json" + "errors" "fmt" "net/http" "os" @@ -1778,24 +1779,86 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew return auth.MiddlewareOptions{ Chain: chain, SkipPaths: auth.DefaultSkipPaths(), - OnAuth: func(req *http.Request, success bool) { - if auditLogger == nil { - return - } - event := coreruntime.AuditAuthSuccess - if !success { - event = coreruntime.AuditAuthFailure + OnAuth: makeAuthAuditCallback(auditLogger), + }, nil +} + +// makeAuthAuditCallback returns the OnAuth callback that emits structured +// auth_verify / auth_fail audit events. +// +// Fields emitted (NO PII — never email, claims, token bytes, or secrets): +// +// auth_verify: { provider, user_id, org_id, groups_count, token_kind, method, path, remote_addr } +// auth_fail: { reason, token_kind, method, path, remote_addr } +// +// Reason codes for auth_fail: +// +// missing_token → no Authorization header +// rejected → provider recognized + denied (revoked, expired, wrong iss/aud) +// invalid → token malformed or cryptographically invalid +// not_for_me → chain exhausted (no provider claimed the token) +// infrastructure → transient provider error (network, etc.) 
+func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Request, *auth.Identity, error, string) { + if auditLogger == nil { + return nil + } + return func(req *http.Request, id *auth.Identity, err error, tokenKind string) { + correlationID := coreruntime.CorrelationIDFromContext(req.Context()) + + if err == nil && id != nil { + // Success → auth_verify. + fields := map[string]any{ + "provider": id.Source, + "user_id": id.UserID, + "org_id": id.OrgID, + "groups_count": len(id.Groups), + "token_kind": tokenKind, + "method": req.Method, + "path": req.URL.Path, + "remote_addr": req.RemoteAddr, } auditLogger.Emit(coreruntime.AuditEvent{ - Event: event, - Fields: map[string]any{ - "method": req.Method, - "path": req.URL.Path, - "remote_addr": req.RemoteAddr, - }, + Event: coreruntime.EventAuthVerify, + CorrelationID: correlationID, + Fields: fields, }) - }, - }, nil + return + } + + // Failure → auth_fail with reason code. + auditLogger.Emit(coreruntime.AuditEvent{ + Event: coreruntime.EventAuthFail, + CorrelationID: correlationID, + Fields: map[string]any{ + "reason": authFailReason(err), + "token_kind": tokenKind, + "method": req.Method, + "path": req.URL.Path, + "remote_addr": req.RemoteAddr, + }, + }) + } +} + +// authFailReason maps a chain error to a stable, low-cardinality reason +// code suitable for dashboarding and alerting. Reason strings are part +// of the audit-event contract — changing them is a breaking change for +// downstream consumers. 
+func authFailReason(err error) string { + switch { + case err == nil: + return "unknown" + case errors.Is(err, auth.ErrMissingBearer): + return "missing_token" + case errors.Is(err, auth.ErrTokenRejected): + return "rejected" + case errors.Is(err, auth.ErrInvalidToken): + return "invalid" + case errors.Is(err, auth.ErrTokenNotForMe): + return "not_for_me" + default: + return "infrastructure" + } } // buildUserAuthChain returns the user-facing portion of the auth chain diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index cf7eb51..4cfe7ab 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -39,10 +39,29 @@ type MiddlewareOptions struct { SkipPaths map[string]bool // OnAuth is an optional callback invoked on every auth decision. - // success is true when the request authenticated; false otherwise. - OnAuth func(r *http.Request, success bool) + // + // - identity is non-nil and err is nil on success. + // - identity is nil on failure and err carries the chain error + // (or auth.ErrMissingBearer when no bearer token was presented); + // err may itself be nil if the chain returned neither an identity nor an error. + // - tokenKind is "jwt", "opaque", or "empty" — structural metadata + // safe to log. The token itself is NOT passed; callers must not + // try to recover it from the request. + // + // Callbacks should be cheap — they run on the request hot path. + OnAuth func(r *http.Request, identity *Identity, err error, tokenKind string) } +// ErrMissingBearer is passed to OnAuth when the request lacked an +// Authorization: Bearer ... header. Distinct from chain-level errors so +// callers can emit a precise "missing_token" reason code without parsing +// error strings. +var ErrMissingBearer = errorString("auth: missing bearer token") + +// errorString is a string-backed error type for sentinels; its values are +// comparable, so errors.Is matches them by value (==), not by pointer identity. 
+type errorString string + +func (e errorString) Error() string { return string(e) } + // Middleware returns an http.Handler that enforces bearer token authentication // via the provided Provider chain. If opts.Chain is nil, requests pass through // without checks (anonymous access). @@ -62,20 +81,22 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { } token := extractBearerToken(r) + kind := TokenKind(token) + if token == "" { - writeAuthFail(w, r, opts.OnAuth, "valid bearer token required") + notifyAuth(opts.OnAuth, r, nil, ErrMissingBearer, kind) + writeAuthError(w, "valid bearer token required") return } identity, err := opts.Chain.Verify(r.Context(), token, HeadersFromRequest(r)) if err != nil || identity == nil { - writeAuthFail(w, r, opts.OnAuth, classifyAuthFailure(err)) + notifyAuth(opts.OnAuth, r, nil, err, kind) + writeAuthError(w, classifyAuthFailure(err)) return } - if opts.OnAuth != nil { - opts.OnAuth(r, true) - } + notifyAuth(opts.OnAuth, r, identity, nil, kind) ctx := WithIdentity(r.Context(), identity) next.ServeHTTP(w, r.WithContext(ctx)) @@ -83,6 +104,15 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { } } +// notifyAuth invokes the OnAuth callback if set, swallowing the nil check +// at the call sites so the main middleware body stays readable. +func notifyAuth(cb func(*http.Request, *Identity, error, string), r *http.Request, id *Identity, err error, kind string) { + if cb == nil { + return + } + cb(r, id, err, kind) +} + // classifyAuthFailure maps a chain error into a user-visible message. // Keep messages generic to avoid leaking information about which provider // rejected the token or why. @@ -99,11 +129,10 @@ func classifyAuthFailure(err error) string { } } -// writeAuthFail sends a 401 response and fires the OnAuth callback if set. 
-func writeAuthFail(w http.ResponseWriter, r *http.Request, onAuth func(*http.Request, bool), msg string) { - if onAuth != nil { - onAuth(r, false) - } +// writeAuthError sends a 401 JSON response. The OnAuth callback is fired +// separately (via notifyAuth) so the audit-emission path can run with the +// full Identity / error context, not just a bool. +func writeAuthError(w http.ResponseWriter, msg string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) _ = json.NewEncoder(w).Encode(errorResponse{ diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index b18e4b5..6954617 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "sync/atomic" @@ -199,8 +200,8 @@ func TestMiddleware_OnAuthCallback(t *testing.T) { identity: Identity{UserID: "u", Source: "test"}, }), SkipPaths: DefaultSkipPaths(), - OnAuth: func(_ *http.Request, success bool) { - if success { + OnAuth: func(_ *http.Request, id *Identity, err error, _ string) { + if err == nil && id != nil { successCount.Add(1) } else { failCount.Add(1) @@ -256,6 +257,111 @@ func TestMiddleware_IdentityIsAttachedToContext(t *testing.T) { } } +func TestMiddleware_OnAuthReceivesIdentityAndError(t *testing.T) { + const goodToken = "good" + wantID := Identity{UserID: "alice", Source: "test_token"} + + var ( + gotIDOnSuccess *Identity + gotErrOnFail error + gotKindSuccess string + gotKindFail string + ) + + opts := MiddlewareOptions{ + Chain: NewChainProvider(&tokenProvider{expected: goodToken, identity: wantID}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, id *Identity, err error, kind string) { + if err == nil && id != nil { + gotIDOnSuccess = id + gotKindSuccess = kind + } else { + gotErrOnFail = err + gotKindFail = kind + } + }, + } + + handler := Middleware(opts)(http.HandlerFunc(func(w 
http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Success path — OnAuth gets the Identity, kind = "opaque" (no dots). + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer "+goodToken) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotIDOnSuccess == nil { + t.Fatal("OnAuth did not receive Identity on success") + } + if gotIDOnSuccess.UserID != "alice" { + t.Errorf("OnAuth identity = %+v, want UserID=alice", gotIDOnSuccess) + } + if gotKindSuccess != "opaque" { + t.Errorf("token kind on success = %q, want opaque", gotKindSuccess) + } + + // Failure path — no Authorization header → ErrMissingBearer. + req2 := httptest.NewRequest("POST", "/", nil) + handler.ServeHTTP(httptest.NewRecorder(), req2) + if !errors.Is(gotErrOnFail, ErrMissingBearer) { + t.Errorf("OnAuth fail err = %v, want ErrMissingBearer", gotErrOnFail) + } + if gotKindFail != "empty" { + t.Errorf("token kind on missing-token fail = %q, want empty", gotKindFail) + } +} + +func TestMiddleware_OnAuthGetsChainError(t *testing.T) { + // When the chain rejects, OnAuth receives the precise sentinel so + // the caller can map to a stable reason code. 
+ opts := MiddlewareOptions{ + Chain: NewChainProvider(rejectingProvider{}), + SkipPaths: DefaultSkipPaths(), + } + var gotErr error + opts.OnAuth = func(_ *http.Request, _ *Identity, err error, _ string) { + gotErr = err + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer x") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if !errors.Is(gotErr, ErrTokenRejected) { + t.Errorf("OnAuth got err = %v, want ErrTokenRejected", gotErr) + } +} + +func TestMiddleware_TokenKindDetection(t *testing.T) { + // JWT-shaped tokens are reported as "jwt" even when the chain + // doesn't recognize them — this is structural, not validity. + jwtToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ4In0.sig" + + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(rejectingProvider{}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer "+jwtToken) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "jwt" { + t.Errorf("token kind = %q, want jwt", gotKind) + } +} + func TestClassifyAuthFailure(t *testing.T) { tests := []struct { name string diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go index a7b3322..9cb5123 100644 --- a/forge-core/auth/provider.go +++ b/forge-core/auth/provider.go @@ -84,6 +84,33 @@ func HeadersFromRequest(r *http.Request) Headers { } } +// TokenKind classifies a presented bearer token structurally — useful for +// audit logging without leaking the token itself. 
+// +// "empty" → the empty string (no token was presented) +// "jwt" → exactly two dots, i.e. three dot-separated segments (segment contents are not inspected) +// "opaque" → anything else (Okta access tokens, custom verifier tokens, dev secrets) +// +// This is a CHEAP structural check — it does not parse or validate. +// Never log the token itself; the returned kind is safe to log. +func TokenKind(token string) string { + if token == "" { + return "empty" + } + dots := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' { + dots++ + if dots > 2 { + return "opaque" + } + } + } + if dots == 2 { + return "jwt" + } + return "opaque" +} + // Sentinel errors that Providers and the ChainProvider use to signal outcomes. // // - ErrTokenNotForMe → provider does not recognize this token shape; diff --git a/forge-core/auth/provider_test.go b/forge-core/auth/provider_test.go index cadcaa1..c175b69 100644 --- a/forge-core/auth/provider_test.go +++ b/forge-core/auth/provider_test.go @@ -450,6 +450,30 @@ func TestUnmarshalSettings_NilOut(t *testing.T) { } } +// --- TokenKind --- + +func TestTokenKind(t *testing.T) { + tests := []struct { + name string + token string + want string + }{ + {"empty", "", "empty"}, + {"jwt three segments", "eyJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJ4In0.signature", "jwt"}, + {"opaque random", "abc123xyz", "opaque"}, + {"opaque one dot", "abc.def", "opaque"}, + {"opaque four segments", "a.b.c.d", "opaque"}, + {"jwt with empty segments still has 2 dots", "..", "jwt"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TokenKind(tt.token); got != tt.want { + t.Errorf("TokenKind(%q) = %q, want %q", tt.token, got, tt.want) + } + }) + } +} + // --- helpers --- func providerNames(ps []Provider) []string { diff --git a/forge-core/runtime/audit.go b/forge-core/runtime/audit.go index f6af274..f8e26bb 100644 --- a/forge-core/runtime/audit.go +++ b/forge-core/runtime/audit.go @@ -23,8 +23,19 @@ const ( AuditScheduleComplete = 
"schedule_complete" AuditScheduleSkip = "schedule_skip" AuditScheduleModify = "schedule_modify" - AuditAuthSuccess = 
"auth_success" - AuditAuthFailure = "auth_failure" + + // Auth events. Carry no PII (no email, no claims, no token bytes) — + // only the audit subject (UserID), tenant (OrgID), and structural + // metadata. Emitted by the OnAuth callback wired into the auth + // middleware (makeAuthAuditCallback in forge-cli/runtime/runner.go). + EventAuthVerify = "auth_verify" // every successful auth decision + EventAuthFail = "auth_fail" // every failed auth decision (with reason code) + + // Deprecated: use EventAuthVerify. Kept only as an identifier alias so + // existing Go callers keep compiling. Its value is now "auth_verify", + // not "auth_success", so audit-log consumers that match on the old + // string must migrate. Scheduled for removal in v0.11.0. + AuditAuthSuccess = EventAuthVerify + // Deprecated: use EventAuthFail. Its value is now "auth_fail", not + // "auth_failure". Same migration window as AuditAuthSuccess. + AuditAuthFailure = EventAuthFail ) // AuditEvent is a single structured audit record emitted as NDJSON. From a894f01f61ce694ab0e157cc84ffc87201f46b4c Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 20:24:38 -0400 Subject: [PATCH 06/20] feat(init/wizard): TUI + non-interactive auth: step (PR5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the Authentication step to the `forge init` wizard so users can configure the auth chain without hand-writing forge.yaml. Same surface is exposed via non-interactive --auth* flags for CI/script use. 
Interactive flow (between Egress and Review steps): 🔐 Authentication ○ None — anonymous (default) ○ OIDC (JWT) — Auth0, Keycloak, Azure AD, Google, Okta-OIDC, … ○ HTTP Verifier — legacy external /verify endpoint ○ Custom — write a commented stub, edit forge.yaml manually OIDC sub-flow: Issuer → Audience → Groups Claim (defaults to "groups") HTTP Verifier sub-flow: URL → Default org_id (optional) None / Custom: pick and continue, no inputs Non-interactive flags: --auth=none|oidc|http_verifier|custom --auth-issuer=<url> (required with --auth=oidc) --auth-audience=<audience> (required with --auth=oidc) --auth-url=<url> (required with --auth=http_verifier) --auth-default-org=<org_id> (http_verifier, optional) --auth-groups-claim=<claim> (oidc, default: groups) Scaffold output: - Adds auth: block to forge.yaml (alphabetical keys for stable diffs) - Auto-merges the issuer/verifier host into egress.allowed_domains - For custom mode: writes a commented-out template stub instead - Settings are pre-rendered as YAML in Go (init_auth.go) so nested maps like claim_map: {groups: roles} emit correctly without template helpers UX details: - "None" is the first option; one Enter to skip the entire step - Backspace on empty input returns to the picker (consistent with other wizard steps) - Live validation on issuer/URL fields: must be parseable + http(s) - The wizard's existing egress preview displays the auth-host auto-add in the final review step No runtime behavior change — PR3 already wired forge.yaml auth: into the runner. This PR is pure UX: easier to author the YAML. 
Tests (TUI step + scaffold helpers): - All 4 mode flows: None (default), Custom (stub), OIDC (full + custom groups claim), HTTP Verifier (with/without default org) - Issuer trailing-slash trimmed; groups claim suppressed at default - AuthEgressHosts populated correctly per mode - Apply() correctness for each mode - Summary text per mode - renderAuthBlock: empty/none/custom/oidc/http_verifier shapes, nested claim_map, deterministic key ordering, name suppression when same as type - buildAuthFromFlags: happy paths, missing required fields, unknown modes, claim-map suppression at default - mergeEgressDomains: dedupe + order preservation - End-to-end: `forge init --auth=oidc` produces a forge.yaml that passes `forge validate` - E2E negative: `--auth=oidc` without --auth-issuer fails with the expected error message No regressions: - All forge-core packages pass (28/28) - All forge-cli packages pass (15/15) - All forge-plugins packages pass (4/4) - golangci-lint v2.10.1: 0 issues - gofmt + go vet clean --- forge-cli/cmd/init.go | 56 ++- forge-cli/cmd/init_auth.go | 161 +++++++ forge-cli/cmd/init_auth_test.go | 267 ++++++++++++ forge-cli/cmd/init_egress.go | 27 ++ forge-cli/internal/tui/steps/step_auth.go | 396 ++++++++++++++++++ .../internal/tui/steps/step_auth_test.go | 313 ++++++++++++++ forge-cli/internal/tui/wizard.go | 21 + forge-cli/templates/init/forge.yaml.tmpl | 4 + 8 files changed, 1242 insertions(+), 3 deletions(-) create mode 100644 forge-cli/cmd/init_auth.go create mode 100644 forge-cli/cmd/init_auth_test.go create mode 100644 forge-cli/internal/tui/steps/step_auth.go create mode 100644 forge-cli/internal/tui/steps/step_auth_test.go diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index 480c971..2a4ac52 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -45,6 +45,11 @@ type initOptions struct { Force bool // overwrite existing directory CustomModel string // custom provider model name AuthMethod string // "apikey" or "oauth" + + // 
A2A auth chain (from the wizard's Authentication step or CLI flags). + AuthMode string // "", "none", "oidc", "http_verifier", "custom" + AuthSettings map[string]any // provider-specific settings block + AuthEgressHosts []string // hosts to merge into egress allowlist } // toolEntry represents a tool parsed from a skills file. @@ -71,6 +76,12 @@ type templateData struct { EgressDomains []string EnvVars []envVarEntry HasSecrets bool + + // Auth chain rendering (see forge.yaml.tmpl). Pre-rendered as a YAML + // fragment because nested maps in the settings block (e.g. claim_map) + // would otherwise require non-trivial template helpers. Empty when + // the user picked "none" or skipped the auth step. + AuthBlock string } // fallbackTmplData holds template data for a fallback provider. @@ -121,6 +132,15 @@ func init() { initCmd.Flags().String("org-id", "", "OpenAI organization ID (enterprise)") initCmd.Flags().StringSlice("fallbacks", nil, "fallback LLM providers (e.g., openai,gemini)") initCmd.Flags().Bool("force", false, "overwrite existing directory") + + // Auth chain (PR5+). All optional. When --auth is unset or "none", no + // auth: block is written and the agent runs anonymously. + initCmd.Flags().String("auth", "", "auth mode: none, oidc, http_verifier, custom") + initCmd.Flags().String("auth-issuer", "", "OIDC issuer URL (required with --auth=oidc)") + initCmd.Flags().String("auth-audience", "", "OIDC audience (required with --auth=oidc)") + initCmd.Flags().String("auth-url", "", "verifier URL (required with --auth=http_verifier)") + initCmd.Flags().String("auth-default-org", "", "default org_id for http_verifier (optional)") + initCmd.Flags().String("auth-groups-claim", "", "claim name for groups (oidc, default: groups)") } func runInit(cmd *cobra.Command, args []string) error { @@ -155,6 +175,18 @@ func runInit(cmd *cobra.Command, args []string) error { opts.NonInteractive = nonInteractive opts.Force, _ = cmd.Flags().GetBool("force") + // Auth chain flags. 
+ authMode, _ := cmd.Flags().GetString("auth") + if authMode != "" { + settings, hosts, err := buildAuthFromFlags(cmd, authMode) + if err != nil { + return err + } + opts.AuthMode = authMode + opts.AuthSettings = settings + opts.AuthEgressHosts = hosts + } + // TTY detection: require a terminal for interactive mode if !nonInteractive && !term.IsTerminal(int(os.Stdout.Fd())) { return fmt.Errorf("interactive mode requires a terminal; use --non-interactive") @@ -258,6 +290,7 @@ func collectInteractive(opts *initOptions) error { steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn), steps.NewSkillsStep(styles, skillInfos), steps.NewEgressStep(styles, deriveEgressFn), + steps.NewAuthStep(styles), steps.NewReviewStep(styles), // scaffold is handled by the caller after collectInteractive returns } @@ -327,11 +360,18 @@ func collectInteractive(opts *initOptions) error { opts.EnvVars["MODEL_API_KEY"] = ctx.CustomAPIKey } - // Store egress domains - if len(ctx.EgressDomains) > 0 { - opts.EnvVars["__egress_domains"] = strings.Join(ctx.EgressDomains, ",") + // Store egress domains, merging in any auth-provider hosts so the + // agent can reach its IdP / verifier at runtime. + mergedEgress := mergeEgressDomains(ctx.EgressDomains, ctx.AuthEgressHosts) + if len(mergedEgress) > 0 { + opts.EnvVars["__egress_domains"] = strings.Join(mergedEgress, ",") } + // Auth chain selection. + opts.AuthMode = ctx.AuthMode + opts.AuthSettings = ctx.AuthSettings + opts.AuthEgressHosts = ctx.AuthEgressHosts + // Check skill requirements checkSkillRequirements(opts) @@ -1002,6 +1042,16 @@ func buildTemplateData(opts *initOptions) templateData { data.EgressDomains = strings.Split(stored, ",") } + // Merge any auth-provider hosts (PR5 wizard / --auth flags) into the + // allowlist so the agent's runtime can reach its IdP/verifier. 
+ if len(opts.AuthEgressHosts) > 0 { + data.EgressDomains = mergeEgressDomains(data.EgressDomains, opts.AuthEgressHosts) + } + + // Auth chain rendering: pre-render the YAML fragment in Go so nested + // maps (claim_map) emit correctly. + data.AuthBlock = renderAuthBlock(opts.AuthMode, opts.AuthMode, opts.AuthSettings) + // Build env vars data.EnvVars = buildEnvVars(opts) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go new file mode 100644 index 0000000..dd3f710 --- /dev/null +++ b/forge-cli/cmd/init_auth.go @@ -0,0 +1,161 @@ +package cmd + +import ( + "fmt" + "net/url" + "sort" + "strings" + + "github.com/spf13/cobra" +) + +// buildAuthFromFlags reads the --auth* flag set into a settings map + +// the egress hosts the runtime will need. Validates required-field +// combinations; rejects unknown modes with a clear error. +func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]any, egressHosts []string, err error) { + switch mode { + case "none", "custom": + return nil, nil, nil + case "oidc": + issuer, _ := cmd.Flags().GetString("auth-issuer") + audience, _ := cmd.Flags().GetString("auth-audience") + groupsClaim, _ := cmd.Flags().GetString("auth-groups-claim") + if issuer == "" || audience == "" { + return nil, nil, fmt.Errorf("--auth=oidc requires --auth-issuer and --auth-audience") + } + issuer = strings.TrimRight(issuer, "/") + settings = map[string]any{ + "issuer": issuer, + "audience": audience, + } + if groupsClaim != "" && groupsClaim != "groups" { + settings["claim_map"] = map[string]any{"groups": groupsClaim} + } + if host := hostFromURL(issuer); host != "" { + egressHosts = []string{host} + } + return settings, egressHosts, nil + case "http_verifier": + verifierURL, _ := cmd.Flags().GetString("auth-url") + defaultOrg, _ := cmd.Flags().GetString("auth-default-org") + if verifierURL == "" { + return nil, nil, fmt.Errorf("--auth=http_verifier requires --auth-url") + } + settings = map[string]any{"url": verifierURL} 
+ if defaultOrg != "" { + settings["default_org"] = defaultOrg + } + if host := hostFromURL(verifierURL); host != "" { + egressHosts = []string{host} + } + return settings, egressHosts, nil + default: + return nil, nil, fmt.Errorf("unknown --auth value %q (supported: none, oidc, http_verifier, custom)", mode) + } +} + +// hostFromURL extracts the bare host (no port) from a URL string. Returns +// "" if the URL is empty, unparseable, or has no host. NOTE: +// buildAuthFromFlags above only checks that the URL flag is non-empty, so on +// the CLI path a malformed URL simply yields no egress host. +func hostFromURL(raw string) string { + if raw == "" { + return "" + } + u, perr := url.Parse(raw) + if perr != nil || u.Host == "" { + return "" + } + return u.Hostname() +} + +// renderAuthBlock returns a forge.yaml `auth:` fragment for the given +// provider type and settings. The output starts at column 0 with no +// trailing newline so the surrounding template controls spacing. +// +// Three shapes are produced: +// +// mode == "" → empty string (skip) +// mode == "none" → empty string (anonymous) +// mode == "custom" +// ↓ +// # Auth provider chain — configure here. +// # auth: +// # required: true +// # providers: [...] +// +// any other mode (e.g. "oidc", "http_verifier", "static_token") +// ↓ +// auth: +// required: true +// providers: +// - type: <mode> +// name: <name> (omitted when empty or equal to the type) +// settings: +// key: value +// nested: +// k: v +// +// Settings keys are emitted in alphabetical order so the output is +// deterministic across runs (useful for diffing generated files). 
+func renderAuthBlock(mode, name string, settings map[string]any) string { + switch mode { + case "", "none": + return "" + case "custom": + return customAuthStub() + } + + var b strings.Builder + b.WriteString("auth:\n") + b.WriteString(" required: true\n") + b.WriteString(" providers:\n") + fmt.Fprintf(&b, " - type: %s\n", mode) + if name != "" && name != mode { + fmt.Fprintf(&b, " name: %s\n", name) + } + if len(settings) > 0 { + b.WriteString(" settings:\n") + writeYAMLMap(&b, settings, " ") + } + // Trim the final newline so the template controls spacing. + return strings.TrimRight(b.String(), "\n") +} + +// customAuthStub returns the commented-out template a user gets when they +// pick "Custom" in the wizard. +func customAuthStub() string { + return strings.Join([]string{ + "# Auth provider chain — configure here. See useforge.ai/docs/auth", + "# Supported types: oidc, http_verifier, static_token", + "# auth:", + "# required: true", + "# providers:", + "# - type: oidc", + "# settings:", + "# issuer: https://login.example.com", + "# audience: api://forge", + }, "\n") +} + +// writeYAMLMap renders a `map[string]any` as YAML lines, recursing into +// nested maps. Only string / number / bool / map values are supported — +// the auth-settings schema doesn't use anything else. +// +// NOTE(review): string values are written verbatim, with no YAML quoting or +// escaping. A value containing ": " or " #", or one that reads as another +// YAML type (e.g. "true", "123"), would be parsed differently than intended. 
+func writeYAMLMap(b *strings.Builder, m map[string]any, indent string) { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + v := m[k] + switch val := v.(type) { + case map[string]any: + fmt.Fprintf(b, "%s%s:\n", indent, k) + writeYAMLMap(b, val, indent+" ") + case string: + fmt.Fprintf(b, "%s%s: %s\n", indent, k, val) + default: + fmt.Fprintf(b, "%s%s: %v\n", indent, k, val) + } + } +} diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go new file mode 100644 index 0000000..094caca --- /dev/null +++ b/forge-cli/cmd/init_auth_test.go @@ -0,0 +1,267 @@ +package cmd + +import ( + "reflect" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +// --- renderAuthBlock --- + +func TestRenderAuthBlock_EmptyAndNone(t *testing.T) { + if got := renderAuthBlock("", "", nil); got != "" { + t.Errorf("mode='' → %q, want empty", got) + } + if got := renderAuthBlock("none", "", nil); got != "" { + t.Errorf("mode='none' → %q, want empty", got) + } +} + +func TestRenderAuthBlock_Custom(t *testing.T) { + got := renderAuthBlock("custom", "", nil) + if !strings.HasPrefix(got, "# Auth provider chain") { + t.Errorf("custom mode should start with comment header, got: %s", got) + } + if !strings.Contains(got, "oidc") { + t.Errorf("custom stub should mention oidc, got: %s", got) + } +} + +func TestRenderAuthBlock_OIDC(t *testing.T) { + got := renderAuthBlock("oidc", "corporate-sso", map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }) + + wantLines := []string{ + "auth:", + " required: true", + " providers:", + " - type: oidc", + " name: corporate-sso", + " settings:", + " audience: api://forge", + " issuer: https://login.example.com", + } + for _, line := range wantLines { + if !strings.Contains(got, line) { + t.Errorf("missing line %q in:\n%s", line, got) + } + } +} + +func TestRenderAuthBlock_NameSameAsModeSuppressed(t *testing.T) { + // 
When the name equals the mode (the default the wizard supplies), + // don't emit a redundant `name:` line. + got := renderAuthBlock("oidc", "oidc", map[string]any{ + "issuer": "https://x", + "audience": "y", + }) + if strings.Contains(got, "name: oidc") { + t.Errorf("expected name: line suppressed when same as type, got:\n%s", got) + } +} + +func TestRenderAuthBlock_NestedClaimMap(t *testing.T) { + got := renderAuthBlock("oidc", "", map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + "claim_map": map[string]any{"groups": "roles"}, + }) + + if !strings.Contains(got, "claim_map:") { + t.Errorf("missing claim_map:\n%s", got) + } + if !strings.Contains(got, " groups: roles") { + t.Errorf("missing nested groups: roles (check indentation):\n%s", got) + } +} + +func TestRenderAuthBlock_HTTPVerifier(t *testing.T) { + got := renderAuthBlock("http_verifier", "", map[string]any{ + "url": "https://verify.example.com", + "default_org": "acme", + }) + if !strings.Contains(got, "- type: http_verifier") { + t.Errorf("missing http_verifier type:\n%s", got) + } + if !strings.Contains(got, "default_org: acme") { + t.Errorf("missing default_org:\n%s", got) + } +} + +func TestRenderAuthBlock_DeterministicOrdering(t *testing.T) { + // Settings keys should always emit in alphabetical order so diffs + // of generated files are stable. + got1 := renderAuthBlock("oidc", "", map[string]any{ + "audience": "a", + "issuer": "i", + }) + got2 := renderAuthBlock("oidc", "", map[string]any{ + "issuer": "i", + "audience": "a", + }) + if got1 != got2 { + t.Errorf("output is not deterministic across map ordering") + } + // audience must come before issuer (alphabetical). 
+ audIdx := strings.Index(got1, "audience:") + issIdx := strings.Index(got1, "issuer:") + if audIdx == -1 || issIdx == -1 || audIdx > issIdx { + t.Errorf("audience should appear before issuer; got:\n%s", got1) + } +} + +// --- buildAuthFromFlags --- + +func mockInitCmdWithFlags() *cobra.Command { + c := &cobra.Command{Use: "init"} + c.Flags().String("auth", "", "") + c.Flags().String("auth-issuer", "", "") + c.Flags().String("auth-audience", "", "") + c.Flags().String("auth-url", "", "") + c.Flags().String("auth-default-org", "", "") + c.Flags().String("auth-groups-claim", "", "") + return c +} + +func TestBuildAuthFromFlags_None(t *testing.T) { + c := mockInitCmdWithFlags() + settings, hosts, err := buildAuthFromFlags(c, "none") + if err != nil { + t.Fatal(err) + } + if settings != nil || hosts != nil { + t.Errorf("none should return nils, got %v / %v", settings, hosts) + } +} + +func TestBuildAuthFromFlags_OIDC_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://login.example.com/") + _ = c.Flags().Set("auth-audience", "api://forge") + + settings, hosts, err := buildAuthFromFlags(c, "oidc") + if err != nil { + t.Fatal(err) + } + if settings["issuer"] != "https://login.example.com" { + t.Errorf("trailing slash should be stripped from issuer; got %v", settings["issuer"]) + } + if settings["audience"] != "api://forge" { + t.Errorf("audience wrong: %v", settings["audience"]) + } + if _, ok := settings["claim_map"]; ok { + t.Errorf("claim_map should not be set without --auth-groups-claim") + } + if !reflect.DeepEqual(hosts, []string{"login.example.com"}) { + t.Errorf("hosts = %v, want [login.example.com]", hosts) + } +} + +func TestBuildAuthFromFlags_OIDC_CustomGroupsClaim(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://x") + _ = c.Flags().Set("auth-audience", "y") + _ = c.Flags().Set("auth-groups-claim", "roles") + + settings, _, err := buildAuthFromFlags(c, "oidc") + if err != 
nil { + t.Fatal(err) + } + cm, ok := settings["claim_map"].(map[string]any) + if !ok { + t.Fatalf("expected claim_map, got %v", settings["claim_map"]) + } + if cm["groups"] != "roles" { + t.Errorf("claim_map.groups = %v, want roles", cm["groups"]) + } +} + +func TestBuildAuthFromFlags_OIDC_DefaultGroupsClaimSuppressed(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://x") + _ = c.Flags().Set("auth-audience", "y") + _ = c.Flags().Set("auth-groups-claim", "groups") // same as default + + settings, _, err := buildAuthFromFlags(c, "oidc") + if err != nil { + t.Fatal(err) + } + if _, ok := settings["claim_map"]; ok { + t.Errorf("default groups claim should be suppressed") + } +} + +func TestBuildAuthFromFlags_OIDC_MissingFields(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://x") + // missing audience + _, _, err := buildAuthFromFlags(c, "oidc") + if err == nil { + t.Error("expected error for missing audience") + } +} + +func TestBuildAuthFromFlags_HTTPVerifier_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-url", "https://verify.example.com/verify") + _ = c.Flags().Set("auth-default-org", "acme") + + settings, hosts, err := buildAuthFromFlags(c, "http_verifier") + if err != nil { + t.Fatal(err) + } + if settings["url"] != "https://verify.example.com/verify" { + t.Errorf("url wrong: %v", settings["url"]) + } + if settings["default_org"] != "acme" { + t.Errorf("default_org wrong: %v", settings["default_org"]) + } + if !reflect.DeepEqual(hosts, []string{"verify.example.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_HTTPVerifier_MissingURL(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "http_verifier") + if err == nil { + t.Error("expected error for missing --auth-url") + } +} + +func TestBuildAuthFromFlags_UnknownMode(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := 
buildAuthFromFlags(c, "ldap") + if err == nil { + t.Error("expected error for unknown mode") + } +} + +// --- mergeEgressDomains --- + +func TestMergeEgressDomains_Dedupe(t *testing.T) { + got := mergeEgressDomains( + []string{"a", "b", "c"}, + []string{"b", "d"}, + ) + want := []string{"a", "b", "c", "d"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestMergeEgressDomains_EmptyInputs(t *testing.T) { + if got := mergeEgressDomains(nil, nil); len(got) != 0 { + t.Errorf("got %v, want empty", got) + } + if got := mergeEgressDomains([]string{"a"}, nil); !reflect.DeepEqual(got, []string{"a"}) { + t.Errorf("got %v", got) + } + if got := mergeEgressDomains(nil, []string{"b"}); !reflect.DeepEqual(got, []string{"b"}) { + t.Errorf("got %v", got) + } +} diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index 77fc237..d2d7eff 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -15,6 +15,33 @@ var providerDomains = map[string]string{ // ollama is local, no egress needed } +// mergeEgressDomains returns the union of `base` and `extra`, preserving +// the order of `base` and appending only previously-unseen items from +// `extra`. The result is deduplicated and stable. +// +// Used by the init scaffold to fold auth-provider hosts (from the wizard +// or --auth flags) into the egress allowlist without disturbing the +// existing ordering. +func mergeEgressDomains(base, extra []string) []string { + seen := make(map[string]bool, len(base)+len(extra)) + out := make([]string, 0, len(base)+len(extra)) + for _, d := range base { + if d == "" || seen[d] { + continue + } + seen[d] = true + out = append(out, d) + } + for _, d := range extra { + if d == "" || seen[d] { + continue + } + seen[d] = true + out = append(out, d) + } + return out +} + // deriveEgressDomains computes the full set of egress domains needed based on // the provider, channels, builtin tools, and selected registry skills. 
func deriveEgressDomains(opts *initOptions, skills []contract.SkillDescriptor) []string { diff --git a/forge-cli/internal/tui/steps/step_auth.go b/forge-cli/internal/tui/steps/step_auth.go new file mode 100644 index 0000000..1e71de0 --- /dev/null +++ b/forge-cli/internal/tui/steps/step_auth.go @@ -0,0 +1,396 @@ +package steps + +import ( + "fmt" + "net/url" + "slices" + "strings" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/initializ/forge/forge-cli/internal/tui" + "github.com/initializ/forge/forge-cli/internal/tui/components" +) + +// authPhase is the internal state machine of the auth step. +type authPhase int + +const ( + authSelectPhase authPhase = iota + authOIDCIssuerPhase + authOIDCAudiencePhase + authOIDCGroupsClaimPhase + authHTTPURLPhase + authHTTPOrgPhase + authDonePhase +) + +// Auth mode constants — stored in WizardContext.AuthMode and read by the +// init scaffold to decide what to render in forge.yaml. +const ( + AuthModeNone = "none" + AuthModeOIDC = "oidc" + AuthModeHTTPVerifier = "http_verifier" + AuthModeCustom = "custom" +) + +// AuthStep is the wizard step that configures the A2A server's auth chain. +// +// UX shape: +// +// ┌─ select provider type +// │ ○ None (default) +// │ ○ OIDC +// │ ○ HTTP Verifier +// │ ○ Custom (edit forge.yaml manually) +// └─ on choice: +// OIDC → issuer → audience → groups claim → done +// HTTP Verifier → url → default org → done +// None / Custom → done immediately +// +// Per-phase: backspace on empty input returns to the picker (Esc-like). +type AuthStep struct { + styles *tui.StyleSet + phase authPhase + + selector components.SingleSelect + input components.TextInput + + mode string // selected provider type + + // OIDC settings + issuer string + audience string + groupsClaim string + + // HTTP verifier settings + httpURL string + httpOrg string + + complete bool +} + +// NewAuthStep constructs the auth step. 
+func NewAuthStep(styles *tui.StyleSet) *AuthStep { + items := []components.SingleSelectItem{ + { + Label: "None", + Value: AuthModeNone, + Description: "Anonymous access — no auth: block in forge.yaml", + Icon: "🔓", + }, + { + Label: "OIDC (JWT)", + Value: AuthModeOIDC, + Description: "Auth0, Keycloak, Azure AD, Google, Okta-OIDC, …", + Icon: "🔐", + }, + { + Label: "HTTP Verifier", + Value: AuthModeHTTPVerifier, + Description: "Legacy — POST tokens to your own /verify endpoint", + Icon: "🔁", + }, + { + Label: "Custom", + Value: AuthModeCustom, + Description: "Write a commented stub — I'll edit forge.yaml myself", + Icon: "✏️", + }, + } + + selector := components.NewSingleSelect( + items, + styles.Theme.Accent, + styles.Theme.Primary, + styles.Theme.Secondary, + styles.Theme.Dim, + styles.Theme.Border, + styles.Theme.ActiveBorder, + styles.Theme.ActiveBg, + styles.KbdKey, + styles.KbdDesc, + ) + + return &AuthStep{ + styles: styles, + selector: selector, + } +} + +func (s *AuthStep) Title() string { return "Authentication" } +func (s *AuthStep) Icon() string { return "🔐" } + +func (s *AuthStep) Init() tea.Cmd { return s.selector.Init() } + +func (s *AuthStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { + if s.complete { + return s, nil + } + + // Forward window-size events to whichever component is active. 
+ if wsm, ok := msg.(tea.WindowSizeMsg); ok { + if s.phase == authSelectPhase { + updated, cmd := s.selector.Update(wsm) + s.selector = updated + return s, cmd + } + updated, cmd := s.input.Update(wsm) + s.input = updated + return s, cmd + } + + switch s.phase { + case authSelectPhase: + return s.updateSelect(msg) + case authOIDCIssuerPhase, authOIDCAudiencePhase, authOIDCGroupsClaimPhase, + authHTTPURLPhase, authHTTPOrgPhase: + return s.updateInput(msg) + } + return s, nil +} + +func (s *AuthStep) updateSelect(msg tea.Msg) (tui.Step, tea.Cmd) { + updated, cmd := s.selector.Update(msg) + s.selector = updated + if !s.selector.Done() { + return s, cmd + } + + _, val := s.selector.Selected() + s.mode = val + + switch val { + case AuthModeNone, AuthModeCustom: + // No further input — finish. + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + case AuthModeOIDC: + s.phase = authOIDCIssuerPhase + s.input = s.newTextInput( + "Issuer URL", + "https://login.example.com", + validateHTTPSURL, + ) + return s, s.input.Init() + case AuthModeHTTPVerifier: + s.phase = authHTTPURLPhase + s.input = s.newTextInput( + "Verifier URL", + "https://auth.example.com/verify", + validateHTTPSURL, + ) + return s, s.input.Init() + } + return s, cmd +} + +func (s *AuthStep) updateInput(msg tea.Msg) (tui.Step, tea.Cmd) { + // Backspace on empty input returns to the picker. 
+ if km, ok := msg.(tea.KeyMsg); ok && km.String() == "backspace" { + if s.input.Value() == "" { + s.phase = authSelectPhase + s.mode = "" + s.selector.Reset() + return s, s.selector.Init() + } + } + + updated, cmd := s.input.Update(msg) + s.input = updated + if !s.input.Done() { + return s, cmd + } + + v := strings.TrimSpace(s.input.Value()) + switch s.phase { + case authOIDCIssuerPhase: + s.issuer = strings.TrimRight(v, "/") + s.phase = authOIDCAudiencePhase + s.input = s.newTextInput("Audience", "api://forge", validateNonEmpty) + return s, s.input.Init() + + case authOIDCAudiencePhase: + s.audience = v + s.phase = authOIDCGroupsClaimPhase + s.input = s.newTextInput( + "Groups claim (default: groups — press Enter to accept)", + "groups", + nil, // optional + ) + return s, s.input.Init() + + case authOIDCGroupsClaimPhase: + s.groupsClaim = v // may be "" + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + + case authHTTPURLPhase: + s.httpURL = v + s.phase = authHTTPOrgPhase + s.input = s.newTextInput( + "Default org_id (optional — press Enter to skip)", + "", + nil, + ) + return s, s.input.Init() + + case authHTTPOrgPhase: + s.httpOrg = v + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + } + + return s, cmd +} + +func (s *AuthStep) newTextInput(label, placeholder string, validate func(string) error) components.TextInput { + return components.NewTextInput( + label, + placeholder, + false, // no slug hint + validate, + s.styles.Theme.Accent, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) +} + +func (s *AuthStep) View(width int) string { + switch s.phase { + case authSelectPhase: + return s.selector.View(width) + case authDonePhase: + return "" + default: + return s.input.View(width) + } +} + +func (s *AuthStep) Complete() bool { return s.complete } + +func (s *AuthStep) Summary() string { + switch s.mode { + case "", AuthModeNone: + return 
"Anonymous (no auth)" + case AuthModeOIDC: + host := hostnameOrEmpty(s.issuer) + if host == "" { + return "OIDC" + } + return "OIDC · " + host + case AuthModeHTTPVerifier: + host := hostnameOrEmpty(s.httpURL) + if host == "" { + return "HTTP Verifier" + } + return "HTTP Verifier · " + host + case AuthModeCustom: + return "Custom (edit forge.yaml)" + } + return s.mode +} + +// Apply writes the auth selection back to the wizard context. +// +// - AuthMode is always set (defaults to "none"). +// - AuthSettings is the typed settings block ready to copy into forge.yaml. +// - AuthEgressHosts captures the outbound hosts the runner will need +// reachable; the init scaffold merges these into egress.allowed_domains. +func (s *AuthStep) Apply(ctx *tui.WizardContext) { + mode := s.mode + if mode == "" { + mode = AuthModeNone + } + ctx.AuthMode = mode + + switch mode { + case AuthModeOIDC: + settings := map[string]any{ + "issuer": s.issuer, + "audience": s.audience, + } + // Only emit claim_map if user customized it (default name = "groups"). + if s.groupsClaim != "" && s.groupsClaim != "groups" { + settings["claim_map"] = map[string]any{"groups": s.groupsClaim} + } + ctx.AuthSettings = settings + if host := hostnameOrEmpty(s.issuer); host != "" { + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, host) + } + case AuthModeHTTPVerifier: + settings := map[string]any{"url": s.httpURL} + if s.httpOrg != "" { + settings["default_org"] = s.httpOrg + } + ctx.AuthSettings = settings + if host := hostnameOrEmpty(s.httpURL); host != "" { + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, host) + } + default: + // None / Custom: nothing to attach. + ctx.AuthSettings = nil + } +} + +// --- helpers --- + +// validateHTTPSURL ensures the input parses as a URL and uses http or https. +// Errors are surfaced inline by the text input. 
+func validateHTTPSURL(val string) error { + if val == "" { + return fmt.Errorf("required") + } + u, err := url.Parse(val) + if err != nil { + return fmt.Errorf("not a URL: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("must start with http:// or https://") + } + if u.Host == "" { + return fmt.Errorf("missing host") + } + return nil +} + +// validateNonEmpty enforces a non-empty input. +func validateNonEmpty(val string) error { + if strings.TrimSpace(val) == "" { + return fmt.Errorf("required") + } + return nil +} + +// hostnameOrEmpty returns the bare host (no port) from a URL, or "" on +// parse failure. +func hostnameOrEmpty(raw string) string { + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil || u.Host == "" { + return "" + } + return u.Hostname() +} + +// appendUnique appends v to slice if not already present. +func appendUnique(slice []string, v string) []string { + if slices.Contains(slice, v) { + return slice + } + return append(slice, v) +} + +// doneCmd emits the wizard "step complete" message. +func doneCmd() tea.Cmd { + return func() tea.Msg { return tui.StepCompleteMsg{} } +} diff --git a/forge-cli/internal/tui/steps/step_auth_test.go b/forge-cli/internal/tui/steps/step_auth_test.go new file mode 100644 index 0000000..f2b0ed1 --- /dev/null +++ b/forge-cli/internal/tui/steps/step_auth_test.go @@ -0,0 +1,313 @@ +package steps + +import ( + "reflect" + "testing" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/initializ/forge/forge-cli/internal/tui" +) + +// newTestAuthStep wires up an AuthStep with a default StyleSet so we can +// drive it without spinning up a real terminal. +func newTestAuthStep(t *testing.T) *AuthStep { + t.Helper() + styles := tui.NewStyleSet(tui.DarkTheme) + return NewAuthStep(styles) +} + +// press injects a key event into the step and returns the updated step. 
+func press(t *testing.T, s *AuthStep, key string) *AuthStep { + t.Helper() + var msg tea.KeyMsg + switch key { + case "down": + msg = tea.KeyMsg{Type: tea.KeyDown} + case "up": + msg = tea.KeyMsg{Type: tea.KeyUp} + case "enter": + msg = tea.KeyMsg{Type: tea.KeyEnter} + case "backspace": + msg = tea.KeyMsg{Type: tea.KeyBackspace} + default: + msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)} + } + updated, _ := s.Update(msg) + return updated.(*AuthStep) +} + +// typeIn injects each rune of the string as a single KeyRunes message. +func typeIn(t *testing.T, s *AuthStep, text string) *AuthStep { + t.Helper() + for _, r := range text { + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}} + updated, _ := s.Update(msg) + s = updated.(*AuthStep) + } + return s +} + +// --- Title / Icon --- + +func TestAuthStep_TitleAndIcon(t *testing.T) { + s := newTestAuthStep(t) + if s.Title() == "" { + t.Error("Title is empty") + } + if s.Icon() == "" { + t.Error("Icon is empty") + } +} + +// --- "None" path completes immediately --- + +func TestAuthStep_NoneIsDefault(t *testing.T) { + s := newTestAuthStep(t) + s = press(t, s, "enter") // Confirm "None" (first item) + if !s.Complete() { + t.Error("expected Complete after selecting None") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeNone { + t.Errorf("AuthMode = %q, want %q", ctx.AuthMode, AuthModeNone) + } + if ctx.AuthSettings != nil { + t.Errorf("AuthSettings = %v, want nil", ctx.AuthSettings) + } + if len(ctx.AuthEgressHosts) != 0 { + t.Errorf("AuthEgressHosts = %v, want empty", ctx.AuthEgressHosts) + } +} + +// --- "Custom" path completes immediately with no settings --- + +func TestAuthStep_Custom(t *testing.T) { + s := newTestAuthStep(t) + // Move down 3 times to reach "Custom" (index 3). 
+ s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after selecting Custom") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeCustom { + t.Errorf("AuthMode = %q, want %q", ctx.AuthMode, AuthModeCustom) + } + if ctx.AuthSettings != nil { + t.Errorf("AuthSettings = %v, want nil for custom mode", ctx.AuthSettings) + } +} + +// --- OIDC full path --- + +func TestAuthStep_OIDC_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + + // Pick OIDC (index 1). + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authOIDCIssuerPhase { + t.Fatalf("phase = %v, want issuer", s.phase) + } + + // Type issuer. + s = typeIn(t, s, "https://login.example.com/") + s = press(t, s, "enter") + if s.phase != authOIDCAudiencePhase { + t.Fatalf("phase = %v, want audience", s.phase) + } + + // Type audience. + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + if s.phase != authOIDCGroupsClaimPhase { + t.Fatalf("phase = %v, want groups claim", s.phase) + } + + // Accept default groups claim. 
+ s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after groups claim") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeOIDC { + t.Errorf("AuthMode = %q, want oidc", ctx.AuthMode) + } + if ctx.AuthSettings["issuer"] != "https://login.example.com" { + t.Errorf("issuer = %v, want trimmed trailing slash", ctx.AuthSettings["issuer"]) + } + if ctx.AuthSettings["audience"] != "api://forge" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + if _, ok := ctx.AuthSettings["claim_map"]; ok { + t.Errorf("claim_map should NOT be emitted when groups claim is default") + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"login.example.com"}) { + t.Errorf("AuthEgressHosts = %v, want [login.example.com]", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_OIDC_CustomGroupsClaim(t *testing.T) { + s := newTestAuthStep(t) + s = press(t, s, "down") + s = press(t, s, "enter") // OIDC + s = typeIn(t, s, "https://login.example.com") + s = press(t, s, "enter") + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + s = typeIn(t, s, "roles") // non-default groups claim + s = press(t, s, "enter") + + ctx := tui.NewWizardContext() + s.Apply(ctx) + cm, ok := ctx.AuthSettings["claim_map"].(map[string]any) + if !ok { + t.Fatalf("claim_map not present or wrong type: %v", ctx.AuthSettings) + } + if cm["groups"] != "roles" { + t.Errorf("claim_map.groups = %v, want roles", cm["groups"]) + } +} + +// --- HTTP Verifier full path --- + +func TestAuthStep_HTTPVerifier_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + + // Pick HTTP Verifier (index 2). + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authHTTPURLPhase { + t.Fatalf("phase = %v, want url", s.phase) + } + + s = typeIn(t, s, "https://verify.example.com/verify") + s = press(t, s, "enter") + if s.phase != authHTTPOrgPhase { + t.Fatalf("phase = %v, want org", s.phase) + } + + // Skip optional default org. 
+ s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeHTTPVerifier { + t.Errorf("AuthMode = %q, want http_verifier", ctx.AuthMode) + } + if ctx.AuthSettings["url"] != "https://verify.example.com/verify" { + t.Errorf("url = %v", ctx.AuthSettings["url"]) + } + if _, ok := ctx.AuthSettings["default_org"]; ok { + t.Errorf("default_org should not be emitted when blank") + } +} + +func TestAuthStep_HTTPVerifier_WithDefaultOrg(t *testing.T) { + s := newTestAuthStep(t) + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") // HTTP Verifier + s = typeIn(t, s, "https://verify.example.com/verify") + s = press(t, s, "enter") + s = typeIn(t, s, "acme") + s = press(t, s, "enter") + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthSettings["default_org"] != "acme" { + t.Errorf("default_org = %v, want acme", ctx.AuthSettings["default_org"]) + } +} + +// --- Summary text --- + +func TestAuthStep_Summary(t *testing.T) { + tests := []struct { + mode string + want string + }{ + {"", "Anonymous (no auth)"}, + {AuthModeNone, "Anonymous (no auth)"}, + {AuthModeOIDC, "OIDC"}, + {AuthModeCustom, "Custom (edit forge.yaml)"}, + } + for _, tt := range tests { + t.Run(tt.mode, func(t *testing.T) { + s := newTestAuthStep(t) + s.mode = tt.mode + if got := s.Summary(); got != tt.want { + t.Errorf("Summary = %q, want %q", got, tt.want) + } + }) + } +} + +// --- Helpers --- + +func TestValidateHTTPSURL(t *testing.T) { + tests := []struct { + val string + ok bool + }{ + {"", false}, + {"https://example.com", true}, + {"http://localhost:8080", true}, + {"login.example.com", false}, // no scheme + {"ftp://example.com", false}, // wrong scheme + {"https://", false}, // no host + } + for _, tt := range tests { + t.Run(tt.val, func(t *testing.T) { + err := validateHTTPSURL(tt.val) + if tt.ok && err != nil { + t.Errorf("expected ok, got %v", err) + } + if !tt.ok && 
err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} + +func TestHostnameOrEmpty(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"https://login.example.com", "login.example.com"}, + {"http://localhost:8080/realms/dev", "localhost"}, + {"not a url", ""}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + if got := hostnameOrEmpty(tt.in); got != tt.want { + t.Errorf("hostnameOrEmpty(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestAppendUnique(t *testing.T) { + out := appendUnique([]string{"a", "b"}, "a") + if !reflect.DeepEqual(out, []string{"a", "b"}) { + t.Errorf("appendUnique duplicate failed: %v", out) + } + out = appendUnique([]string{"a"}, "b") + if !reflect.DeepEqual(out, []string{"a", "b"}) { + t.Errorf("appendUnique new failed: %v", out) + } +} diff --git a/forge-cli/internal/tui/wizard.go b/forge-cli/internal/tui/wizard.go index 73dd0c1..d992496 100644 --- a/forge-cli/internal/tui/wizard.go +++ b/forge-cli/internal/tui/wizard.go @@ -30,6 +30,26 @@ type WizardContext struct { CustomModel string CustomAPIKey string EnvVars map[string]string + + // AuthMode is the user's selection from the auth step: + // "" — wizard step did not run + // "none" — anonymous access (no auth: block written) + // "oidc" — OIDC provider + // "http_verifier" — legacy external verifier + // "custom" — write a commented stub; user edits manually + AuthMode string + + // AuthSettings is the type-specific settings block for the chosen + // provider, ready to plug into `auth.providers[0].settings` in + // forge.yaml. Keys: "issuer", "audience", "url", "default_org", + // "claim_map" (a nested map), etc. + AuthSettings map[string]any + + // AuthEgressHosts is the list of bare hosts (no port) that should be + // auto-added to the egress allowlist because the auth step requires + // them — e.g. an OIDC issuer or http_verifier URL. 
The init scaffold + // merges these into the rendered forge.yaml egress.allowed_domains. + AuthEgressHosts []string } // NewWizardContext creates an initialized WizardContext. @@ -37,6 +57,7 @@ func NewWizardContext() *WizardContext { return &WizardContext{ ChannelTokens: make(map[string]string), EnvVars: make(map[string]string), + AuthSettings: make(map[string]any), } } diff --git a/forge-cli/templates/init/forge.yaml.tmpl b/forge-cli/templates/init/forge.yaml.tmpl index 6ebf375..f7e93e4 100644 --- a/forge-cli/templates/init/forge.yaml.tmpl +++ b/forge-cli/templates/init/forge.yaml.tmpl @@ -57,3 +57,7 @@ secrets: - encrypted-file - env {{- end}} +{{- if .AuthBlock}} + +{{.AuthBlock}} +{{- end}} From fd55b142ff80861ca4b3c033100cdbd825d8cc88 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Fri, 22 May 2026 20:50:30 -0400 Subject: [PATCH 07/20] feat(ui): web wizard Authentication step (PR6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the matching Authentication step to the Web UI wizard (`forge ui`) so users without a terminal can configure the auth chain through the browser-based create-agent flow. Same backend contract as PR5 — both surfaces feed the single scaffold path that writes forge.yaml. Backend changes: - forge-ui/types.go: new AuthCreateOptions (mode + settings) field on AgentCreateOptions; new AuthProviderTypeMeta + AuthProviderTypes on WizardMetadata. - forge-ui/handlers_create.go: GET /api/wizard/meta now returns the four founding auth provider types so the frontend renders the picker from server-driven metadata (adding Okta in Phase 3 = one line here, no frontend change). - forge-cli/cmd/ui.go: createFunc translates opts.Auth into the initOptions fields the scaffold expects (AuthMode, AuthSettings, AuthEgressHosts). - forge-cli/cmd/init_auth.go: new authEgressHostsFromSettings helper extracts issuer/verifier hosts for the Web UI path (same hosts the TUI wizard and --auth flags compute). 
Frontend changes (forge-ui/static/dist/app.js): - WIZARD_STEPS extended from 9 to 10 entries (Authentication slotted between Env Vars and Review). - form.auth = {mode, settings} added to initial state. - Step renders a radio group fed by meta.auth_provider_types, with conditional OIDC / HTTP Verifier sub-forms. - canNext validation: OIDC requires issuer + audience; HTTP Verifier requires url. - buildAuthPayload reshapes the wizard's flat groups_claim into the nested claim_map: {groups: ...} shape before submission; strips empty optional fields and trims trailing slashes from issuer. - Review step adds an "A2A Auth" row with a one-line summary. UX details: - None is the default — anonymous unless the user picks otherwise. - Switching modes resets settings so stale fields from a previously selected mode don't leak into the submitted payload. - Custom mode writes a commented stub the user can edit later. Backward compat: - Requests without `auth` continue to work (anonymous access). - AuthCreateOptions.Settings uses `map[string]any` so future provider types add fields server-side without breaking older payloads. Tests: - handleGetWizardMeta now asserts all 4 auth provider types are exposed with non-empty label + description. - New TestHandleCreateAgent_WithAuthPayload: AgentCreateOptions.Auth round-trips through the JSON boundary with correct settings. - New TestHandleCreateAgent_WithoutAuthPayload: omitted auth field keeps Auth nil (no spurious anonymous-mode payload). - New authEgressHostsFromSettings tests: oidc, oidc+jwks_url, http_verifier, no-network modes, nil settings, malformed URL. - End-to-end smoke (manual): forge ui POST /api/agents with auth payload → generated forge.yaml has auth: block + login.example.com auto-added to egress.allowed_domains. 
No regressions: - All 4 modules pass (forge-core, forge-cli, forge-plugins, forge-ui) - go test -race clean on touched packages - golangci-lint v2.10.1: 0 issues - gofmt + go vet clean --- forge-cli/cmd/init_auth.go | 38 +++++++ forge-cli/cmd/init_auth_test.go | 55 ++++++++++ forge-cli/cmd/ui.go | 9 ++ forge-ui/handlers_create.go | 10 ++ forge-ui/handlers_create_test.go | 97 +++++++++++++++++ forge-ui/static/dist/app.js | 172 ++++++++++++++++++++++++++++++- forge-ui/types.go | 22 ++++ 7 files changed, 400 insertions(+), 3 deletions(-) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index dd3f710..48197eb 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -54,6 +54,44 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an } } +// authEgressHostsFromSettings extracts the outbound hosts implied by an +// auth provider's settings block. Used by the Web UI path (cmd/ui.go) to +// derive the same egress hosts the TUI wizard / --auth flags compute. +// +// Returns nil for modes that don't require outbound (none, static_token, +// custom) or for malformed/missing URLs. +func authEgressHostsFromSettings(mode string, settings map[string]any) []string { + if settings == nil { + return nil + } + var hosts []string + switch mode { + case "oidc": + if h := hostFromURL(asStringSetting(settings, "issuer")); h != "" { + hosts = append(hosts, h) + } + if h := hostFromURL(asStringSetting(settings, "jwks_url")); h != "" { + hosts = append(hosts, h) + } + case "http_verifier": + if h := hostFromURL(asStringSetting(settings, "url")); h != "" { + hosts = append(hosts, h) + } + } + return hosts +} + +// asStringSetting reads a string-valued setting, returning "" for missing +// or non-string values. +func asStringSetting(m map[string]any, key string) string { + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + // hostFromURL extracts the bare host (no port) from a URL string. 
Returns // "" if the URL is malformed — validation happens in the wizard / // buildAuthFromFlags above. diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go index 094caca..b883d24 100644 --- a/forge-cli/cmd/init_auth_test.go +++ b/forge-cli/cmd/init_auth_test.go @@ -241,6 +241,61 @@ func TestBuildAuthFromFlags_UnknownMode(t *testing.T) { } } +// --- authEgressHostsFromSettings (PR6) --- + +func TestAuthEgressHostsFromSettings_OIDC(t *testing.T) { + got := authEgressHostsFromSettings("oidc", map[string]any{ + "issuer": "https://login.example.com/realms/x", + "audience": "api://forge", + }) + if !reflect.DeepEqual(got, []string{"login.example.com"}) { + t.Errorf("got %v, want [login.example.com]", got) + } +} + +func TestAuthEgressHostsFromSettings_OIDCWithJWKSURL(t *testing.T) { + got := authEgressHostsFromSettings("oidc", map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + "jwks_url": "https://keys.example.com/jwks", + }) + want := []string{"login.example.com", "keys.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestAuthEgressHostsFromSettings_HTTPVerifier(t *testing.T) { + got := authEgressHostsFromSettings("http_verifier", map[string]any{ + "url": "https://verify.example.com/verify", + }) + if !reflect.DeepEqual(got, []string{"verify.example.com"}) { + t.Errorf("got %v, want [verify.example.com]", got) + } +} + +func TestAuthEgressHostsFromSettings_NoNetworkModes(t *testing.T) { + for _, mode := range []string{"none", "custom", "static_token", "unknown"} { + if got := authEgressHostsFromSettings(mode, map[string]any{"x": "y"}); got != nil { + t.Errorf("mode %q returned %v, want nil", mode, got) + } + } +} + +func TestAuthEgressHostsFromSettings_NilSettings(t *testing.T) { + if got := authEgressHostsFromSettings("oidc", nil); got != nil { + t.Errorf("got %v, want nil for nil settings", got) + } +} + +func 
TestAuthEgressHostsFromSettings_MalformedURL(t *testing.T) { + if got := authEgressHostsFromSettings("oidc", map[string]any{ + "issuer": "::not a url", + }); got != nil { + t.Errorf("got %v, want nil for malformed URL", got) + } +} + // --- mergeEgressDomains --- func TestMergeEgressDomains_Dedupe(t *testing.T) { diff --git a/forge-cli/cmd/ui.go b/forge-cli/cmd/ui.go index a77aed5..4773287 100644 --- a/forge-cli/cmd/ui.go +++ b/forge-cli/cmd/ui.go @@ -112,6 +112,15 @@ func runUI(cmd *cobra.Command, args []string) error { _ = os.Setenv("FORGE_PASSPHRASE", opts.Passphrase) } + // Web UI auth chain selection (PR6). Translate the wizard payload + // into the same fields the TUI wizard / non-interactive flags use, + // so scaffold() has a single source of truth. + if opts.Auth != nil && opts.Auth.Mode != "" { + initOpts.AuthMode = opts.Auth.Mode + initOpts.AuthSettings = opts.Auth.Settings + initOpts.AuthEgressHosts = authEgressHostsFromSettings(opts.Auth.Mode, opts.Auth.Settings) + } + storeProviderEnvVar(initOpts) checkSkillRequirements(initOpts) diff --git a/forge-ui/handlers_create.go b/forge-ui/handlers_create.go index 781bc36..d668e73 100644 --- a/forge-ui/handlers_create.go +++ b/forge-ui/handlers_create.go @@ -118,6 +118,16 @@ func (s *UIServer) handleGetWizardMeta(w http.ResponseWriter, _ *http.Request) { } } + // Auth provider types — server-driven list. Adding a new provider + // (e.g., Okta in Phase 3) means appending one entry here; the + // frontend renders the picker from this metadata. 
+ meta.AuthProviderTypes = []AuthProviderTypeMeta{ + {Type: "none", Label: "None", Description: "Anonymous access — no auth: block written"}, + {Type: "oidc", Label: "OIDC (JWT)", Description: "Auth0, Keycloak, Azure AD, Google, Okta-OIDC, …"}, + {Type: "http_verifier", Label: "HTTP Verifier", Description: "Legacy — POST tokens to your own /verify endpoint"}, + {Type: "custom", Label: "Custom", Description: "Write a commented stub, edit forge.yaml manually"}, + } + writeJSON(w, http.StatusOK, meta) } diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go index 031c710..9ae8671 100644 --- a/forge-ui/handlers_create_test.go +++ b/forge-ui/handlers_create_test.go @@ -385,4 +385,101 @@ func TestHandleGetWizardMeta(t *testing.T) { if len(meta.WebSearchProviders) == 0 { t.Error("expected non-empty web_search_providers list") } + + // PR6: auth_provider_types is server-driven so the frontend doesn't + // hardcode the list. Must include the four founding types. + if len(meta.AuthProviderTypes) != 4 { + t.Errorf("auth_provider_types len = %d, want 4", len(meta.AuthProviderTypes)) + } + wantTypes := map[string]bool{"none": false, "oidc": false, "http_verifier": false, "custom": false} + for _, a := range meta.AuthProviderTypes { + if _, ok := wantTypes[a.Type]; !ok { + t.Errorf("unexpected auth type %q", a.Type) + continue + } + wantTypes[a.Type] = true + if a.Label == "" || a.Description == "" { + t.Errorf("auth type %q missing label/description", a.Type) + } + } + for typ, seen := range wantTypes { + if !seen { + t.Errorf("missing required auth type %q", typ) + } + } +} + +// setupCreateWithCapture returns a UIServer whose CreateFunc records the +// last AgentCreateOptions it received. Used to assert that opts.Auth +// round-trips through the JSON boundary. 
+func setupCreateWithCapture(t *testing.T) (*UIServer, *AgentCreateOptions) { + t.Helper() + root := t.TempDir() + captured := &AgentCreateOptions{} + mockCreate := func(opts AgentCreateOptions) (string, error) { + *captured = opts + return filepath.Join(root, opts.Name), nil + } + srv := NewUIServer(UIServerConfig{ + Port: 4200, + WorkDir: root, + ExePath: "/usr/bin/false", + CreateFunc: mockCreate, + AgentPort: 9100, + }) + return srv, captured +} + +func TestHandleCreateAgent_WithAuthPayload(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "auth-test-agent", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge" + } + } + }`) + + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + } + if captured.Auth == nil { + t.Fatal("captured Auth is nil — opts.Auth did not round-trip") + } + if captured.Auth.Mode != "oidc" { + t.Errorf("Auth.Mode = %q, want oidc", captured.Auth.Mode) + } + if captured.Auth.Settings["issuer"] != "https://login.example.com" { + t.Errorf("issuer = %v, want https://login.example.com", captured.Auth.Settings["issuer"]) + } + if captured.Auth.Settings["audience"] != "api://forge" { + t.Errorf("audience = %v, want api://forge", captured.Auth.Settings["audience"]) + } +} + +func TestHandleCreateAgent_WithoutAuthPayload(t *testing.T) { + // Omitting `auth` from the request is still a valid create — the + // agent gets anonymous access by default. 
+ srv, captured := setupCreateWithCapture(t) + + body := []byte(`{"name": "no-auth-agent", "model_provider": "openai"}`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d", w.Code, http.StatusCreated) + } + if captured.Auth != nil { + t.Errorf("Auth = %v, want nil for omitted field", captured.Auth) + } } diff --git a/forge-ui/static/dist/app.js b/forge-ui/static/dist/app.js index 6eed95a..d4bc86b 100644 --- a/forge-ui/static/dist/app.js +++ b/forge-ui/static/dist/app.js @@ -1121,7 +1121,56 @@ function loadMonaco() { // ── Create Agent Wizard ────────────────────────────────────── -const WIZARD_STEPS = ['Name', 'Provider', 'Model & Key', 'Channels', 'Tools', 'Skills', 'Fallback', 'Env Vars', 'Review']; +const WIZARD_STEPS = ['Name', 'Provider', 'Model & Key', 'Channels', 'Tools', 'Skills', 'Fallback', 'Env Vars', 'Authentication', 'Review']; + +// buildAuthPayload reshapes the wizard's form.auth into the JSON shape +// the backend (cmd/ui.go → scaffold) expects. The wizard collects +// `groups_claim` as a flat field for usability; we translate it into the +// nested `claim_map: {groups: }` shape here. Empty optional fields +// are stripped so the generated forge.yaml stays clean. +// authReviewLabel renders the one-line summary shown on the Review step. +function authReviewLabel(auth) { + const mode = auth?.mode || 'none'; + const s = auth?.settings || {}; + if (mode === 'none') return 'Anonymous (no auth)'; + if (mode === 'custom') return 'Custom (edit forge.yaml)'; + if (mode === 'oidc') { + try { + return s.issuer ? 'OIDC · ' + new URL(s.issuer).hostname : 'OIDC'; + } catch (e) { + return 'OIDC'; + } + } + if (mode === 'http_verifier') { + try { + return s.url ? 
'HTTP Verifier · ' + new URL(s.url).hostname : 'HTTP Verifier'; + } catch (e) { + return 'HTTP Verifier'; + } + } + return mode; +} + +function buildAuthPayload(auth) { + const mode = auth?.mode || 'none'; + const src = auth?.settings || {}; + if (mode === 'none' || mode === 'custom') { + return { mode, settings: {} }; + } + const settings = {}; + if (mode === 'oidc') { + const issuer = (src.issuer || '').replace(/\/+$/, ''); + if (issuer) settings.issuer = issuer; + if (src.audience) settings.audience = src.audience; + if (src.groups_claim && src.groups_claim !== 'groups') { + settings.claim_map = { groups: src.groups_claim }; + } + } else if (mode === 'http_verifier') { + if (src.url) settings.url = src.url; + if (src.default_org) settings.default_org = src.default_org; + } + return { mode, settings }; +} // Provider-specific API key labels, placeholders, and env var names const PROVIDER_KEY_INFO = { @@ -1168,6 +1217,10 @@ function CreatePage() { fallbacks: [], // [{provider, api_key}] passphrase: '', env_vars: {}, + // A2A server auth chain (PR6). The wizard's Authentication step + // sets `mode` to one of {none, oidc, http_verifier, custom} and + // populates `settings` for non-trivial modes. + auth: { mode: 'none', settings: {} }, }); const [creating, setCreating] = useState(false); const [oauthLoading, setOauthLoading] = useState(false); @@ -1250,6 +1303,15 @@ function CreatePage() { } return form.model_name.trim().length > 0; } + // Authentication step: required fields must be filled when a + // non-trivial mode is chosen. None / Custom always advance. 
+ if (step === 8) { + const mode = form.auth?.mode || 'none'; + const s = form.auth?.settings || {}; + if (mode === 'oidc') return !!(s.issuer && s.audience); + if (mode === 'http_verifier') return !!s.url; + return true; // none / custom + } return true; }, [step, form, oauthDone]); @@ -1257,7 +1319,10 @@ function CreatePage() { setCreating(true); setError(null); try { - const result = await createAgent(form); + // Reshape the auth payload before submission so the backend gets + // the typed settings the scaffold expects (e.g., claim_map nesting). + const payload = { ...form, auth: buildAuthPayload(form.auth) }; + const result = await createAgent(payload); setSuccess(result); setTimeout(() => rescanAgents().catch(() => {}), 500); } catch (err) { @@ -1750,7 +1815,104 @@ function CreatePage() { `; } - case 8: { // Review + case 8: { // Authentication (A2A server auth chain — PR6) + const authTypes = meta?.auth_provider_types || []; + const mode = form.auth?.mode || 'none'; + const settings = form.auth?.settings || {}; + + const setAuthMode = (newMode) => { + // Switching modes always resets settings so stale fields + // (e.g., OIDC issuer when user switches to HTTP Verifier) + // don't leak into the submitted payload. + updateForm('auth', { mode: newMode, settings: {} }); + }; + const setAuthSetting = (key, val) => { + updateForm('auth', { mode, settings: { ...settings, [key]: val } }); + }; + + return html` +
+
Authentication
+
+ How should the agent's A2A server authenticate incoming requests? Default: anonymous. +
+ +
+ ${authTypes.map(opt => html` + + `)} +
+ + ${mode === 'oidc' && html` +
+
+ OIDC Configuration +
+ + setAuthSetting('issuer', e.target.value)} autocomplete="off" /> + + + setAuthSetting('audience', e.target.value)} autocomplete="off" /> + + + setAuthSetting('groups_claim', e.target.value)} autocomplete="off" /> +
+ Defaults to "groups". Set to e.g. "roles" if your IdP uses a different claim name. +
+
+ `} + + ${mode === 'http_verifier' && html` +
+
+ HTTP Verifier Configuration +
+ + setAuthSetting('url', e.target.value)} autocomplete="off" /> + + + setAuthSetting('default_org', e.target.value)} autocomplete="off" /> +
+ `} + + ${mode === 'custom' && html` +
+ A commented stub will be written to forge.yaml. Edit it after creation to plug in providers we don't expose in the wizard (e.g., static_token). +
+ `} +
+ `; + } + case 9: { // Review const configuredEnvCount = Object.values(form.env_vars).filter(v => v).length; const authLabel = form.auth_method === 'oauth' ? 'OAuth' : (form.api_key ? '\u2713 API Key provided' : 'Not set'); return html` @@ -1817,6 +1979,10 @@ function CreatePage() {
Secret Encryption
${form.passphrase ? '\u2713 Passphrase set' : 'None (plaintext .env)'}
+
+
A2A Auth
+
${authReviewLabel(form.auth)}
+
`; } diff --git a/forge-ui/types.go b/forge-ui/types.go index e2c598e..d628890 100644 --- a/forge-ui/types.go +++ b/forge-ui/types.go @@ -167,6 +167,17 @@ type AgentCreateOptions struct { Passphrase string `json:"passphrase,omitempty"` EnvVars map[string]string `json:"env_vars,omitempty"` Force bool `json:"force,omitempty"` + Auth *AuthCreateOptions `json:"auth,omitempty"` // A2A server auth chain (PR6+) +} + +// AuthCreateOptions describes the auth chain selection the web wizard +// captured. Mode is one of "none", "oidc", "http_verifier", "custom". +// Settings is the provider-type-specific settings block (issuer, audience, +// url, default_org, claim_map, …). Mirrors the TUI wizard's contract so +// both surfaces feed the same scaffold path. +type AuthCreateOptions struct { + Mode string `json:"mode"` + Settings map[string]any `json:"settings,omitempty"` } // FallbackProvider describes a fallback LLM provider with its API key. @@ -250,4 +261,15 @@ type WizardMetadata struct { Skills []SkillBrowserEntry `json:"skills"` ProviderModels map[string]ProviderModels `json:"provider_models"` WebSearchProviders []WebSearchProviderOption `json:"web_search_providers"` + AuthProviderTypes []AuthProviderTypeMeta `json:"auth_provider_types"` +} + +// AuthProviderTypeMeta describes one selectable auth provider type so the +// frontend renders its picker from server-driven metadata (no hardcoded +// provider list in JavaScript). When a new provider ships (e.g., Okta in +// Phase 3), append one entry here and the wizard picks it up. 
+type AuthProviderTypeMeta struct { + Type string `json:"type"` // "none", "oidc", "http_verifier", "custom" + Label string `json:"label"` // human-readable label for the picker + Description string `json:"description"` // single-line description under the label } From 379d5786f4748836dbcee32efe763620a38d0229 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:15:12 -0400 Subject: [PATCH 08/20] fix(auth/oidc): TTL-driven JWKS cache refresh (review #1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this fix the JWKS cache's ttl/lastFetch fields were set but never read — the cache only refreshed on unknown-kid lookups. Real-world IdPs revoke keys by REMOVING them from the JWKS document, which means a stolen private key remained trusted by Forge until process restart. The "refresh-on-miss" semantics handle key ROTATION (new kid arrives) but not key REVOCATION (kid silently disappears). Reviewer flagged: forge-core/auth/providers/oidc/jwks.go — high severity. Fix: - Split refresh-time state into two fields: lastSuccessful — drives TTL accounting; only advanced on success lastAttempt — drives error backoff; advanced on every attempt This stops a failed refresh during an IdP outage from silently extending trust in already-cached keys. - Get() now triggers a refresh when (now - lastSuccessful) > ttl, on the request path. A stale cache treats every kid as a miss so the refresh actually runs. - backoffActive() suppresses retry storms during JWKS outages (refetchGrace window — 30s default). - Successful refresh clears lastMissRefresh, so a previously-unknown kid can be re-attempted immediately after a successful reload. - Injectable clock (`now func() time.Time`) so tests exercise TTL expiry without time.Sleep. Tests (all -race clean): TestJWKS_TTLExpiryForcesRefresh Same kid, second call within TTL → no refetch. Advance virtual time past TTL → refetch happens. 
TestJWKS_RevokedKeyEventuallyRejected Sign with key A → accept. IdP removes key A from JWKS. Within TTL → still accepted (acceptable revocation lag). Past TTL → refresh picks up new JWKS → token rejected (ErrInvalidToken). TestJWKS_FailedRefreshDoesNotExtendTTL Warm cache → fail JWKS → past TTL → refresh fails (returns error). Endpoint recovers → past refetchGrace → next call retries refresh successfully (proves lastSuccessful was NOT bumped by the failure). Misleading doc comment removed; new struct comment documents the two-timestamp design and exact refresh triggers. Backward compat: no API changes. Default TTL (1 hour) and min TTL (5 min) clamp unchanged. Phase 1 OIDC behavior preserved otherwise. Verification: go test -race ./forge-core/auth/providers/oidc/ — all green go test -count=1 forge-core/... forge-cli/... forge-plugins/... forge-ui/... — 47/47 packages pass golangci-lint v2.10.1 — 0 issues --- forge-core/auth/providers/oidc/export_test.go | 13 ++ forge-core/auth/providers/oidc/jwks.go | 151 +++++++++++---- .../auth/providers/oidc/provider_test.go | 173 ++++++++++++++++++ 3 files changed, 298 insertions(+), 39 deletions(-) create mode 100644 forge-core/auth/providers/oidc/export_test.go diff --git a/forge-core/auth/providers/oidc/export_test.go b/forge-core/auth/providers/oidc/export_test.go new file mode 100644 index 0000000..8bd6a98 --- /dev/null +++ b/forge-core/auth/providers/oidc/export_test.go @@ -0,0 +1,13 @@ +package oidc + +import "time" + +// SetCacheClockForTest replaces the JWKS cache's time source. Exposed only +// in _test builds so tests can advance virtual time past TTL boundaries +// without sleeping. +// +// Not part of the public API. Test code in this package's external test +// file (oidc_test) uses this via the package-import boundary. 
+func SetCacheClockForTest(p *Provider, now func() time.Time) { + p.jwks.now = now +} diff --git a/forge-core/auth/providers/oidc/jwks.go b/forge-core/auth/providers/oidc/jwks.go index e12d075..5088866 100644 --- a/forge-core/auth/providers/oidc/jwks.go +++ b/forge-core/auth/providers/oidc/jwks.go @@ -46,16 +46,26 @@ type parsedKey struct { // jwksCache is a TTL-bounded cache of parsed JWK keys keyed by `kid`. // // Refresh strategy: -// - First lookup ever: fetch JWKS. -// - kid not in cache and last fetch is older than the small "refetch -// grace window" (avoids stampedes on bad input): fetch JWKS again. -// - kid still not in cache after refresh: return ErrKeyNotFound. -// - kid in cache: return immediately. // -// The cache does not eagerly expire entries by TTL — keys are durable; we -// refresh on miss. This matches how real-world identity providers rotate -// JWKS (add new key, then later remove old key — the new kid triggers a -// refresh that picks up both). +// 1. First lookup ever: fetch JWKS. +// 2. TTL expired (now - lastSuccessful > ttl): fetch JWKS on the request +// path. This is what invalidates keys that the IdP has revoked — +// without it, a compromised key removed from the issuer's JWKS would +// remain trusted by Forge until process restart. +// 3. kid not in cache (and TTL not expired): fetch JWKS to pick up new +// keys an IdP may have added since last refresh. +// 4. kid still not in cache after refresh: return ErrKeyNotFound. +// 5. kid in cache and TTL not expired: return immediately (hot path). +// +// Two separate timestamps track refresh state so failures don't extend the +// TTL window: +// - lastSuccessful: timestamp of the most recent successful refresh. +// Drives TTL accounting; only advanced on success. +// - lastAttempt: timestamp of the most recent attempt (success or fail). +// Drives failure backpressure so a dead JWKS endpoint isn't hammered. 
+// +// Concurrency: a single fetch is in flight at a time; concurrent misses +// coalesce behind the write lock. type jwksCache struct { jwksURL func(ctx context.Context) (string, error) httpClient *http.Client @@ -64,10 +74,14 @@ type jwksCache struct { // per request. Within this window, repeated unknown-kid lookups skip // the refetch. refetchGrace time.Duration + // now returns the current time. Injectable for tests so TTL expiry + // can be exercised without real time passing. + now func() time.Time - mu sync.RWMutex - keys map[string]parsedKey - lastFetch time.Time + mu sync.RWMutex + keys map[string]parsedKey + lastSuccessful time.Time // last successful JWKS load — drives TTL + lastAttempt time.Time // last refresh attempt (success or fail) — drives error backoff // lastMissRefresh is the last time we refreshed *because of* an // unknown kid and the kid was still not present after refresh. // Used to suppress stampedes of refetches for the same bad kid. @@ -85,6 +99,11 @@ func newJWKSCache(jwksURL func(ctx context.Context) (string, error), client *htt if ttl == 0 { ttl = defaultJWKSTTL } + // Only clamp the minimum when a TTL was actually requested — leaving + // callers to opt out of TTL entirely would require an explicit zero + // (which the previous branch promoted to the default). The clamp + // protects against operators setting an absurdly short value that + // would hammer the IdP. if ttl < minJWKSTTL { ttl = minJWKSTTL } @@ -93,10 +112,35 @@ func newJWKSCache(jwksURL func(ctx context.Context) (string, error), client *htt httpClient: client, ttl: ttl, refetchGrace: defaultRefetchGrace, + now: time.Now, keys: map[string]parsedKey{}, } } +// ttlExpired reports whether the cache's last successful refresh is older +// than the configured TTL. Returns true when there's no successful fetch +// yet (initial state). Must be called with at least an RLock held. 
+func (c *jwksCache) ttlExpired() bool { + if c.lastSuccessful.IsZero() { + return true + } + return c.now().Sub(c.lastSuccessful) > c.ttl +} + +// backoffActive reports whether we should suppress a refresh attempt +// because a recent one failed. Prevents a dead JWKS endpoint from being +// hit on every request during an outage. Must be called with at least +// an RLock held. +func (c *jwksCache) backoffActive() bool { + if c.lastErr == nil { + return false + } + if c.lastAttempt.IsZero() { + return false + } + return c.now().Sub(c.lastAttempt) < c.refetchGrace +} + // ErrKeyNotFound is returned when a `kid` is not in the cache and a // refresh did not produce it. var ErrKeyNotFound = fmt.Errorf("oidc: signing key not found") @@ -105,37 +149,62 @@ var ErrKeyNotFound = fmt.Errorf("oidc: signing key not found") // if necessary. Safe for concurrent use. // // Refresh semantics: -// - kid in cache → return immediately. -// - kid missing, never refreshed for a miss before → refresh once. -// - kid missing AND we already refreshed for a miss recently (within -// refetchGrace) → deny without re-fetching, to prevent a stream of -// bad kids from hammering the JWKS endpoint. +// - TTL expired (cache stale): refresh on the request path. Required +// for revocation to take effect — without this, a key removed from +// the IdP's JWKS would remain trusted by Forge until process restart. +// - kid in cache AND cache fresh: return immediately (hot path). +// - kid missing AND not recently refreshed for a miss: refresh once. +// - kid missing AND recently refreshed for a miss (within refetchGrace): +// deny without re-fetching — prevents bad kids hammering the endpoint. +// - last refresh failed AND within failure backoff: serve from existing +// cache without re-attempt — protects a dead IdP from request storms. func (c *jwksCache) Get(ctx context.Context, kid string) (parsedKey, error) { - // Fast path: lock-free read. + // Fast path: read-locked lookup.
Treat a stale (TTL-expired) cache as if + // every kid were a miss, so we fall into the write path below where + // the actual refresh happens. c.mu.RLock() - if key, ok := c.keys[kid]; ok { - c.mu.RUnlock() - return key, nil + stale := c.ttlExpired() + var ( + key parsedKey + hit bool + ) + if !stale { + key, hit = c.keys[kid] } - cooldownActive := !c.lastMissRefresh.IsZero() && time.Since(c.lastMissRefresh) < c.refetchGrace + missCooldown := !c.lastMissRefresh.IsZero() && c.now().Sub(c.lastMissRefresh) < c.refetchGrace + errBackoff := c.backoffActive() c.mu.RUnlock() - if cooldownActive { + if hit { + return key, nil + } + // During error backoff we serve from the (possibly empty) cache — + // don't re-attempt a known-broken endpoint. Note that on the very + // first call (never-fetched cache + endpoint already broken), this + // path is impossible because lastErr would still be nil. + if errBackoff { + return parsedKey{}, ErrKeyNotFound + } + // Don't refresh just to look up a kid we already concluded is absent. + // TTL-expired cache always tries a refresh (security), regardless of + // the unknown-kid grace window. + if !stale && missCooldown { return parsedKey{}, ErrKeyNotFound } - // Miss — refresh. Take the write lock first so concurrent misses - // coalesce into one fetch. + // Take the write lock so concurrent refresh attempts coalesce. c.mu.Lock() defer c.mu.Unlock() - // Re-check after acquiring the write lock — another goroutine may - // have just refreshed and populated this kid. - if key, ok := c.keys[kid]; ok { - return key, nil - } - if !c.lastMissRefresh.IsZero() && time.Since(c.lastMissRefresh) < c.refetchGrace { - return parsedKey{}, ErrKeyNotFound + // Re-check under write lock — another goroutine may have refreshed + // while we were waiting. 
+ if !c.ttlExpired() { + if key, ok := c.keys[kid]; ok { + return key, nil + } + if !c.lastMissRefresh.IsZero() && c.now().Sub(c.lastMissRefresh) < c.refetchGrace { + return parsedKey{}, ErrKeyNotFound + } } if err := c.refreshLocked(ctx); err != nil { @@ -146,29 +215,31 @@ func (c *jwksCache) Get(ctx context.Context, kid string) (parsedKey, error) { } // Refresh happened but kid is still missing — remember that so we // don't hammer JWKS on the next unknown-kid request. - c.lastMissRefresh = time.Now() + c.lastMissRefresh = c.now() return parsedKey{}, ErrKeyNotFound } // refreshLocked fetches the JWKS endpoint and replaces the cache. Must be // called with c.mu held for writing. +// +// Every call updates lastAttempt (drives error backoff). lastSuccessful +// is updated ONLY when a full parse succeeds — failures must not extend the TTL window. func (c *jwksCache) refreshLocked(ctx context.Context) error { + c.lastAttempt = c.now() + url, err := c.jwksURL(ctx) if err != nil { c.lastErr = err - c.lastFetch = time.Now() return fmt.Errorf("oidc: resolve jwks_uri: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { c.lastErr = err - c.lastFetch = time.Now() return fmt.Errorf("oidc: build jwks request: %w", err) } resp, err := c.httpClient.Do(req) if err != nil { c.lastErr = err - c.lastFetch = time.Now() return fmt.Errorf("oidc: jwks fetch: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -177,7 +248,6 @@ func (c *jwksCache) refreshLocked(ctx context.Context) error { _, _ = io.CopyN(io.Discard, resp.Body, 1024) err := fmt.Errorf("oidc: jwks returned status %d", resp.StatusCode) c.lastErr = err - c.lastFetch = time.Now() return err } @@ -186,7 +256,6 @@ func (c *jwksCache) refreshLocked(ctx context.Context) error { } if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { c.lastErr = err - c.lastFetch = time.Now() return fmt.Errorf("oidc: jwks decode: %w", err) } @@
-218,8 +287,12 @@ func (c *jwksCache) refreshLocked(ctx context.Context) error { } c.keys = parsed - c.lastFetch = time.Now() + c.lastSuccessful = c.now() c.lastErr = nil + // Clear the unknown-kid grace window on a successful refresh so the + // next miss can re-attempt a refresh immediately (the prior miss may + // have been for a kid that's now present in the freshly-loaded JWKS). + c.lastMissRefresh = time.Time{} return nil } diff --git a/forge-core/auth/providers/oidc/provider_test.go b/forge-core/auth/providers/oidc/provider_test.go index 82fe138..d131be3 100644 --- a/forge-core/auth/providers/oidc/provider_test.go +++ b/forge-core/auth/providers/oidc/provider_test.go @@ -757,6 +757,179 @@ func TestVerify_TokenWithoutAudClaim(t *testing.T) { } } +// --- TTL-driven refresh (review finding #1) --- +// +// These tests guarantee that JWKS keys revoked at the IdP eventually stop +// being accepted by Forge. Before the fix, the cache's ttl field was dead +// code and revoked keys remained trusted until process restart. + +func TestJWKS_TTLExpiryForcesRefresh(t *testing.T) { + // Build the provider with a 10-minute TTL (above the 5-min clamp). + // The cache's injectable clock is unexported, so the test swaps it + // for a virtual clock through SetCacheClockForTest (export_test.go) + // and advances virtual time past the TTL instead of sleeping or + // touching the wall clock. + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.JWKSCacheTTL = 10 * time.Minute // any value above the 5-min clamp + }) + + // Drive a virtual clock so we can move past TTL without sleeping. + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("first Verify: %v", err) + } + fetchesAfterFirst := fi.jwksFetches + + // Same kid, still within TTL → no new JWKS fetch.
+ if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("second Verify within TTL: %v", err) + } + if fi.jwksFetches != fetchesAfterFirst { + t.Errorf("JWKS refetched within TTL (%d → %d) — should have served from cache", + fetchesAfterFirst, fi.jwksFetches) + } + + // Advance virtual time past TTL. + clk.advance(11 * time.Minute) + + // Same kid, but TTL expired → MUST refetch. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("third Verify after TTL: %v", err) + } + if fi.jwksFetches != fetchesAfterFirst+1 { + t.Errorf("JWKS not refetched after TTL expiry (fetches: %d, expected %d)", + fi.jwksFetches, fetchesAfterFirst+1) + } +} + +func TestJWKS_RevokedKeyEventuallyRejected(t *testing.T) { + // The security-critical case: IdP removes a key from JWKS (revocation). + // Within TTL, Forge still trusts it (acceptable — that's the TTL window). + // After TTL, Forge picks up the new JWKS without the revoked key and + // must reject tokens signed by it. + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.JWKSCacheTTL = 10 * time.Minute + }) + + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + // Add a second key the issuer rotates between. + fi.addRSAKey("key-revoked") + + // Token is signed by key-revoked. + tok := fi.SignWith("key-revoked", fi.DefaultClaims(testAudience)) + + // Initial verify — cache warms with both keys present. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("pre-revocation verify: %v", err) + } + + // IdP revokes the key (removes it from JWKS). + delete(fi.keys, "key-revoked") + for i, k := range fi.keyOrder { + if k == "key-revoked" { + fi.keyOrder = append(fi.keyOrder[:i], fi.keyOrder[i+1:]...) + break + } + } + + // Within TTL: Forge still trusts the old key (acceptable revocation lag). 
+ if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("within-TTL verify failed unexpectedly: %v", err) + } + + // Advance past TTL → refresh picks up the new JWKS without key-revoked. + clk.advance(11 * time.Minute) + + // After refresh, the token's kid is no longer in JWKS → ErrInvalidToken. + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("post-TTL verify err = %v, want ErrInvalidToken (kid removed by IdP)", err) + } +} + +func TestJWKS_FailedRefreshDoesNotExtendTTL(t *testing.T) { + // Bug guard: a failed refresh must not bump lastSuccessful. Otherwise + // an IdP outage during a TTL-refresh window would silently extend + // trust in already-cached keys. + // + // Strategy: warm the cache, then make the JWKS endpoint return 500, + // advance past TTL, attempt a Verify. The refresh fails, but the + // existing keys remain available — and the next attempt after the + // failure backoff should retry (proving lastSuccessful did NOT advance). + fi := newFakeIssuer(t) + failJWKS := atomic.Bool{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if failJWKS.Load() && r.URL.Path == "/jwks" { + w.WriteHeader(http.StatusInternalServerError) + return + } + fi.serve(w, r) + })) + defer srv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: srv.URL, + Audience: testAudience, + JWKSURL: srv.URL + "/jwks", + JWKSCacheTTL: 10 * time.Minute, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Patch the fakeIssuer's IssuerURL so claim's iss matches. + fi.server = srv + + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + // Warm the cache. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("warmup: %v", err) + } + + // Endpoint goes down; advance past TTL. 
+ failJWKS.Store(true) + clk.advance(11 * time.Minute) + + // Refresh attempt fails. The TTL clock was NOT extended by the + // failure, so subsequent calls outside the backoff window will + // re-attempt. + _, err = p.Verify(context.Background(), tok, nil) + if err == nil { + t.Errorf("expected error during JWKS outage, got nil") + } + + // Endpoint comes back, but we're still in error-backoff window. + // Advance past refetchGrace (30s) so a retry can happen. + failJWKS.Store(false) + clk.advance(35 * time.Second) + + // Next verify should retry the refresh and succeed. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("post-recovery verify failed: %v", err) + } +} + +// virtualClock is a controllable time source for tests. +type virtualClock struct { + now time.Time +} + +func (c *virtualClock) Now() time.Time { return c.now } +func (c *virtualClock) advance(d time.Duration) { + c.now = c.now.Add(d) +} + func TestVerify_TokenWithoutExp(t *testing.T) { fi := newFakeIssuer(t) p := newProvider(t, fi, nil) From 3a7a131a78c0b6ffe69575f5c7729a92ef275f9f Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:20:00 -0400 Subject: [PATCH 09/20] fix(auth/oidc): normalize issuer trailing slash everywhere (review #2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before: newDiscovery trimmed the configured Issuer, but the discovery-doc check compared the trimmed value against the IdP's RAW emitted form, and jwt.WithIssuer received the user's RAW configured form. Result: any IdP that emits iss with a trailing slash (Auth0) was spuriously rejected against a config without slash, and vice versa. Reviewer flagged: provider.go vs discovery.go. Fix: - New Config.normalize() trims trailing slash from Issuer; called once in New() so every downstream comparison uses the same canonical form. - Discovery comparison now trims meta.Issuer before checking against the already-trimmed d.issuer. 
- JWT iss validation: replaced jwt.WithIssuer (strict equality, no trim option) with a manual checkIssuer that trims both the token's iss claim and the configured Issuer before comparing. Security: trimming a single trailing slash cannot collapse distinct issuers — "https://attacker.example" and "https://legit.example" remain distinct after trim. The trim is normalization, not weakening. TestIssuer_GenuineMismatchStillRejected guards this invariant. Backward compat: existing configs continue to work; new configs with trailing slashes now also work. JWT iss claim with trailing slash now also works. Tests: TestIssuer_ConfigWithoutSlash_TokenWithSlash (Auth0-shaped iss + Okta-shaped config) TestIssuer_ConfigWithSlash_TokenWithoutSlash (reverse) TestIssuer_BothWithoutSlash_StillWorks (regression — common case) TestIssuer_BothWithSlash_StillWorks (regression — Auth0 paired) TestIssuer_GenuineMismatchStillRejected (security guard) TestIssuer_TokenMissingIssClaim (no iss → ErrTokenRejected) TestDiscovery_TrailingSlashTolerated (discovery doc has slash, config doesn't) Verification: go test -race ./forge-core/auth/providers/oidc/ — all green full sweep — 47/47 packages pass golangci-lint v2.10.1 — 0 issues --- forge-core/auth/providers/oidc/discovery.go | 8 +- forge-core/auth/providers/oidc/provider.go | 61 +++++++- .../auth/providers/oidc/provider_test.go | 143 ++++++++++++++++++ 3 files changed, 207 insertions(+), 5 deletions(-) diff --git a/forge-core/auth/providers/oidc/discovery.go b/forge-core/auth/providers/oidc/discovery.go index 0537fb1..3ce6c53 100644 --- a/forge-core/auth/providers/oidc/discovery.go +++ b/forge-core/auth/providers/oidc/discovery.go @@ -57,7 +57,13 @@ func (d *discovery) EnsureLoaded(ctx context.Context) error { // Per OIDC spec, the discovery doc's `issuer` must match the issuer // the client used to fetch it. Catches misconfiguration where someone // points at a discovery doc for a different tenant. 
- if meta.Issuer != d.issuer { + // + // Trim trailing slashes on both sides — IdPs and operators disagree + // on whether the canonical form has one (Auth0 emits with, many + // Okta deployments without). The trim is normalization, not a + // security weakening — distinct hosts remain distinct after trim. + // See Config.normalize() in provider.go. + if strings.TrimRight(meta.Issuer, "/") != d.issuer { return fmt.Errorf("oidc: discovery issuer %q does not match configured issuer %q", meta.Issuer, d.issuer) } if meta.JWKSURI == "" { diff --git a/forge-core/auth/providers/oidc/provider.go b/forge-core/auth/providers/oidc/provider.go index 924c7df..ad3e7e0 100644 --- a/forge-core/auth/providers/oidc/provider.go +++ b/forge-core/auth/providers/oidc/provider.go @@ -34,6 +34,7 @@ import ( "fmt" "net/http" "slices" + "strings" "time" "github.com/golang-jwt/jwt/v5" @@ -99,6 +100,21 @@ func (c Config) Validate() error { return nil } +// normalize returns a copy of the Config with the Issuer in canonical +// form (trailing slash stripped). Called once at New() so every +// downstream comparison — discovery doc check, JWT iss validation, +// JWKS URL construction — uses the same form. +// +// Why this matters: OIDC IdPs are inconsistent about trailing slashes. +// Auth0 emits "iss": "https://tenant.auth0.com/", many enterprise +// Okta tenants emit without, etc. Operators paste whatever they see in +// their admin UI into forge.yaml. We trim once at the boundary so the +// mismatch never reaches a comparison. +func (c Config) normalize() Config { + c.Issuer = strings.TrimRight(c.Issuer, "/") + return c +} + // Provider implements auth.Provider for OIDC issuers. type Provider struct { cfg Config @@ -115,6 +131,10 @@ func New(cfg Config) (*Provider, error) { if err := cfg.Validate(); err != nil { return nil, err } + // Normalize trailing slash once so every downstream comparison + // (discovery doc, token iss, JWKS URL construction) uses the same + // canonical form. 
Review finding #2. + cfg = cfg.normalize() if cfg.ClockSkew == 0 { cfg.ClockSkew = DefaultClockSkew } @@ -133,15 +153,20 @@ func New(cfg Config) (*Provider, error) { p.jwks = newJWKSCache(p.resolveJWKSURL, client, cfg.JWKSCacheTTL) // Configure the JWT parser once. + // + // Issuer is validated manually (post-parse) instead of via + // jwt.WithIssuer — that helper does strict string equality on the + // token's iss claim, which fails when the IdP emits with a trailing + // slash but our config normalized it off (or vice versa). The manual + // check trims both sides. Same reasoning as the discovery doc + // comparison in EnsureLoaded. + // + // Audience is similarly manual because of the azp fallback below. parserOpts := []jwt.ParserOption{ jwt.WithValidMethods(allowedAlgNames()), - jwt.WithIssuer(cfg.Issuer), jwt.WithLeeway(cfg.ClockSkew), jwt.WithExpirationRequired(), } - // Audience validation handled manually so we can implement the azp - // fallback below. (jwt.WithAudience rejects when aud lacks the - // expected value without considering azp.) p.parser = jwt.NewParser(parserOpts...) return p, nil @@ -221,6 +246,14 @@ func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Hea return nil, classifyJWTError(err) } + // Issuer validation (manual — trims trailing slash on both sides so + // IdPs that emit iss with/without a trailing slash interop with + // configs that use the opposite form). See Config.normalize and + // review finding #2. + if err := p.checkIssuer(claims); err != nil { + return nil, err + } + // Audience validation (with azp fallback). if err := p.checkAudience(claims); err != nil { return nil, err @@ -229,6 +262,26 @@ func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Hea return mapClaims(p.cfg.ClaimMap, claims, headers), nil } +// checkIssuer validates the token's `iss` claim against the configured +// Issuer. 
Both values are trimmed of trailing slashes before comparison +// so trailing-slash drift between the IdP and the operator's configured +// value doesn't cause spurious rejections. +// +// Security note: trimming a single trailing slash cannot enable issuer +// impersonation — "https://attacker.example" and "https://legit.example" +// remain distinct. The trim is purely a normalization fix. +func (p *Provider) checkIssuer(claims jwt.MapClaims) error { + iss, _ := claims["iss"].(string) + if iss == "" { + return fmt.Errorf("%w: missing iss claim", auth.ErrTokenRejected) + } + if strings.TrimRight(iss, "/") != p.cfg.Issuer { + return fmt.Errorf("%w: issuer %q does not match configured %q", + auth.ErrTokenRejected, iss, p.cfg.Issuer) + } + return nil +} + // looksLikeJWT does a cheap structural check: exactly three non-empty // dot-separated segments. Used to yield to the next provider on non-JWT // tokens (e.g., opaque tokens that another chain entry would handle). diff --git a/forge-core/auth/providers/oidc/provider_test.go b/forge-core/auth/providers/oidc/provider_test.go index d131be3..bf24b87 100644 --- a/forge-core/auth/providers/oidc/provider_test.go +++ b/forge-core/auth/providers/oidc/provider_test.go @@ -757,6 +757,149 @@ func TestVerify_TokenWithoutAudClaim(t *testing.T) { } } +// --- Issuer trailing-slash normalization (review finding #2) --- +// +// IdPs disagree on whether the canonical issuer has a trailing slash. +// Auth0 emits "iss": "https://tenant.auth0.com/"; most Okta deployments +// do not. Operators paste whatever they see into forge.yaml. These +// tests guarantee both directions interop. + +func TestIssuer_ConfigWithoutSlash_TokenWithSlash(t *testing.T) { + fi := newFakeIssuer(t) + // Config the user-facing way: no trailing slash. + p := newProvider(t, fi, func(c *oidc.Config) { + c.Issuer = fi.IssuerURL() // already no slash + }) + + // Mint a token whose iss has a trailing slash (mimics Auth0). 
+ claims := fi.DefaultClaims(testAudience) + claims["iss"] = fi.IssuerURL() + "/" + tok := fi.SignWith("key-1", claims) + + id, err := p.Verify(context.Background(), tok, nil) + if err != nil { + t.Fatalf("token iss with slash, config without — should accept, got: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } +} + +func TestIssuer_ConfigWithSlash_TokenWithoutSlash(t *testing.T) { + fi := newFakeIssuer(t) + // Config the other way: user pastes with trailing slash. + p := newProvider(t, fi, func(c *oidc.Config) { + c.Issuer = fi.IssuerURL() + "/" + }) + + // Token has no slash (the more common case). + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("config with slash, token without — should accept, got: %v", err) + } +} + +func TestIssuer_BothWithoutSlash_StillWorks(t *testing.T) { + // Regression: ensure the trim didn't break the most common case. + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("config without slash + token without slash should accept: %v", err) + } +} + +func TestIssuer_BothWithSlash_StillWorks(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.Issuer = fi.IssuerURL() + "/" + }) + claims := fi.DefaultClaims(testAudience) + claims["iss"] = fi.IssuerURL() + "/" + tok := fi.SignWith("key-1", claims) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("config + token both with slash should accept: %v", err) + } +} + +func TestIssuer_GenuineMismatchStillRejected(t *testing.T) { + // Security guard: trimming trailing slashes must not collapse + // genuinely-different issuers into the same comparison value. 
+ fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["iss"] = "https://attacker.example.com/" // completely different host + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("genuine issuer mismatch err = %v, want ErrTokenRejected", err) + } +} + +func TestIssuer_TokenMissingIssClaim(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + delete(claims, "iss") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("missing iss err = %v, want ErrTokenRejected", err) + } +} + +func TestDiscovery_TrailingSlashTolerated(t *testing.T) { + // Discovery side: IdP's doc has issuer with trailing slash; + // operator configured without. Must not bomb at startup-ish phase. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Deliberately use a different host than the test server's URL + // to simulate normalization-only mismatch — wait, we need it to + // be the same host, just different slash form. + })) + srv.Close() // we just needed the type + + // Build a real fake issuer and add a slash to the doc's response. + docHost := "" + dynamicSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + // Emit issuer WITH trailing slash even though docHost has none. 
+ _, _ = w.Write([]byte(`{"issuer":"` + docHost + `/","jwks_uri":"` + docHost + `/jwks"}`)) + case "/jwks": + _, _ = w.Write([]byte(`{"keys":[]}`)) + default: + http.NotFound(w, r) + } + })) + defer dynamicSrv.Close() + docHost = dynamicSrv.URL + + p, err := oidc.New(oidc.Config{ + Issuer: docHost, // no trailing slash + Audience: testAudience, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Trigger the discovery path with a real JWT-shaped string. + // Discovery happens lazily on first JWKS lookup. We're testing that + // the comparison doesn't reject due to trailing-slash mismatch. + dummyJWT := "eyJhbGciOiJSUzI1NiIsImtpZCI6ImsxIn0.eyJpc3MiOiJ4In0.AAAA" + _, err = p.Verify(context.Background(), dummyJWT, nil) + // Expect an ErrInvalidToken (kid not found in empty JWKS) — NOT a + // "discovery issuer mismatch" error. The previous bug surfaced as + // the latter. + if err != nil && strings.Contains(err.Error(), "discovery issuer") { + t.Errorf("discovery rejected due to trailing-slash drift: %v", err) + } +} + // --- TTL-driven refresh (review finding #1) --- // // These tests guarantee that JWKS keys revoked at the IdP eventually stop From 19f72512a710440c8e022f995a0b8a1578be9e62 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:27:21 -0400 Subject: [PATCH 10/20] fix(auth): panic on nil Chain unless AllowAnonymous=true (review #3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: forge-core/auth/middleware.go silently returned a no-op wrapper when opts.Chain == nil. A misconfigured runner that forgot to wire a chain would silently serve unauthenticated requests in production. High severity — the open-prod-endpoint failure mode is the highest-impact misconfiguration in the auth subsystem. Fix: - New MiddlewareOptions.AllowAnonymous explicit opt-in (default false). - Middleware() panics at construction when Chain == nil && !AllowAnonymous. 
Fails loudly at startup, not silently at the first real request. - Runner updated to set AllowAnonymous=true in the two legitimate anonymous paths: 1. --no-auth flag (operator explicit opt-in) 2. No auth: block AND no --auth-url AND no channels (legacy local dev default — preserved for backward compat) Backward compat: - Existing deployments with a configured chain — unchanged. - Deployments with --no-auth — unchanged (runner sets the flag). - Runner deployments with no auth config and no channels — unchanged (runner sets the flag for the legacy local dev default). - Callers that construct auth.Middleware with a nil Chain and without AllowAnonymous (anonymous by accident) — now panic at startup with a clear, actionable message naming the flag to set. This IS a behavior change, by design. Tests (all -race clean): TestMiddleware_NilChainPanicsWithoutAllowAnonymous Asserts the panic happens AND the message mentions AllowAnonymous so operators know how to fix it. TestMiddleware_NilChainWithAllowAnonymousPassesThrough Counterpart: explicit opt-in works. TestMiddleware_NonNilChainIgnoresAllowAnonymous AllowAnonymous is only consulted when Chain is nil; a configured chain always enforces auth. The existing PR1 "nil chain passes through" test updated to use the new AllowAnonymous flag. The E2E TestE2E_NoAuthConfigured_AnonymousAccess test driver sets AllowAnonymous=true automatically when given a nil chain. Verification: go test -race ./forge-core/auth/...
./forge-cli/runtime/ — all green full sweep — 47/47 packages pass golangci-lint v2.10.1 — 0 issues --- forge-cli/runtime/auth_chain_test.go | 7 ++- forge-cli/runtime/runner.go | 19 ++++++- forge-core/auth/middleware.go | 34 +++++++++-- forge-core/auth/middleware_test.go | 85 +++++++++++++++++++++++++++- 4 files changed, 136 insertions(+), 9 deletions(-) diff --git a/forge-cli/runtime/auth_chain_test.go b/forge-cli/runtime/auth_chain_test.go index 53a5f5f..d45b9ba 100644 --- a/forge-cli/runtime/auth_chain_test.go +++ b/forge-cli/runtime/auth_chain_test.go @@ -178,11 +178,14 @@ func TestBuildLegacyAuthChain_OrgIDPropagation(t *testing.T) { // runMiddleware is a small driver that runs an http.Handler chain through // the auth.Middleware with a given chain and returns the recorded response. +// When chain is nil, the test explicitly opts the middleware into +// anonymous mode (review #3 — nil Chain without AllowAnonymous panics). func runMiddleware(t *testing.T, chain auth.Provider, method, path, authHeader string) *httptest.ResponseRecorder { t.Helper() handler := auth.Middleware(auth.MiddlewareOptions{ - Chain: chain, - SkipPaths: auth.DefaultSkipPaths(), + Chain: chain, + AllowAnonymous: chain == nil, + SkipPaths: auth.DefaultSkipPaths(), })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 82f4dfd..8cdb2b4 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -1751,7 +1751,12 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew } if r.cfg.NoAuth { - return auth.MiddlewareOptions{}, nil // nil chain → passthrough + // Operator explicitly chose anonymous via --no-auth. AllowAnonymous + // makes that choice visible at the middleware boundary (review #3). 
+ return auth.MiddlewareOptions{ + AllowAnonymous: true, + SkipPaths: auth.DefaultSkipPaths(), + }, nil } userChain, err := r.buildUserAuthChain() @@ -1776,6 +1781,18 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew chain = auth.PrependChain(userChain, loopback) } + // No user chain AND no loopback token → legacy "no auth config, no + // channels" default. Preserve backward compat by allowing anonymous, + // but flag it explicitly so the middleware's nil-chain panic guard + // is satisfied (review #3). + if chain == nil { + return auth.MiddlewareOptions{ + AllowAnonymous: true, + SkipPaths: auth.DefaultSkipPaths(), + OnAuth: makeAuthAuditCallback(auditLogger), + }, nil + } + return auth.MiddlewareOptions{ Chain: chain, SkipPaths: auth.DefaultSkipPaths(), diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index 4cfe7ab..e5e7d57 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -30,10 +30,26 @@ func DefaultSkipPaths() map[string]bool { // MiddlewareOptions configures Middleware. type MiddlewareOptions struct { - // Chain is the provider chain that verifies bearer tokens. If nil, - // the middleware behaves as a passthrough (anonymous access). + // Chain is the provider chain that verifies bearer tokens. May only + // be nil when AllowAnonymous is true (see below). Chain Provider + // AllowAnonymous explicitly opts the middleware into running without + // authentication. Required whenever Chain is nil — otherwise + // Middleware() panics at construction. This prevents a misconfigured + // runner from silently serving unauthenticated requests because + // someone forgot to wire a chain. 
+ // + // Set this to true when: + // - --no-auth flag is in effect (operator explicitly chose anon) + // - No auth: block AND no --auth-url AND no channels (legacy local + // dev default — preserved for backward compat) + // + // Leave this false for any production deployment that intends to + // enforce auth; a nil chain will then panic loudly at startup + // instead of running open. + AllowAnonymous bool + // SkipPaths maps "METHOD /path" keys that bypass authentication. // If nil, DefaultSkipPaths() is used. SkipPaths map[string]bool @@ -63,10 +79,20 @@ type errorString string func (e errorString) Error() string { return string(e) } // Middleware returns an http.Handler that enforces bearer token authentication -// via the provided Provider chain. If opts.Chain is nil, requests pass through -// without checks (anonymous access). +// via the provided Provider chain. +// +// Panics at construction if opts.Chain is nil and opts.AllowAnonymous is +// false. This is intentional — silently passing through requests when the +// caller forgot to wire a chain is the highest-impact misconfiguration in +// the auth subsystem (open prod endpoint). Fail-loud catches it at startup, +// not at the first request from a real user. func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { if opts.Chain == nil { + if !opts.AllowAnonymous { + panic("auth: Middleware called with nil Chain and AllowAnonymous=false. 
" + + "Set MiddlewareOptions.AllowAnonymous: true to explicitly allow " + + "unauthenticated access, or provide a Chain.") + } return func(next http.Handler) http.Handler { return next } } skip := opts.SkipPaths diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index 6954617..643195b 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -68,8 +68,8 @@ func TestMiddleware(t *testing.T) { wantStatus int }{ { - name: "nil chain passes through", - opts: MiddlewareOptions{Chain: nil}, + name: "nil chain with AllowAnonymous passes through", + opts: MiddlewareOptions{Chain: nil, AllowAnonymous: true}, method: "POST", path: "/", wantStatus: http.StatusOK, @@ -362,6 +362,87 @@ func TestMiddleware_TokenKindDetection(t *testing.T) { } } +// --- Nil chain panic guard (review finding #3) --- + +func TestMiddleware_NilChainPanicsWithoutAllowAnonymous(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic when Chain==nil and AllowAnonymous==false") + } + msg, _ := r.(string) + if msg == "" { + t.Fatalf("panic value = %v, want a descriptive string", r) + } + // The message should mention the option name so callers know + // how to opt in. + if !contains(msg, "AllowAnonymous") { + t.Errorf("panic message %q does not reference AllowAnonymous", msg) + } + }() + + // This call MUST panic — silent anonymous passthrough is the bug + // the AllowAnonymous flag exists to prevent. + _ = Middleware(MiddlewareOptions{Chain: nil, AllowAnonymous: false}) +} + +func TestMiddleware_NilChainWithAllowAnonymousPassesThrough(t *testing.T) { + // Counterpart: explicit opt-in does NOT panic; serves all requests. 
+ handler := Middleware(MiddlewareOptions{ + Chain: nil, + AllowAnonymous: true, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/anything", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("status = %d, want %d (AllowAnonymous should pass through)", rr.Code, http.StatusOK) + } +} + +func TestMiddleware_NonNilChainIgnoresAllowAnonymous(t *testing.T) { + // AllowAnonymous is only consulted when Chain == nil. When a chain + // is configured, the flag has no effect — requests are still + // validated through the chain. + const goodToken = "ok" + chain := NewChainProvider(&tokenProvider{ + expected: goodToken, + identity: Identity{UserID: "u", Source: "test"}, + }) + handler := Middleware(MiddlewareOptions{ + Chain: chain, + AllowAnonymous: true, // should be ignored + SkipPaths: DefaultSkipPaths(), + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // No token → 401 (chain still enforced). + req := httptest.NewRequest("POST", "/tasks", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Errorf("non-nil chain + AllowAnonymous=true: status = %d, want 401 (chain must still enforce)", rr.Code) + } +} + +// contains is a tiny strings.Contains substitute for the panic-message test +// (kept inline to avoid pulling in strings just for the assertion). 
+func contains(s, sub string) bool { + if len(sub) == 0 { + return true + } + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} + func TestClassifyAuthFailure(t *testing.T) { tests := []struct { name string From b33522d2c7ad7690202c1c44638245c41abc08ef Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:31:06 -0400 Subject: [PATCH 11/20] fix(auth): cross-check --no-auth against forge.yaml auth block (review #4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: a user who configures forge.yaml with 'auth: { required: true, providers: [...] }' and then runs the agent with --no-auth gets anonymous access silently. The flag and the YAML block were treated independently. The --no-auth happy-path short-circuit returned an anonymous middleware before consulting cfg.Auth at all. This is the second-highest-impact misconfiguration class in the auth subsystem (open prod endpoint by mistake), and it's worse than the nil-Chain case (#3) because here the operator deliberately wrote 'required: true' into their config. Fix (graduated severity): auth.required == true + --no-auth → startup error (refuse) providers populated + --no-auth → warning (operator intent unclear; we proceed but make the override visible) empty auth block + --no-auth → unchanged (anonymous, no log) Error message names what's wrong AND lists the three ways to fix it: - remove --no-auth - set 'auth.required: false' - delete the 'auth:' block Warning includes provider count so dashboards can alert. Tests (all -race clean): TestResolveAuth_NoAuthWithRequiredFails auth.required=true + --no-auth → error; message contains "--no-auth conflicts with". TestResolveAuth_NoAuthWithProvidersWarns providers set, required=false + --no-auth → no error, warning emitted containing "--no-auth overrides". 
TestResolveAuth_NoAuthWithEmptyYAMLConfig_NoWarning Regression: --no-auth alone (no auth: block) still works without spurious warning. TestResolveAuth_NoAuthWithRequiredFalseAndProviders_WarnsNotFails Explicit Required: false is respected — warn, don't refuse. New recordLogger test helper captures Warn/Info calls so assertions can check what the runner emitted; mirrors the no-op nopLogger pattern already in use. Backward compat: - --no-auth alone (no auth: block): unchanged - --no-auth + providers + no required: warns, still anonymous - --no-auth + required: true: BREAKING — refuses to start. This is intentional; the previous behavior silently exposed a "required auth" deployment as anonymous, which the user explicitly tried to prevent in their forge.yaml. Verification: go test -race ./forge-core/auth/... ./forge-cli/runtime/ — green full sweep — 47/47 packages pass golangci-lint v2.10.1 — 0 issues --- forge-cli/runtime/auth_config_test.go | 111 ++++++++++++++++++++++++++ forge-cli/runtime/runner.go | 22 +++++ 2 files changed, 133 insertions(+) diff --git a/forge-cli/runtime/auth_config_test.go b/forge-cli/runtime/auth_config_test.go index 607f951..d076785 100644 --- a/forge-cli/runtime/auth_config_test.go +++ b/forge-cli/runtime/auth_config_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/initializ/forge/forge-core/auth" @@ -252,6 +253,116 @@ func TestBuildUserAuthChain_BothSourcesPrefersYAML(t *testing.T) { } } +// --- --no-auth vs forge.yaml auth: block (review #4) --- +// +// Phase 1 + PR3 let --no-auth and the YAML auth block be set +// independently. Review #4 cross-checks them at resolveAuth time so +// a contradictory combination fails loudly instead of silently +// granting anonymous access. + +// makeRunnerWithNoAuth returns a Runner with NoAuth=true and the given +// AuthConfig in forge.yaml. Captures logger output via a recording +// logger so warning assertions can read what the runner emitted. 
+func makeRunnerWithNoAuth(t *testing.T, authYAML types.AuthConfig) (*Runner, *recordLogger) { + t.Helper() + logger := &recordLogger{} + r := &Runner{logger: logger} + r.cfg.Config = &types.ForgeConfig{Auth: authYAML} + r.cfg.NoAuth = true + return r, logger +} + +func TestResolveAuth_NoAuthWithRequiredFails(t *testing.T) { + r, _ := makeRunnerWithNoAuth(t, types.AuthConfig{ + Required: true, + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }) + _, err := r.resolveAuth(nil) + if err == nil { + t.Fatal("expected error when --no-auth conflicts with auth.required: true") + } + if !strings.Contains(err.Error(), "--no-auth conflicts with") { + t.Errorf("err = %q, want descriptive --no-auth/required mismatch message", err) + } +} + +func TestResolveAuth_NoAuthWithProvidersWarns(t *testing.T) { + r, logger := makeRunnerWithNoAuth(t, types.AuthConfig{ + // Required: false — the operator hasn't asserted that auth is + // mandatory; --no-auth should still work but it surprises them + // that the YAML providers are ignored, so we warn. + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }) + opts, err := r.resolveAuth(nil) + if err != nil { + t.Fatalf("resolveAuth returned error, want warning only: %v", err) + } + if !opts.AllowAnonymous { + t.Error("expected AllowAnonymous=true under --no-auth") + } + if !logger.warnedAbout("--no-auth overrides") { + t.Errorf("expected --no-auth-overrides warning; got: %+v", logger.warnings) + } +} + +func TestResolveAuth_NoAuthWithEmptyYAMLConfig_NoWarning(t *testing.T) { + // Regression: --no-auth alone (no auth: block at all) must still + // work without spamming a warning. 
+ r, logger := makeRunnerWithNoAuth(t, types.AuthConfig{}) + opts, err := r.resolveAuth(nil) + if err != nil { + t.Fatalf("resolveAuth: %v", err) + } + if !opts.AllowAnonymous { + t.Error("expected AllowAnonymous=true under --no-auth") + } + if len(logger.warnings) > 0 { + t.Errorf("unexpected warnings emitted: %+v", logger.warnings) + } +} + +func TestResolveAuth_NoAuthWithRequiredFalseAndProviders_WarnsNotFails(t *testing.T) { + // Explicit Required: false + Providers set — operator knows what + // they're doing; we warn but don't refuse. + r, logger := makeRunnerWithNoAuth(t, types.AuthConfig{ + Required: false, + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }) + if _, err := r.resolveAuth(nil); err != nil { + t.Fatalf("resolveAuth should not fail when Required=false: %v", err) + } + if !logger.warnedAbout("--no-auth overrides") { + t.Error("expected override warning") + } +} + +// recordLogger captures Warn/Info/Debug/Error calls for assertions. +// Mirrors the no-op nopLogger pattern in auth_config_test.go but +// records messages so tests can assert on them. +type recordLogger struct { + warnings []string + infos []string +} + +func (r *recordLogger) Debug(msg string, _ map[string]any) {} +func (r *recordLogger) Info(msg string, _ map[string]any) { r.infos = append(r.infos, msg) } +func (r *recordLogger) Warn(msg string, _ map[string]any) { r.warnings = append(r.warnings, msg) } +func (r *recordLogger) Error(msg string, _ map[string]any) {} +func (r *recordLogger) warnedAbout(substr string) bool { + for _, w := range r.warnings { + if strings.Contains(w, substr) { + return true + } + } + return false +} + func TestBuildUserAuthChain_OIDCFromYAML(t *testing.T) { // Builds an oidc provider successfully — confirms the side-effect // import wired it into the registry. 
Network isn't actually exercised diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 8cdb2b4..95f69ed 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -1751,6 +1751,28 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew } if r.cfg.NoAuth { + // Cross-check --no-auth against the forge.yaml auth block. The + // flag and the YAML have historically been treated independently; + // review #4 closes the gap so a misaligned pair fails loudly + // instead of silently serving anonymous traffic on what the + // operator declared a required-auth deployment. + if r.cfg.Config != nil { + authCfg := r.cfg.Config.Auth + if authCfg.Required { + return auth.MiddlewareOptions{}, fmt.Errorf( + "--no-auth conflicts with forge.yaml 'auth.required: true' — " + + "either remove --no-auth, set 'auth.required: false', or " + + "delete the 'auth:' block to confirm anonymous access is intended") + } + if len(authCfg.Providers) > 0 { + r.logger.Warn( + "--no-auth overrides forge.yaml 'auth.providers' — configured "+ + "providers will be ignored and the agent will accept anonymous "+ + "traffic. Remove --no-auth to enforce the configured chain.", + map[string]any{"providers_configured": len(authCfg.Providers)}, + ) + } + } // Operator explicitly chose anonymous via --no-auth. AllowAnonymous // makes that choice visible at the middleware boundary (review #3). 
return auth.MiddlewareOptions{ From d8f1f9f0a1baaaab334f42efcf03c4698d7fd1d4 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:36:21 -0400 Subject: [PATCH 12/20] test(auth/oidc): prove no-hammering on JWKS refresh failure (review #5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: when refreshLocked errors, lastMissRefresh is not updated → every subsequent unknown-kid request hammers the IdP again → a flapping IdP becomes a DDoS amplifier (N callers × broken IdP = N IdP hits per request burst). The reviewer was looking at the pre-fix-#1 code. Fix #1 (commit 379d578) addressed this concern through a DIFFERENT mechanism: it added a lastAttempt + lastErr pair and a backoffActive() guard, so any refresh error (not just unknown-kid-after-refresh) suppresses subsequent refresh attempts inside the refetchGrace window. This commit makes the no-amplifier property explicit by: 1. Adding TestJWKS_FailedRefreshDoesNotHammerIdP — fires 50 unknown-kid requests during a backoff window and asserts the JWKS endpoint receives exactly ONE hit (the initial refresh that failed). Then advances time past refetchGrace and asserts the next request triggers EXACTLY ONE retry — and the burst after that retry stays capped at one IdP call again. 2. Adding a doc comment on backoffActive() that names the DDoS-amplifier concern explicitly and references the test that pins the contract, so future maintainers don't accidentally remove the gate. No code behavior change — fix #1 already produced the right behavior; this commit makes the test coverage match. Verification: go test -race ./forge-core/auth/providers/oidc/ — green (incl. 
new test) full sweep — 39/39 packages pass golangci-lint v2.10.1 — 0 issues --- forge-core/auth/providers/oidc/jwks.go | 11 ++- .../auth/providers/oidc/provider_test.go | 99 +++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/forge-core/auth/providers/oidc/jwks.go b/forge-core/auth/providers/oidc/jwks.go index 5088866..d56a98b 100644 --- a/forge-core/auth/providers/oidc/jwks.go +++ b/forge-core/auth/providers/oidc/jwks.go @@ -129,8 +129,15 @@ func (c *jwksCache) ttlExpired() bool { // backoffActive reports whether we should suppress a refresh attempt // because a recent one failed. Prevents a dead JWKS endpoint from being -// hit on every request during an outage. Must be called with at least -// an RLock held. +// hit on every request during an outage — without this gate, N concurrent +// callers against a flapping IdP would amplify into N IdP hits per +// request burst (a DDoS amplifier). Must be called with at least an +// RLock held. +// +// Test coverage: TestJWKS_FailedRefreshDoesNotHammerIdP in provider_test.go +// pins the "exactly one IdP hit per refetchGrace window during failure" +// contract — fires 50 unknown-kid requests in a backoff window and asserts +// JWKS endpoint is called once. func (c *jwksCache) backoffActive() bool { if c.lastErr == nil { return false diff --git a/forge-core/auth/providers/oidc/provider_test.go b/forge-core/auth/providers/oidc/provider_test.go index bf24b87..37b7adc 100644 --- a/forge-core/auth/providers/oidc/provider_test.go +++ b/forge-core/auth/providers/oidc/provider_test.go @@ -997,6 +997,105 @@ func TestJWKS_RevokedKeyEventuallyRejected(t *testing.T) { } } +func TestJWKS_FailedRefreshDoesNotHammerIdP(t *testing.T) { + // Review #5 explicit guard: when refreshLocked fails, every + // subsequent unknown-kid (or any miss) request inside the same + // refetchGrace window must NOT trigger another IdP fetch. 
A + // flapping IdP otherwise turns into a DDoS amplifier — N concurrent + // callers × broken IdP = N IdP hits per request burst. + // + // The mechanism is in jwks.go's backoffActive(): on any refresh + // error, lastErr+lastAttempt are set; subsequent Get calls within + // refetchGrace short-circuit with ErrKeyNotFound, without re-calling + // refreshLocked. This test pins that behavior. + fi := newFakeIssuer(t) + var jwksHits atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/jwks" { + jwksHits.Add(1) + w.WriteHeader(http.StatusInternalServerError) + return + } + fi.serve(w, r) + })) + defer srv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: srv.URL, + Audience: testAudience, + JWKSURL: srv.URL + "/jwks", + JWKSCacheTTL: 10 * time.Minute, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + // Build a syntactically-valid JWT pointing at a kid the broken + // JWKS would never produce — guarantees we hit the refresh path + // (cache miss + TTL expired since lastSuccessful is zero). + tok := buildJWTWithKid(t, fi.keys["key-1"].priv, "unknown-kid") + + // First call attempts to refresh, fails (500), records the error + // and bumps lastAttempt → from now backoff is active. + _, _ = p.Verify(context.Background(), tok, nil) + firstWindowHits := jwksHits.Load() + if firstWindowHits != 1 { + t.Fatalf("first Verify did not attempt JWKS (hits=%d)", firstWindowHits) + } + + // Fire a burst of requests inside the backoff window. NONE of these + // should reach the JWKS endpoint — backoffActive() short-circuits. 
+ for range 50 { + _, _ = p.Verify(context.Background(), tok, nil) + } + if got := jwksHits.Load(); got != firstWindowHits { + t.Errorf("DDoS amplifier: 50 unknown-kid requests during backoff produced %d IdP hits (want %d)", + got, firstWindowHits) + } + + // Advance past refetchGrace (default 30s). The next request should + // be allowed to re-attempt the IdP exactly once. + clk.advance(35 * time.Second) + _, _ = p.Verify(context.Background(), tok, nil) + if got := jwksHits.Load(); got != firstWindowHits+1 { + t.Errorf("post-grace retry: hits = %d, want %d (one retry per grace window)", + got, firstWindowHits+1) + } + + // And another burst still inside the new grace window — same lock-out. + for range 50 { + _, _ = p.Verify(context.Background(), tok, nil) + } + if got := jwksHits.Load(); got != firstWindowHits+1 { + t.Errorf("second backoff window leaked: hits = %d, want %d", + got, firstWindowHits+1) + } +} + +// buildJWTWithKid signs a token with the given private key but stamps an +// arbitrary kid in the header. Used by the JWKS hammer test to force the +// provider into the unknown-kid + refresh path. +func buildJWTWithKid(t *testing.T, priv any, kid string) string { + t.Helper() + claims := jwt.MapClaims{ + "iss": "x", + "aud": testAudience, + "sub": "x", + "exp": time.Now().Add(15 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + s, err := tok.SignedString(priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + return s +} + func TestJWKS_FailedRefreshDoesNotExtendTTL(t *testing.T) { // Bug guard: a failed refresh must not bump lastSuccessful. 
Otherwise // an IdP outage during a TTL-refresh window would silently extend From 093079e6baf47265d132272be400d6b0c0e848be Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:41:22 -0400 Subject: [PATCH 13/20] fix(auth): ErrProviderUnavailable for verifier-down scenarios (review #6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: httpverifier classified 5xx + network errors as ErrInvalidToken. Fail-closed direction was correct, but the audit signal was misleading — operators reading audit logs saw "invalid token" alerts firing when the actual problem was the IdP being down. Different response, different runbook, distinct alert. Fix: - New sentinel auth.ErrProviderUnavailable. Same chain behavior as ErrInvalidToken (stop, fail-closed, 401 response), but distinct audit signal. - httpverifier reclassifies its error mapping: HTTP 200 + valid:true → Identity (unchanged) HTTP 200 + valid:false → ErrTokenRejected (unchanged) HTTP 401 → ErrTokenRejected (unchanged) HTTP other 4xx → ErrTokenRejected (NEW: was ErrInvalidToken) HTTP 5xx → ErrProviderUnavailable (CHANGED) Network/transport error → ErrProviderUnavailable (CHANGED) Undecodable body → ErrProviderUnavailable (CHANGED; 200-but-garbage is verifier misbehavior, not a token issue) Local marshal/build err → ErrInvalidToken (unchanged; never the verifier's fault) - middleware.classifyAuthFailure: new branch returns "auth provider unavailable" as the client-visible message. - runner.authFailReason: new reason code "provider_unavailable" in the audit event. Documented in the function's reason-code table. Audit-event contract additions: auth_fail { reason: "provider_unavailable", ... } Operators dashboarding on `reason:provider_unavailable` can now alert on IdP downtime distinctly from token issues. 
Tests (all -race clean): TestVerify_500_ReturnsProviderUnavailable (httpverifier) TestVerify_502BadGateway_ReturnsProviderUnavailable TestVerify_NetworkError_ReturnsProviderUnavailable TestVerify_Non401_4xx_ReturnsRejected (regression: 400/403 stay rejected) TestVerify_UndecodableBody_ReturnsProviderUnavailable TestAuthAudit_EmitsAuthFailOnError (new "provider_unavailable" row) TestClassifyAuthFailure (new "auth provider unavailable" message row) Backward compat: - Chain behavior preserved (ErrProviderUnavailable fail-closed like ErrInvalidToken). - Existing client retry logic that retried on "invalid token" responses will still retry on "auth provider unavailable" (both are 401-class). - ErrInvalidToken sentinel preserved for token-side failures elsewhere (oidc malformed JWT, bad kid, etc.) — unchanged. OIDC provider not updated in this commit. Its JWKS-fetch failures currently surface via the chain's "any other error fails closed" path, mapped to authFailReason "infrastructure". Could be tightened to ErrProviderUnavailable in a follow-up, but reviewer's specific concern was the misleading "invalid token" signal from httpverifier — OIDC already produces "infrastructure" which is at least not misleading. Verification: go test -race forge-core/auth/... 
forge-cli/runtime/ — green full sweep — all packages pass golangci-lint v2.10.1 — 0 issues --- forge-cli/runtime/auth_audit_test.go | 3 + forge-cli/runtime/runner.go | 15 +++++ forge-core/auth/middleware.go | 5 ++ forge-core/auth/middleware_test.go | 4 ++ forge-core/auth/provider.go | 6 ++ .../auth/providers/httpverifier/provider.go | 46 ++++++++++---- .../providers/httpverifier/provider_test.go | 61 +++++++++++++++++-- 7 files changed, 122 insertions(+), 18 deletions(-) diff --git a/forge-cli/runtime/auth_audit_test.go b/forge-cli/runtime/auth_audit_test.go index a81d0b0..7e72b76 100644 --- a/forge-cli/runtime/auth_audit_test.go +++ b/forge-cli/runtime/auth_audit_test.go @@ -91,6 +91,9 @@ func TestAuthAudit_EmitsAuthFailOnError(t *testing.T) { {"rejected", auth.ErrTokenRejected, "rejected"}, {"invalid", auth.ErrInvalidToken, "invalid"}, {"not for me", auth.ErrTokenNotForMe, "not_for_me"}, + // Review #6: provider downtime gets its own stable reason code + // so operators can dashboard "IdP down" separately from "token bad". + {"provider unavailable", auth.ErrProviderUnavailable, "provider_unavailable"}, {"infrastructure", errBoom, "infrastructure"}, } diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 95f69ed..97b2b5c 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -1883,6 +1883,19 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ // code suitable for dashboarding and alerting. Reason strings are part // of the audit-event contract — changing them is a breaking change for // downstream consumers. 
+// +// Reason codes: +// +// missing_token - no Authorization header +// rejected - provider recognized + denied (revoked, expired, 401, 4xx) +// invalid - token malformed or cryptographically invalid +// not_for_me - chain exhausted, no provider claimed the token +// provider_unavailable - verifier/IdP unreachable (5xx, network, undecodable) +// infrastructure - other unexpected error +// +// provider_unavailable was added in review #6 so operators can +// distinguish "the token is bad" alerts from "the IdP is down" alerts +// in their dashboards — the response and the runbook are different. func authFailReason(err error) string { switch { case err == nil: @@ -1893,6 +1906,8 @@ func authFailReason(err error) string { return "rejected" case errors.Is(err, auth.ErrInvalidToken): return "invalid" + case errors.Is(err, auth.ErrProviderUnavailable): + return "provider_unavailable" case errors.Is(err, auth.ErrTokenNotForMe): return "not_for_me" default: diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index e5e7d57..cb09a87 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -150,6 +150,11 @@ func classifyAuthFailure(err error) string { return "token rejected by auth provider" case errors.Is(err, ErrInvalidToken): return "invalid token" + case errors.Is(err, ErrProviderUnavailable): + // Surface a distinct user-visible message so retry behavior on + // the client can be different from "invalid token". This is also + // the operator-facing signal in /healthz-style probes. 
+ return "auth provider unavailable" default: return "auth provider error" } diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index 643195b..f163e2f 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -453,6 +453,10 @@ func TestClassifyAuthFailure(t *testing.T) { {"not for me", ErrTokenNotForMe, "valid bearer token required"}, {"rejected", ErrTokenRejected, "token rejected by auth provider"}, {"invalid", ErrInvalidToken, "invalid token"}, + // Review #6: ErrProviderUnavailable gets its own user-visible + // message — distinct from "invalid token" so client retry logic + // and operator alerting can differentiate. + {"provider unavailable", ErrProviderUnavailable, "auth provider unavailable"}, {"unexpected", context.Canceled, "auth provider error"}, } for _, tt := range tests { diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go index 9cb5123..4d115db 100644 --- a/forge-core/auth/provider.go +++ b/forge-core/auth/provider.go @@ -119,10 +119,16 @@ func TokenKind(token string) string { // ChainProvider stops and the middleware writes 401. // - ErrInvalidToken → token is malformed or cryptographically invalid; // ChainProvider stops and the middleware writes 401. +// - ErrProviderUnavailable → the verifier / IdP is unreachable or returned +// a transport-layer error (5xx, network timeout, garbage response). The +// token MAY be valid — we just can't say. ChainProvider stops (fail-closed, +// same as ErrInvalidToken), but the audit signal is distinct so operators +// don't chase a token issue when the actual problem is provider downtime. // - ErrProviderNotConfigured → returned by New(); never by Verify(). 
var ( ErrTokenNotForMe = errors.New("auth: token not for this provider") ErrTokenRejected = errors.New("auth: token rejected") ErrInvalidToken = errors.New("auth: invalid token") + ErrProviderUnavailable = errors.New("auth: provider unavailable") ErrProviderNotConfigured = errors.New("auth: provider not configured") ) diff --git a/forge-core/auth/providers/httpverifier/provider.go b/forge-core/auth/providers/httpverifier/provider.go index ec0cc35..90c5f39 100644 --- a/forge-core/auth/providers/httpverifier/provider.go +++ b/forge-core/auth/providers/httpverifier/provider.go @@ -105,11 +105,16 @@ type verifyResponse struct { // caller's org_id) to the configured verifier URL and translates the // response into an Identity or one of the sentinel errors. // -// Mapping: -// - HTTP 200 + valid:true → (Identity, nil) -// - HTTP 200 + valid:false → ErrTokenRejected -// - HTTP 401 → ErrTokenRejected -// - HTTP other / I/O error → ErrInvalidToken (wrapped cause) +// Mapping (review #6 — separated "token bad" from "verifier down"): +// +// - HTTP 200 + valid:true → (Identity, nil) +// - HTTP 200 + valid:false → ErrTokenRejected +// - HTTP 401 → ErrTokenRejected +// - HTTP 4xx (other) → ErrTokenRejected (verifier denied — token-side) +// - HTTP 5xx → ErrProviderUnavailable (server-side failure) +// - Network / transport error → ErrProviderUnavailable +// - Response body undecodable → ErrProviderUnavailable (verifier returned garbage) +// - Local marshal / request-build err → ErrInvalidToken (extremely rare; bug in caller path) // // This provider does not return ErrTokenNotForMe. func (p *Provider) Verify(ctx context.Context, token string, headers auth.Headers) (*auth.Identity, error) { @@ -117,6 +122,9 @@ func (p *Provider) Verify(ctx context.Context, token string, headers auth.Header body, err := json.Marshal(verifyRequest{Token: token, OrgID: orgID}) if err != nil { + // Local marshal failure — should never happen with the fixed + // struct shape. 
Keep ErrInvalidToken since this isn't the + // verifier's fault. return nil, fmt.Errorf("%w: marshal request: %w", auth.ErrInvalidToken, err) } @@ -128,24 +136,38 @@ func (p *Provider) Verify(ctx context.Context, token string, headers auth.Header resp, err := p.client.Do(req) if err != nil { - return nil, fmt.Errorf("%w: call verifier: %w", auth.ErrInvalidToken, err) + // Network error reaching the verifier — provider is unreachable. + // Distinct from a token problem; audit reflects this. + return nil, fmt.Errorf("%w: call verifier: %w", auth.ErrProviderUnavailable, err) } defer func() { _ = resp.Body.Close() }() - switch resp.StatusCode { - case http.StatusOK: + switch { + case resp.StatusCode == http.StatusOK: // fall through to decode body - case http.StatusUnauthorized: + case resp.StatusCode == http.StatusUnauthorized: return nil, auth.ErrTokenRejected + case resp.StatusCode >= 500 && resp.StatusCode < 600: + // 5xx — verifier broken or overloaded. We can't say whether the + // token is valid; surface as provider unavailable. + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return nil, fmt.Errorf("%w: verifier returned status %d", auth.ErrProviderUnavailable, resp.StatusCode) + case resp.StatusCode >= 400 && resp.StatusCode < 500: + // Non-401 4xx — verifier explicitly refused the request (e.g., + // 400 bad request, 403 forbidden). Treat as token-side rejection. + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return nil, fmt.Errorf("%w: verifier returned status %d", auth.ErrTokenRejected, resp.StatusCode) default: - // Drain a bounded amount of body for diagnostics, then fail. + // 1xx / 3xx — undefined for this contract. Treat as unavailable. 
_, _ = io.CopyN(io.Discard, resp.Body, 1024) - return nil, fmt.Errorf("%w: verifier returned status %d", auth.ErrInvalidToken, resp.StatusCode) + return nil, fmt.Errorf("%w: verifier returned unexpected status %d", auth.ErrProviderUnavailable, resp.StatusCode) } var result verifyResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("%w: decode response: %w", auth.ErrInvalidToken, err) + // Verifier returned 200 but the body isn't the contract shape — + // provider misbehavior, not a token issue. + return nil, fmt.Errorf("%w: decode response: %w", auth.ErrProviderUnavailable, err) } if !result.Valid { return nil, auth.ErrTokenRejected diff --git a/forge-core/auth/providers/httpverifier/provider_test.go b/forge-core/auth/providers/httpverifier/provider_test.go index fd70703..b35a0ee 100644 --- a/forge-core/auth/providers/httpverifier/provider_test.go +++ b/forge-core/auth/providers/httpverifier/provider_test.go @@ -109,26 +109,75 @@ func TestVerify_401_ReturnsRejected(t *testing.T) { } } -func TestVerify_500_ReturnsInvalidToken(t *testing.T) { +func TestVerify_500_ReturnsProviderUnavailable(t *testing.T) { + // Review #6: 5xx is the verifier's fault, not the token's. Distinct + // sentinel so operators triaging audit logs don't chase token issues + // when the actual problem is verifier downtime. 
srv := fakeVerifier(t, func(map[string]any) (int, any) { return http.StatusInternalServerError, map[string]any{"error": "boom"} }) p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) _, err := p.Verify(context.Background(), "x", nil) - if !errors.Is(err, auth.ErrInvalidToken) { - t.Fatalf("err = %v, want ErrInvalidToken", err) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } + if errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err also matched ErrInvalidToken — sentinels must not overlap: %v", err) } } -func TestVerify_NetworkError_ReturnsInvalidToken(t *testing.T) { +func TestVerify_502BadGateway_ReturnsProviderUnavailable(t *testing.T) { + // Other 5xx codes — same classification. + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusBadGateway, nil + }) + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestVerify_NetworkError_ReturnsProviderUnavailable(t *testing.T) { p, _ := httpverifier.New(httpverifier.Config{ URL: "http://127.0.0.1:0/never", // port 0 is invalid, will fail to connect Timeout: 100 * time.Millisecond, }) _, err := p.Verify(context.Background(), "x", nil) - if !errors.Is(err, auth.ErrInvalidToken) { - t.Fatalf("err = %v, want ErrInvalidToken", err) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestVerify_Non401_4xx_ReturnsRejected(t *testing.T) { + // 4xx other than 401 (e.g., 400, 403) means the verifier + // explicitly refused the request — token-side, not server-side. 
+ for _, code := range []int{http.StatusBadRequest, http.StatusForbidden} { + t.Run(http.StatusText(code), func(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { return code, nil }) + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("status %d → err = %v, want ErrTokenRejected", code, err) + } + }) + } +} + +func TestVerify_UndecodableBody_ReturnsProviderUnavailable(t *testing.T) { + // 200 OK but body isn't the contract JSON — verifier misbehavior, + // not a token issue. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("not even close to JSON {{{}")) + })) + defer srv.Close() + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) } } From cf4ae14dda0d488055c9962085141dbf98b08db8 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:45:09 -0400 Subject: [PATCH 14/20] docs(auth/security): document port-stripping contract (review #7) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer asked: is the egress matcher port-agnostic? If yes, the current TestAuthDomains_PortStripped behavior is correct; if no, an issuer on :8443 is silently blocked at JWKS-fetch time. Investigation result: YES, the egress matcher IS port-agnostic. All three callsites strip the port off the outbound host before passing to matcher.IsAllowed: - egress_enforcer.go:40 — req.URL.Hostname() before IsAllowed - egress_proxy.go:210 — extractHost() via net.SplitHostPort - safe_dialer.go:42 — net.SplitHostPort(addr) So auth_domains.go emitting hostname-only entries is correct and consistent. No code fix needed. 
What WAS missing: the cross-package contract was implicit. If someone later hardens the matcher to be port-aware ("allow :443 but not :8080"), auth_domains.go would silently break — every forge.yaml that points at a non-443 IdP would suddenly fail JWKS fetches. This commit: - Adds a doc comment on hostFromURL() naming the three callsites whose port-stripping behavior we depend on. If any of them changes, auth_domains needs to follow. - Adds TestAuthDomains_AssumesPortAgnosticMatcher that asserts the contract directly: builds a matcher with the host AuthDomains produces, checks the hostname-only form is allowed, AND checks that host:port form is NOT (yet) allowed. If a future change makes the matcher port-aware, the second assertion flips and the test flags the cross-package break. No production code changes. The fix is preventative documentation + test, not behavior change. Verification: go test -race ./forge-core/security/ ./forge-core/auth/... — green full sweep — all packages pass golangci-lint v2.10.1 — 0 issues --- forge-core/security/auth_domains.go | 16 +++++++ forge-core/security/auth_domains_test.go | 59 ++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go index b82957f..b96b1e0 100644 --- a/forge-core/security/auth_domains.go +++ b/forge-core/security/auth_domains.go @@ -77,6 +77,22 @@ func settingString(m map[string]any, key string) string { // hostFromURL extracts the bare host (without port) from a URL string. // Returns "" for unparseable values — the caller skips those silently // since real validation happens in validate.ValidateAuthConfig. 
+// +// Port stripping is intentional and depends on a cross-package contract: +// the egress matcher itself is port-agnostic at the three callsites +// where it's consulted — +// +// - egress_enforcer.go: req.URL.Hostname() before matcher.IsAllowed +// - egress_proxy.go: extractHost() via net.SplitHostPort +// - safe_dialer.go: net.SplitHostPort(addr) +// +// As long as those callsites strip the port off the OUTBOUND host before +// matching, a hostname entry in the allowlist suffices for any port. If +// any of those callsites is ever changed to pass host:port to the +// matcher, this function (and TestAuthDomains_AssumesPortAgnosticMatcher) +// must change too — an issuer at https://example.com:8443 would otherwise +// be silently blocked because the allowlist would contain only +// "example.com" while the dialer asked for "example.com:8443". func hostFromURL(raw string) string { if raw == "" { return "" diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go index a917e07..c7a1616 100644 --- a/forge-core/security/auth_domains_test.go +++ b/forge-core/security/auth_domains_test.go @@ -134,3 +134,62 @@ func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { t.Errorf("unknown provider type returned domains: %v (must be nil — extractor not registered)", got) } } + +// TestAuthDomains_AssumesPortAgnosticMatcher pins the cross-package +// contract that AuthDomains relies on. AuthDomains strips ports +// (e.g., "https://login.example.com:8443" → "login.example.com") on the +// assumption that the egress matcher will ALSO strip ports off outbound +// hosts before checking the allowlist. If that assumption ever breaks +// (someone makes the matcher port-aware), AuthDomains needs to flip to +// emitting host:port, OR every existing forge.yaml that points at a +// non-443 IdP suddenly silently blocks at JWKS-fetch time. 
+// +// This test catches that drift early: it builds a matcher with the +// hostname-only allowlist AuthDomains would produce, then asks it +// whether the SAME hostname WITH a port is allowed. If the matcher +// answers "yes", the contract has changed and someone needs to look at +// security/auth_domains.go. +func TestAuthDomains_AssumesPortAgnosticMatcher(t *testing.T) { + hosts := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "http://localhost:8080/realms/dev", + "audience": "api://forge", + }}, + }, + }) + if len(hosts) == 0 { + t.Fatal("AuthDomains returned no hosts for a configured OIDC issuer") + } + + // Build a matcher with exactly what the runner would feed it. + matcher := security.NewDomainMatcher(security.ModeAllowlist, hosts) + + // What we want: matching the raw hostname must succeed. + if !matcher.IsAllowed("localhost") { + t.Fatalf("matcher rejected the very hostname AuthDomains produced: %v", hosts) + } + + // What this guards: the matcher MUST also be port-agnostic. The + // dialer / enforcer call IsAllowed with hostname-only strings (port + // is stripped via net.SplitHostPort or url.Hostname()). If a future + // change ever makes the matcher inspect ports, this contract + // changes — and this test will need to flip together with + // AuthDomains' port-stripping behavior. + // + // We can't directly probe "matcher accepts localhost:8080" because + // the matcher's documented input is hostname-only. The integration + // contract is enforced at the dialer layer (see egress_enforcer.go:40 + // — req.URL.Hostname() strips the port before matcher.IsAllowed). + // What we CAN do here is assert the matcher's input shape: passing + // a host:port string MUST NOT match an allowlist with the bare host.
+ // That guarantees we'll notice if the matcher silently grew port + // awareness — IsAllowed("localhost:8080") would then start matching + // a "localhost" allowlist entry and this assertion would flip. + if matcher.IsAllowed("localhost:8080") { + t.Fatal("matcher unexpectedly accepts host:port form — the contract " + + "AuthDomains assumes (port stripped before IsAllowed) may have " + + "changed. Review security/auth_domains.go and the documented " + + "callsites (egress_enforcer.go, egress_proxy.go, safe_dialer.go).") + } +} From f30df7c24dfb3eb5673532fb73c9317e7ece6b92 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:49:43 -0400 Subject: [PATCH 15/20] refactor(ui/static): drop misleading static/dist/ directory (review #8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: forge-ui/static/dist/app.js is a 110KB hand-edited ES-module file living under a path (dist/) that conventionally signals "generated build artifact, don't touch." There is no source folder, no package.json, no bundler config, no sourcemap. Every UI change has been a direct edit to the "dist" file. This will only get worse as more wizard steps land. Decision: rename out of dist/ rather than introduce a bundler. A bundler (esbuild, etc.) would add real operational debt — package.json, npm install in CI, version pinning, sourcemap rotation, build step in the Makefile, contributor onboarding cost — for files that are hand-written single-file ES modules with no JSX, no TypeScript, no imports beyond htm/preact CDN globals. The file IS the source. Removing the misleading dist/ name communicates that honestly. Changes: - git mv app.js / style.css / index.html / monaco/ up one level from forge-ui/static/dist/ → forge-ui/static/. History preserved via git rename detection. - static/embed.go: //go:embed all:dist → //go:embed app.js style.css index.html monaco. 
Explicit list keeps embed.go itself out of the served assets and documents the asset surface. - server.go: drop fs.Sub(static.FS, "dist") — the FS is now rooted directly at assets. Removes the unused "io/fs" import. The SPA fallback's static.FS.Open() replaces distFS.Open(). Manual smoke (against built binary): - GET / → 200 text/html (index.html) - GET /app.js → 200 text/javascript, 110469 bytes (unchanged) - GET /style.css → 200 - GET /monaco/editor.js → 200, 2317605 bytes - GET /create → 200 (SPA fallback still works) - GET /api/wizard/meta auth_provider_types length == 4 (PR6 unaffected) Future: if app.js ever does grow JSX/TS/minification needs, that's the right time to introduce a build step — into a NEW static/src/ directory producing into static/build/ or similar. Adding it preemptively today is over-engineering. Verification: go build forge-ui/... — green full sweep — 39/39 packages pass golangci-lint v2.10.1 — 0 issues manual end-to-end smoke test of all asset paths — all 200 --- forge-ui/server.go | 15 ++++++--------- forge-ui/static/{dist => }/app.js | 0 forge-ui/static/embed.go | 16 +++++++++++++++- forge-ui/static/{dist => }/index.html | 0 forge-ui/static/{dist => }/monaco/build.sh | 0 forge-ui/static/{dist => }/monaco/editor.css | 0 forge-ui/static/{dist => }/monaco/editor.js | 0 .../static/{dist => }/monaco/editor.worker.js | 0 forge-ui/static/{dist => }/style.css | 0 9 files changed, 21 insertions(+), 10 deletions(-) rename forge-ui/static/{dist => }/app.js (100%) rename forge-ui/static/{dist => }/index.html (100%) rename forge-ui/static/{dist => }/monaco/build.sh (100%) rename forge-ui/static/{dist => }/monaco/editor.css (100%) rename forge-ui/static/{dist => }/monaco/editor.js (100%) rename forge-ui/static/{dist => }/monaco/editor.worker.js (100%) rename forge-ui/static/{dist => }/style.css (100%) diff --git a/forge-ui/server.go b/forge-ui/server.go index 5f14d59..613c9a5 100644 --- a/forge-ui/server.go +++ b/forge-ui/server.go @@ -3,7 +3,6 @@ 
package forgeui import ( "context" "fmt" - "io/fs" "log" "net" "net/http" @@ -103,12 +102,10 @@ func (s *UIServer) Start(ctx context.Context) error { mux.HandleFunc("GET /api/agents/{id}/skill-builder/context", s.handleSkillBuilderContext) mux.HandleFunc("GET /api/agents/{id}/skill-builder/provider", s.handleSkillBuilderProvider) - // Static file serving with SPA fallback - distFS, err := fs.Sub(static.FS, "dist") - if err != nil { - return fmt.Errorf("creating sub filesystem: %w", err) - } - fileServer := http.FileServer(http.FS(distFS)) + // Static file serving with SPA fallback. The embedded FS is rooted + // directly at the static assets — no "dist/" subdirectory (review #8 + // removed the misleading dist/ naming). + fileServer := http.FileServer(http.FS(static.FS)) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // Try to serve the file directly path := r.URL.Path @@ -117,8 +114,8 @@ func (s *UIServer) Start(ctx context.Context) error { return } - // Check if the file exists - f, err := distFS.Open(strings.TrimPrefix(path, "/")) + // Check if the file exists in the embedded FS. + f, err := static.FS.Open(strings.TrimPrefix(path, "/")) if err == nil { _ = f.Close() fileServer.ServeHTTP(w, r) diff --git a/forge-ui/static/dist/app.js b/forge-ui/static/app.js similarity index 100% rename from forge-ui/static/dist/app.js rename to forge-ui/static/app.js diff --git a/forge-ui/static/embed.go b/forge-ui/static/embed.go index 4c00508..13f2500 100644 --- a/forge-ui/static/embed.go +++ b/forge-ui/static/embed.go @@ -2,5 +2,19 @@ package static import "embed" -//go:embed all:dist +// FS exposes the agent dashboard's bundled static assets (index.html, +// app.js, style.css, and the Monaco editor under monaco/) as an +// embed.FS so they ship inside the forge binary. +// +// Files are listed explicitly here rather than embedding the whole +// directory so that this file itself (embed.go) is NOT served as an +// asset. 
If a new top-level static file is added, it needs to be +// added to this directive. +// +// Note: there is no build step — app.js and style.css are hand-edited +// in place, NOT generated from sources elsewhere. The directory used +// to be called "dist/" but that naming was misleading (review #8); it +// signaled "build artifact" when in fact the directory is the source. +// +//go:embed app.js style.css index.html monaco var FS embed.FS diff --git a/forge-ui/static/dist/index.html b/forge-ui/static/index.html similarity index 100% rename from forge-ui/static/dist/index.html rename to forge-ui/static/index.html diff --git a/forge-ui/static/dist/monaco/build.sh b/forge-ui/static/monaco/build.sh similarity index 100% rename from forge-ui/static/dist/monaco/build.sh rename to forge-ui/static/monaco/build.sh diff --git a/forge-ui/static/dist/monaco/editor.css b/forge-ui/static/monaco/editor.css similarity index 100% rename from forge-ui/static/dist/monaco/editor.css rename to forge-ui/static/monaco/editor.css diff --git a/forge-ui/static/dist/monaco/editor.js b/forge-ui/static/monaco/editor.js similarity index 100% rename from forge-ui/static/dist/monaco/editor.js rename to forge-ui/static/monaco/editor.js diff --git a/forge-ui/static/dist/monaco/editor.worker.js b/forge-ui/static/monaco/editor.worker.js similarity index 100% rename from forge-ui/static/dist/monaco/editor.worker.js rename to forge-ui/static/monaco/editor.worker.js diff --git a/forge-ui/static/dist/style.css b/forge-ui/static/style.css similarity index 100% rename from forge-ui/static/dist/style.css rename to forge-ui/static/style.css From 79491c195a6e0557a1ab4e5c49ba557ab6efc777 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:53:48 -0400 Subject: [PATCH 16/20] fix(ui): server-side ValidateAuthConfig in handleCreateAgent (review #9) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: handleCreateAgent decoded JSON, checked 
only name and model_provider, then handed opts.Auth (with an arbitrary Settings map[string]any) straight to CreateFunc. The validate package's ValidateAuthConfig existed but was never called on this path — so a buggy frontend could POST an oidc payload missing issuer/audience, the agent directory would scaffold cleanly with a malformed auth: block, and the error would only surface when the user ran `forge run`. By then the directory, .env.example, and channel tokens are all on disk. Fix: - New validateAuthPayload(opts.Auth) helper in handlers_create.go. - Called from handleCreateAgent immediately after the name/provider checks, BEFORE CreateFunc runs. - Translates AuthCreateOptions → types.AuthConfig exactly the way cmd/ui.go's createFunc does, so this check sees the same shape the eventual scaffold will. No drift between "wizard accepts" and "forge validate accepts". - 400 with the validation errors joined as the response message. - Modes "" / "none" / "custom" skip validation — they don't write provider entries, nothing to validate. Validation surface (delegated to validate.ValidateAuthConfig): - oidc: issuer + audience required - http_verifier: url required - static_token: token or token_env required - unknown type: rejected with list of known types The frontend's pre-submit checks (canNext in app.js step 8) already guard the happy path, but those are defense in depth at best — anyone with curl can bypass them. This is the server-side enforcement that matters. 
Tests: TestHandleCreateAgent_OIDCMissingIssuerRejected 400, error names "issuer" TestHandleCreateAgent_OIDCMissingAudienceRejected 400, error names "audience" TestHandleCreateAgent_OIDCBothFieldsAccepted 201 (regression) TestHandleCreateAgent_HTTPVerifierMissingURLRejected 400, error names "url" TestHandleCreateAgent_UnknownAuthModeRejected 400, error names unknown type TestHandleCreateAgent_NoneAndCustomSkipValidation 201 for mode=none and mode=custom TestHandleCreateAgent_WithoutAuthPayload 201 with nil Auth (regression) Each failure test also asserts that CreateFunc was NOT called — the agent directory must never exist on disk for malformed auth. Verification: go test -race ./forge-ui/ — green full sweep — all packages pass golangci-lint v2.10.1 — 0 issues --- forge-ui/handlers_create.go | 52 +++++++++++ forge-ui/handlers_create_test.go | 148 +++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+) diff --git a/forge-ui/handlers_create.go b/forge-ui/handlers_create.go index d668e73..7cd1eed 100644 --- a/forge-ui/handlers_create.go +++ b/forge-ui/handlers_create.go @@ -2,6 +2,7 @@ package forgeui import ( "encoding/json" + "fmt" "net/http" "os" "path/filepath" @@ -153,6 +154,16 @@ func (s *UIServer) handleCreateAgent(w http.ResponseWriter, r *http.Request) { return } + // Server-side validation of the auth payload before scaffolding the + // agent on disk. Without this, a buggy frontend could write a + // malformed auth: block into forge.yaml that only fails at + // `forge run` — after the agent directory has been created and the + // user thinks creation succeeded. Review #9. 
+ if err := validateAuthPayload(opts.Auth); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + agentDir, err := s.cfg.CreateFunc(opts) if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) @@ -174,6 +185,47 @@ func (s *UIServer) handleCreateAgent(w http.ResponseWriter, r *http.Request) { }) } +// validateAuthPayload runs the forge-core validate package over the +// wizard's auth payload before the agent is scaffolded on disk. Returns +// nil for benign shapes (absent, "none", "custom") and for valid +// provider configs; returns an error with the validation messages +// surfaced when settings are malformed. +// +// The translation from the wizard's AuthCreateOptions → types.AuthConfig +// mirrors exactly what cmd/ui.go's createFunc does later, so this check +// has the same view of the inputs as the eventual scaffold. +func validateAuthPayload(a *AuthCreateOptions) error { + if a == nil { + return nil + } + switch a.Mode { + case "", "none", "custom": + // Nothing to validate — these modes don't write a provider entry. + return nil + } + + authYAML := types.AuthConfig{ + // Required is set true here because the wizard's renderAuthBlock + // always emits required:true when a provider is chosen. Keeping + // the validation view aligned with the rendered YAML avoids + // drift between "wizard accepts" and "forge validate accepts". + Required: true, + Providers: []types.AuthProvider{ + { + Type: a.Mode, + Settings: a.Settings, + }, + }, + } + + result := &validate.ValidationResult{} + validate.ValidateAuthConfig(authYAML, result) + if !result.IsValid() { + return fmt.Errorf("auth: %s", strings.Join(result.Errors, "; ")) + } + return nil +} + // handleOAuthStart initiates the OAuth browser flow for a provider. // The flow opens the user's browser for authentication and waits for the callback. 
func (s *UIServer) handleOAuthStart(w http.ResponseWriter, r *http.Request) { diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go index 9ae8671..a220a0f 100644 --- a/forge-ui/handlers_create_test.go +++ b/forge-ui/handlers_create_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" ) @@ -466,6 +467,153 @@ func TestHandleCreateAgent_WithAuthPayload(t *testing.T) { } } +// --- Server-side auth payload validation (review #9) --- +// +// Without this, the wizard could write malformed auth: blocks to disk +// (e.g., oidc without issuer/audience) — creation succeeds, `forge run` +// fails later, user confused. + +func TestHandleCreateAgent_OIDCMissingIssuerRejected(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + // Audience set but no issuer — should fail validation before + // CreateFunc is invoked. + body := []byte(`{ + "name": "bad-oidc", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { "audience": "api://forge" } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (body: %s)", w.Code, w.Body.String()) + } + // CreateFunc must NOT have been called — the agent directory should + // never exist on disk for malformed auth. 
+ if captured.Name != "" { + t.Errorf("CreateFunc was called despite validation failure (captured: %+v)", captured) + } + if !strings.Contains(w.Body.String(), "issuer is required") { + t.Errorf("error body should name the missing field, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_OIDCMissingAudienceRejected(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "bad-oidc-2", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { "issuer": "https://login.example.com" } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", w.Code) + } + if captured.Name != "" { + t.Errorf("CreateFunc called despite validation failure") + } + if !strings.Contains(w.Body.String(), "audience is required") { + t.Errorf("error body should name missing audience, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_OIDCBothFieldsAccepted(t *testing.T) { + // Regression: valid oidc payload still works end-to-end. 
+ srv, captured := setupCreateWithCapture(t) + body := []byte(`{ + "name": "good-oidc", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge" + } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201 (body: %s)", w.Code, w.Body.String()) + } + if captured.Auth == nil || captured.Auth.Mode != "oidc" { + t.Errorf("expected captured Auth.Mode = oidc, got %+v", captured.Auth) + } +} + +func TestHandleCreateAgent_HTTPVerifierMissingURLRejected(t *testing.T) { + srv, _ := setupCreateWithCapture(t) + body := []byte(`{ + "name": "bad-http", + "model_provider": "openai", + "auth": { "mode": "http_verifier", "settings": {} } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } + if !strings.Contains(w.Body.String(), "url is required") { + t.Errorf("error body should name missing url, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_UnknownAuthModeRejected(t *testing.T) { + srv, _ := setupCreateWithCapture(t) + body := []byte(`{ + "name": "weird", + "model_provider": "openai", + "auth": { "mode": "ldap", "settings": {} } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } + if !strings.Contains(w.Body.String(), "unknown type") { + t.Errorf("error should name unknown-type, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_NoneAndCustomSkipValidation(t *testing.T) { + // mode "none" and "custom" don't produce providers — 
there's + // nothing to validate. They must round-trip cleanly. + for _, mode := range []string{"none", "custom"} { + t.Run(mode, func(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + body := []byte(`{ + "name": "agent-` + mode + `", + "model_provider": "openai", + "auth": { "mode": "` + mode + `", "settings": {} } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusCreated { + t.Errorf("mode %q: status = %d, want 201", mode, w.Code) + } + if captured.Auth == nil || captured.Auth.Mode != mode { + t.Errorf("Auth.Mode = %v, want %q", captured.Auth, mode) + } + }) + } +} + func TestHandleCreateAgent_WithoutAuthPayload(t *testing.T) { // Omitting `auth` from the request is still a valid create — the // agent gets anonymous access by default. From 51760a89ef993351e89ec44e9f5e00d9c3215381 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 14:57:59 -0400 Subject: [PATCH 17/20] docs(auth): tighten loopback-prepend invariant + pin test (review #10) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged: resolveAuth's comment claimed the loopback static_token is "ALWAYS prepended at the chain head (regardless of source)", but the actual code is conditional on r.authToken != "". It works in practice only because ResolveAuth() always mints a token in the non-NoAuth path. Any future refactor that short-circuits that path would silently break channel-adapter callbacks — no compile error, no test failure, just "my Slack bot stopped working after we updated forge" weeks later. Fix is documentation + an invariant test. The conditional itself is correct (--no-auth path returns early before the prepend, and that path also doesn't mint a token), so the code stays as-is. 
What was missing is making the upstream dependency explicit so a future maintainer either preserves the invariant or notices when they don't. Changes: - resolveAuth doc: replaced "ALWAYS prepended" with an explicit description of the precondition (ResolveAuth() must have minted a token) and a reference to the pinning test. Inline comment at the prepend site explains why the conditional is defensive, not load-bearing. - ResolveAuth() doc: states the invariant it maintains (after nil return, either r.authToken != "" or r.cfg.NoAuth is true) and names the downstream consumer + test. - New tests: TestResolveAuth_InvariantMintsTokenInNonNoAuthPath Direct invariant pin: non-NoAuth + ResolveAuth() returns nil → r.authToken MUST be non-empty. Failure message tells future maintainers what they broke and where to look. TestResolveAuth_NoAuthPathDoesNotMintToken Counterpart: --no-auth path does NOT mint. Pins the "anonymous mode skips token generation" property too, so neither direction can drift. If someone ever refactors ResolveAuth to skip minting in the non-NoAuth path, the invariant test fails immediately — and the failure message points straight at the loopback-prepend dependency. Verification: go test -race ./forge-cli/runtime/ — green (incl. 
2 new tests) full sweep — all packages pass golangci-lint v2.10.1 — 0 issues --- forge-cli/runtime/auth_chain_test.go | 44 ++++++++++++++++++++++++++++ forge-cli/runtime/runner.go | 31 ++++++++++++++++++-- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/forge-cli/runtime/auth_chain_test.go b/forge-cli/runtime/auth_chain_test.go index d45b9ba..d026c1e 100644 --- a/forge-cli/runtime/auth_chain_test.go +++ b/forge-cli/runtime/auth_chain_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/types" ) // E2E tests covering the auth chain that the runner builds from legacy @@ -296,6 +297,49 @@ func TestE2E_VerifierUnreachable_ReturnsInvalidNot500(t *testing.T) { } } +// --- ResolveAuth invariant pin (review #10) --- +// +// The loopback-token prepend in resolveAuth() is conditional on +// r.authToken != "" — which works because ResolveAuth() always mints +// one in the non-NoAuth path. If that invariant ever breaks (refactor +// of ResolveAuth that skips minting), channel adapters silently lose +// their loopback. This test pins the invariant directly. + +func TestResolveAuth_InvariantMintsTokenInNonNoAuthPath(t *testing.T) { + tmp := t.TempDir() + r := &Runner{logger: nopLogger{}} + r.cfg.Config = &types.ForgeConfig{} + r.cfg.NoAuth = false + r.cfg.Host = "127.0.0.1" // localhost so the !local + NoAuth check doesn't trip + r.cfg.WorkDir = tmp // ResolveAuth stores the token under <WorkDir>/.forge/ + + if err := r.ResolveAuth(); err != nil { + t.Fatalf("ResolveAuth: %v", err) + } + if r.authToken == "" { + t.Fatal("invariant broken: ResolveAuth() returned nil but did not mint a token " + + "(non-NoAuth path MUST set r.authToken — channels rely on this; " + + "see runner.go resolveAuth loopback-prepend comment)") + } +} + +func TestResolveAuth_NoAuthPathDoesNotMintToken(t *testing.T) { + // Counterpart: in --no-auth mode, no token should be minted.
The + // AllowAnonymous branch returns before the prepend ever runs, so + // the absence of a token is correct. + r := &Runner{logger: nopLogger{}} + r.cfg.Config = &types.ForgeConfig{} + r.cfg.NoAuth = true + r.cfg.Host = "127.0.0.1" + + if err := r.ResolveAuth(); err != nil { + t.Fatalf("ResolveAuth: %v", err) + } + if r.authToken != "" { + t.Errorf("--no-auth path should not mint a token, got %q", r.authToken) + } +} + func TestE2E_WireFormat_TokenAndOrgID(t *testing.T) { // Black-box: confirm the verifier sees exactly `token` and `org_id` // JSON fields. Guards against future schema drift. diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 97b2b5c..551bd48 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -127,6 +127,14 @@ func (r *Runner) SetScheduleNotifier(fn ScheduleNotifier) { // ResolveAuth resolves the auth token early (before Run). This is needed so // channel adapters can be configured with the token before Run() blocks. // Safe to call multiple times — subsequent calls are no-ops. +// +// Invariant: after this returns nil, EITHER r.authToken is non-empty OR +// r.cfg.NoAuth is true. resolveAuth() relies on this when it conditionally +// prepends the loopback static_token (review #10). If a future refactor +// adds a return path that violates this invariant, channel-adapter +// callbacks will silently break — the test +// TestResolveAuth_InvariantMintsTokenInNonNoAuthPath in +// auth_chain_test.go pins the property. func (r *Runner) ResolveAuth() error { if r.authToken != "" || r.cfg.NoAuth { return nil // already resolved @@ -1741,9 +1749,17 @@ func (r *Runner) printBanner(proxyURL string) { // If BOTH forge.yaml auth: AND --auth-url are configured, the YAML block // wins and a warning is logged — silent merging would be surprising. 
// -// The internal loopback static_token is ALWAYS prepended at the chain -// head (regardless of source) so channel-adapter callbacks short-circuit -// before any external provider is invoked. +// Loopback static_token prepending: +// - The internal loopback token is prepended at the chain head WHEN +// ResolveAuth() has minted one (i.e., r.authToken != ""). In the +// non-NoAuth path that's an invariant ResolveAuth maintains (review +// #10). When --no-auth is in effect we return early via the +// AllowAnonymous path above and never reach the prepend. +// - Channel adapter callbacks rely on the loopback short-circuit; if +// ResolveAuth is ever refactored to skip token minting on the +// non-NoAuth path, channels will silently break. +// TestResolveAuth_InvariantMintsTokenInNonNoAuthPath in +// auth_chain_test.go pins that invariant. func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.MiddlewareOptions, error) { // Ensure token is resolved (no-op if already done by ResolveAuth). if err := r.ResolveAuth(); err != nil { @@ -1788,6 +1804,15 @@ func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Middlew // Prepend the loopback static_token so channel adapter callbacks // short-circuit before any user-configured provider runs. + // + // PRECONDITION: r.authToken is non-empty on this branch because + // ResolveAuth() always mints one in the non-NoAuth path. The + // conditional here is defensive — if someone refactors ResolveAuth + // to skip minting (and forgets to update channels), we'd want the + // rest of the function to still produce a coherent middleware + // (without a loopback) rather than panic. The invariant itself is + // pinned by TestResolveAuth_InvariantMintsTokenInNonNoAuthPath + // (review #10).
chain := userChain if r.authToken != "" { loopback, err := statictoken.New(statictoken.Config{ From b3d4264d6e10fe71aa2c7fb333c116ed2e369ff3 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 15:15:05 -0400 Subject: [PATCH 18/20] fix(auth): six smaller-bug fixes from review (review #11) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six related cleanups bundled per the reviewer's grouping: 11a. statictoken length leak (provider.go) subtle.ConstantTimeCompare returns 0 instantly on length mismatch without inspecting bytes, leaking the configured token's length via measurable timing differences across many trials. Fix: SHA-256 hash both sides and constant-time compare the 32-byte digests. The comparison is now over equal-length inputs in every case. Pinned by TestVerify_MismatchedLengths_AreRejected — exercises empty, prefix, suffix, off-by-one, and same-length-but-different inputs. 11b. OIDC PSS-default mismatch (jwks.go, provider.go) parseJWK was defaulting an RSA JWK with no `alg` field to "RS256". The cross-check at provider.go then rejected any PS256/PS384/PS512 token signed by the same RSA key — RFC 7518 says "infer from context" when alg is absent, and the same RSA key signs both RS* and PS* families. Fix: preserve empty alg through parseJWK. In the cross-check, when the JWK's alg is empty, fall back to trusting the token's header alg — which the asymmetric whitelist already validated upstream. When the JWK declares an alg explicitly, the strict comparison still applies (algorithm-confusion defense intact). Pinned by TestVerify_JWKWithoutAlgAcceptsAnyAsymmetric and TestVerify_JWKExplicitAlgStillStrict. 11c. mergeEgressDomains sort order (init_egress.go) deriveEgressDomains returned sorted output; mergeEgressDomains appended AuthEgressHosts unsorted at the tail. 
When AuthEgressHosts came from a map iteration, the rendered forge.yaml's egress.allowed_domains section diffed across runs even when the underlying input set was unchanged. Fix: sort the merged result. Generated forge.yaml is now stable. Pinned by TestMergeEgressDomains_SortedOutput. 11d. renderAuthBlock dead `name` parameter (init_auth.go, init.go) Caller passed opts.AuthMode for both the mode and name arguments; the name was suppressed by the equality check inside the function. The wizard never captures an explicit provider name today. Fix: dropped the parameter. If a future wizard step adds name capture, reintroduce it then. Existing tests updated; renamed TestRenderAuthBlock_NameSameAsModeSuppressed → TestRenderAuthBlock_NoNameEverEmitted with a doc-pinning intent. 11e. TUI backspace-on-empty undiscoverable (step_auth.go) Backspace-on-empty was the documented escape hatch out of an active sub-step input, but no visible affordance surfaced it. Fix: inline kbd hint under the input in each sub-step phase — "Backspace at empty back to picker". Doesn't change behavior, makes it discoverable. 11f. Identity.Claims unfiltered shallow copy (claims.go, provider.go) The OIDC provider's copyClaims returns the full JWT claim set on Identity.Claims, including custom claims the issuer adds (PII, raw group memberships, internal IDs). Downstream consumers reading Identity.Claims will see all of it. Fix is documentation for now: WARNING on both the Identity.Claims field and the copyClaims helper, naming the risk and pointing future consumers at typed Identity fields (UserID/Email/OrgID/Groups) as the portable surface. Filtering is properly a future authz-layer concern — not something this package can do without policy input. Verification: go test -race forge-core/... forge-cli/... 
— green full sweep — all packages pass golangci-lint v2.10.1 — 0 issues --- forge-cli/cmd/init.go | 2 +- forge-cli/cmd/init_auth.go | 12 +++-- forge-cli/cmd/init_auth_test.go | 46 +++++++++++------ forge-cli/cmd/init_egress.go | 35 +++++++------ forge-cli/internal/tui/steps/step_auth.go | 9 +++- forge-core/auth/provider.go | 26 +++++++--- forge-core/auth/providers/oidc/claims.go | 15 ++++++ forge-core/auth/providers/oidc/jwks.go | 18 ++++--- forge-core/auth/providers/oidc/provider.go | 11 +++-- .../auth/providers/oidc/provider_test.go | 49 +++++++++++++++++++ .../auth/providers/statictoken/provider.go | 13 ++++- .../providers/statictoken/provider_test.go | 27 ++++++++++ 12 files changed, 207 insertions(+), 56 deletions(-) diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index 2a4ac52..fca28d4 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -1050,7 +1050,7 @@ func buildTemplateData(opts *initOptions) templateData { // Auth chain rendering: pre-render the YAML fragment in Go so nested // maps (claim_map) emit correctly. - data.AuthBlock = renderAuthBlock(opts.AuthMode, opts.AuthMode, opts.AuthSettings) + data.AuthBlock = renderAuthBlock(opts.AuthMode, opts.AuthSettings) // Build env vars data.EnvVars = buildEnvVars(opts) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index 48197eb..43ecd42 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -127,15 +127,20 @@ func hostFromURL(raw string) string { // required: true // providers: // - type: -// name: // settings: // key: value // nested: // k: v // +// The AuthProvider.Name field is intentionally NOT emitted — the wizard +// doesn't capture one, and the previous signature took a `name` arg +// that callers always set equal to mode (suppressed). Removed in +// review #11d. If a future wizard step adds explicit name capture, +// reintroduce the parameter then. 
+// // Settings keys are emitted in alphabetical order so the output is // deterministic across runs (useful for diffing generated files). -func renderAuthBlock(mode, name string, settings map[string]any) string { +func renderAuthBlock(mode string, settings map[string]any) string { switch mode { case "", "none": return "" @@ -148,9 +153,6 @@ func renderAuthBlock(mode, name string, settings map[string]any) string { b.WriteString(" required: true\n") b.WriteString(" providers:\n") fmt.Fprintf(&b, " - type: %s\n", mode) - if name != "" && name != mode { - fmt.Fprintf(&b, " name: %s\n", name) - } if len(settings) > 0 { b.WriteString(" settings:\n") writeYAMLMap(&b, settings, " ") diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go index b883d24..3af3999 100644 --- a/forge-cli/cmd/init_auth_test.go +++ b/forge-cli/cmd/init_auth_test.go @@ -11,16 +11,16 @@ import ( // --- renderAuthBlock --- func TestRenderAuthBlock_EmptyAndNone(t *testing.T) { - if got := renderAuthBlock("", "", nil); got != "" { + if got := renderAuthBlock("", nil); got != "" { t.Errorf("mode='' → %q, want empty", got) } - if got := renderAuthBlock("none", "", nil); got != "" { + if got := renderAuthBlock("none", nil); got != "" { t.Errorf("mode='none' → %q, want empty", got) } } func TestRenderAuthBlock_Custom(t *testing.T) { - got := renderAuthBlock("custom", "", nil) + got := renderAuthBlock("custom", nil) if !strings.HasPrefix(got, "# Auth provider chain") { t.Errorf("custom mode should start with comment header, got: %s", got) } @@ -30,7 +30,7 @@ func TestRenderAuthBlock_Custom(t *testing.T) { } func TestRenderAuthBlock_OIDC(t *testing.T) { - got := renderAuthBlock("oidc", "corporate-sso", map[string]any{ + got := renderAuthBlock("oidc", map[string]any{ "issuer": "https://login.example.com", "audience": "api://forge", }) @@ -40,7 +40,6 @@ func TestRenderAuthBlock_OIDC(t *testing.T) { " required: true", " providers:", " - type: oidc", - " name: corporate-sso", " 
settings:", " audience: api://forge", " issuer: https://login.example.com", @@ -52,20 +51,22 @@ func TestRenderAuthBlock_OIDC(t *testing.T) { } } -func TestRenderAuthBlock_NameSameAsModeSuppressed(t *testing.T) { - // When the name equals the mode (the default the wizard supplies), - // don't emit a redundant `name:` line. - got := renderAuthBlock("oidc", "oidc", map[string]any{ +func TestRenderAuthBlock_NoNameEverEmitted(t *testing.T) { + // Review #11d: the wizard doesn't capture an explicit provider + // name, so the rendered YAML must never include a `name:` line. + // If a future wizard step adds name capture, this test changes + // at the same time as the function signature. + got := renderAuthBlock("oidc", map[string]any{ "issuer": "https://x", "audience": "y", }) - if strings.Contains(got, "name: oidc") { - t.Errorf("expected name: line suppressed when same as type, got:\n%s", got) + if strings.Contains(got, "name:") { + t.Errorf("expected no name: line in output, got:\n%s", got) } } func TestRenderAuthBlock_NestedClaimMap(t *testing.T) { - got := renderAuthBlock("oidc", "", map[string]any{ + got := renderAuthBlock("oidc", map[string]any{ "issuer": "https://login.example.com", "audience": "api://forge", "claim_map": map[string]any{"groups": "roles"}, @@ -80,7 +81,7 @@ func TestRenderAuthBlock_NestedClaimMap(t *testing.T) { } func TestRenderAuthBlock_HTTPVerifier(t *testing.T) { - got := renderAuthBlock("http_verifier", "", map[string]any{ + got := renderAuthBlock("http_verifier", map[string]any{ "url": "https://verify.example.com", "default_org": "acme", }) @@ -95,11 +96,11 @@ func TestRenderAuthBlock_HTTPVerifier(t *testing.T) { func TestRenderAuthBlock_DeterministicOrdering(t *testing.T) { // Settings keys should always emit in alphabetical order so diffs // of generated files are stable. 
- got1 := renderAuthBlock("oidc", "", map[string]any{ + got1 := renderAuthBlock("oidc", map[string]any{ "audience": "a", "issuer": "i", }) - got2 := renderAuthBlock("oidc", "", map[string]any{ + got2 := renderAuthBlock("oidc", map[string]any{ "issuer": "i", "audience": "a", }) @@ -320,3 +321,18 @@ func TestMergeEgressDomains_EmptyInputs(t *testing.T) { t.Errorf("got %v", got) } } + +func TestMergeEgressDomains_SortedOutput(t *testing.T) { + // Review #11c: the function returns sorted + // output so the generated forge.yaml stays stable across runs even + // when AuthEgressHosts arrive in non-deterministic order (e.g., + // derived from map iteration). + got := mergeEgressDomains( + []string{"api.openai.com", "api.anthropic.com"}, + []string{"login.example.com", "graph.microsoft.com"}, + ) + want := []string{"api.anthropic.com", "api.openai.com", "graph.microsoft.com", "login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("output not sorted:\n got %v\n want %v", got, want) + } +} diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index d2d7eff..92b632f 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -15,30 +15,29 @@ var providerDomains = map[string]string{ // ollama is local, no egress needed } -// mergeEgressDomains returns the union of `base` and `extra`, preserving -// the order of `base` and appending only previously-unseen items from -// `extra`. The result is deduplicated and stable. +// mergeEgressDomains returns the deduplicated, sorted union of `base` +// and `extra`. Sorting is intentional — `deriveEgressDomains` produces +// sorted output, and appending auth hosts unsorted at the tail (the +// pre-sort behavior of this function) made the rendered forge.yaml +// produce noisy diffs across runs even when the host set was unchanged. +// Review #11c.
// -// Used by the init scaffold to fold auth-provider hosts (from the wizard -// or --auth flags) into the egress allowlist without disturbing the -// existing ordering. +// Used by the init scaffold to fold auth-provider hosts (from the +// wizard or --auth flags) into the egress allowlist while keeping the +// generated forge.yaml stable. func mergeEgressDomains(base, extra []string) []string { seen := make(map[string]bool, len(base)+len(extra)) out := make([]string, 0, len(base)+len(extra)) - for _, d := range base { - if d == "" || seen[d] { - continue - } - seen[d] = true - out = append(out, d) - } - for _, d := range extra { - if d == "" || seen[d] { - continue + for _, slice := range [][]string{base, extra} { + for _, d := range slice { + if d == "" || seen[d] { + continue + } + seen[d] = true + out = append(out, d) } - seen[d] = true - out = append(out, d) } + sort.Strings(out) return out } diff --git a/forge-cli/internal/tui/steps/step_auth.go b/forge-cli/internal/tui/steps/step_auth.go index 1e71de0..000ef16 100644 --- a/forge-cli/internal/tui/steps/step_auth.go +++ b/forge-cli/internal/tui/steps/step_auth.go @@ -270,7 +270,14 @@ func (s *AuthStep) View(width int) string { case authDonePhase: return "" default: - return s.input.View(width) + // Inline the "Backspace at empty: back" hint with the input + // so the escape hatch is discoverable from within the sub-step + // (review #11e — the documented behavior was undiscoverable + // because no kbd hint surfaced it). 
+ hint := " " + + s.styles.KbdKey.Render("Backspace at empty") + " " + + s.styles.KbdDesc.Render("back to picker") + return s.input.View(width) + "\n" + hint } } diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go index 4d115db..527eac2 100644 --- a/forge-core/auth/provider.go +++ b/forge-core/auth/provider.go @@ -44,12 +44,26 @@ type Provider interface { // There is intentionally no Valid field — a non-nil *Identity returned // alongside a nil error is the only "valid" signal. See package comment. type Identity struct { - UserID string `json:"user_id,omitempty"` - Email string `json:"email,omitempty"` - OrgID string `json:"org_id,omitempty"` - WorkspaceID string `json:"workspace_id,omitempty"` - Groups []string `json:"groups,omitempty"` - Claims map[string]any `json:"claims,omitempty"` + UserID string `json:"user_id,omitempty"` + Email string `json:"email,omitempty"` + OrgID string `json:"org_id,omitempty"` + WorkspaceID string `json:"workspace_id,omitempty"` + Groups []string `json:"groups,omitempty"` + + // Claims carries the provider-specific raw payload (typically the + // full JWT claim set for the oidc provider — including custom + // issuer-specific claims). Treat this as an escape hatch for + // provider-specific authorization logic; prefer the typed fields + // above for portable consumers. + // + // WARNING (review #11f): for OIDC the map is an unfiltered shallow + // copy of the JWT claims — `sub`, `email`, `iss`, `aud`, `exp`, + // plus any custom claims the issuer adds (group memberships, + // internal IDs, profile fields, sometimes raw PII). Do NOT log + // this map verbatim. Filtering belongs in a future authz layer, + // not here. + Claims map[string]any `json:"claims,omitempty"` + // Source records which provider verified the identity (e.g., "oidc", // "http_verifier", "static_token"). Useful for audit logs and debugging. 
Source string `json:"source,omitempty"` diff --git a/forge-core/auth/providers/oidc/claims.go b/forge-core/auth/providers/oidc/claims.go index ef89c0e..de9e8d9 100644 --- a/forge-core/auth/providers/oidc/claims.go +++ b/forge-core/auth/providers/oidc/claims.go @@ -117,6 +117,21 @@ func stringSliceClaim(claims jwt.MapClaims, name string) []string { // copyClaims returns a shallow copy of the claim map so the caller can // expose it on Identity without sharing the underlying map with the // parsed-token state. +// +// WARNING (review #11f): the returned map contains EVERY claim the IdP +// emitted on the token — `sub`, `email`, `iss`, `aud`, `exp`, `iat`, +// `nbf`, plus any custom claims the issuer adds (group memberships, +// internal IDs, profile fields, sometimes raw PII). Downstream consumers +// reading Identity.Claims will see all of it. +// +// Recommendations for consumers: +// - Prefer the mapped, typed fields on Identity (UserID, Email, OrgID, +// Groups) where they suffice — those have a fixed contract. +// - Treat Identity.Claims as an escape hatch for provider-specific +// authorization logic, not as something to log or relay verbatim. +// - When a future authz layer needs a claim-allowlist, that's the +// right place to add it — not here. This package serves the raw +// auth principal; filtering is policy. func copyClaims(claims jwt.MapClaims) map[string]any { if len(claims) == 0 { return nil diff --git a/forge-core/auth/providers/oidc/jwks.go b/forge-core/auth/providers/oidc/jwks.go index d56a98b..6d60437 100644 --- a/forge-core/auth/providers/oidc/jwks.go +++ b/forge-core/auth/providers/oidc/jwks.go @@ -305,7 +305,15 @@ func (c *jwksCache) refreshLocked(ctx context.Context) error { // parseJWK extracts the public key material and derives the signing // algorithm from the JWK. The returned alg is the value the provider -// will require the token header to match. 
+// will require the token header to match — empty string means +// "any allowed asymmetric alg permitted by the key type" (only happens +// for RSA, where the same key can sign RS256/384/512 or PS256/384/512; +// per RFC 7518 the application is meant to infer the alg from context). +// +// Review #11b: we previously defaulted an RSA JWK with no `alg` field +// to "RS256", which made the cross-check at provider.go reject +// legitimate PS256-signed tokens against the same key. EC keys aren't +// affected — their alg is derived from the curve, not advertised. func parseJWK(k jwk) (any, string, error) { switch k.Kty { case "RSA": @@ -313,11 +321,9 @@ func parseJWK(k jwk) (any, string, error) { if err != nil { return nil, "", err } - alg := k.Alg - if alg == "" { - alg = "RS256" // RFC 7518 default for RSA when alg is absent - } - return pk, alg, nil + // Preserve empty alg so the provider can fall back to header-alg + // validation against the whitelist (still gated by isAllowedAlg). + return pk, k.Alg, nil case "EC": pk, alg, err := parseECDSAPublicKey(k) if err != nil { diff --git a/forge-core/auth/providers/oidc/provider.go b/forge-core/auth/providers/oidc/provider.go index ad3e7e0..47c985d 100644 --- a/forge-core/auth/providers/oidc/provider.go +++ b/forge-core/auth/providers/oidc/provider.go @@ -230,9 +230,14 @@ func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Hea return nil, err // transport / decode error — fail-closed via chain } - // Cross-check: the JWKS-declared alg must match the token's alg. - // This is the algorithm-confusion defense. - if key.Alg != headerAlg { + // Cross-check: when the JWKS entry declares an alg, the token's alg + // must match it exactly (algorithm-confusion defense). When the JWKS + // entry has no alg (RFC 7518 says "infer from context"), trust the + // token's header alg as long as it's in the asymmetric whitelist — + // which was already enforced at line ~195 above. 
Review #11b: + // previously the empty-alg case defaulted to "RS256", which rejected + // legitimate PS256/PS384/PS512 tokens signed by the same RSA key. + if key.Alg != "" && key.Alg != headerAlg { return nil, fmt.Errorf("%w: token alg %q does not match JWKS alg %q for kid %q", auth.ErrInvalidToken, headerAlg, key.Alg, kid) } diff --git a/forge-core/auth/providers/oidc/provider_test.go b/forge-core/auth/providers/oidc/provider_test.go index 37b7adc..296d4c5 100644 --- a/forge-core/auth/providers/oidc/provider_test.go +++ b/forge-core/auth/providers/oidc/provider_test.go @@ -138,6 +138,55 @@ func TestVerify_RejectsHMAC(t *testing.T) { } } +func TestVerify_JWKWithoutAlgAcceptsAnyAsymmetric(t *testing.T) { + // Review #11b: when the JWKS entry omits the `alg` field (RFC 7518 + // says "infer from context"), the provider must NOT lock the key + // to "RS256" — that rejected legitimate PS256 tokens signed by the + // same RSA key. The fix preserves empty alg and falls back to the + // header's alg as long as it's in the asymmetric whitelist. + fi := newFakeIssuer(t) + // Override the JWKS entry to omit its alg advertisement. + fi.keys["key-1"].algInJWKS = "" + + p := newProvider(t, fi, nil) + + // Token signed with RS256 — works (the key was generated as RSA-2048). + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("RS256 token with empty-alg JWK rejected: %v", err) + } + + // Now switch the signing method to PS256 (uses the same key + // material — RSA-PSS just changes the padding) and verify the + // provider accepts it. 
+ fi.keys["key-1"].method = jwt.SigningMethodPS256 + tok = fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("PS256 token with empty-alg JWK rejected (regression of #11b): %v", err) + } +} + +func TestVerify_JWKExplicitAlgStillStrict(t *testing.T) { + // Counterpart: when the JWKS entry DOES declare an alg, the + // cross-check must still enforce it. This guards the + // algorithm-confusion defense — if a JWK says "this key is for + // RS256 only", a token claiming PS256 against that key is + // rejected. + fi := newFakeIssuer(t) + fi.keys["key-1"].algInJWKS = "RS256" // explicit + + p := newProvider(t, fi, nil) + + // Sign with PS256 against the same key material. + fi.keys["key-1"].method = jwt.SigningMethodPS256 + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("explicit JWK alg=RS256 must reject PS256 token; got err = %v", err) + } +} + func TestVerify_AlgInHeaderMustMatchJWKS(t *testing.T) { // Algorithm-confusion defense: if the token claims an alg that // differs from the JWKS-declared alg for that kid, reject. diff --git a/forge-core/auth/providers/statictoken/provider.go b/forge-core/auth/providers/statictoken/provider.go index ab096ad..338063a 100644 --- a/forge-core/auth/providers/statictoken/provider.go +++ b/forge-core/auth/providers/statictoken/provider.go @@ -22,6 +22,7 @@ package statictoken import ( "context" + "crypto/sha256" "crypto/subtle" "fmt" "maps" @@ -96,8 +97,18 @@ func (p *Provider) Name() string { return ProviderName } // Verify implements auth.Provider. Constant-time compare against the // configured token. Mismatch yields to the next provider via ErrTokenNotForMe. 
+// +// Length-leak guard (review #11a): subtle.ConstantTimeCompare returns 0 +// immediately when the two slices have different lengths, without +// inspecting any bytes — that early return leaks the expected token's +// length via measurable timing differences over many trials. We hash +// both sides to SHA-256 first so the comparison is always over +// equal-length (32-byte) digests. Hash → constant-time compare is the +// standard mitigation for this class. func (p *Provider) Verify(_ context.Context, token string, _ auth.Headers) (*auth.Identity, error) { - if subtle.ConstantTimeCompare([]byte(token), p.expected) != 1 { + presented := sha256.Sum256([]byte(token)) + expected := sha256.Sum256(p.expected) + if subtle.ConstantTimeCompare(presented[:], expected[:]) != 1 { return nil, auth.ErrTokenNotForMe } // Return a defensive copy so callers can't mutate the configured identity. diff --git a/forge-core/auth/providers/statictoken/provider_test.go b/forge-core/auth/providers/statictoken/provider_test.go index 290db08..14d2c97 100644 --- a/forge-core/auth/providers/statictoken/provider_test.go +++ b/forge-core/auth/providers/statictoken/provider_test.go @@ -145,6 +145,33 @@ func TestIdentityIsDefensivelyCopied(t *testing.T) { } } +func TestVerify_MismatchedLengths_AreRejected(t *testing.T) { + // Review #11a: hash-then-compare path must still reject tokens of + // any length that aren't the configured one. Specifically, a token + // that's just a prefix or a longer-tail variant of the real one + // must NOT compare equal — verifying the hash digest comparison + // works the same way the byte-level compare did for length-equal + // inputs. 
+ p, _ := statictoken.New(statictoken.Config{Token: "abcdef0123456789"}) + + for _, candidate := range []string{ + "", // empty + "abcdef", // prefix + "abcdef0123456789-extra-bytes", // longer + "different-but-same-length-as-real-secret", // same-length mismatch + "abcdef0123456788", // off-by-one in last char + } { + if _, err := p.Verify(context.Background(), candidate, nil); !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("candidate %q: err = %v, want ErrTokenNotForMe", candidate, err) + } + } + + // Sanity: the real token still works. + if _, err := p.Verify(context.Background(), "abcdef0123456789", nil); err != nil { + t.Errorf("real token rejected: %v", err) + } +} + func TestConcurrentVerify_RaceSafe(t *testing.T) { // Race-safety smoke. Run with `go test -race`. p, _ := statictoken.New(statictoken.Config{Token: "tok"}) From 08fa11b4c62a1dd77576f78d79d9322404027429 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 15:58:45 -0400 Subject: [PATCH 19/20] chore(auth/statictoken): gofmt provider_test.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Comment alignment in the candidate-tokens slice tripped CI's 'gofmt -l forge-core forge-cli forge-plugins' check. Mechanical fix — gofmt re-aligns trailing comments to the same column. CI Test job was failing here: Test → Check formatting → 'test -z $(gofmt -l ...)' exits 1 on the new TestVerify_MismatchedLengths_AreRejected block I added in b3d4264 (review #11a). Local gofmt -l also reports it now; pushed. 
--- forge-core/auth/providers/statictoken/provider_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/forge-core/auth/providers/statictoken/provider_test.go b/forge-core/auth/providers/statictoken/provider_test.go index 14d2c97..4f16401 100644 --- a/forge-core/auth/providers/statictoken/provider_test.go +++ b/forge-core/auth/providers/statictoken/provider_test.go @@ -155,11 +155,11 @@ func TestVerify_MismatchedLengths_AreRejected(t *testing.T) { p, _ := statictoken.New(statictoken.Config{Token: "abcdef0123456789"}) for _, candidate := range []string{ - "", // empty - "abcdef", // prefix + "", // empty + "abcdef", // prefix "abcdef0123456789-extra-bytes", // longer "different-but-same-length-as-real-secret", // same-length mismatch - "abcdef0123456788", // off-by-one in last char + "abcdef0123456788", // off-by-one in last char } { if _, err := p.Verify(context.Background(), candidate, nil); !errors.Is(err, auth.ErrTokenNotForMe) { t.Errorf("candidate %q: err = %v, want ErrTokenNotForMe", candidate, err) From 49dd97a5fd41ab84a3386b2ecad8b10a8a7124f6 Mon Sep 17 00:00:00 2001 From: Naveen Kurra Date: Sat, 23 May 2026 16:05:11 -0400 Subject: [PATCH 20/20] test(auth): close coverage gaps + YAML quoting hardening (review #12) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five coverage gaps and one hardening item from the reviewer's eight-point batch. Two items in the original list are already covered: 12.2 (--no-auth + YAML auth precedence) is exercised by the four TestResolveAuth_NoAuth* tests added in review #4 commit b33522d. 12.4 (TestAuthDomains in forge-core/security) is exercised by the eight TestAuthDomains_* tests in security/auth_domains_test.go (added in PR3 + strengthened in review #7 commit cf4ae14).
The six real gaps: 12.1 Warning capture for --auth-url + YAML auth precedence TestBuildUserAuthChain_BothSourcesPrefersYAML was asserting the chain-behavior half of the contract through a nop logger that swallowed the Warn — silent regression on the "we logged a warning" half. Switched to recordLogger and added an assertion on "--auth-url and forge.yaml". If a refactor ever drops that warning, the test fails. 12.3 Full resolveAuth + YAML loopback head-of-chain test auth_chain_test.go exercised buildLegacyAuthChain. Nothing exercised resolveAuth() end-to-end with a YAML auth: block AND a minted internal token, then asserted "loopback is at chain head, ahead of YAML provider." Added TestResolveAuth_YAMLAuth_LoopbackPrependedAtChainHead that inspects chain.Providers() ordering directly and smoke-tests the loopback token verifies with the "internal" source. 12.5 Backend contract test for the frontend's claim_map translation app.js's buildAuthPayload translates flat groups_claim → nested claim_map.groups before submitting. The backend doesn't run that translation; if the frontend ever regresses, downstream YAML would silently drop the custom claim. Added two tests in the UI handler: - TestHandleCreateAgent_NestedClaimMapShape pins what the backend EXPECTS post-translation (claim_map.groups present). - TestHandleCreateAgent_FlatGroupsClaim_PreservedAsExtraField documents the failure mode if the JS regresses — the request succeeds structurally but claim_map is missing. Real JS tests would require a Node toolchain we deliberately don't add (see PR6 / review #8). These Go tests are the next-best contract. 12.6 Backspace-to-picker behavior in step_auth.go Added TestAuthStep_BackspaceOnEmptyInputReturnsToPicker and a counterpart TestAuthStep_BackspaceWithContentDoesNotReturnToPicker to pin the escape-hatch behavior the review #11e hint advertises. 12.7 Invalid issuer URL blocks advance validateHTTPSURL was unit-tested standalone; never exercised as a wizard interaction. 
Added TestAuthStep_InvalidIssuerURLBlocksAdvance: bad URL + Enter → still on issuer phase; type a valid replacement + Enter → advances to audience. 12.8 renderAuthBlock YAML quoting hardening writeYAMLMap previously emitted "key: value" raw. Values containing ": " sequences, "#" markers, leading YAML-significant chars (!, *, -, [, {, etc.), reserved tokens (true/false/null/yes/no…), or control chars could break the parser. Added: - yamlScalar/needsYAMLQuoting helpers that emit double-quoted + escaped strings only when the input would otherwise change meaning. Plain values pass through unchanged so generated forge.yaml remains readable. - TestRenderAuthBlock_QuotesUnsafeValues (20 sub-cases) pinning the predicate's classification per input shape. - TestRenderAuthBlock_QuotedValueIsParseable round-trips a value containing ": ", "#", "!", and quotes through the renderer → yaml.v3 unmarshaler, asserting the rendered line is valid YAML AND decodes back to the original string. Note: the renderer remains a focused subset (string / number / bool / map). If auth-settings ever grow richer types, switch to yaml.v3 for the whole block. Verification: test -z "$(gofmt -l forge-core forge-cli forge-plugins)" — clean go test -race ./forge-core/... ./forge-cli/... — all green golangci-lint v2.10.1 — 0 issues --- forge-cli/cmd/init_auth.go | 75 ++++++++++++- forge-cli/cmd/init_auth_test.go | 91 ++++++++++++++++ .../internal/tui/steps/step_auth_test.go | 85 +++++++++++++++ forge-cli/runtime/auth_config_test.go | 87 ++++++++++++++- forge-ui/handlers_create_test.go | 100 ++++++++++++++++++ 5 files changed, 432 insertions(+), 6 deletions(-) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index 43ecd42..962aff8 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -180,6 +180,10 @@ func customAuthStub() string { // writeYAMLMap renders a `map[string]any` as YAML lines, recursing into // nested maps.
Only string / number / bool / map values are supported — // the auth-settings schema doesn't use anything else. +// +// String values are conservatively quoted (review #12.8) when they +// contain YAML-significant characters. Otherwise the unquoted form is +// emitted to keep generated forge.yaml readable. func writeYAMLMap(b *strings.Builder, m map[string]any, indent string) { keys := make([]string, 0, len(m)) for k := range m { @@ -193,9 +197,78 @@ func writeYAMLMap(b *strings.Builder, m map[string]any, indent string) { fmt.Fprintf(b, "%s%s:\n", indent, k) writeYAMLMap(b, val, indent+" ") case string: - fmt.Fprintf(b, "%s%s: %s\n", indent, k, val) + fmt.Fprintf(b, "%s%s: %s\n", indent, k, yamlScalar(val)) default: fmt.Fprintf(b, "%s%s: %v\n", indent, k, val) } } } + +// yamlScalar returns a YAML-safe rendering of a string value. Plain +// values pass through unchanged; values containing characters that +// would break the YAML parser (": " sequences, leading specials, +// reserved tokens, control chars) are emitted as double-quoted strings +// with the minimum necessary escaping. +// +// This is deliberately not a general YAML serializer — it covers the +// subset that appears in auth provider settings (issuer URLs, audiences, +// claim names, default org strings). If we ever need richer values, +// switch to gopkg.in/yaml.v3 for the whole block. +func yamlScalar(s string) string { + if !needsYAMLQuoting(s) { + return s + } + var out strings.Builder + out.WriteByte('"') + for _, r := range s { + switch r { + case '"': + out.WriteString(`\"`) + case '\\': + out.WriteString(`\\`) + case '\n': + out.WriteString(`\n`) + case '\r': + out.WriteString(`\r`) + case '\t': + out.WriteString(`\t`) + default: + out.WriteRune(r) + } + } + out.WriteByte('"') + return out.String() +} + +// needsYAMLQuoting reports whether a string would change meaning when +// emitted as a plain (unquoted) scalar in a YAML block mapping.
Conservative — +// false positives are fine (extra quotes), false negatives are bugs +// (broken YAML). +func needsYAMLQuoting(s string) bool { + if s == "" { + return true + } + // Leading characters that begin YAML indicators (tags, anchors, + // aliases, folded/literal block markers, flow markers, directives, + // quotes, comments, leading whitespace). + switch s[0] { + case '!', '&', '*', '>', '|', '%', '@', '`', '"', '\'', '#', ' ', '\t', '[', ']', '{', '}', ',', '?', ':', '-': + return true + } + // Trailing colon (key-like form) or any unquoted ": " (mapping + // indicator inside a scalar would split into key/value). + if strings.HasSuffix(s, ":") || strings.Contains(s, ": ") || strings.Contains(s, " #") { + return true + } + // Control / newline characters. + if strings.ContainsAny(s, "\n\r\t\v\f") { + return true + } + // YAML 1.1 boolean / null literals — unquoted they decode to bool/nil. + // We keep this case-insensitive to match yaml.v3 defaults. + switch strings.ToLower(s) { + case "true", "false", "yes", "no", "on", "off", "null", "~": + return true + } + return false +} diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go index 3af3999..3c6a78e 100644 --- a/forge-cli/cmd/init_auth_test.go +++ b/forge-cli/cmd/init_auth_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/spf13/cobra" + "gopkg.in/yaml.v3" ) // --- renderAuthBlock --- @@ -93,6 +94,96 @@ func TestRenderAuthBlock_HTTPVerifier(t *testing.T) { } } +func TestRenderAuthBlock_QuotesUnsafeValues(t *testing.T) { + // Review #12.8: writeYAMLMap used to emit "key: value" raw. Values + // containing ": " or starting with YAML-significant chars could + // break the parser. The hardening adds quoting; this test pins it. 
+ tests := []struct { + name string + value string + wantQuote bool + }{ + {"plain url", "https://login.example.com", false}, + {"audience uri", "api://forge", false}, + {"plain identifier", "my-org", false}, + {"contains ': ' (colon-space — splits as map)", "title: subtitle", true}, + {"contains ' #' (comment marker)", "value #with comment", true}, + {"trailing colon", "foo:", true}, + {"leading ! (tag)", "!secret", true}, + {"leading * (alias)", "*ref", true}, + {"leading - (sequence)", "-name", true}, + {"leading [ (flow seq)", "[item]", true}, + {"leading { (flow map)", "{key: val}", true}, + {"leading # (comment)", "#commented", true}, + {"empty", "", true}, + {"YAML 1.1 boolean 'true'", "true", true}, + {"YAML 1.1 boolean 'YES'", "YES", true}, + {"YAML null literal", "null", true}, + {"YAML tilde null", "~", true}, + {"contains newline", "line1\nline2", true}, + {"contains tab", "with\ttab", true}, + {"leading whitespace", " starts with space", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://x", // anchor that's never quoted + "audience": "y", + "probe": tt.value, + }) + line := findSettingsLine(got, "probe") + if line == "" { + t.Fatalf("no probe: line in output:\n%s", got) + } + isQuoted := strings.Contains(line, `"`) + if isQuoted != tt.wantQuote { + t.Errorf("value=%q quoted=%v, want=%v\nrendered: %q", + tt.value, isQuoted, tt.wantQuote, line) + } + }) + } +} + +func TestRenderAuthBlock_QuotedValueIsParseable(t *testing.T) { + // Round-trip check: the quoted output must parse back via yaml.v3 + // as the original string. Catches escaping bugs. + weird := `weird: value # with comment ! and "quote"` + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://x", + "audience": "y", + "probe": weird, + }) + + // Find the probe: line, extract the value, parse via yaml.v3. 
+ line := findSettingsLine(got, "probe") + if line == "" { + t.Fatalf("no probe line:\n%s", got) + } + // " probe: \"…\"" → yaml.Unmarshal of "value: " + idx := strings.Index(line, "probe:") + docFragment := line[idx:] + var parsed map[string]string + if err := yaml.Unmarshal([]byte(docFragment), &parsed); err != nil { + t.Fatalf("rendered line %q is not valid YAML: %v", line, err) + } + if parsed["probe"] != weird { + t.Errorf("round-trip mismatch:\n got %q\n want %q", parsed["probe"], weird) + } +} + +// findSettingsLine returns the first line in `block` whose key is `key`. +// Used by the quoting tests to inspect a single rendered key/value pair. +func findSettingsLine(block, key string) string { + prefix := key + ":" + for _, line := range strings.Split(block, "\n") { + trimmed := strings.TrimLeft(line, " ") + if strings.HasPrefix(trimmed, prefix) { + return line + } + } + return "" +} + func TestRenderAuthBlock_DeterministicOrdering(t *testing.T) { // Settings keys should always emit in alphabetical order so diffs // of generated files are stable. diff --git a/forge-cli/internal/tui/steps/step_auth_test.go b/forge-cli/internal/tui/steps/step_auth_test.go index f2b0ed1..81e3ddf 100644 --- a/forge-cli/internal/tui/steps/step_auth_test.go +++ b/forge-cli/internal/tui/steps/step_auth_test.go @@ -301,6 +301,91 @@ func TestHostnameOrEmpty(t *testing.T) { } } +// --- Wizard interaction tests (review #12.6, #12.7) --- + +func TestAuthStep_BackspaceOnEmptyInputReturnsToPicker(t *testing.T) { + // Review #12.6: backspace-on-empty is the documented escape hatch + // from a sub-step input back to the provider picker, but until + // review #11e there was no visible hint AND no test. The hint was + // added; this test pins the behavior. + s := newTestAuthStep(t) + + // Walk into OIDC issuer phase. 
+ s = press(t, s, "down") // OIDC + s = press(t, s, "enter") + if s.phase != authOIDCIssuerPhase { + t.Fatalf("setup: phase = %v, want issuer", s.phase) + } + + // Backspace on EMPTY input → returns to picker. + s = press(t, s, "backspace") + if s.phase != authSelectPhase { + t.Errorf("backspace-on-empty: phase = %v, want authSelectPhase", s.phase) + } + if s.mode != "" { + t.Errorf("backspace-on-empty: mode = %q, want \"\" (reset)", s.mode) + } +} + +func TestAuthStep_BackspaceWithContentDoesNotReturnToPicker(t *testing.T) { + // Counterpart: backspace with TEXT in the input only deletes the + // last char — it must not escape out of the sub-step. + s := newTestAuthStep(t) + s = press(t, s, "down") // OIDC + s = press(t, s, "enter") + s = typeIn(t, s, "https://example.com") + if s.phase != authOIDCIssuerPhase { + t.Fatalf("setup: phase = %v, want issuer", s.phase) + } + + s = press(t, s, "backspace") + if s.phase != authOIDCIssuerPhase { + t.Errorf("backspace-with-content escaped to phase %v, expected to stay on issuer", s.phase) + } +} + +func TestAuthStep_InvalidIssuerURLBlocksAdvance(t *testing.T) { + // Review #12.7: validateHTTPSURL is unit-tested standalone, but + // the wizard's gating on bad input was never asserted as an + // interaction. Drive the step with a bad URL and confirm Enter + // doesn't advance to the audience sub-step. + s := newTestAuthStep(t) + s = press(t, s, "down") // OIDC + s = press(t, s, "enter") + + // Type something that fails validateHTTPSURL (no scheme). + s = typeIn(t, s, "login.example.com") + s = press(t, s, "enter") + + // MUST still be on the issuer phase — the bad URL blocked advance. + if s.phase != authOIDCIssuerPhase { + t.Errorf("invalid URL advanced past issuer phase to %v", s.phase) + } + + // And the step must not be complete — Review/scaffold MUST NOT run + // for a half-configured OIDC entry. 
+ if s.Complete() { + t.Errorf("step reported complete with invalid issuer URL") + } + + // Now type a valid replacement and confirm the gate releases. + // Clear the input first by sending backspaces until empty — easier + // than reaching into TextInput state. + for range len("login.example.com") { + s = press(t, s, "backspace") + } + if s.phase != authOIDCIssuerPhase { + // Bug guard: a long sequence of backspaces should NOT escape to + // the picker mid-way (only one more backspace, on the now-empty input, would). + t.Fatalf("backspaces escaped to phase %v unexpectedly", s.phase) + } + s = typeIn(t, s, "https://login.example.com") + s = press(t, s, "enter") + if s.phase != authOIDCAudiencePhase { + t.Errorf("valid URL did not advance: phase = %v, want audience", s.phase) + } +} + func TestAppendUnique(t *testing.T) { out := appendUnique([]string{"a", "b"}, "a") if !reflect.DeepEqual(out, []string{"a", "b"}) { diff --git a/forge-cli/runtime/auth_config_test.go b/forge-cli/runtime/auth_config_test.go index d076785..199a016 100644 --- a/forge-cli/runtime/auth_config_test.go +++ b/forge-cli/runtime/auth_config_test.go @@ -226,31 +226,108 @@ func TestBuildUserAuthChain_OnlyYAMLAuth(t *testing.T) { func TestBuildUserAuthChain_BothSourcesPrefersYAML(t *testing.T) { // YAML auth points at srvYAML (accepts "yaml-only"). // Legacy --auth-url points at srvLegacy (accepts "legacy-only"). - // Expect: YAML wins, legacy verifier never consulted. + // Expect: YAML wins, legacy verifier never consulted, AND a warning + // is emitted (review #12.1 — previously asserted only the chain + // behavior; the "we logged a warning" half of the contract was + // invisible because the test used the nop logger).
srvYAML := fakeVerifierFor(t, "yaml-only") srvLegacy := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { t.Error("legacy verifier should not be called when YAML auth: is present") })) defer srvLegacy.Close() - r := runnerWith(types.AuthConfig{ + // Use a capturing logger so we can assert the warning emission too. + logger := &recordLogger{} + r := &Runner{logger: logger} + r.cfg.Config = &types.ForgeConfig{Auth: types.AuthConfig{ Providers: []types.AuthProvider{ {Type: "http_verifier", Settings: map[string]any{"url": srvYAML.URL}}, }, - }, srvLegacy.URL) + }} + r.cfg.AuthURL = srvLegacy.URL chain, err := r.buildUserAuthChain() if err != nil { t.Fatalf("buildUserAuthChain: %v", err) } - // YAML accepts "yaml-only". if _, err := chain.Verify(context.Background(), "yaml-only", nil); err != nil { t.Errorf("yaml-only rejected: %v", err) } - // Legacy token rejected because YAML verifier doesn't recognize it. if _, err := chain.Verify(context.Background(), "legacy-only", nil); !errors.Is(err, auth.ErrTokenRejected) { t.Errorf("legacy-only token err = %v, want ErrTokenRejected", err) } + + // The contract: when both YAML auth and --auth-url are configured, + // we warn that --auth-url is being ignored. Without this assertion + // a refactor that silently drops the warning would never surface. + if !logger.warnedAbout("--auth-url and forge.yaml") { + t.Errorf("expected warning when both sources configured; got: %+v", logger.warnings) + } +} + +// --- Full resolveAuth + YAML auth + loopback prepend (review #12.3) --- +// +// auth_chain_test.go exercises buildLegacyAuthChain. This test goes one +// level higher and exercises resolveAuth() end-to-end with a forge.yaml +// `auth:` block configured AND an internal token minted by +// ResolveAuth(). The assertion: the loopback static_token sits at the +// chain head, AHEAD of the user-configured providers, so channel +// adapters short-circuit before the OIDC/http_verifier roundtrip. 
+
+func TestResolveAuth_YAMLAuth_LoopbackPrependedAtChainHead(t *testing.T) {
+	tmp := t.TempDir()
+
+	// Fake verifier for the YAML-configured http_verifier provider.
+	srvYAML := fakeVerifierFor(t, "external-token")
+
+	logger := &recordLogger{}
+	r := &Runner{logger: logger}
+	r.cfg.Config = &types.ForgeConfig{Auth: types.AuthConfig{
+		Providers: []types.AuthProvider{
+			{Type: "http_verifier", Settings: map[string]any{"url": srvYAML.URL}},
+		},
+	}}
+	r.cfg.NoAuth = false
+	r.cfg.Host = "127.0.0.1" // local — passes the localhost gate in ResolveAuth
+	r.cfg.WorkDir = tmp      // ResolveAuth writes runtime.token here
+
+	opts, err := r.resolveAuth(nil)
+	if err != nil {
+		t.Fatalf("resolveAuth: %v", err)
+	}
+	if opts.Chain == nil {
+		t.Fatal("resolveAuth returned a nil Chain — loopback not prepended")
+	}
+	chain, ok := opts.Chain.(*auth.ChainProvider)
+	if !ok {
+		t.Fatalf("Chain is not a *ChainProvider; got %T", opts.Chain)
+	}
+	providers := chain.Providers()
+	if len(providers) < 2 {
+		t.Fatalf("chain has %d providers; want ≥ 2 (loopback + YAML)", len(providers))
+	}
+
+	// Head MUST be the loopback static_token — channel adapter
+	// short-circuit depends on this. If a refactor inserts a YAML
+	// provider before it, channels start round-tripping to the
+	// external verifier on every callback.
+	if providers[0].Name() != "static_token" {
+		t.Errorf("chain.providers[0].Name() = %q, want %q (loopback)", providers[0].Name(), "static_token")
+	}
+	if providers[1].Name() != "http_verifier" {
+		t.Errorf("chain.providers[1].Name() = %q, want %q (YAML)", providers[1].Name(), "http_verifier")
+	}
+
+	// And smoke-test that the loopback token is actually the one
+	// ResolveAuth minted: Verify with r.authToken must succeed without
+	// touching the YAML verifier.
+	id, err := chain.Verify(context.Background(), r.authToken, nil)
+	if err != nil {
+		t.Fatalf("loopback verify failed: %v", err)
+	}
+	if id.Source != "internal" {
+		t.Errorf("loopback identity Source = %q, want %q", id.Source, "internal")
+	}
 }
 
 // --- --no-auth vs forge.yaml auth: block (review #4) ---
diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go
index a220a0f..906de6d 100644
--- a/forge-ui/handlers_create_test.go
+++ b/forge-ui/handlers_create_test.go
@@ -614,6 +614,111 @@ func TestHandleCreateAgent_NoneAndCustomSkipValidation(t *testing.T) {
 	}
 }
 
+// TestHandleCreateAgent_NestedClaimMapShape pins the contract the
+// frontend's buildAuthPayload promises (app.js): a flat `groups_claim`
+// from the wizard is translated client-side into the nested shape
+// `settings.claim_map.groups`. The backend doesn't run that
+// translation — but if the frontend ever regresses (and app.js is
+// hand-edited, so it's plausible), the request shape arriving here
+// would change and downstream YAML rendering would silently drop the
+// custom groups claim.
+//
+// This test documents what the backend EXPECTS to receive after the
+// JS translation runs. If a future contributor changes app.js's
+// buildAuthPayload to send the flat shape, this test still passes
+// (we accept anything in Settings) but a documented JS-side test
+// in app.js would catch the drift. Until JS tests exist, this is the
+// canonical reference for the post-translation payload shape.
+// Review #12.5.
+func TestHandleCreateAgent_NestedClaimMapShape(t *testing.T) {
+	srv, captured := setupCreateWithCapture(t)
+
+	// Shape AFTER the JS translation runs: groups_claim has become
+	// claim_map.groups. This is what the backend should see.
+	body := []byte(`{
+		"name": "agent-with-roles",
+		"model_provider": "openai",
+		"auth": {
+			"mode": "oidc",
+			"settings": {
+				"issuer": "https://login.example.com",
+				"audience": "api://forge",
+				"claim_map": { "groups": "roles" }
+			}
+		}
+	}`)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body))
+	w := httptest.NewRecorder()
+	srv.handleCreateAgent(w, req)
+
+	if w.Code != http.StatusCreated {
+		t.Fatalf("status = %d, want 201 (body: %s)", w.Code, w.Body.String())
+	}
+	if captured.Auth == nil {
+		t.Fatal("captured Auth is nil")
+	}
+	cm, ok := captured.Auth.Settings["claim_map"].(map[string]any)
+	if !ok {
+		t.Fatalf("claim_map not present or wrong type: %v", captured.Auth.Settings["claim_map"])
+	}
+	if cm["groups"] != "roles" {
+		t.Errorf("claim_map.groups = %v, want roles", cm["groups"])
+	}
+
+	// Negative half: the FLAT shape (pre-translation) should NOT
+	// round-trip as if it had been translated. If the frontend ever
+	// stops calling buildAuthPayload, the backend receives the raw
+	// form and the YAML will be missing claim_map entirely.
+	if _, hasFlat := captured.Auth.Settings["groups_claim"]; hasFlat {
+		t.Errorf("settings contains flat groups_claim — frontend translation may have skipped")
+	}
+}
+
+func TestHandleCreateAgent_FlatGroupsClaim_PreservedAsExtraField(t *testing.T) {
+	// Counterpart: if the frontend sends the FLAT shape by mistake,
+	// the backend accepts the request (no validation rejects unknown
+	// keys) but the claim_map nested mapping is missing — which means
+	// the generated forge.yaml won't include claim_map.groups, and the
+	// agent runtime will fall back to the default "groups" claim.
+	//
+	// This documents the failure mode so if anyone investigates
+	// "my custom group claim isn't being read," this test is the
+	// breadcrumb.
+	srv, captured := setupCreateWithCapture(t)
+
+	body := []byte(`{
+		"name": "agent-flat",
+		"model_provider": "openai",
+		"auth": {
+			"mode": "oidc",
+			"settings": {
+				"issuer": "https://login.example.com",
+				"audience": "api://forge",
+				"groups_claim": "roles"
+			}
+		}
+	}`)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body))
+	w := httptest.NewRecorder()
+	srv.handleCreateAgent(w, req)
+
+	// Request is structurally valid (issuer + audience present). It just
+	// won't behave as the user expected at runtime.
+	if w.Code != http.StatusCreated {
+		t.Fatalf("status = %d, want 201", w.Code)
+	}
+	// Guard before dereferencing: a nil Auth must fail the test with a
+	// clear message (as the nested-shape test does), not panic the run.
+	if captured.Auth == nil {
+		t.Fatal("captured Auth is nil")
+	}
+	if _, hasNested := captured.Auth.Settings["claim_map"]; hasNested {
+		t.Error("backend should not synthesize claim_map from groups_claim — that's the frontend's job")
+	}
+}
+
 func TestHandleCreateAgent_WithoutAuthPayload(t *testing.T) {
 	// Omitting `auth` from the request is still a valid create — the
 	// agent gets anonymous access by default.