diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index 480c971..fca28d4 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -45,6 +45,11 @@ type initOptions struct { Force bool // overwrite existing directory CustomModel string // custom provider model name AuthMethod string // "apikey" or "oauth" + + // A2A auth chain (from the wizard's Authentication step or CLI flags). + AuthMode string // "", "none", "oidc", "http_verifier", "custom" + AuthSettings map[string]any // provider-specific settings block + AuthEgressHosts []string // hosts to merge into egress allowlist } // toolEntry represents a tool parsed from a skills file. @@ -71,6 +76,12 @@ type templateData struct { EgressDomains []string EnvVars []envVarEntry HasSecrets bool + + // Auth chain rendering (see forge.yaml.tmpl). Pre-rendered as a YAML + // fragment because nested maps in the settings block (e.g. claim_map) + // would otherwise require non-trivial template helpers. Empty when + // the user picked "none" or skipped the auth step. + AuthBlock string } // fallbackTmplData holds template data for a fallback provider. @@ -121,6 +132,15 @@ func init() { initCmd.Flags().String("org-id", "", "OpenAI organization ID (enterprise)") initCmd.Flags().StringSlice("fallbacks", nil, "fallback LLM providers (e.g., openai,gemini)") initCmd.Flags().Bool("force", false, "overwrite existing directory") + + // Auth chain (PR5+). All optional. When --auth is unset or "none", no + // auth: block is written and the agent runs anonymously. + initCmd.Flags().String("auth", "", "auth mode: none, oidc, http_verifier, custom") + initCmd.Flags().String("auth-issuer", "", "OIDC issuer URL (required with --auth=oidc)") + initCmd.Flags().String("auth-audience", "", "OIDC audience (required with --auth=oidc)") + initCmd.Flags().String("auth-url", "", "verifier URL (required with --auth=http_verifier)") + initCmd.Flags().String("auth-default-org", "", "default org_id for http_verifier (optional)") + initCmd.Flags().String("auth-groups-claim", "", "claim name for groups (oidc, default: groups)") } func runInit(cmd *cobra.Command, args []string) error { @@ -155,6 +175,18 @@ func runInit(cmd *cobra.Command, args []string) error { opts.NonInteractive = nonInteractive opts.Force, _ = cmd.Flags().GetBool("force") + // Auth chain flags. + authMode, _ := cmd.Flags().GetString("auth") + if authMode != "" { + settings, hosts, err := buildAuthFromFlags(cmd, authMode) + if err != nil { + return err + } + opts.AuthMode = authMode + opts.AuthSettings = settings + opts.AuthEgressHosts = hosts + } + // TTY detection: require a terminal for interactive mode if !nonInteractive && !term.IsTerminal(int(os.Stdout.Fd())) { return fmt.Errorf("interactive mode requires a terminal; use --non-interactive") @@ -258,6 +290,7 @@ func collectInteractive(opts *initOptions) error { steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn), steps.NewSkillsStep(styles, skillInfos), steps.NewEgressStep(styles, deriveEgressFn), + steps.NewAuthStep(styles), steps.NewReviewStep(styles), // scaffold is handled by the caller after collectInteractive returns } @@ -327,11 +360,18 @@ func collectInteractive(opts *initOptions) error { opts.EnvVars["MODEL_API_KEY"] = ctx.CustomAPIKey } - // Store egress domains - if len(ctx.EgressDomains) > 0 { - opts.EnvVars["__egress_domains"] = strings.Join(ctx.EgressDomains, ",") + // Store egress domains, merging in any auth-provider hosts so the + // agent can reach its IdP / verifier at runtime. + mergedEgress := mergeEgressDomains(ctx.EgressDomains, ctx.AuthEgressHosts) + if len(mergedEgress) > 0 { + opts.EnvVars["__egress_domains"] = strings.Join(mergedEgress, ",") } + // Auth chain selection. + opts.AuthMode = ctx.AuthMode + opts.AuthSettings = ctx.AuthSettings + opts.AuthEgressHosts = ctx.AuthEgressHosts + // Check skill requirements checkSkillRequirements(opts) @@ -1002,6 +1042,16 @@ func buildTemplateData(opts *initOptions) templateData { data.EgressDomains = strings.Split(stored, ",") } + // Merge any auth-provider hosts (PR5 wizard / --auth flags) into the + // allowlist so the agent's runtime can reach its IdP/verifier. + if len(opts.AuthEgressHosts) > 0 { + data.EgressDomains = mergeEgressDomains(data.EgressDomains, opts.AuthEgressHosts) + } + + // Auth chain rendering: pre-render the YAML fragment in Go so nested + // maps (claim_map) emit correctly. + data.AuthBlock = renderAuthBlock(opts.AuthMode, opts.AuthSettings) + // Build env vars data.EnvVars = buildEnvVars(opts) diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go new file mode 100644 index 0000000..962aff8 --- /dev/null +++ b/forge-cli/cmd/init_auth.go @@ -0,0 +1,274 @@ +package cmd + +import ( + "fmt" + "net/url" + "sort" + "strings" + + "github.com/spf13/cobra" +) + +// buildAuthFromFlags reads the --auth* flag set into a settings map + +// the egress hosts the runtime will need. Validates required-field +// combinations; rejects unknown modes with a clear error. +func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]any, egressHosts []string, err error) { + switch mode { + case "none", "custom": + return nil, nil, nil + case "oidc": + issuer, _ := cmd.Flags().GetString("auth-issuer") + audience, _ := cmd.Flags().GetString("auth-audience") + groupsClaim, _ := cmd.Flags().GetString("auth-groups-claim") + if issuer == "" || audience == "" { + return nil, nil, fmt.Errorf("--auth=oidc requires --auth-issuer and --auth-audience") + } + issuer = strings.TrimRight(issuer, "/") + settings = map[string]any{ + "issuer": issuer, + "audience": audience, + } + if groupsClaim != "" && groupsClaim != "groups" { + settings["claim_map"] = map[string]any{"groups": groupsClaim} + } + if host := hostFromURL(issuer); host != "" { + egressHosts = []string{host} + } + return settings, egressHosts, nil + case "http_verifier": + verifierURL, _ := cmd.Flags().GetString("auth-url") + defaultOrg, _ := cmd.Flags().GetString("auth-default-org") + if verifierURL == "" { + return nil, nil, fmt.Errorf("--auth=http_verifier requires --auth-url") + } + settings = map[string]any{"url": verifierURL} + if defaultOrg != "" { + settings["default_org"] = defaultOrg + } + if host := hostFromURL(verifierURL); host != "" { + egressHosts = []string{host} + } + return settings, egressHosts, nil + default: + return nil, nil, fmt.Errorf("unknown --auth value %q (supported: none, oidc, http_verifier, custom)", mode) + } +} + +// authEgressHostsFromSettings extracts the outbound hosts implied by an +// auth provider's settings block. Used by the Web UI path (cmd/ui.go) to +// derive the same egress hosts the TUI wizard / --auth flags compute. +// +// Returns nil for modes that don't require outbound (none, static_token, +// custom) or for malformed/missing URLs. +func authEgressHostsFromSettings(mode string, settings map[string]any) []string { + if settings == nil { + return nil + } + var hosts []string + switch mode { + case "oidc": + if h := hostFromURL(asStringSetting(settings, "issuer")); h != "" { + hosts = append(hosts, h) + } + if h := hostFromURL(asStringSetting(settings, "jwks_url")); h != "" { + hosts = append(hosts, h) + } + case "http_verifier": + if h := hostFromURL(asStringSetting(settings, "url")); h != "" { + hosts = append(hosts, h) + } + } + return hosts +} + +// asStringSetting reads a string-valued setting, returning "" for missing +// or non-string values. +func asStringSetting(m map[string]any, key string) string { + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +// hostFromURL extracts the bare host (no port) from a URL string. Returns +// "" if the URL is malformed — validation happens in the wizard / +// buildAuthFromFlags above. +func hostFromURL(raw string) string { + if raw == "" { + return "" + } + u, perr := url.Parse(raw) + if perr != nil || u.Host == "" { + return "" + } + return u.Hostname() +} + +// renderAuthBlock returns a forge.yaml `auth:` fragment for the given +// provider type and settings. The output starts at column 0 with no +// trailing newline so the surrounding template controls spacing. +// +// Three shapes are produced: +// +// mode == "" → empty string (skip) +// mode == "none" → empty string (anonymous) +// mode == "custom" +// ↓ +// # Auth provider chain — configure here. +// # auth: +// # required: true +// # providers: [...] +// +// mode == "oidc" | "http_verifier" | "static_token" +// ↓ +// auth: +// required: true +// providers: +// - type: +// settings: +// key: value +// nested: +// k: v +// +// The AuthProvider.Name field is intentionally NOT emitted — the wizard +// doesn't capture one, and the previous signature took a `name` arg +// that callers always set equal to mode (suppressed). Removed in +// review #11d. If a future wizard step adds explicit name capture, +// reintroduce the parameter then. +// +// Settings keys are emitted in alphabetical order so the output is +// deterministic across runs (useful for diffing generated files). +func renderAuthBlock(mode string, settings map[string]any) string { + switch mode { + case "", "none": + return "" + case "custom": + return customAuthStub() + } + + var b strings.Builder + b.WriteString("auth:\n") + b.WriteString(" required: true\n") + b.WriteString(" providers:\n") + fmt.Fprintf(&b, " - type: %s\n", mode) + if len(settings) > 0 { + b.WriteString(" settings:\n") + writeYAMLMap(&b, settings, " ") + } + // Trim the final newline so the template controls spacing. + return strings.TrimRight(b.String(), "\n") +} + +// customAuthStub returns the commented-out template a user gets when they +// pick "Custom" in the wizard. +func customAuthStub() string { + return strings.Join([]string{ + "# Auth provider chain — configure here. See useforge.ai/docs/auth", + "# Supported types: oidc, http_verifier, static_token", + "# auth:", + "# required: true", + "# providers:", + "# - type: oidc", + "# settings:", + "# issuer: https://login.example.com", + "# audience: api://forge", + }, "\n") +} + +// writeYAMLMap renders a `map[string]any` as YAML lines, recursing into +// nested maps. Only string / number / bool / map values are supported — +// the auth-settings schema doesn't use anything else. +// +// String values are conservatively quoted (review #12.8) when they +// contain YAML-significant characters. Otherwise the unquoted form is +// emitted to keep generated forge.yaml readable. +func writeYAMLMap(b *strings.Builder, m map[string]any, indent string) { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + v := m[k] + switch val := v.(type) { + case map[string]any: + fmt.Fprintf(b, "%s%s:\n", indent, k) + writeYAMLMap(b, val, indent+" ") + case string: + fmt.Fprintf(b, "%s%s: %s\n", indent, k, yamlScalar(val)) + default: + fmt.Fprintf(b, "%s%s: %v\n", indent, k, val) + } + } +} + +// yamlScalar returns a YAML-safe rendering of a string value. Plain +// values pass through unchanged; values containing characters that +// would break the YAML parser (": " sequences, leading specials, +// reserved tokens, control chars) are emitted as double-quoted strings +// with the minimum necessary escaping. +// +// This is deliberately not a general YAML serializer — it covers the +// subset that appears in auth provider settings (issuer URLs, audiences, +// claim names, default org strings). If we ever need richer values, +// switch to gopkg.in/yaml.v3 for the whole block. +func yamlScalar(s string) string { + if !needsYAMLQuoting(s) { + return s + } + var out strings.Builder + out.WriteByte('"') + for _, r := range s { + switch r { + case '"': + out.WriteString(`\"`) + case '\\': + out.WriteString(`\\`) + case '\n': + out.WriteString(`\n`) + case '\r': + out.WriteString(`\r`) + case '\t': + out.WriteString(`\t`) + default: + out.WriteRune(r) + } + } + out.WriteByte('"') + return out.String() +} + +// needsYAMLQuoting reports whether a string would change meaning when +// emitted unquoted in a YAML block-scalar context. Conservative — +// false positives are fine (extra quotes), false negatives are bugs +// (broken YAML). +func needsYAMLQuoting(s string) bool { + if s == "" { + return true + } + // Leading characters that begin YAML indicators (tags, anchors, + // aliases, folded/literal block markers, flow markers, directives, + // quotes, comments, leading whitespace). + switch s[0] { + case '!', '&', '*', '>', '|', '%', '@', '`', '"', '\'', '#', ' ', '\t', '[', ']', '{', '}', ',', '?', ':', '-': + return true + } + // Trailing colon (key-like form) or any unquoted ": " (mapping + // indicator inside a scalar would split into key/value). + if strings.HasSuffix(s, ":") || strings.Contains(s, ": ") || strings.Contains(s, " #") { + return true + } + // Control / newline characters. + if strings.ContainsAny(s, "\n\r\t\v\f") { + return true + } + // YAML 1.1 boolean / null literals — unquoted they decode to bool/nil. + // We keep this case-insensitive to match yaml.v3 defaults. + switch strings.ToLower(s) { + case "true", "false", "yes", "no", "on", "off", "null", "~": + return true + } + return false +} diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go new file mode 100644 index 0000000..3c6a78e --- /dev/null +++ b/forge-cli/cmd/init_auth_test.go @@ -0,0 +1,429 @@ +package cmd + +import ( + "reflect" + "strings" + "testing" + + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +// --- renderAuthBlock --- + +func TestRenderAuthBlock_EmptyAndNone(t *testing.T) { + if got := renderAuthBlock("", nil); got != "" { + t.Errorf("mode='' → %q, want empty", got) + } + if got := renderAuthBlock("none", nil); got != "" { + t.Errorf("mode='none' → %q, want empty", got) + } +} + +func TestRenderAuthBlock_Custom(t *testing.T) { + got := renderAuthBlock("custom", nil) + if !strings.HasPrefix(got, "# Auth provider chain") { + t.Errorf("custom mode should start with comment header, got: %s", got) + } + if !strings.Contains(got, "oidc") { + t.Errorf("custom stub should mention oidc, got: %s", got) + } +} + +func TestRenderAuthBlock_OIDC(t *testing.T) { + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }) + + wantLines := []string{ + "auth:", + " required: true", + " providers:", + " - type: oidc", + " settings:", + " audience: api://forge", + " issuer: https://login.example.com", + } + for _, line := range wantLines { + if !strings.Contains(got, line) { + t.Errorf("missing line %q in:\n%s", line, got) + } + } +} + +func TestRenderAuthBlock_NoNameEverEmitted(t *testing.T) { + // Review #11d: the wizard doesn't capture an explicit provider + // name, so the rendered YAML must never include a `name:` line. + // If a future wizard step adds name capture, this test changes + // at the same time as the function signature. + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://x", + "audience": "y", + }) + if strings.Contains(got, "name:") { + t.Errorf("expected no name: line in output, got:\n%s", got) + } +} + +func TestRenderAuthBlock_NestedClaimMap(t *testing.T) { + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + "claim_map": map[string]any{"groups": "roles"}, + }) + + if !strings.Contains(got, "claim_map:") { + t.Errorf("missing claim_map:\n%s", got) + } + if !strings.Contains(got, " groups: roles") { + t.Errorf("missing nested groups: roles (check indentation):\n%s", got) + } +} + +func TestRenderAuthBlock_HTTPVerifier(t *testing.T) { + got := renderAuthBlock("http_verifier", map[string]any{ + "url": "https://verify.example.com", + "default_org": "acme", + }) + if !strings.Contains(got, "- type: http_verifier") { + t.Errorf("missing http_verifier type:\n%s", got) + } + if !strings.Contains(got, "default_org: acme") { + t.Errorf("missing default_org:\n%s", got) + } +} + +func TestRenderAuthBlock_QuotesUnsafeValues(t *testing.T) { + // Review #12.8: writeYAMLMap used to emit "key: value" raw. Values + // containing ": " or starting with YAML-significant chars could + // break the parser. The hardening adds quoting; this test pins it. + tests := []struct { + name string + value string + wantQuote bool + }{ + {"plain url", "https://login.example.com", false}, + {"audience uri", "api://forge", false}, + {"plain identifier", "my-org", false}, + {"contains ': ' (colon-space — splits as map)", "title: subtitle", true}, + {"contains ' #' (comment marker)", "value #with comment", true}, + {"trailing colon", "foo:", true}, + {"leading ! (tag)", "!secret", true}, + {"leading * (alias)", "*ref", true}, + {"leading - (sequence)", "-name", true}, + {"leading [ (flow seq)", "[item]", true}, + {"leading { (flow map)", "{key: val}", true}, + {"leading # (comment)", "#commented", true}, + {"empty", "", true}, + {"YAML 1.1 boolean 'true'", "true", true}, + {"YAML 1.1 boolean 'YES'", "YES", true}, + {"YAML null literal", "null", true}, + {"YAML tilde null", "~", true}, + {"contains newline", "line1\nline2", true}, + {"contains tab", "with\ttab", true}, + {"leading whitespace", " starts with space", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://x", // anchor that's never quoted + "audience": "y", + "probe": tt.value, + }) + line := findSettingsLine(got, "probe") + if line == "" { + t.Fatalf("no probe: line in output:\n%s", got) + } + isQuoted := strings.Contains(line, `"`) + if isQuoted != tt.wantQuote { + t.Errorf("value=%q quoted=%v, want=%v\nrendered: %q", + tt.value, isQuoted, tt.wantQuote, line) + } + }) + } +} + +func TestRenderAuthBlock_QuotedValueIsParseable(t *testing.T) { + // Round-trip check: the quoted output must parse back via yaml.v3 + // as the original string. Catches escaping bugs. + weird := `weird: value # with comment ! and "quote"` + got := renderAuthBlock("oidc", map[string]any{ + "issuer": "https://x", + "audience": "y", + "probe": weird, + }) + + // Find the probe: line, extract the value, parse via yaml.v3. + line := findSettingsLine(got, "probe") + if line == "" { + t.Fatalf("no probe line:\n%s", got) + } + // " probe: \"…\"" → yaml.Unmarshal of "value: " + idx := strings.Index(line, "probe:") + docFragment := line[idx:] + var parsed map[string]string + if err := yaml.Unmarshal([]byte(docFragment), &parsed); err != nil { + t.Fatalf("rendered line %q is not valid YAML: %v", line, err) + } + if parsed["probe"] != weird { + t.Errorf("round-trip mismatch:\n got %q\n want %q", parsed["probe"], weird) + } +} + +// findSettingsLine returns the first line in `block` whose key is `key`. +// Used by the quoting tests to inspect a single rendered key/value pair. +func findSettingsLine(block, key string) string { + prefix := key + ":" + for _, line := range strings.Split(block, "\n") { + trimmed := strings.TrimLeft(line, " ") + if strings.HasPrefix(trimmed, prefix) { + return line + } + } + return "" +} + +func TestRenderAuthBlock_DeterministicOrdering(t *testing.T) { + // Settings keys should always emit in alphabetical order so diffs + // of generated files are stable. + got1 := renderAuthBlock("oidc", map[string]any{ + "audience": "a", + "issuer": "i", + }) + got2 := renderAuthBlock("oidc", map[string]any{ + "issuer": "i", + "audience": "a", + }) + if got1 != got2 { + t.Errorf("output is not deterministic across map ordering") + } + // audience must come before issuer (alphabetical). + audIdx := strings.Index(got1, "audience:") + issIdx := strings.Index(got1, "issuer:") + if audIdx == -1 || issIdx == -1 || audIdx > issIdx { + t.Errorf("audience should appear before issuer; got:\n%s", got1) + } +} + +// --- buildAuthFromFlags --- + +func mockInitCmdWithFlags() *cobra.Command { + c := &cobra.Command{Use: "init"} + c.Flags().String("auth", "", "") + c.Flags().String("auth-issuer", "", "") + c.Flags().String("auth-audience", "", "") + c.Flags().String("auth-url", "", "") + c.Flags().String("auth-default-org", "", "") + c.Flags().String("auth-groups-claim", "", "") + return c +} + +func TestBuildAuthFromFlags_None(t *testing.T) { + c := mockInitCmdWithFlags() + settings, hosts, err := buildAuthFromFlags(c, "none") + if err != nil { + t.Fatal(err) + } + if settings != nil || hosts != nil { + t.Errorf("none should return nils, got %v / %v", settings, hosts) + } +} + +func TestBuildAuthFromFlags_OIDC_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://login.example.com/") + _ = c.Flags().Set("auth-audience", "api://forge") + + settings, hosts, err := buildAuthFromFlags(c, "oidc") + if err != nil { + t.Fatal(err) + } + if settings["issuer"] != "https://login.example.com" { + t.Errorf("trailing slash should be stripped from issuer; got %v", settings["issuer"]) + } + if settings["audience"] != "api://forge" { + t.Errorf("audience wrong: %v", settings["audience"]) + } + if _, ok := settings["claim_map"]; ok { + t.Errorf("claim_map should not be set without --auth-groups-claim") + } + if !reflect.DeepEqual(hosts, []string{"login.example.com"}) { + t.Errorf("hosts = %v, want [login.example.com]", hosts) + } +} + +func TestBuildAuthFromFlags_OIDC_CustomGroupsClaim(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://x") + _ = c.Flags().Set("auth-audience", "y") + _ = c.Flags().Set("auth-groups-claim", "roles") + + settings, _, err := buildAuthFromFlags(c, "oidc") + if err != nil { + t.Fatal(err) + } + cm, ok := settings["claim_map"].(map[string]any) + if !ok { + t.Fatalf("expected claim_map, got %v", settings["claim_map"]) + } + if cm["groups"] != "roles" { + t.Errorf("claim_map.groups = %v, want roles", cm["groups"]) + } +} + +func TestBuildAuthFromFlags_OIDC_DefaultGroupsClaimSuppressed(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://x") + _ = c.Flags().Set("auth-audience", "y") + _ = c.Flags().Set("auth-groups-claim", "groups") // same as default + + settings, _, err := buildAuthFromFlags(c, "oidc") + if err != nil { + t.Fatal(err) + } + if _, ok := settings["claim_map"]; ok { + t.Errorf("default groups claim should be suppressed") + } +} + +func TestBuildAuthFromFlags_OIDC_MissingFields(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-issuer", "https://x") + // missing audience + _, _, err := buildAuthFromFlags(c, "oidc") + if err == nil { + t.Error("expected error for missing audience") + } +} + +func TestBuildAuthFromFlags_HTTPVerifier_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-url", "https://verify.example.com/verify") + _ = c.Flags().Set("auth-default-org", "acme") + + settings, hosts, err := buildAuthFromFlags(c, "http_verifier") + if err != nil { + t.Fatal(err) + } + if settings["url"] != "https://verify.example.com/verify" { + t.Errorf("url wrong: %v", settings["url"]) + } + if settings["default_org"] != "acme" { + t.Errorf("default_org wrong: %v", settings["default_org"]) + } + if !reflect.DeepEqual(hosts, []string{"verify.example.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_HTTPVerifier_MissingURL(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "http_verifier") + if err == nil { + t.Error("expected error for missing --auth-url") + } +} + +func TestBuildAuthFromFlags_UnknownMode(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "ldap") + if err == nil { + t.Error("expected error for unknown mode") + } +} + +// --- authEgressHostsFromSettings (PR6) --- + +func TestAuthEgressHostsFromSettings_OIDC(t *testing.T) { + got := authEgressHostsFromSettings("oidc", map[string]any{ + "issuer": "https://login.example.com/realms/x", + "audience": "api://forge", + }) + if !reflect.DeepEqual(got, []string{"login.example.com"}) { + t.Errorf("got %v, want [login.example.com]", got) + } +} + +func TestAuthEgressHostsFromSettings_OIDCWithJWKSURL(t *testing.T) { + got := authEgressHostsFromSettings("oidc", map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + "jwks_url": "https://keys.example.com/jwks", + }) + want := []string{"login.example.com", "keys.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestAuthEgressHostsFromSettings_HTTPVerifier(t *testing.T) { + got := authEgressHostsFromSettings("http_verifier", map[string]any{ + "url": "https://verify.example.com/verify", + }) + if !reflect.DeepEqual(got, []string{"verify.example.com"}) { + t.Errorf("got %v, want [verify.example.com]", got) + } +} + +func TestAuthEgressHostsFromSettings_NoNetworkModes(t *testing.T) { + for _, mode := range []string{"none", "custom", "static_token", "unknown"} { + if got := authEgressHostsFromSettings(mode, map[string]any{"x": "y"}); got != nil { + t.Errorf("mode %q returned %v, want nil", mode, got) + } + } +} + +func TestAuthEgressHostsFromSettings_NilSettings(t *testing.T) { + if got := authEgressHostsFromSettings("oidc", nil); got != nil { + t.Errorf("got %v, want nil for nil settings", got) + } +} + +func TestAuthEgressHostsFromSettings_MalformedURL(t *testing.T) { + if got := authEgressHostsFromSettings("oidc", map[string]any{ + "issuer": "::not a url", + }); got != nil { + t.Errorf("got %v, want nil for malformed URL", got) + } +} + +// --- mergeEgressDomains --- + +func TestMergeEgressDomains_Dedupe(t *testing.T) { + got := mergeEgressDomains( + []string{"a", "b", "c"}, + []string{"b", "d"}, + ) + want := []string{"a", "b", "c", "d"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestMergeEgressDomains_EmptyInputs(t *testing.T) { + if got := mergeEgressDomains(nil, nil); len(got) != 0 { + t.Errorf("got %v, want empty", got) + } + if got := mergeEgressDomains([]string{"a"}, nil); !reflect.DeepEqual(got, []string{"a"}) { + t.Errorf("got %v", got) + } + if got := mergeEgressDomains(nil, []string{"b"}); !reflect.DeepEqual(got, []string{"b"}) { + t.Errorf("got %v", got) + } +} + +func TestMergeEgressDomains_SortedOutput(t *testing.T) { + // Review #11c — review of #11 batch: the function returns sorted + // output so the generated forge.yaml stays stable across runs even + // when AuthEgressHosts arrive in non-deterministic order (e.g., + // derived from map iteration). + got := mergeEgressDomains( + []string{"api.openai.com", "api.anthropic.com"}, + []string{"login.example.com", "graph.microsoft.com"}, + ) + want := []string{"api.anthropic.com", "api.openai.com", "graph.microsoft.com", "login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("output not sorted:\n got %v\n want %v", got, want) + } +} diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index 77fc237..92b632f 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -15,6 +15,32 @@ var providerDomains = map[string]string{ // ollama is local, no egress needed } +// mergeEgressDomains returns the deduplicated, sorted union of `base` +// and `extra`. Sorting is intentional — `deriveEgressDomains` produces +// sorted output, and appending auth hosts unsorted at the tail (the +// pre-sort behavior of this function) made the rendered forge.yaml +// produce noisy diffs across runs when the auth host set changed. +// Review #11c. +// +// Used by the init scaffold to fold auth-provider hosts (from the +// wizard or --auth flags) into the egress allowlist while keeping the +// generated forge.yaml stable. +func mergeEgressDomains(base, extra []string) []string { + seen := make(map[string]bool, len(base)+len(extra)) + out := make([]string, 0, len(base)+len(extra)) + for _, slice := range [][]string{base, extra} { + for _, d := range slice { + if d == "" || seen[d] { + continue + } + seen[d] = true + out = append(out, d) + } + } + sort.Strings(out) + return out +} + // deriveEgressDomains computes the full set of egress domains needed based on // the provider, channels, builtin tools, and selected registry skills. func deriveEgressDomains(opts *initOptions, skills []contract.SkillDescriptor) []string { diff --git a/forge-cli/cmd/ui.go b/forge-cli/cmd/ui.go index a77aed5..4773287 100644 --- a/forge-cli/cmd/ui.go +++ b/forge-cli/cmd/ui.go @@ -112,6 +112,15 @@ func runUI(cmd *cobra.Command, args []string) error { _ = os.Setenv("FORGE_PASSPHRASE", opts.Passphrase) } + // Web UI auth chain selection (PR6). Translate the wizard payload + // into the same fields the TUI wizard / non-interactive flags use, + // so scaffold() has a single source of truth. + if opts.Auth != nil && opts.Auth.Mode != "" { + initOpts.AuthMode = opts.Auth.Mode + initOpts.AuthSettings = opts.Auth.Settings + initOpts.AuthEgressHosts = authEgressHostsFromSettings(opts.Auth.Mode, opts.Auth.Settings) + } + storeProviderEnvVar(initOpts) checkSkillRequirements(initOpts) diff --git a/forge-cli/internal/tui/steps/step_auth.go b/forge-cli/internal/tui/steps/step_auth.go new file mode 100644 index 0000000..000ef16 --- /dev/null +++ b/forge-cli/internal/tui/steps/step_auth.go @@ -0,0 +1,403 @@ +package steps + +import ( + "fmt" + "net/url" + "slices" + "strings" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/initializ/forge/forge-cli/internal/tui" + "github.com/initializ/forge/forge-cli/internal/tui/components" +) + +// authPhase is the internal state machine of the auth step. +type authPhase int + +const ( + authSelectPhase authPhase = iota + authOIDCIssuerPhase + authOIDCAudiencePhase + authOIDCGroupsClaimPhase + authHTTPURLPhase + authHTTPOrgPhase + authDonePhase +) + +// Auth mode constants — stored in WizardContext.AuthMode and read by the +// init scaffold to decide what to render in forge.yaml. +const ( + AuthModeNone = "none" + AuthModeOIDC = "oidc" + AuthModeHTTPVerifier = "http_verifier" + AuthModeCustom = "custom" +) + +// AuthStep is the wizard step that configures the A2A server's auth chain. +// +// UX shape: +// +// ┌─ select provider type +// │ ○ None (default) +// │ ○ OIDC +// │ ○ HTTP Verifier +// │ ○ Custom (edit forge.yaml manually) +// └─ on choice: +// OIDC → issuer → audience → groups claim → done +// HTTP Verifier → url → default org → done +// None / Custom → done immediately +// +// Per-phase: backspace on empty input returns to the picker (Esc-like). +type AuthStep struct { + styles *tui.StyleSet + phase authPhase + + selector components.SingleSelect + input components.TextInput + + mode string // selected provider type + + // OIDC settings + issuer string + audience string + groupsClaim string + + // HTTP verifier settings + httpURL string + httpOrg string + + complete bool +} + +// NewAuthStep constructs the auth step. +func NewAuthStep(styles *tui.StyleSet) *AuthStep { + items := []components.SingleSelectItem{ + { + Label: "None", + Value: AuthModeNone, + Description: "Anonymous access — no auth: block in forge.yaml", + Icon: "🔓", + }, + { + Label: "OIDC (JWT)", + Value: AuthModeOIDC, + Description: "Auth0, Keycloak, Azure AD, Google, Okta-OIDC, …", + Icon: "🔐", + }, + { + Label: "HTTP Verifier", + Value: AuthModeHTTPVerifier, + Description: "Legacy — POST tokens to your own /verify endpoint", + Icon: "🔁", + }, + { + Label: "Custom", + Value: AuthModeCustom, + Description: "Write a commented stub — I'll edit forge.yaml myself", + Icon: "✏️", + }, + } + + selector := components.NewSingleSelect( + items, + styles.Theme.Accent, + styles.Theme.Primary, + styles.Theme.Secondary, + styles.Theme.Dim, + styles.Theme.Border, + styles.Theme.ActiveBorder, + styles.Theme.ActiveBg, + styles.KbdKey, + styles.KbdDesc, + ) + + return &AuthStep{ + styles: styles, + selector: selector, + } +} + +func (s *AuthStep) Title() string { return "Authentication" } +func (s *AuthStep) Icon() string { return "🔐" } + +func (s *AuthStep) Init() tea.Cmd { return s.selector.Init() } + +func (s *AuthStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { + if s.complete { + return s, nil + } + + // Forward window-size events to whichever component is active. + if wsm, ok := msg.(tea.WindowSizeMsg); ok { + if s.phase == authSelectPhase { + updated, cmd := s.selector.Update(wsm) + s.selector = updated + return s, cmd + } + updated, cmd := s.input.Update(wsm) + s.input = updated + return s, cmd + } + + switch s.phase { + case authSelectPhase: + return s.updateSelect(msg) + case authOIDCIssuerPhase, authOIDCAudiencePhase, authOIDCGroupsClaimPhase, + authHTTPURLPhase, authHTTPOrgPhase: + return s.updateInput(msg) + } + return s, nil +} + +func (s *AuthStep) updateSelect(msg tea.Msg) (tui.Step, tea.Cmd) { + updated, cmd := s.selector.Update(msg) + s.selector = updated + if !s.selector.Done() { + return s, cmd + } + + _, val := s.selector.Selected() + s.mode = val + + switch val { + case AuthModeNone, AuthModeCustom: + // No further input — finish. + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + case AuthModeOIDC: + s.phase = authOIDCIssuerPhase + s.input = s.newTextInput( + "Issuer URL", + "https://login.example.com", + validateHTTPSURL, + ) + return s, s.input.Init() + case AuthModeHTTPVerifier: + s.phase = authHTTPURLPhase + s.input = s.newTextInput( + "Verifier URL", + "https://auth.example.com/verify", + validateHTTPSURL, + ) + return s, s.input.Init() + } + return s, cmd +} + +func (s *AuthStep) updateInput(msg tea.Msg) (tui.Step, tea.Cmd) { + // Backspace on empty input returns to the picker. + if km, ok := msg.(tea.KeyMsg); ok && km.String() == "backspace" { + if s.input.Value() == "" { + s.phase = authSelectPhase + s.mode = "" + s.selector.Reset() + return s, s.selector.Init() + } + } + + updated, cmd := s.input.Update(msg) + s.input = updated + if !s.input.Done() { + return s, cmd + } + + v := strings.TrimSpace(s.input.Value()) + switch s.phase { + case authOIDCIssuerPhase: + s.issuer = strings.TrimRight(v, "/") + s.phase = authOIDCAudiencePhase + s.input = s.newTextInput("Audience", "api://forge", validateNonEmpty) + return s, s.input.Init() + + case authOIDCAudiencePhase: + s.audience = v + s.phase = authOIDCGroupsClaimPhase + s.input = s.newTextInput( + "Groups claim (default: groups — press Enter to accept)", + "groups", + nil, // optional + ) + return s, s.input.Init() + + case authOIDCGroupsClaimPhase: + s.groupsClaim = v // may be "" + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + + case authHTTPURLPhase: + s.httpURL = v + s.phase = authHTTPOrgPhase + s.input = s.newTextInput( + "Default org_id (optional — press Enter to skip)", + "", + nil, + ) + return s, s.input.Init() + + case authHTTPOrgPhase: + s.httpOrg = v + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + } + + return s, cmd +} + +func (s *AuthStep) newTextInput(label, placeholder string, validate func(string) error) components.TextInput { + return components.NewTextInput( + label, + placeholder, + false, // no slug hint + validate, + s.styles.Theme.Accent, + s.styles.AccentTxt, + s.styles.InactiveBorder, + s.styles.ErrorTxt, + s.styles.DimTxt, + s.styles.KbdKey, + s.styles.KbdDesc, + ) +} + +func (s *AuthStep) View(width int) string { + switch s.phase { + case authSelectPhase: + return s.selector.View(width) + case authDonePhase: + return "" + default: + // Inline the "Backspace at empty: back" hint with the input + // so the escape hatch is discoverable from within the sub-step + // (review #11e — the documented behavior was undiscoverable + // because no kbd hint surfaced it). + hint := " " + + s.styles.KbdKey.Render("Backspace at empty") + " " + + s.styles.KbdDesc.Render("back to picker") + return s.input.View(width) + "\n" + hint + } +} + +func (s *AuthStep) Complete() bool { return s.complete } + +func (s *AuthStep) Summary() string { + switch s.mode { + case "", AuthModeNone: + return "Anonymous (no auth)" + case AuthModeOIDC: + host := hostnameOrEmpty(s.issuer) + if host == "" { + return "OIDC" + } + return "OIDC · " + host + case AuthModeHTTPVerifier: + host := hostnameOrEmpty(s.httpURL) + if host == "" { + return "HTTP Verifier" + } + return "HTTP Verifier · " + host + case AuthModeCustom: + return "Custom (edit forge.yaml)" + } + return s.mode +} + +// Apply writes the auth selection back to the wizard context. +// +// - AuthMode is always set (defaults to "none"). +// - AuthSettings is the typed settings block ready to copy into forge.yaml. +// - AuthEgressHosts captures the outbound hosts the runner will need +// reachable; the init scaffold merges these into egress.allowed_domains. +func (s *AuthStep) Apply(ctx *tui.WizardContext) { + mode := s.mode + if mode == "" { + mode = AuthModeNone + } + ctx.AuthMode = mode + + switch mode { + case AuthModeOIDC: + settings := map[string]any{ + "issuer": s.issuer, + "audience": s.audience, + } + // Only emit claim_map if user customized it (default name = "groups"). + if s.groupsClaim != "" && s.groupsClaim != "groups" { + settings["claim_map"] = map[string]any{"groups": s.groupsClaim} + } + ctx.AuthSettings = settings + if host := hostnameOrEmpty(s.issuer); host != "" { + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, host) + } + case AuthModeHTTPVerifier: + settings := map[string]any{"url": s.httpURL} + if s.httpOrg != "" { + settings["default_org"] = s.httpOrg + } + ctx.AuthSettings = settings + if host := hostnameOrEmpty(s.httpURL); host != "" { + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, host) + } + default: + // None / Custom: nothing to attach. + ctx.AuthSettings = nil + } +} + +// --- helpers --- + +// validateHTTPSURL ensures the input parses as a URL and uses http or https. +// Errors are surfaced inline by the text input. +func validateHTTPSURL(val string) error { + if val == "" { + return fmt.Errorf("required") + } + u, err := url.Parse(val) + if err != nil { + return fmt.Errorf("not a URL: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("must start with http:// or https://") + } + if u.Host == "" { + return fmt.Errorf("missing host") + } + return nil +} + +// validateNonEmpty enforces a non-empty input. +func validateNonEmpty(val string) error { + if strings.TrimSpace(val) == "" { + return fmt.Errorf("required") + } + return nil +} + +// hostnameOrEmpty returns the bare host (no port) from a URL, or "" on +// parse failure. +func hostnameOrEmpty(raw string) string { + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil || u.Host == "" { + return "" + } + return u.Hostname() +} + +// appendUnique appends v to slice if not already present. +func appendUnique(slice []string, v string) []string { + if slices.Contains(slice, v) { + return slice + } + return append(slice, v) +} + +// doneCmd emits the wizard "step complete" message. +func doneCmd() tea.Cmd { + return func() tea.Msg { return tui.StepCompleteMsg{} } +} diff --git a/forge-cli/internal/tui/steps/step_auth_test.go b/forge-cli/internal/tui/steps/step_auth_test.go new file mode 100644 index 0000000..81e3ddf --- /dev/null +++ b/forge-cli/internal/tui/steps/step_auth_test.go @@ -0,0 +1,398 @@ +package steps + +import ( + "reflect" + "testing" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/initializ/forge/forge-cli/internal/tui" +) + +// newTestAuthStep wires up an AuthStep with a default StyleSet so we can +// drive it without spinning up a real terminal. +func newTestAuthStep(t *testing.T) *AuthStep { + t.Helper() + styles := tui.NewStyleSet(tui.DarkTheme) + return NewAuthStep(styles) +} + +// press injects a key event into the step and returns the updated step. +func press(t *testing.T, s *AuthStep, key string) *AuthStep { + t.Helper() + var msg tea.KeyMsg + switch key { + case "down": + msg = tea.KeyMsg{Type: tea.KeyDown} + case "up": + msg = tea.KeyMsg{Type: tea.KeyUp} + case "enter": + msg = tea.KeyMsg{Type: tea.KeyEnter} + case "backspace": + msg = tea.KeyMsg{Type: tea.KeyBackspace} + default: + msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)} + } + updated, _ := s.Update(msg) + return updated.(*AuthStep) +} + +// typeIn injects each rune of the string as a single KeyRunes message. +func typeIn(t *testing.T, s *AuthStep, text string) *AuthStep { + t.Helper() + for _, r := range text { + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}} + updated, _ := s.Update(msg) + s = updated.(*AuthStep) + } + return s +} + +// --- Title / Icon --- + +func TestAuthStep_TitleAndIcon(t *testing.T) { + s := newTestAuthStep(t) + if s.Title() == "" { + t.Error("Title is empty") + } + if s.Icon() == "" { + t.Error("Icon is empty") + } +} + +// --- "None" path completes immediately --- + +func TestAuthStep_NoneIsDefault(t *testing.T) { + s := newTestAuthStep(t) + s = press(t, s, "enter") // Confirm "None" (first item) + if !s.Complete() { + t.Error("expected Complete after selecting None") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeNone { + t.Errorf("AuthMode = %q, want %q", ctx.AuthMode, AuthModeNone) + } + if ctx.AuthSettings != nil { + t.Errorf("AuthSettings = %v, want nil", ctx.AuthSettings) + } + if len(ctx.AuthEgressHosts) != 0 { + t.Errorf("AuthEgressHosts = %v, want empty", ctx.AuthEgressHosts) + } +} + +// --- "Custom" path completes immediately with no settings --- + +func TestAuthStep_Custom(t *testing.T) { + s := newTestAuthStep(t) + // Move down 3 times to reach "Custom" (index 3). + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after selecting Custom") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeCustom { + t.Errorf("AuthMode = %q, want %q", ctx.AuthMode, AuthModeCustom) + } + if ctx.AuthSettings != nil { + t.Errorf("AuthSettings = %v, want nil for custom mode", ctx.AuthSettings) + } +} + +// --- OIDC full path --- + +func TestAuthStep_OIDC_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + + // Pick OIDC (index 1). + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authOIDCIssuerPhase { + t.Fatalf("phase = %v, want issuer", s.phase) + } + + // Type issuer. + s = typeIn(t, s, "https://login.example.com/") + s = press(t, s, "enter") + if s.phase != authOIDCAudiencePhase { + t.Fatalf("phase = %v, want audience", s.phase) + } + + // Type audience. + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + if s.phase != authOIDCGroupsClaimPhase { + t.Fatalf("phase = %v, want groups claim", s.phase) + } + + // Accept default groups claim. + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after groups claim") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeOIDC { + t.Errorf("AuthMode = %q, want oidc", ctx.AuthMode) + } + if ctx.AuthSettings["issuer"] != "https://login.example.com" { + t.Errorf("issuer = %v, want trimmed trailing slash", ctx.AuthSettings["issuer"]) + } + if ctx.AuthSettings["audience"] != "api://forge" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + if _, ok := ctx.AuthSettings["claim_map"]; ok { + t.Errorf("claim_map should NOT be emitted when groups claim is default") + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"login.example.com"}) { + t.Errorf("AuthEgressHosts = %v, want [login.example.com]", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_OIDC_CustomGroupsClaim(t *testing.T) { + s := newTestAuthStep(t) + s = press(t, s, "down") + s = press(t, s, "enter") // OIDC + s = typeIn(t, s, "https://login.example.com") + s = press(t, s, "enter") + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + s = typeIn(t, s, "roles") // non-default groups claim + s = press(t, s, "enter") + + ctx := tui.NewWizardContext() + s.Apply(ctx) + cm, ok := ctx.AuthSettings["claim_map"].(map[string]any) + if !ok { + t.Fatalf("claim_map not present or wrong type: %v", ctx.AuthSettings) + } + if cm["groups"] != "roles" { + t.Errorf("claim_map.groups = %v, want roles", cm["groups"]) + } +} + +// --- HTTP Verifier full path --- + +func TestAuthStep_HTTPVerifier_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + + // Pick HTTP Verifier (index 2). + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authHTTPURLPhase { + t.Fatalf("phase = %v, want url", s.phase) + } + + s = typeIn(t, s, "https://verify.example.com/verify") + s = press(t, s, "enter") + if s.phase != authHTTPOrgPhase { + t.Fatalf("phase = %v, want org", s.phase) + } + + // Skip optional default org. + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeHTTPVerifier { + t.Errorf("AuthMode = %q, want http_verifier", ctx.AuthMode) + } + if ctx.AuthSettings["url"] != "https://verify.example.com/verify" { + t.Errorf("url = %v", ctx.AuthSettings["url"]) + } + if _, ok := ctx.AuthSettings["default_org"]; ok { + t.Errorf("default_org should not be emitted when blank") + } +} + +func TestAuthStep_HTTPVerifier_WithDefaultOrg(t *testing.T) { + s := newTestAuthStep(t) + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") // HTTP Verifier + s = typeIn(t, s, "https://verify.example.com/verify") + s = press(t, s, "enter") + s = typeIn(t, s, "acme") + s = press(t, s, "enter") + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthSettings["default_org"] != "acme" { + t.Errorf("default_org = %v, want acme", ctx.AuthSettings["default_org"]) + } +} + +// --- Summary text --- + +func TestAuthStep_Summary(t *testing.T) { + tests := []struct { + mode string + want string + }{ + {"", "Anonymous (no auth)"}, + {AuthModeNone, "Anonymous (no auth)"}, + {AuthModeOIDC, "OIDC"}, + {AuthModeCustom, "Custom (edit forge.yaml)"}, + } + for _, tt := range tests { + t.Run(tt.mode, func(t *testing.T) { + s := newTestAuthStep(t) + s.mode = tt.mode + if got := s.Summary(); got != tt.want { + t.Errorf("Summary = %q, want %q", got, tt.want) + } + }) + } +} + +// --- Helpers --- + +func TestValidateHTTPSURL(t *testing.T) { + tests := []struct { + val string + ok bool + }{ + {"", false}, + {"https://example.com", true}, + {"http://localhost:8080", true}, + {"login.example.com", false}, // no scheme + {"ftp://example.com", false}, // wrong scheme + {"https://", false}, // no host + } + for _, tt := range tests { + t.Run(tt.val, func(t *testing.T) { + err := validateHTTPSURL(tt.val) + if tt.ok && err != nil { + t.Errorf("expected ok, got %v", err) + } + if !tt.ok && err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} + +func TestHostnameOrEmpty(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"https://login.example.com", "login.example.com"}, + {"http://localhost:8080/realms/dev", "localhost"}, + {"not a url", ""}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + if got := hostnameOrEmpty(tt.in); got != tt.want { + t.Errorf("hostnameOrEmpty(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +// --- Wizard interaction tests (review #12.6, #12.7) --- + +func TestAuthStep_BackspaceOnEmptyInputReturnsToPicker(t *testing.T) { + // Review #12.6: backspace-on-empty is the documented escape hatch + // from a sub-step input back to the provider picker, but until + // review #11e there was no visible hint AND no test. The hint was + // added; this test pins the behavior. + s := newTestAuthStep(t) + + // Walk into OIDC issuer phase. + s = press(t, s, "down") // OIDC + s = press(t, s, "enter") + if s.phase != authOIDCIssuerPhase { + t.Fatalf("setup: phase = %v, want issuer", s.phase) + } + + // Backspace on EMPTY input → returns to picker. + s = press(t, s, "backspace") + if s.phase != authSelectPhase { + t.Errorf("backspace-on-empty: phase = %v, want authSelectPhase", s.phase) + } + if s.mode != "" { + t.Errorf("backspace-on-empty: mode = %q, want \"\" (reset)", s.mode) + } +} + +func TestAuthStep_BackspaceWithContentDoesNotReturnToPicker(t *testing.T) { + // Counterpart: backspace with TEXT in the input only deletes the + // last char — it must not escape out of the sub-step. + s := newTestAuthStep(t) + s = press(t, s, "down") // OIDC + s = press(t, s, "enter") + s = typeIn(t, s, "https://example.com") + if s.phase != authOIDCIssuerPhase { + t.Fatalf("setup: phase = %v, want issuer", s.phase) + } + + s = press(t, s, "backspace") + if s.phase != authOIDCIssuerPhase { + t.Errorf("backspace-with-content escaped to phase %v, expected to stay on issuer", s.phase) + } +} + +func TestAuthStep_InvalidIssuerURLBlocksAdvance(t *testing.T) { + // Review #12.7: validateHTTPSURL is unit-tested standalone, but + // the wizard's gating on bad input was never asserted as an + // interaction. Drive the step with a bad URL and confirm Enter + // doesn't advance to the audience sub-step. + s := newTestAuthStep(t) + s = press(t, s, "down") // OIDC + s = press(t, s, "enter") + + // Type something that fails validateHTTPSURL (no scheme). + s = typeIn(t, s, "login.example.com") + s = press(t, s, "enter") + + // MUST still be on the issuer phase — the bad URL blocked advance. + if s.phase != authOIDCIssuerPhase { + t.Errorf("invalid URL advanced past issuer phase to %v", s.phase) + } + + // And the step must not be complete — Review/scaffold MUST NOT run + // for a half-configured OIDC entry. + if s.Complete() { + t.Errorf("step reported complete with invalid issuer URL") + } + + // Now type a valid replacement and confirm the gate releases. + // Clear the input first by sending backspaces until empty — easier + // than reaching into TextInput state. + for range len("login.example.com") { + s = press(t, s, "backspace") + } + if s.phase != authOIDCIssuerPhase { + // Bug guard: a long sequence of backspaces should NOT escape to + // the picker mid-way (only an Enter-after-empty would). + t.Fatalf("backspaces escaped to phase %v unexpectedly", s.phase) + } + s = typeIn(t, s, "https://login.example.com") + s = press(t, s, "enter") + if s.phase != authOIDCAudiencePhase { + t.Errorf("valid URL did not advance: phase = %v, want audience", s.phase) + } +} + +func TestAppendUnique(t *testing.T) { + out := appendUnique([]string{"a", "b"}, "a") + if !reflect.DeepEqual(out, []string{"a", "b"}) { + t.Errorf("appendUnique duplicate failed: %v", out) + } + out = appendUnique([]string{"a"}, "b") + if !reflect.DeepEqual(out, []string{"a", "b"}) { + t.Errorf("appendUnique new failed: %v", out) + } +} diff --git a/forge-cli/internal/tui/wizard.go b/forge-cli/internal/tui/wizard.go index 73dd0c1..d992496 100644 --- a/forge-cli/internal/tui/wizard.go +++ b/forge-cli/internal/tui/wizard.go @@ -30,6 +30,26 @@ type WizardContext struct { CustomModel string CustomAPIKey string EnvVars map[string]string + + // AuthMode is the user's selection from the auth step: + // "" — wizard step did not run + // "none" — anonymous access (no auth: block written) + // "oidc" — OIDC provider + // "http_verifier" — legacy external verifier + // "custom" — write a commented stub; user edits manually + AuthMode string + + // AuthSettings is the type-specific settings block for the chosen + // provider, ready to plug into `auth.providers[0].settings` in + // forge.yaml. Keys: "issuer", "audience", "url", "default_org", + // "claim_map" (a nested map), etc. + AuthSettings map[string]any + + // AuthEgressHosts is the list of bare hosts (no port) that should be + // auto-added to the egress allowlist because the auth step requires + // them — e.g. an OIDC issuer or http_verifier URL. The init scaffold + // merges these into the rendered forge.yaml egress.allowed_domains. + AuthEgressHosts []string } // NewWizardContext creates an initialized WizardContext. @@ -37,6 +57,7 @@ func NewWizardContext() *WizardContext { return &WizardContext{ ChannelTokens: make(map[string]string), EnvVars: make(map[string]string), + AuthSettings: make(map[string]any), } } diff --git a/forge-cli/runtime/auth_audit_test.go b/forge-cli/runtime/auth_audit_test.go new file mode 100644 index 0000000..7e72b76 --- /dev/null +++ b/forge-cli/runtime/auth_audit_test.go @@ -0,0 +1,258 @@ +package runtime + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/initializ/forge/forge-core/auth" + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// PR4 audit-event contract tests. These guarantee: +// 1. auth_verify is emitted on success with the right fields. +// 2. auth_fail is emitted on failure with a stable reason code. +// 3. Neither event ever carries PII (token bytes, email, claim payloads). + +// captureAudit decodes each NDJSON line emitted to buf into an AuditEvent. +func captureAudit(t *testing.T, buf *bytes.Buffer) []coreruntime.AuditEvent { + t.Helper() + var events []coreruntime.AuditEvent + for _, line := range strings.Split(strings.TrimSpace(buf.String()), "\n") { + if line == "" { + continue + } + var ev coreruntime.AuditEvent + if err := json.Unmarshal([]byte(line), &ev); err != nil { + t.Fatalf("decode audit line %q: %v", line, err) + } + events = append(events, ev) + } + return events +} + +func TestAuthAudit_EmitsAuthVerifyOnSuccess(t *testing.T) { + var buf bytes.Buffer + logger := coreruntime.NewAuditLogger(&buf) + + cb := makeAuthAuditCallback(logger) + if cb == nil { + t.Fatal("callback nil with non-nil logger") + } + + req := httptest.NewRequest("POST", "/tasks", nil) + id := &auth.Identity{ + UserID: "alice-123", + Email: "alice@example.com", // must NOT leak into the event + OrgID: "tenant-1", + Groups: []string{"a", "b", "c"}, + Source: "oidc", + } + cb(req, id, nil, "jwt") + + events := captureAudit(t, &buf) + if len(events) != 1 { + t.Fatalf("got %d events, want 1", len(events)) + } + ev := events[0] + + if ev.Event != coreruntime.EventAuthVerify { + t.Errorf("Event = %q, want %q", ev.Event, coreruntime.EventAuthVerify) + } + if ev.Fields["provider"] != "oidc" { + t.Errorf("provider = %v, want oidc", ev.Fields["provider"]) + } + if ev.Fields["user_id"] != "alice-123" { + t.Errorf("user_id = %v, want alice-123", ev.Fields["user_id"]) + } + if ev.Fields["org_id"] != "tenant-1" { + t.Errorf("org_id = %v, want tenant-1", ev.Fields["org_id"]) + } + // JSON unmarshal puts numbers in float64. + if g, _ := ev.Fields["groups_count"].(float64); g != 3 { + t.Errorf("groups_count = %v, want 3", ev.Fields["groups_count"]) + } + if ev.Fields["token_kind"] != "jwt" { + t.Errorf("token_kind = %v, want jwt", ev.Fields["token_kind"]) + } +} + +func TestAuthAudit_EmitsAuthFailOnError(t *testing.T) { + tests := []struct { + name string + err error + wantReason string + }{ + {"missing token", auth.ErrMissingBearer, "missing_token"}, + {"rejected", auth.ErrTokenRejected, "rejected"}, + {"invalid", auth.ErrInvalidToken, "invalid"}, + {"not for me", auth.ErrTokenNotForMe, "not_for_me"}, + // Review #6: provider downtime gets its own stable reason code + // so operators can dashboard "IdP down" separately from "token bad". + {"provider unavailable", auth.ErrProviderUnavailable, "provider_unavailable"}, + {"infrastructure", errBoom, "infrastructure"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest("POST", "/tasks", nil) + cb(req, nil, tt.err, "opaque") + + events := captureAudit(t, &buf) + if len(events) != 1 { + t.Fatalf("got %d events, want 1", len(events)) + } + ev := events[0] + if ev.Event != coreruntime.EventAuthFail { + t.Errorf("Event = %q, want %q", ev.Event, coreruntime.EventAuthFail) + } + if got := ev.Fields["reason"]; got != tt.wantReason { + t.Errorf("reason = %v, want %q", got, tt.wantReason) + } + if ev.Fields["token_kind"] != "opaque" { + t.Errorf("token_kind = %v, want opaque", ev.Fields["token_kind"]) + } + }) + } +} + +func TestAuthAudit_NoPIIInPayload(t *testing.T) { + // Strong negative test: even if the Identity carries an email and + // rich Claims, the emitted audit line MUST NOT include them. + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest("POST", "/tasks", nil) + id := &auth.Identity{ + UserID: "u", + Email: "secret@hr.example.com", + Claims: map[string]any{ + "ssn": "999-99-9999", + "secret": "deadbeef", + }, + Source: "oidc", + } + cb(req, id, nil, "jwt") + + line := buf.String() + forbidden := []string{ + "secret@hr.example.com", + "999-99-9999", + "deadbeef", + "\"email\"", + "\"claims\"", + } + for _, f := range forbidden { + if strings.Contains(line, f) { + t.Errorf("audit line leaked %q:\n%s", f, line) + } + } +} + +func TestAuthAudit_NilLoggerReturnsNilCallback(t *testing.T) { + if cb := makeAuthAuditCallback(nil); cb != nil { + t.Error("expected nil callback for nil logger") + } +} + +func TestAuthAudit_CorrelationIDPropagated(t *testing.T) { + var buf bytes.Buffer + cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf)) + + req := httptest.NewRequest("POST", "/tasks", nil) + ctx := coreruntime.WithCorrelationID(req.Context(), "corr-abc-123") + req = req.WithContext(ctx) + + cb(req, &auth.Identity{UserID: "u", Source: "oidc"}, nil, "jwt") + + events := captureAudit(t, &buf) + if events[0].CorrelationID != "corr-abc-123" { + t.Errorf("CorrelationID = %q, want corr-abc-123", events[0].CorrelationID) + } +} + +func TestAuthFailReason_UnknownError(t *testing.T) { + // Unknown error types map to "infrastructure" (catch-all for things + // like network failures, JWKS-fetch errors). + if got := authFailReason(errBoom); got != "infrastructure" { + t.Errorf("authFailReason(custom err) = %q, want infrastructure", got) + } + if got := authFailReason(nil); got != "unknown" { + t.Errorf("authFailReason(nil) = %q, want unknown", got) + } +} + +// Smoke test: making sure the OnAuth callback hooks into auth.Middleware +// end-to-end and produces audit events on real HTTP traffic. +func TestAuthAudit_E2EThroughMiddleware(t *testing.T) { + var buf bytes.Buffer + logger := coreruntime.NewAuditLogger(&buf) + + chain := auth.NewChainProvider(&testProvider{ + token: "ok", + identity: auth.Identity{UserID: "u", OrgID: "o", Source: "test"}, + }) + + handler := auth.Middleware(auth.MiddlewareOptions{ + Chain: chain, + SkipPaths: auth.DefaultSkipPaths(), + OnAuth: makeAuthAuditCallback(logger), + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // 1. Successful request. + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "Bearer ok") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // 2. Missing token. + req2 := httptest.NewRequest("POST", "/tasks", nil) + handler.ServeHTTP(httptest.NewRecorder(), req2) + + events := captureAudit(t, &buf) + if len(events) != 2 { + t.Fatalf("got %d events, want 2", len(events)) + } + if events[0].Event != coreruntime.EventAuthVerify { + t.Errorf("first event = %q, want auth_verify", events[0].Event) + } + if events[1].Event != coreruntime.EventAuthFail { + t.Errorf("second event = %q, want auth_fail", events[1].Event) + } + if events[1].Fields["reason"] != "missing_token" { + t.Errorf("missing-token reason = %v, want missing_token", events[1].Fields["reason"]) + } +} + +// --- helpers --- + +// errBoom is an opaque non-sentinel error for testing the catch-all +// "infrastructure" reason mapping. +type customErr struct{ s string } + +func (e customErr) Error() string { return e.s } + +var errBoom error = customErr{s: "network is on fire"} + +// testProvider is a tiny Provider that accepts one token. +type testProvider struct { + token string + identity auth.Identity +} + +func (t *testProvider) Name() string { return "test" } +func (t *testProvider) Verify(_ context.Context, token string, _ auth.Headers) (*auth.Identity, error) { + if token != t.token { + return nil, auth.ErrTokenNotForMe + } + id := t.identity + return &id, nil +} diff --git a/forge-cli/runtime/auth_chain_test.go b/forge-cli/runtime/auth_chain_test.go new file mode 100644 index 0000000..d026c1e --- /dev/null +++ b/forge-cli/runtime/auth_chain_test.go @@ -0,0 +1,365 @@ +package runtime + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/types" +) + +// E2E tests covering the auth chain that the runner builds from legacy +// CLI flags (--auth-token, --auth-url, --auth-org-id). These exist to +// guarantee that the PR1 refactor preserves the pre-refactor behavior +// byte-for-byte. If any of these fail, the refactor changed externally +// observable behavior — investigate before merging. + +// fakeVerifier accepts a handler and returns a httptest.Server that +// implements the http_verifier contract. +func newFakeVerifier(t *testing.T, handler func(req map[string]any) (status int, body any)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + _ = json.Unmarshal(body, &req) + status, respBody := handler(req) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if respBody != nil { + _ = json.NewEncoder(w).Encode(respBody) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func TestBuildLegacyAuthChain_NoAuth(t *testing.T) { + chain, err := buildLegacyAuthChain("", "", "") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + if chain != nil { + t.Errorf("chain = %v, want nil (no providers configured)", chain) + } +} + +func TestBuildLegacyAuthChain_TokenOnly(t *testing.T) { + chain, err := buildLegacyAuthChain("local-token", "", "") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + if chain == nil { + t.Fatal("chain is nil, want a ChainProvider with static_token") + } + + // Valid token accepted. + id, err := chain.Verify(context.Background(), "local-token", nil) + if err != nil { + t.Errorf("valid token rejected: %v", err) + } + if id == nil || id.Source != "internal" { + t.Errorf("identity = %+v, want Source=internal", id) + } + + // Wrong token yields with ErrTokenNotForMe (and at chain end, that's + // what we see — no http_verifier behind to claim it). + _, err = chain.Verify(context.Background(), "wrong", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("wrong token err = %v, want ErrTokenNotForMe", err) + } +} + +func TestBuildLegacyAuthChain_AuthURLOnly(t *testing.T) { + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + if req["token"] == "external-token" { + return http.StatusOK, map[string]any{ + "valid": true, + "user_id": "external-user", + "org_id": "external-org", + } + } + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, err := buildLegacyAuthChain("", srv.URL, "default-org") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + if chain == nil { + t.Fatal("chain is nil, want http_verifier chain") + } + + id, err := chain.Verify(context.Background(), "external-token", nil) + if err != nil { + t.Errorf("external token rejected: %v", err) + } + if id == nil || id.UserID != "external-user" || id.Source != "http_verifier" { + t.Errorf("identity = %+v", id) + } + + _, err = chain.Verify(context.Background(), "wrong", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("wrong token err = %v, want ErrTokenRejected", err) + } +} + +func TestBuildLegacyAuthChain_TokenAndAuthURL_LoopbackShortCircuits(t *testing.T) { + // This is the pre-refactor behavior: when both an internal Token AND + // AuthURL are set, the internal token is checked FIRST so loopback + // calls from channel adapters don't hit the external verifier. + var verifierCalls atomic.Int32 + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + verifierCalls.Add(1) + return http.StatusOK, map[string]any{"valid": true, "user_id": "external"} + }) + + chain, err := buildLegacyAuthChain("loopback-token", srv.URL, "") + if err != nil { + t.Fatalf("buildLegacyAuthChain: %v", err) + } + + // Loopback token: external verifier should NOT be called. + id, err := chain.Verify(context.Background(), "loopback-token", nil) + if err != nil { + t.Fatalf("loopback verify: %v", err) + } + if id.UserID != "forge-internal" || id.Source != "internal" { + t.Errorf("loopback identity = %+v, want forge-internal/internal", id) + } + if got := verifierCalls.Load(); got != 0 { + t.Errorf("external verifier called %d times during loopback (must be 0)", got) + } + + // Non-loopback token: falls through to external verifier. + id2, err := chain.Verify(context.Background(), "other-token", nil) + if err != nil { + t.Fatalf("external verify: %v", err) + } + if id2.Source != "http_verifier" { + t.Errorf("external identity Source = %q, want http_verifier", id2.Source) + } + if got := verifierCalls.Load(); got != 1 { + t.Errorf("external verifier called %d times, want 1", got) + } +} + +func TestBuildLegacyAuthChain_OrgIDPropagation(t *testing.T) { + var capturedOrgID atomic.Value + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + capturedOrgID.Store(req["org_id"]) + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "default-org") + + // No headers → config default. + if _, err := chain.Verify(context.Background(), "tok", nil); err != nil { + t.Fatalf("Verify: %v", err) + } + if got := capturedOrgID.Load(); got != "default-org" { + t.Errorf("verifier got org_id %v, want default-org", got) + } + + // X-Org-ID header overrides default. + if _, err := chain.Verify(context.Background(), "tok", auth.Headers{"X-Org-ID": "header-org"}); err != nil { + t.Fatalf("Verify: %v", err) + } + if got := capturedOrgID.Load(); got != "header-org" { + t.Errorf("verifier got org_id %v, want header-org", got) + } +} + +// --- middleware-level end-to-end --- + +// runMiddleware is a small driver that runs an http.Handler chain through +// the auth.Middleware with a given chain and returns the recorded response. +// When chain is nil, the test explicitly opts the middleware into +// anonymous mode (review #3 — nil Chain without AllowAnonymous panics). +func runMiddleware(t *testing.T, chain auth.Provider, method, path, authHeader string) *httptest.ResponseRecorder { + t.Helper() + handler := auth.Middleware(auth.MiddlewareOptions{ + Chain: chain, + AllowAnonymous: chain == nil, + SkipPaths: auth.DefaultSkipPaths(), + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + + req := httptest.NewRequest(method, path, nil) + if authHeader != "" { + req.Header.Set("Authorization", authHeader) + } + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + return rr +} + +func TestE2E_NoAuthConfigured_AnonymousAccess(t *testing.T) { + chain, _ := buildLegacyAuthChain("", "", "") + // nil chain → middleware passthrough + rr := runMiddleware(t, chain, "POST", "/tasks/send", "") + if rr.Code != http.StatusOK { + t.Errorf("anonymous access blocked: status = %d", rr.Code) + } +} + +func TestE2E_LegacyAuthURL_HappyPath(t *testing.T) { + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + if req["token"] == "good" { + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + } + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer good") + if rr.Code != http.StatusOK { + t.Errorf("good token rejected: status = %d, body = %s", rr.Code, rr.Body.String()) + } +} + +func TestE2E_LegacyAuthURL_BadTokenReturns401(t *testing.T) { + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer bad") + if rr.Code != http.StatusUnauthorized { + t.Errorf("bad token accepted: status = %d", rr.Code) + } +} + +func TestE2E_SkipPathsBypassAuth(t *testing.T) { + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + t.Error("verifier should not be called for skip paths") + return http.StatusOK, map[string]any{"valid": false} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + + skipPaths := []struct { + method, path string + }{ + {"GET", "/"}, + {"GET", "/.well-known/agent.json"}, + {"GET", "/healthz"}, + {"GET", "/health"}, + } + + for _, sp := range skipPaths { + t.Run(sp.method+" "+sp.path, func(t *testing.T) { + rr := runMiddleware(t, chain, sp.method, sp.path, "") + if rr.Code != http.StatusOK { + t.Errorf("skip path %s %s blocked: status = %d", sp.method, sp.path, rr.Code) + } + }) + } +} + +func TestE2E_ChannelLoopback_DoesNotCallVerifier(t *testing.T) { + // Simulates a Slack/Telegram adapter calling the local A2A server + // with the internal loopback token. The external verifier must NOT + // be consulted — that would be wasteful and would break if the + // verifier is misconfigured/unreachable. + var verifierCalls atomic.Int32 + srv := newFakeVerifier(t, func(map[string]any) (int, any) { + verifierCalls.Add(1) + return http.StatusOK, map[string]any{"valid": true, "user_id": "external"} + }) + + chain, _ := buildLegacyAuthChain("loopback-secret", srv.URL, "") + + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer loopback-secret") + if rr.Code != http.StatusOK { + t.Errorf("loopback token rejected: status = %d", rr.Code) + } + if got := verifierCalls.Load(); got != 0 { + t.Errorf("external verifier called %d times for loopback (must be 0)", got) + } +} + +func TestE2E_VerifierUnreachable_ReturnsInvalidNot500(t *testing.T) { + // When the external verifier is down (port closed), the middleware + // should return 401, not 500 — preserving the pre-refactor "auth + // provider error" behavior. + chain, _ := buildLegacyAuthChain("", "http://127.0.0.1:1/never", "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer x") + if rr.Code != http.StatusUnauthorized { + t.Errorf("verifier-unreachable status = %d, want 401", rr.Code) + } +} + +// --- ResolveAuth invariant pin (review #10) --- +// +// The loopback-token prepend in resolveAuth() is conditional on +// r.authToken != "" — which works because ResolveAuth() always mints +// one in the non-NoAuth path. If that invariant ever breaks (refactor +// of ResolveAuth that skips minting), channel adapters silently lose +// their loopback. This test pins the invariant directly. + +func TestResolveAuth_InvariantMintsTokenInNonNoAuthPath(t *testing.T) { + tmp := t.TempDir() + r := &Runner{logger: nopLogger{}} + r.cfg.Config = &types.ForgeConfig{} + r.cfg.NoAuth = false + r.cfg.Host = "127.0.0.1" // localhost so the !local + NoAuth check doesn't trip + r.cfg.WorkDir = tmp // ResolveAuth stores the token under /.forge/ + + if err := r.ResolveAuth(); err != nil { + t.Fatalf("ResolveAuth: %v", err) + } + if r.authToken == "" { + t.Fatal("invariant broken: ResolveAuth() returned nil but did not mint a token " + + "(non-NoAuth path MUST set r.authToken — channels rely on this; " + + "see runner.go resolveAuth loopback-prepend comment)") + } +} + +func TestResolveAuth_NoAuthPathDoesNotMintToken(t *testing.T) { + // Counterpart: in --no-auth mode, no token should be minted. The + // AllowAnonymous branch returns before the prepend ever runs, so + // the absence of a token is correct. + r := &Runner{logger: nopLogger{}} + r.cfg.Config = &types.ForgeConfig{} + r.cfg.NoAuth = true + r.cfg.Host = "127.0.0.1" + + if err := r.ResolveAuth(); err != nil { + t.Fatalf("ResolveAuth: %v", err) + } + if r.authToken != "" { + t.Errorf("--no-auth path should not mint a token, got %q", r.authToken) + } +} + +func TestE2E_WireFormat_TokenAndOrgID(t *testing.T) { + // Black-box: confirm the verifier sees exactly `token` and `org_id` + // JSON fields. Guards against future schema drift. + var sawToken, sawOrgID atomic.Value + srv := newFakeVerifier(t, func(req map[string]any) (int, any) { + sawToken.Store(req["token"]) + sawOrgID.Store(req["org_id"]) + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + }) + + chain, _ := buildLegacyAuthChain("", srv.URL, "") + rr := runMiddleware(t, chain, "POST", "/tasks/send", "Bearer the-token") + if rr.Code != http.StatusOK { + t.Fatalf("status = %d", rr.Code) + } + if got := sawToken.Load(); got != "the-token" { + t.Errorf("verifier got token = %v, want the-token", got) + } + // No org headers sent → verifier sees empty org_id. + if got := sawOrgID.Load(); got != "" { + t.Errorf("verifier got org_id = %v, want empty", got) + } +} diff --git a/forge-cli/runtime/auth_config_test.go b/forge-cli/runtime/auth_config_test.go new file mode 100644 index 0000000..199a016 --- /dev/null +++ b/forge-cli/runtime/auth_config_test.go @@ -0,0 +1,462 @@ +package runtime + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/types" +) + +// E2E coverage for PR3's forge.yaml `auth:` block → ChainProvider wiring. + +// fakeVerifierFor is a test verifier that accepts `goodToken` and rejects +// everything else. +func fakeVerifierFor(t *testing.T, goodToken string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + _ = json.Unmarshal(body, &req) + if req["token"] == goodToken { + _ = json.NewEncoder(w).Encode(map[string]any{"valid": true, "user_id": "u"}) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{"valid": false}) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestBuildChainFromConfig_Empty(t *testing.T) { + chain, err := buildChainFromConfig(types.AuthConfig{}) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if chain != nil { + t.Errorf("chain = %v, want nil for empty config", chain) + } +} + +func TestBuildChainFromConfig_SingleHTTPVerifier(t *testing.T) { + srv := fakeVerifierFor(t, "good") + chain, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srv.URL}}, + }, + }) + if err != nil { + t.Fatalf("buildChainFromConfig: %v", err) + } + id, err := chain.Verify(context.Background(), "good", nil) + if err != nil || id == nil { + t.Errorf("good token rejected: id=%v err=%v", id, err) + } + if _, err := chain.Verify(context.Background(), "bad", nil); !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("bad token err = %v, want ErrTokenRejected", err) + } +} + +func TestBuildChainFromConfig_StaticToken(t *testing.T) { + t.Setenv("FORGE_TEST_STATIC_TOKEN", "dev-secret") + + chain, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "static_token", Settings: map[string]any{"token_env": "FORGE_TEST_STATIC_TOKEN"}}, + }, + }) + if err != nil { + t.Fatalf("buildChainFromConfig: %v", err) + } + if _, err := chain.Verify(context.Background(), "dev-secret", nil); err != nil { + t.Errorf("dev-secret rejected: %v", err) + } +} + +func TestBuildChainFromConfig_OrderedChain(t *testing.T) { + // Verify first-match-wins ordering: a static_token at position 1 + // short-circuits an http_verifier at position 2. + verifierHit := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + verifierHit = true + _ = json.NewEncoder(w).Encode(map[string]any{"valid": true, "user_id": "external"}) + })) + defer srv.Close() + + t.Setenv("FORGE_TEST_LOOPBACK_TOKEN", "loopback") + + chain, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "static_token", Settings: map[string]any{"token_env": "FORGE_TEST_LOOPBACK_TOKEN"}}, + {Type: "http_verifier", Settings: map[string]any{"url": srv.URL}}, + }, + }) + if err != nil { + t.Fatalf("buildChainFromConfig: %v", err) + } + + // Loopback token matches first provider — verifier must NOT be called. + id, err := chain.Verify(context.Background(), "loopback", nil) + if err != nil { + t.Fatalf("loopback verify: %v", err) + } + if id.Source != "static_token" { + t.Errorf("identity.Source = %q, want static_token", id.Source) + } + if verifierHit { + t.Error("http_verifier was called despite static_token match (chain ordering broken)") + } + + // Different token falls through to verifier. + if _, err := chain.Verify(context.Background(), "other", nil); err != nil { + t.Fatalf("fallthrough verify: %v", err) + } + if !verifierHit { + t.Error("http_verifier was not called for non-loopback token (chain ordering broken)") + } +} + +func TestBuildChainFromConfig_InvalidProviderType(t *testing.T) { + _, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "ldap"}, // not registered + }, + }) + if err == nil { + t.Fatal("expected error for unknown provider type") + } +} + +func TestBuildChainFromConfig_FactoryError(t *testing.T) { + _, err := buildChainFromConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc"}, // missing issuer + audience + }, + }) + if err == nil { + t.Fatal("expected error for oidc without issuer/audience") + } + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Errorf("err = %v, want wrapped ErrProviderNotConfigured", err) + } +} + +func TestBuildLegacyHTTPVerifierChain_HappyPath(t *testing.T) { + srv := fakeVerifierFor(t, "ok") + chain, err := buildLegacyHTTPVerifierChain(srv.URL, "default-org") + if err != nil { + t.Fatalf("buildLegacyHTTPVerifierChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "ok", nil); err != nil { + t.Errorf("verify failed: %v", err) + } +} + +func TestBuildLegacyHTTPVerifierChain_EmptyURLFails(t *testing.T) { + if _, err := buildLegacyHTTPVerifierChain("", ""); err == nil { + t.Error("expected error for empty URL") + } +} + +// --- buildUserAuthChain (precedence rules) --- + +// runnerWith returns a Runner whose Config has the given AuthConfig and +// CLI fields. Minimal — only what buildUserAuthChain needs. +func runnerWith(authYAML types.AuthConfig, authURL string) *Runner { + r := &Runner{logger: nopLogger{}} + r.cfg.Config = &types.ForgeConfig{Auth: authYAML} + r.cfg.AuthURL = authURL + return r +} + +// nopLogger is a Logger that does nothing — keeps the warning path out +// of test output for the prefer-YAML-over-flag test below. +type nopLogger struct{} + +func (nopLogger) Debug(string, map[string]any) {} +func (nopLogger) Info(string, map[string]any) {} +func (nopLogger) Warn(string, map[string]any) {} +func (nopLogger) Error(string, map[string]any) {} + +func TestBuildUserAuthChain_NeitherSourcePresent(t *testing.T) { + r := runnerWith(types.AuthConfig{}, "") + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if chain != nil { + t.Errorf("chain = %v, want nil when neither YAML nor flag is set", chain) + } +} + +func TestBuildUserAuthChain_OnlyLegacyURL(t *testing.T) { + srv := fakeVerifierFor(t, "ok") + r := runnerWith(types.AuthConfig{}, srv.URL) + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "ok", nil); err != nil { + t.Errorf("verify failed: %v", err) + } +} + +func TestBuildUserAuthChain_OnlyYAMLAuth(t *testing.T) { + srv := fakeVerifierFor(t, "yaml-token") + r := runnerWith(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srv.URL}}, + }, + }, "") + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "yaml-token", nil); err != nil { + t.Errorf("yaml-source verifier rejected: %v", err) + } +} + +func TestBuildUserAuthChain_BothSourcesPrefersYAML(t *testing.T) { + // YAML auth points at srvYAML (accepts "yaml-only"). + // Legacy --auth-url points at srvLegacy (accepts "legacy-only"). + // Expect: YAML wins, legacy verifier never consulted, AND a warning + // is emitted (review #12.1 — previously asserted only the chain + // behavior; the "we logged a warning" half of the contract was + // invisible because the test used the nop logger). + srvYAML := fakeVerifierFor(t, "yaml-only") + srvLegacy := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("legacy verifier should not be called when YAML auth: is present") + })) + defer srvLegacy.Close() + + // Use a capturing logger so we can assert the warning emission too. + logger := &recordLogger{} + r := &Runner{logger: logger} + r.cfg.Config = &types.ForgeConfig{Auth: types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srvYAML.URL}}, + }, + }} + r.cfg.AuthURL = srvLegacy.URL + + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if _, err := chain.Verify(context.Background(), "yaml-only", nil); err != nil { + t.Errorf("yaml-only rejected: %v", err) + } + if _, err := chain.Verify(context.Background(), "legacy-only", nil); !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("legacy-only token err = %v, want ErrTokenRejected", err) + } + + // The contract: when both YAML auth and --auth-url are configured, + // we warn that --auth-url is being ignored. Without this assertion + // a refactor that silently drops the warning would never surface. + if !logger.warnedAbout("--auth-url and forge.yaml") { + t.Errorf("expected warning when both sources configured; got: %+v", logger.warnings) + } +} + +// --- Full resolveAuth + YAML auth + loopback prepend (review #12.3) --- +// +// auth_chain_test.go exercises buildLegacyAuthChain. This test goes one +// level higher and exercises resolveAuth() end-to-end with a forge.yaml +// `auth:` block configured AND an internal token minted by +// ResolveAuth(). The assertion: the loopback static_token sits at the +// chain head, AHEAD of the user-configured providers, so channel +// adapters short-circuit before the OIDC/http_verifier roundtrip. + +func TestResolveAuth_YAMLAuth_LoopbackPrependedAtChainHead(t *testing.T) { + tmp := t.TempDir() + + // Fake verifier for the YAML-configured http_verifier provider. + srvYAML := fakeVerifierFor(t, "external-token") + + logger := &recordLogger{} + r := &Runner{logger: logger} + r.cfg.Config = &types.ForgeConfig{Auth: types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{"url": srvYAML.URL}}, + }, + }} + r.cfg.NoAuth = false + r.cfg.Host = "127.0.0.1" // local — passes the localhost gate in ResolveAuth + r.cfg.WorkDir = tmp // ResolveAuth writes runtime.token here + + opts, err := r.resolveAuth(nil) + if err != nil { + t.Fatalf("resolveAuth: %v", err) + } + if opts.Chain == nil { + t.Fatal("resolveAuth returned a nil Chain — loopback not prepended") + } + chain, ok := opts.Chain.(*auth.ChainProvider) + if !ok { + t.Fatalf("Chain is not a *ChainProvider; got %T", opts.Chain) + } + providers := chain.Providers() + if len(providers) < 2 { + t.Fatalf("chain has %d providers; want ≥ 2 (loopback + YAML)", len(providers)) + } + + // Head MUST be the loopback static_token — channel adapter + // short-circuit depends on this. If a refactor inserts a YAML + // provider before it, channels start round-tripping to the + // external verifier on every callback. + if providers[0].Name() != "static_token" { + t.Errorf("chain.providers[0].Name() = %q, want %q (loopback)", providers[0].Name(), "static_token") + } + if providers[1].Name() != "http_verifier" { + t.Errorf("chain.providers[1].Name() = %q, want %q (YAML)", providers[1].Name(), "http_verifier") + } + + // And smoke-test that the loopback token is actually the one + // ResolveAuth minted: Verify with r.authToken must succeed without + // touching the YAML verifier. + id, err := chain.Verify(context.Background(), r.authToken, nil) + if err != nil { + t.Fatalf("loopback verify failed: %v", err) + } + if id.Source != "internal" { + t.Errorf("loopback identity Source = %q, want %q", id.Source, "internal") + } +} + +// --- --no-auth vs forge.yaml auth: block (review #4) --- +// +// Phase 1 + PR3 let --no-auth and the YAML auth block be set +// independently. Review #4 cross-checks them at resolveAuth time so +// a contradictory combination fails loudly instead of silently +// granting anonymous access. + +// makeRunnerWithNoAuth returns a Runner with NoAuth=true and the given +// AuthConfig in forge.yaml. Captures logger output via a recording +// logger so warning assertions can read what the runner emitted. +func makeRunnerWithNoAuth(t *testing.T, authYAML types.AuthConfig) (*Runner, *recordLogger) { + t.Helper() + logger := &recordLogger{} + r := &Runner{logger: logger} + r.cfg.Config = &types.ForgeConfig{Auth: authYAML} + r.cfg.NoAuth = true + return r, logger +} + +func TestResolveAuth_NoAuthWithRequiredFails(t *testing.T) { + r, _ := makeRunnerWithNoAuth(t, types.AuthConfig{ + Required: true, + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }) + _, err := r.resolveAuth(nil) + if err == nil { + t.Fatal("expected error when --no-auth conflicts with auth.required: true") + } + if !strings.Contains(err.Error(), "--no-auth conflicts with") { + t.Errorf("err = %q, want descriptive --no-auth/required mismatch message", err) + } +} + +func TestResolveAuth_NoAuthWithProvidersWarns(t *testing.T) { + r, logger := makeRunnerWithNoAuth(t, types.AuthConfig{ + // Required: false — the operator hasn't asserted that auth is + // mandatory; --no-auth should still work but it surprises them + // that the YAML providers are ignored, so we warn. + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }) + opts, err := r.resolveAuth(nil) + if err != nil { + t.Fatalf("resolveAuth returned error, want warning only: %v", err) + } + if !opts.AllowAnonymous { + t.Error("expected AllowAnonymous=true under --no-auth") + } + if !logger.warnedAbout("--no-auth overrides") { + t.Errorf("expected --no-auth-overrides warning; got: %+v", logger.warnings) + } +} + +func TestResolveAuth_NoAuthWithEmptyYAMLConfig_NoWarning(t *testing.T) { + // Regression: --no-auth alone (no auth: block at all) must still + // work without spamming a warning. + r, logger := makeRunnerWithNoAuth(t, types.AuthConfig{}) + opts, err := r.resolveAuth(nil) + if err != nil { + t.Fatalf("resolveAuth: %v", err) + } + if !opts.AllowAnonymous { + t.Error("expected AllowAnonymous=true under --no-auth") + } + if len(logger.warnings) > 0 { + t.Errorf("unexpected warnings emitted: %+v", logger.warnings) + } +} + +func TestResolveAuth_NoAuthWithRequiredFalseAndProviders_WarnsNotFails(t *testing.T) { + // Explicit Required: false + Providers set — operator knows what + // they're doing; we warn but don't refuse. + r, logger := makeRunnerWithNoAuth(t, types.AuthConfig{ + Required: false, + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }) + if _, err := r.resolveAuth(nil); err != nil { + t.Fatalf("resolveAuth should not fail when Required=false: %v", err) + } + if !logger.warnedAbout("--no-auth overrides") { + t.Error("expected override warning") + } +} + +// recordLogger captures Warn/Info/Debug/Error calls for assertions. +// Mirrors the no-op nopLogger pattern in auth_config_test.go but +// records messages so tests can assert on them. +type recordLogger struct { + warnings []string + infos []string +} + +func (r *recordLogger) Debug(msg string, _ map[string]any) {} +func (r *recordLogger) Info(msg string, _ map[string]any) { r.infos = append(r.infos, msg) } +func (r *recordLogger) Warn(msg string, _ map[string]any) { r.warnings = append(r.warnings, msg) } +func (r *recordLogger) Error(msg string, _ map[string]any) {} +func (r *recordLogger) warnedAbout(substr string) bool { + for _, w := range r.warnings { + if strings.Contains(w, substr) { + return true + } + } + return false +} + +func TestBuildUserAuthChain_OIDCFromYAML(t *testing.T) { + // Builds an oidc provider successfully — confirms the side-effect + // import wired it into the registry. Network isn't actually exercised + // because we don't call Verify with a real JWT. + r := runnerWith(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }}, + }, + }, "") + chain, err := r.buildUserAuthChain() + if err != nil { + t.Fatalf("buildUserAuthChain: %v", err) + } + if chain == nil { + t.Fatal("chain is nil; OIDC factory may not be registered") + } +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index de974df..551bd48 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -3,6 +3,7 @@ package runtime import ( "context" "encoding/json" + "errors" "fmt" "net/http" "os" @@ -16,6 +17,12 @@ import ( "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/agentspec" "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/httpverifier" + // Side-effect import: registers the "oidc" provider with the auth registry + // so forge.yaml `auth: { type: oidc }` blocks construct successfully via + // auth.Build("oidc", settings). The package is not referenced directly. + _ "github.com/initializ/forge/forge-core/auth/providers/oidc" + "github.com/initializ/forge/forge-core/auth/providers/statictoken" "github.com/initializ/forge/forge-core/llm" "github.com/initializ/forge/forge-core/llm/oauth" "github.com/initializ/forge/forge-core/llm/providers" @@ -120,6 +127,14 @@ func (r *Runner) SetScheduleNotifier(fn ScheduleNotifier) { // ResolveAuth resolves the auth token early (before Run). This is needed so // channel adapters can be configured with the token before Run() blocks. // Safe to call multiple times — subsequent calls are no-ops. +// +// Invariant: after this returns nil, EITHER r.authToken is non-empty OR +// r.cfg.NoAuth is true. resolveAuth() relies on this when it conditionally +// prepends the loopback static_token (review #10). If a future refactor +// adds a return path that violates this invariant, channel-adapter +// callbacks will silently break — the test +// TestResolveAuth_InvariantMintsTokenInNonNoAuthPath in +// auth_chain_test.go pins the property. func (r *Runner) ResolveAuth() error { if r.authToken != "" || r.cfg.NoAuth { return nil // already resolved @@ -246,6 +261,10 @@ func (r *Runner) Run(ctx context.Context) error { egressDomains = append(egressDomains, expandEgressDomains(d, envVars)...) } } + // Auto-merge auth-provider issuer/verifier hosts. Without this, an + // OIDC issuer or http_verifier URL configured in forge.yaml would be + // silently blocked at runtime by the egress enforcer. + egressDomains = append(egressDomains, security.AuthDomains(r.cfg.Config.Auth)...) egressCfg, egressErr := security.Resolve( r.cfg.Config.Egress.Profile, r.cfg.Config.Egress.Mode, @@ -1719,44 +1738,299 @@ func (r *Runner) printBanner(proxyURL string) { fmt.Fprintf(os.Stderr, " Press Ctrl+C to stop\n\n") } -// resolveAuth builds the auth middleware config. Token resolution is done by -// ResolveAuth() (called early so channel adapters can use it); this method -// just wires the already-resolved token into a middleware Config with the audit callback. -func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.Config, error) { +// resolveAuth builds the auth middleware options for the A2A server. +// +// Source precedence (highest first): +// 1. --no-auth flag → nil chain, anonymous access +// 2. forge.yaml auth: → Registry.BuildChain(cfg.Auth.Providers) +// 3. --auth-url / env → legacy http_verifier chain +// 4. nothing → loopback-token-only chain (or nil if also no channels) +// +// If BOTH forge.yaml auth: AND --auth-url are configured, the YAML block +// wins and a warning is logged — silent merging would be surprising. +// +// Loopback static_token prepending: +// - The internal loopback token is prepended at the chain head WHEN +// ResolveAuth() has minted one (i.e., r.authToken != ""). In the +// non-NoAuth path that's an invariant ResolveAuth maintains (review +// #10). When --no-auth is in effect we return early via the +// AllowAnonymous path above and never reach the prepend. +// - Channel adapter callbacks rely on the loopback short-circuit; if +// ResolveAuth is ever refactored to skip token minting on the +// non-NoAuth path, channels will silently break. TestResolveAuth_ +// InvariantMintsTokenInNonNoAuthPath in auth_chain_test.go pins +// that invariant. +func (r *Runner) resolveAuth(auditLogger *coreruntime.AuditLogger) (auth.MiddlewareOptions, error) { // Ensure token is resolved (no-op if already done by ResolveAuth). if err := r.ResolveAuth(); err != nil { - return auth.Config{}, err + return auth.MiddlewareOptions{}, err } if r.cfg.NoAuth { - return auth.Config{Enabled: false}, nil + // Cross-check --no-auth against the forge.yaml auth block. The + // flag and the YAML have historically been treated independently; + // review #4 closes the gap so a misaligned pair fails loudly + // instead of silently serving anonymous traffic on what the + // operator declared a required-auth deployment. + if r.cfg.Config != nil { + authCfg := r.cfg.Config.Auth + if authCfg.Required { + return auth.MiddlewareOptions{}, fmt.Errorf( + "--no-auth conflicts with forge.yaml 'auth.required: true' — " + + "either remove --no-auth, set 'auth.required: false', or " + + "delete the 'auth:' block to confirm anonymous access is intended") + } + if len(authCfg.Providers) > 0 { + r.logger.Warn( + "--no-auth overrides forge.yaml 'auth.providers' — configured "+ + "providers will be ignored and the agent will accept anonymous "+ + "traffic. Remove --no-auth to enforce the configured chain.", + map[string]any{"providers_configured": len(authCfg.Providers)}, + ) + } + } + // Operator explicitly chose anonymous via --no-auth. AllowAnonymous + // makes that choice visible at the middleware boundary (review #3). + return auth.MiddlewareOptions{ + AllowAnonymous: true, + SkipPaths: auth.DefaultSkipPaths(), + }, nil + } + + userChain, err := r.buildUserAuthChain() + if err != nil { + return auth.MiddlewareOptions{}, fmt.Errorf("building auth chain: %w", err) + } + + // Prepend the loopback static_token so channel adapter callbacks + // short-circuit before any user-configured provider runs. + // + // PRECONDITION: r.authToken is non-empty on this branch because + // ResolveAuth() always mints one in the non-NoAuth path. The + // conditional here is defensive — if someone refactors ResolveAuth + // to skip minting (and forgets to update channels), we'd want the + // rest of the function to still produce a coherent middleware + // (without a loopback) rather than panic. The invariant itself is + // pinned by TestResolveAuth_InvariantMintsTokenInNonNoAuthPath + // (review #10). + chain := userChain + if r.authToken != "" { + loopback, err := statictoken.New(statictoken.Config{ + Token: r.authToken, + Identity: auth.Identity{ + UserID: "forge-internal", + Source: "internal", + }, + }) + if err != nil { + return auth.MiddlewareOptions{}, fmt.Errorf("loopback static_token: %w", err) + } + chain = auth.PrependChain(userChain, loopback) } - cfg := auth.Config{ - Enabled: true, - Token: r.authToken, - AuthURL: r.cfg.AuthURL, - AuthOrgID: r.cfg.AuthOrgID, + // No user chain AND no loopback token → legacy "no auth config, no + // channels" default. Preserve backward compat by allowing anonymous, + // but flag it explicitly so the middleware's nil-chain panic guard + // is satisfied (review #3). + if chain == nil { + return auth.MiddlewareOptions{ + AllowAnonymous: true, + SkipPaths: auth.DefaultSkipPaths(), + OnAuth: makeAuthAuditCallback(auditLogger), + }, nil + } + + return auth.MiddlewareOptions{ + Chain: chain, SkipPaths: auth.DefaultSkipPaths(), - OnAuth: func(req *http.Request, success bool) { - if auditLogger == nil { - return - } - event := coreruntime.AuditAuthSuccess - if !success { - event = coreruntime.AuditAuthFailure + OnAuth: makeAuthAuditCallback(auditLogger), + }, nil +} + +// makeAuthAuditCallback returns the OnAuth callback that emits structured +// auth_verify / auth_fail audit events. +// +// Fields emitted (NO PII — never email, claims, token bytes, or secrets): +// +// auth_verify: { provider, user_id, org_id, groups_count, token_kind, method, path, remote_addr } +// auth_fail: { reason, token_kind, method, path, remote_addr } +// +// Reason codes for auth_fail: +// +// missing_token → no Authorization header +// rejected → provider recognized + denied (revoked, expired, wrong iss/aud) +// invalid → token malformed or cryptographically invalid +// not_for_me → chain exhausted (no provider claimed the token) +// infrastructure → transient provider error (network, etc.) +func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Request, *auth.Identity, error, string) { + if auditLogger == nil { + return nil + } + return func(req *http.Request, id *auth.Identity, err error, tokenKind string) { + correlationID := coreruntime.CorrelationIDFromContext(req.Context()) + + if err == nil && id != nil { + // Success → auth_verify. + fields := map[string]any{ + "provider": id.Source, + "user_id": id.UserID, + "org_id": id.OrgID, + "groups_count": len(id.Groups), + "token_kind": tokenKind, + "method": req.Method, + "path": req.URL.Path, + "remote_addr": req.RemoteAddr, } auditLogger.Emit(coreruntime.AuditEvent{ - Event: event, - Fields: map[string]any{ - "method": req.Method, - "path": req.URL.Path, - "remote_addr": req.RemoteAddr, - }, + Event: coreruntime.EventAuthVerify, + CorrelationID: correlationID, + Fields: fields, }) - }, + return + } + + // Failure → auth_fail with reason code. + auditLogger.Emit(coreruntime.AuditEvent{ + Event: coreruntime.EventAuthFail, + CorrelationID: correlationID, + Fields: map[string]any{ + "reason": authFailReason(err), + "token_kind": tokenKind, + "method": req.Method, + "path": req.URL.Path, + "remote_addr": req.RemoteAddr, + }, + }) + } +} + +// authFailReason maps a chain error to a stable, low-cardinality reason +// code suitable for dashboarding and alerting. Reason strings are part +// of the audit-event contract — changing them is a breaking change for +// downstream consumers. +// +// Reason codes: +// +// missing_token - no Authorization header +// rejected - provider recognized + denied (revoked, expired, 401, 4xx) +// invalid - token malformed or cryptographically invalid +// not_for_me - chain exhausted, no provider claimed the token +// provider_unavailable - verifier/IdP unreachable (5xx, network, undecodable) +// infrastructure - other unexpected error +// +// provider_unavailable was added in review #6 so operators can +// distinguish "the token is bad" alerts from "the IdP is down" alerts +// in their dashboards — the response and the runbook are different. +func authFailReason(err error) string { + switch { + case err == nil: + return "unknown" + case errors.Is(err, auth.ErrMissingBearer): + return "missing_token" + case errors.Is(err, auth.ErrTokenRejected): + return "rejected" + case errors.Is(err, auth.ErrInvalidToken): + return "invalid" + case errors.Is(err, auth.ErrProviderUnavailable): + return "provider_unavailable" + case errors.Is(err, auth.ErrTokenNotForMe): + return "not_for_me" + default: + return "infrastructure" + } +} + +// buildUserAuthChain returns the user-facing portion of the auth chain +// (everything EXCEPT the loopback static_token, which the caller prepends). +// +// - forge.yaml auth.providers populated → Registry.BuildChain +// - --auth-url / FORGE_AUTH_URL only → legacy http_verifier +// - neither → nil (no user chain) +func (r *Runner) buildUserAuthChain() (auth.Provider, error) { + hasYAMLAuth := len(r.cfg.Config.Auth.Providers) > 0 + hasLegacyURL := r.cfg.AuthURL != "" + + if hasYAMLAuth && hasLegacyURL { + r.logger.Warn("both --auth-url and forge.yaml 'auth:' block configured; preferring 'auth:' block (--auth-url ignored)", nil) + } + + if hasYAMLAuth { + return buildChainFromConfig(r.cfg.Config.Auth) + } + if hasLegacyURL { + return buildLegacyHTTPVerifierChain(r.cfg.AuthURL, r.cfg.AuthOrgID) + } + return nil, nil +} + +// buildChainFromConfig builds a ChainProvider from a typed AuthConfig by +// delegating to the package-level registry. Each AuthProvider entry is +// constructed via its registered factory. +func buildChainFromConfig(cfg types.AuthConfig) (auth.Provider, error) { + if len(cfg.Providers) == 0 { + return nil, nil + } + providers := make([]auth.Provider, 0, len(cfg.Providers)) + for i, spec := range cfg.Providers { + p, err := auth.Build(spec.Type, spec.Settings) + if err != nil { + return nil, fmt.Errorf("auth.providers[%d] (%s): %w", i, spec.Type, err) + } + providers = append(providers, p) + } + return auth.NewChainProvider(providers...), nil +} + +// buildLegacyHTTPVerifierChain constructs the legacy --auth-url chain +// (single http_verifier provider). Kept separate from +// buildChainFromConfig so the legacy code path is obvious. +func buildLegacyHTTPVerifierChain(authURL, authOrgID string) (auth.Provider, error) { + p, err := httpverifier.New(httpverifier.Config{ + URL: authURL, + DefaultOrg: authOrgID, + }) + if err != nil { + return nil, fmt.Errorf("legacy http_verifier: %w", err) + } + return auth.NewChainProvider(p), nil +} + +// buildLegacyAuthChain is a test-friendly helper that combines the +// loopback-token prepend and the legacy http_verifier chain. Mirrors the +// production wiring in resolveAuth so PR1's e2e tests keep working +// unchanged. +// +// Order: +// 1. static_token (loopback) — if internalToken non-empty +// 2. http_verifier — if authURL non-empty +// +// Returns nil chain when both inputs are empty. +func buildLegacyAuthChain(internalToken, authURL, authOrgID string) (auth.Provider, error) { + var userChain auth.Provider + if authURL != "" { + c, err := buildLegacyHTTPVerifierChain(authURL, authOrgID) + if err != nil { + return nil, err + } + userChain = c + } + if internalToken != "" { + loopback, err := statictoken.New(statictoken.Config{ + Token: internalToken, + Identity: auth.Identity{ + UserID: "forge-internal", + Source: "internal", + }, + }) + if err != nil { + return nil, fmt.Errorf("static_token provider: %w", err) + } + if userChain == nil { + return auth.NewChainProvider(loopback), nil + } + return auth.PrependChain(userChain, loopback), nil } - return cfg, nil + return userChain, nil } // ensureGitignore makes sure .forge/ is listed in the project's .gitignore. diff --git a/forge-cli/templates/init/forge.yaml.tmpl b/forge-cli/templates/init/forge.yaml.tmpl index 6ebf375..f7e93e4 100644 --- a/forge-cli/templates/init/forge.yaml.tmpl +++ b/forge-cli/templates/init/forge.yaml.tmpl @@ -57,3 +57,7 @@ secrets: - encrypted-file - env {{- end}} +{{- if .AuthBlock}} + +{{.AuthBlock}} +{{- end}} diff --git a/forge-core/auth/chain_provider.go b/forge-core/auth/chain_provider.go new file mode 100644 index 0000000..31d3080 --- /dev/null +++ b/forge-core/auth/chain_provider.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "errors" +) + +// ChainProvider composes multiple Providers into a single Provider, calling +// them in order and returning the first non-yielding result. +// +// Behavior: +// - A Provider returning (id, nil) wins; later providers are not consulted. +// - A Provider returning ErrTokenNotForMe is skipped; the chain advances. +// - Any other error stops the chain immediately (fail-closed). The middleware +// surfaces this as 401. Falling through on non-yield errors would let an +// attacker bypass a temporarily-misbehaving provider. +// +// A ChainProvider is immutable after construction and safe for concurrent use. +type ChainProvider struct { + providers []Provider +} + +// NewChainProvider returns a chain that calls the given providers in order. +// A chain with zero providers verifies nothing — Verify returns ErrTokenNotForMe. +func NewChainProvider(providers ...Provider) *ChainProvider { + // Defensive copy so callers can't mutate the slice after construction. + cp := make([]Provider, len(providers)) + copy(cp, providers) + return &ChainProvider{providers: cp} +} + +// Name implements Provider. +func (c *ChainProvider) Name() string { return "chain" } + +// Providers returns a defensive copy of the configured providers, in order. +func (c *ChainProvider) Providers() []Provider { + out := make([]Provider, len(c.providers)) + copy(out, c.providers) + return out +} + +// Verify implements Provider with first-match-wins semantics. +func (c *ChainProvider) Verify(ctx context.Context, token string, headers Headers) (*Identity, error) { + for _, p := range c.providers { + id, err := p.Verify(ctx, token, headers) + if err == nil { + return id, nil + } + if errors.Is(err, ErrTokenNotForMe) { + continue + } + // Fail closed: non-yield error stops the chain. + return nil, err + } + return nil, ErrTokenNotForMe +} + +// PrependChain returns a new ChainProvider whose providers are `prepend` +// followed by the providers of `chain`. If chain is a *ChainProvider, its +// providers are flattened (no nested chains). If chain is a non-chain +// Provider, it is appended as a single element. If chain is nil, the result +// contains only the prepended providers. +// +// Useful for the runner to inject a loopback static_token at the chain head +// without callers having to know whether their chain already exists. +func PrependChain(chain Provider, prepend ...Provider) *ChainProvider { + var existing []Provider + switch v := chain.(type) { + case nil: + // nothing to extend + case *ChainProvider: + existing = v.providers + default: + existing = []Provider{v} + } + combined := make([]Provider, 0, len(prepend)+len(existing)) + combined = append(combined, prepend...) + combined = append(combined, existing...) + return &ChainProvider{providers: combined} +} diff --git a/forge-core/auth/context.go b/forge-core/auth/context.go new file mode 100644 index 0000000..045e581 --- /dev/null +++ b/forge-core/auth/context.go @@ -0,0 +1,26 @@ +package auth + +import "context" + +// identityKey is an unexported type used as the context key, so that no +// other package can collide with our value. +type identityKey struct{} + +// WithIdentity returns a copy of ctx that carries the given Identity. +// Storing nil is a no-op (returns ctx unchanged). +func WithIdentity(ctx context.Context, id *Identity) context.Context { + if id == nil { + return ctx + } + return context.WithValue(ctx, identityKey{}, id) +} + +// IdentityFromContext returns the Identity stored on ctx by WithIdentity, +// or nil if no Identity is present. +func IdentityFromContext(ctx context.Context) *Identity { + if ctx == nil { + return nil + } + id, _ := ctx.Value(identityKey{}).(*Identity) + return id +} diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index e94bbb3..cb09a87 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -1,86 +1,16 @@ package auth import ( - "bytes" "encoding/json" - "fmt" + "errors" "net/http" "strings" - "time" ) -// Config controls bearer-token authentication for the A2A server. -type Config struct { - // Enabled controls whether authentication is enforced. - Enabled bool - - // Token is the expected bearer token value (local validation). - Token string - - // AuthURL is an external auth provider endpoint. When set, the middleware - // forwards the bearer token to this URL via POST for validation instead - // of comparing against the local Token value. - AuthURL string - - // AuthOrgID is the default org_id sent to the external auth provider. - // Overridden per-request by the X-Org-ID header when present. - AuthOrgID string - - // SkipPaths maps "METHOD /path" keys that bypass authentication. - // Example: "GET /" → true allows unauthenticated GET on root. - SkipPaths map[string]bool - - // OnAuth is an optional callback invoked on every auth decision. - // success indicates whether the request was authenticated. - OnAuth func(r *http.Request, success bool) -} - -// authHTTPClient is a shared client with reasonable timeouts for auth requests. -var authHTTPClient = &http.Client{Timeout: 10 * time.Second} - -// externalVerifyRequest is the request body sent to the external auth URL. -type externalVerifyRequest struct { - Token string `json:"token"` - OrgID string `json:"org_id"` -} - -// externalVerifyResponse is the response from the external auth URL. -type externalVerifyResponse struct { - Valid bool `json:"valid"` - Error string `json:"error,omitempty"` - UserID string `json:"user_id,omitempty"` - OrgID string `json:"org_id,omitempty"` - Email string `json:"email,omitempty"` - WorkspaceID string `json:"workspace_id,omitempty"` -} - -// validateExternal sends the bearer token to the external auth URL and returns -// whether the token is valid. -func validateExternal(authURL, token, orgID string) (bool, error) { - body, err := json.Marshal(externalVerifyRequest{Token: token, OrgID: orgID}) - if err != nil { - return false, fmt.Errorf("marshalling auth request: %w", err) - } - - resp, err := authHTTPClient.Post(authURL, "application/json", bytes.NewReader(body)) - if err != nil { - return false, fmt.Errorf("calling auth URL: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode == http.StatusUnauthorized { - return false, nil - } - if resp.StatusCode != http.StatusOK { - return false, fmt.Errorf("auth URL returned status %d", resp.StatusCode) - } - - var result externalVerifyResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return false, fmt.Errorf("decoding auth response: %w", err) - } - - return result.Valid, nil +// errorResponse is the JSON body returned for auth failures. +type errorResponse struct { + Error string `json:"error"` + Message string `json:"message"` } // DefaultSkipPaths returns the default set of public endpoints @@ -98,115 +28,160 @@ func DefaultSkipPaths() map[string]bool { } } -// errorResponse is the JSON body returned for auth failures. -type errorResponse struct { - Error string `json:"error"` - Message string `json:"message"` +// MiddlewareOptions configures Middleware. +type MiddlewareOptions struct { + // Chain is the provider chain that verifies bearer tokens. May only + // be nil when AllowAnonymous is true (see below). + Chain Provider + + // AllowAnonymous explicitly opts the middleware into running without + // authentication. Required whenever Chain is nil — otherwise + // Middleware() panics at construction. This prevents a misconfigured + // runner from silently serving unauthenticated requests because + // someone forgot to wire a chain. + // + // Set this to true when: + // - --no-auth flag is in effect (operator explicitly chose anon) + // - No auth: block AND no --auth-url AND no channels (legacy local + // dev default — preserved for backward compat) + // + // Leave this false for any production deployment that intends to + // enforce auth; a nil chain will then panic loudly at startup + // instead of running open. + AllowAnonymous bool + + // SkipPaths maps "METHOD /path" keys that bypass authentication. + // If nil, DefaultSkipPaths() is used. + SkipPaths map[string]bool + + // OnAuth is an optional callback invoked on every auth decision. + // + // - identity is non-nil and err is nil on success. + // - identity is nil and err carries the chain error on failure + // (or auth.ErrMissingBearer when the header was absent). + // - tokenKind is "jwt", "opaque", or "empty" — structural metadata + // safe to log. The token itself is NOT passed; callers must not + // try to recover it from the request. + // + // Callbacks should be cheap — they run on the request hot path. + OnAuth func(r *http.Request, identity *Identity, err error, tokenKind string) } -// Middleware returns an http.Handler that enforces bearer token authentication. -// If cfg.Enabled is false, requests pass through without checks. -// When cfg.AuthURL is set, tokens are validated against the external provider. -func Middleware(cfg Config) func(http.Handler) http.Handler { +// ErrMissingBearer is returned (via OnAuth) when the request lacked an +// Authorization: Bearer ... header. Distinct from chain-level errors so +// callers can emit a precise "missing_token" reason code without parsing +// error strings. +var ErrMissingBearer = errorString("auth: missing bearer token") + +// errorString is a sentinel-friendly error type that compares by identity. +type errorString string + +func (e errorString) Error() string { return string(e) } + +// Middleware returns an http.Handler that enforces bearer token authentication +// via the provided Provider chain. +// +// Panics at construction if opts.Chain is nil and opts.AllowAnonymous is +// false. This is intentional — silently passing through requests when the +// caller forgot to wire a chain is the highest-impact misconfiguration in +// the auth subsystem (open prod endpoint). Fail-loud catches it at startup, +// not at the first request from a real user. +func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { + if opts.Chain == nil { + if !opts.AllowAnonymous { + panic("auth: Middleware called with nil Chain and AllowAnonymous=false. " + + "Set MiddlewareOptions.AllowAnonymous: true to explicitly allow " + + "unauthenticated access, or provide a Chain.") + } + return func(next http.Handler) http.Handler { return next } + } + skip := opts.SkipPaths + if skip == nil { + skip = DefaultSkipPaths() + } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !cfg.Enabled { - next.ServeHTTP(w, r) - return - } - - // Check if this method+path combination is public. - key := r.Method + " " + r.URL.Path - if cfg.SkipPaths[key] { + if skip[r.Method+" "+r.URL.Path] { next.ServeHTTP(w, r) return } - // Extract bearer token from Authorization header. token := extractBearerToken(r) + kind := TokenKind(token) + if token == "" { - authFail(w, r, cfg, "valid bearer token required") + notifyAuth(opts.OnAuth, r, nil, ErrMissingBearer, kind) + writeAuthError(w, "valid bearer token required") return } - // When external auth is configured, first check the internal - // token (used by channel adapter loopback calls), then fall - // through to the external auth provider for other tokens. - if cfg.AuthURL != "" { - // Accept internal token for loopback (channel adapters) - if cfg.Token != "" && ValidateToken(token, cfg.Token) { - if cfg.OnAuth != nil { - cfg.OnAuth(r, true) - } - next.ServeHTTP(w, r) - return - } - // Use org_id from request header (multiple header names), fall back to config. - orgID := extractOrgID(r, cfg.AuthOrgID) - valid, err := validateExternal(cfg.AuthURL, token, orgID) - if err != nil { - authFail(w, r, cfg, "auth provider error") - return - } - if valid { - if cfg.OnAuth != nil { - cfg.OnAuth(r, true) - } - next.ServeHTTP(w, r) - return - } - authFail(w, r, cfg, "token rejected by auth provider") + identity, err := opts.Chain.Verify(r.Context(), token, HeadersFromRequest(r)) + if err != nil || identity == nil { + notifyAuth(opts.OnAuth, r, nil, err, kind) + writeAuthError(w, classifyAuthFailure(err)) return } - // Local token validation. - if ValidateToken(token, cfg.Token) { - if cfg.OnAuth != nil { - cfg.OnAuth(r, true) - } - next.ServeHTTP(w, r) - return - } + notifyAuth(opts.OnAuth, r, identity, nil, kind) - authFail(w, r, cfg, "valid bearer token required") + ctx := WithIdentity(r.Context(), identity) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } -// authFail sends a 401 response and fires the OnAuth callback. -func authFail(w http.ResponseWriter, r *http.Request, cfg Config, msg string) { - if cfg.OnAuth != nil { - cfg.OnAuth(r, false) +// notifyAuth invokes the OnAuth callback if set, swallowing the nil check +// at the call sites so the main middleware body stays readable. +func notifyAuth(cb func(*http.Request, *Identity, error, string), r *http.Request, id *Identity, err error, kind string) { + if cb == nil { + return + } + cb(r, id, err, kind) +} + +// classifyAuthFailure maps a chain error into a user-visible message. +// Keep messages generic to avoid leaking information about which provider +// rejected the token or why. +func classifyAuthFailure(err error) string { + switch { + case err == nil, errors.Is(err, ErrTokenNotForMe): + return "valid bearer token required" + case errors.Is(err, ErrTokenRejected): + return "token rejected by auth provider" + case errors.Is(err, ErrInvalidToken): + return "invalid token" + case errors.Is(err, ErrProviderUnavailable): + // Surface a distinct user-visible message so retry behavior on + // the client can be different from "invalid token". This is also + // the operator-facing signal in /healthz-style probes. + return "auth provider unavailable" + default: + return "auth provider error" } +} + +// writeAuthError sends a 401 JSON response. The OnAuth callback is fired +// separately (via notifyAuth) so the audit-emission path can run with the +// full Identity / error context, not just a bool. +func writeAuthError(w http.ResponseWriter, msg string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(errorResponse{ //nolint:errcheck + _ = json.NewEncoder(w).Encode(errorResponse{ Error: "unauthorized", Message: msg, }) } -// extractOrgID reads the org ID from the request headers, checking multiple -// common header names: "X-Org-ID", "org-id", "org_id". Falls back to the -// provided default if none are set. -func extractOrgID(r *http.Request, fallback string) string { - for _, h := range []string{"X-Org-ID", "org-id", "org_id"} { - if v := r.Header.Get(h); v != "" { - return v - } - } - return fallback -} - // extractBearerToken extracts the token from "Authorization: Bearer ". +// Case-insensitive on the "Bearer" prefix to preserve historical behavior. func extractBearerToken(r *http.Request) string { - auth := r.Header.Get("Authorization") - if auth == "" { + header := r.Header.Get("Authorization") + if header == "" { return "" } const prefix = "Bearer " - if len(auth) > len(prefix) && strings.EqualFold(auth[:len(prefix)], prefix) { - return auth[len(prefix):] + if len(header) > len(prefix) && strings.EqualFold(header[:len(prefix)], prefix) { + return header[len(prefix):] } return "" } diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index 7107698..f163e2f 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -1,132 +1,166 @@ package auth import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "sync/atomic" "testing" ) +// tokenProvider is a minimal test-only Provider that compares against a fixed +// token. We define it here so middleware tests don't need to import the +// statictoken provider package (which would create a cycle since it imports +// this package). +type tokenProvider struct { + expected string + identity Identity +} + +func (t *tokenProvider) Name() string { return "test_token" } + +func (t *tokenProvider) Verify(_ context.Context, token string, _ Headers) (*Identity, error) { + if token != t.expected { + return nil, ErrTokenNotForMe + } + id := t.identity + return &id, nil +} + +// rejectingProvider returns ErrTokenRejected for any token, simulating an +// external verifier that recognized the token but denied it. +type rejectingProvider struct{} + +func (rejectingProvider) Name() string { return "test_reject" } +func (rejectingProvider) Verify(_ context.Context, _ string, _ Headers) (*Identity, error) { + return nil, ErrTokenRejected +} + +// brokenProvider returns ErrInvalidToken (e.g., signature failure). +type brokenProvider struct{} + +func (brokenProvider) Name() string { return "test_broken" } +func (brokenProvider) Verify(_ context.Context, _ string, _ Headers) (*Identity, error) { + return nil, ErrInvalidToken +} + func TestMiddleware(t *testing.T) { const validToken = "test-secret-token" okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("ok")) //nolint:errcheck + _, _ = w.Write([]byte("ok")) + }) + + chain := NewChainProvider(&tokenProvider{ + expected: validToken, + identity: Identity{UserID: "u1", Source: "test_token"}, }) tests := []struct { name string - cfg Config + opts MiddlewareOptions method string path string authHeader string wantStatus int }{ { - name: "disabled passes through", - cfg: Config{Enabled: false}, + name: "nil chain with AllowAnonymous passes through", + opts: MiddlewareOptions{Chain: nil, AllowAnonymous: true}, method: "POST", path: "/", wantStatus: http.StatusOK, }, { - name: "valid token accepted", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "valid token accepted", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", authHeader: "Bearer " + validToken, wantStatus: http.StatusOK, }, { - name: "missing token rejected", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "missing token rejected", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", wantStatus: http.StatusUnauthorized, }, { - name: "wrong token rejected", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "wrong token rejected", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", authHeader: "Bearer wrong-token", wantStatus: http.StatusUnauthorized, }, { - name: "GET / is public", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "GET / is public", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "GET", path: "/", wantStatus: http.StatusOK, }, { - name: "GET /.well-known/agent.json is public", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "GET /.well-known/agent.json is public", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "GET", path: "/.well-known/agent.json", wantStatus: http.StatusOK, }, { - name: "GET /healthz is public", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "GET /healthz is public", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "GET", path: "/healthz", wantStatus: http.StatusOK, }, { - name: "POST /tasks/send requires auth", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "POST /tasks/send requires auth", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/tasks/send", wantStatus: http.StatusUnauthorized, }, { - name: "case insensitive Bearer prefix", - cfg: Config{ - Enabled: true, - Token: validToken, - SkipPaths: DefaultSkipPaths(), - }, + name: "case insensitive Bearer prefix", + opts: MiddlewareOptions{Chain: chain, SkipPaths: DefaultSkipPaths()}, method: "POST", path: "/", authHeader: "bearer " + validToken, wantStatus: http.StatusOK, }, + { + name: "rejected by provider returns 401", + opts: MiddlewareOptions{Chain: NewChainProvider(rejectingProvider{}), SkipPaths: DefaultSkipPaths()}, + method: "POST", + path: "/", + authHeader: "Bearer anything", + wantStatus: http.StatusUnauthorized, + }, + { + name: "invalid token from provider returns 401", + opts: MiddlewareOptions{Chain: NewChainProvider(brokenProvider{}), SkipPaths: DefaultSkipPaths()}, + method: "POST", + path: "/", + authHeader: "Bearer malformed", + wantStatus: http.StatusUnauthorized, + }, + { + name: "nil SkipPaths uses defaults", + opts: MiddlewareOptions{Chain: chain}, + method: "GET", + path: "/healthz", + wantStatus: http.StatusOK, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mw := Middleware(tt.cfg) + mw := Middleware(tt.opts) handler := mw(okHandler) req := httptest.NewRequest(tt.method, tt.path, nil) @@ -155,17 +189,19 @@ func TestMiddleware(t *testing.T) { } } -func TestMiddlewareOnAuthCallback(t *testing.T) { +func TestMiddleware_OnAuthCallback(t *testing.T) { const token = "callback-token" var successCount, failCount atomic.Int32 - cfg := Config{ - Enabled: true, - Token: token, + opts := MiddlewareOptions{ + Chain: NewChainProvider(&tokenProvider{ + expected: token, + identity: Identity{UserID: "u", Source: "test"}, + }), SkipPaths: DefaultSkipPaths(), - OnAuth: func(r *http.Request, success bool) { - if success { + OnAuth: func(_ *http.Request, id *Identity, err error, _ string) { + if err == nil && id != nil { successCount.Add(1) } else { failCount.Add(1) @@ -173,7 +209,7 @@ func TestMiddlewareOnAuthCallback(t *testing.T) { }, } - handler := Middleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) @@ -193,3 +229,241 @@ func TestMiddlewareOnAuthCallback(t *testing.T) { t.Errorf("failure callbacks = %d, want 1", got) } } + +func TestMiddleware_IdentityIsAttachedToContext(t *testing.T) { + const token = "ctx-token" + + wantID := Identity{UserID: "ctx-user", Email: "ctx@example.com", Source: "test_token"} + + opts := MiddlewareOptions{ + Chain: NewChainProvider(&tokenProvider{expected: token, identity: wantID}), + SkipPaths: DefaultSkipPaths(), + } + + var gotID *Identity + handler := Middleware(opts)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotID = IdentityFromContext(r.Context()) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer "+token) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotID == nil { + t.Fatal("identity not attached to context") + } + if gotID.UserID != wantID.UserID || gotID.Email != wantID.Email { + t.Errorf("identity = %+v, want %+v", gotID, wantID) + } +} + +func TestMiddleware_OnAuthReceivesIdentityAndError(t *testing.T) { + const goodToken = "good" + wantID := Identity{UserID: "alice", Source: "test_token"} + + var ( + gotIDOnSuccess *Identity + gotErrOnFail error + gotKindSuccess string + gotKindFail string + ) + + opts := MiddlewareOptions{ + Chain: NewChainProvider(&tokenProvider{expected: goodToken, identity: wantID}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, id *Identity, err error, kind string) { + if err == nil && id != nil { + gotIDOnSuccess = id + gotKindSuccess = kind + } else { + gotErrOnFail = err + gotKindFail = kind + } + }, + } + + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Success path — OnAuth gets the Identity, kind = "opaque" (no dots). + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer "+goodToken) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotIDOnSuccess == nil { + t.Fatal("OnAuth did not receive Identity on success") + } + if gotIDOnSuccess.UserID != "alice" { + t.Errorf("OnAuth identity = %+v, want UserID=alice", gotIDOnSuccess) + } + if gotKindSuccess != "opaque" { + t.Errorf("token kind on success = %q, want opaque", gotKindSuccess) + } + + // Failure path — no Authorization header → ErrMissingBearer. + req2 := httptest.NewRequest("POST", "/", nil) + handler.ServeHTTP(httptest.NewRecorder(), req2) + if !errors.Is(gotErrOnFail, ErrMissingBearer) { + t.Errorf("OnAuth fail err = %v, want ErrMissingBearer", gotErrOnFail) + } + if gotKindFail != "empty" { + t.Errorf("token kind on missing-token fail = %q, want empty", gotKindFail) + } +} + +func TestMiddleware_OnAuthGetsChainError(t *testing.T) { + // When the chain rejects, OnAuth receives the precise sentinel so + // the caller can map to a stable reason code. + opts := MiddlewareOptions{ + Chain: NewChainProvider(rejectingProvider{}), + SkipPaths: DefaultSkipPaths(), + } + var gotErr error + opts.OnAuth = func(_ *http.Request, _ *Identity, err error, _ string) { + gotErr = err + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer x") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if !errors.Is(gotErr, ErrTokenRejected) { + t.Errorf("OnAuth got err = %v, want ErrTokenRejected", gotErr) + } +} + +func TestMiddleware_TokenKindDetection(t *testing.T) { + // JWT-shaped tokens are reported as "jwt" even when the chain + // doesn't recognize them — this is structural, not validity. + jwtToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ4In0.sig" + + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(rejectingProvider{}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer "+jwtToken) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "jwt" { + t.Errorf("token kind = %q, want jwt", gotKind) + } +} + +// --- Nil chain panic guard (review finding #3) --- + +func TestMiddleware_NilChainPanicsWithoutAllowAnonymous(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic when Chain==nil and AllowAnonymous==false") + } + msg, _ := r.(string) + if msg == "" { + t.Fatalf("panic value = %v, want a descriptive string", r) + } + // The message should mention the option name so callers know + // how to opt in. + if !contains(msg, "AllowAnonymous") { + t.Errorf("panic message %q does not reference AllowAnonymous", msg) + } + }() + + // This call MUST panic — silent anonymous passthrough is the bug + // the AllowAnonymous flag exists to prevent. + _ = Middleware(MiddlewareOptions{Chain: nil, AllowAnonymous: false}) +} + +func TestMiddleware_NilChainWithAllowAnonymousPassesThrough(t *testing.T) { + // Counterpart: explicit opt-in does NOT panic; serves all requests. + handler := Middleware(MiddlewareOptions{ + Chain: nil, + AllowAnonymous: true, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/anything", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("status = %d, want %d (AllowAnonymous should pass through)", rr.Code, http.StatusOK) + } +} + +func TestMiddleware_NonNilChainIgnoresAllowAnonymous(t *testing.T) { + // AllowAnonymous is only consulted when Chain == nil. When a chain + // is configured, the flag has no effect — requests are still + // validated through the chain. + const goodToken = "ok" + chain := NewChainProvider(&tokenProvider{ + expected: goodToken, + identity: Identity{UserID: "u", Source: "test"}, + }) + handler := Middleware(MiddlewareOptions{ + Chain: chain, + AllowAnonymous: true, // should be ignored + SkipPaths: DefaultSkipPaths(), + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // No token → 401 (chain still enforced). + req := httptest.NewRequest("POST", "/tasks", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Errorf("non-nil chain + AllowAnonymous=true: status = %d, want 401 (chain must still enforce)", rr.Code) + } +} + +// contains is a tiny strings.Contains substitute for the panic-message test +// (kept inline to avoid pulling in strings just for the assertion). +func contains(s, sub string) bool { + if len(sub) == 0 { + return true + } + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} + +func TestClassifyAuthFailure(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + {"nil", nil, "valid bearer token required"}, + {"not for me", ErrTokenNotForMe, "valid bearer token required"}, + {"rejected", ErrTokenRejected, "token rejected by auth provider"}, + {"invalid", ErrInvalidToken, "invalid token"}, + // Review #6: ErrProviderUnavailable gets its own user-visible + // message — distinct from "invalid token" so client retry logic + // and operator alerting can differentiate. + {"provider unavailable", ErrProviderUnavailable, "auth provider unavailable"}, + {"unexpected", context.Canceled, "auth provider error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := classifyAuthFailure(tt.err); got != tt.want { + t.Errorf("classifyAuthFailure(%v) = %q, want %q", tt.err, got, tt.want) + } + }) + } +} diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go new file mode 100644 index 0000000..527eac2 --- /dev/null +++ b/forge-core/auth/provider.go @@ -0,0 +1,148 @@ +// Package auth provides bearer-token authentication for the A2A server, +// built around a pluggable Provider chain. +// +// Each Provider claims tokens it recognizes and rejects the rest with +// ErrTokenNotForMe, allowing a ChainProvider to compose multiple +// providers in a first-match-wins fashion. The error return is the only +// signal for the verification outcome — there is intentionally no +// Identity.Valid field, because a nil-error return is the contract for +// "this token is valid." +// +// New providers live under forge-core/auth/providers// and register +// themselves via init() against the package-level registry, mirroring the +// database/sql driver pattern. +package auth + +import ( + "context" + "errors" + "net/http" + "strings" +) + +// Provider verifies a bearer token and returns the caller's Identity. +// +// Implementations must: +// - Return (id, nil) on a verified token. +// - Return (nil, ErrTokenNotForMe) when the token is not for this provider +// (so the ChainProvider can try the next provider). +// - Return (nil, ErrTokenRejected) when the token is recognized but denied +// (e.g., revoked, expired, untrusted issuer). +// - Return (nil, ErrInvalidToken) when the token is malformed or +// cryptographically invalid. +// - Return (nil, other-error) for transient failures (network, etc.). +// The ChainProvider treats these as fatal (fail-closed) — it does NOT +// fall through to the next provider on infrastructure errors, because +// doing so would allow attackers to evade a temporarily-down provider. +type Provider interface { + Name() string + Verify(ctx context.Context, token string, headers Headers) (*Identity, error) +} + +// Identity is the authenticated principal extracted by a Provider. +// +// There is intentionally no Valid field — a non-nil *Identity returned +// alongside a nil error is the only "valid" signal. See package comment. +type Identity struct { + UserID string `json:"user_id,omitempty"` + Email string `json:"email,omitempty"` + OrgID string `json:"org_id,omitempty"` + WorkspaceID string `json:"workspace_id,omitempty"` + Groups []string `json:"groups,omitempty"` + + // Claims carries the provider-specific raw payload (typically the + // full JWT claim set for the oidc provider — including custom + // issuer-specific claims). Treat this as an escape hatch for + // provider-specific authorization logic; prefer the typed fields + // above for portable consumers. + // + // WARNING (review #11f): for OIDC the map is an unfiltered shallow + // copy of the JWT claims — `sub`, `email`, `iss`, `aud`, `exp`, + // plus any custom claims the issuer adds (group memberships, + // internal IDs, profile fields, sometimes raw PII). Do NOT log + // this map verbatim. Filtering belongs in a future authz layer, + // not here. + Claims map[string]any `json:"claims,omitempty"` + + // Source records which provider verified the identity (e.g., "oidc", + // "http_verifier", "static_token"). Useful for audit logs and debugging. + Source string `json:"source,omitempty"` +} + +// Headers is a case-insensitive view over selected request headers passed +// to providers. Providers should not assume any particular casing. +type Headers map[string]string + +// Get returns the value for the given header, matched case-insensitively. +func (h Headers) Get(key string) string { + if v, ok := h[key]; ok { + return v + } + lower := strings.ToLower(key) + for k, v := range h { + if strings.ToLower(k) == lower { + return v + } + } + return "" +} + +// HeadersFromRequest extracts the well-known headers providers may use. +// Keep this list narrow — providers should be explicit about the contract. +func HeadersFromRequest(r *http.Request) Headers { + return Headers{ + "X-Org-ID": r.Header.Get("X-Org-ID"), + "X-Request-ID": r.Header.Get("X-Request-ID"), + "org-id": r.Header.Get("org-id"), + "org_id": r.Header.Get("org_id"), + } +} + +// TokenKind classifies a presented bearer token structurally — useful for +// audit logging without leaking the token itself. +// +// "jwt" → three base64url segments separated by dots +// "opaque" → anything else (Okta access tokens, custom verifier tokens, dev secrets) +// +// This is a CHEAP structural check — it does not parse or validate. +// Never log the token; this helper is safe to log. +func TokenKind(token string) string { + if token == "" { + return "empty" + } + dots := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' { + dots++ + if dots > 2 { + return "opaque" + } + } + } + if dots == 2 { + return "jwt" + } + return "opaque" +} + +// Sentinel errors that Providers and the ChainProvider use to signal outcomes. +// +// - ErrTokenNotForMe → provider does not recognize this token shape; +// ChainProvider should try the next provider. +// - ErrTokenRejected → provider recognized the token and denied it; +// ChainProvider stops and the middleware writes 401. +// - ErrInvalidToken → token is malformed or cryptographically invalid; +// ChainProvider stops and the middleware writes 401. +// - ErrProviderUnavailable → the verifier / IdP is unreachable or returned +// a transport-layer error (5xx, network timeout, garbage response). The +// token MAY be valid — we just can't say. ChainProvider stops (fail-closed, +// same as ErrInvalidToken), but the audit signal is distinct so operators +// don't chase a token issue when the actual problem is provider downtime. +// - ErrProviderNotConfigured → returned by New(); never by Verify(). +var ( + ErrTokenNotForMe = errors.New("auth: token not for this provider") + ErrTokenRejected = errors.New("auth: token rejected") + ErrInvalidToken = errors.New("auth: invalid token") + ErrProviderUnavailable = errors.New("auth: provider unavailable") + ErrProviderNotConfigured = errors.New("auth: provider not configured") +) diff --git a/forge-core/auth/provider_test.go b/forge-core/auth/provider_test.go new file mode 100644 index 0000000..c175b69 --- /dev/null +++ b/forge-core/auth/provider_test.go @@ -0,0 +1,485 @@ +package auth + +import ( + "context" + "errors" + "maps" + "reflect" + "sort" + "sync" + "testing" +) + +// stubProvider is an in-package test provider that lets us script the +// outcome of Verify deterministically. +type stubProvider struct { + name string + identity *Identity + err error + calls int +} + +func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) Verify(_ context.Context, _ string, _ Headers) (*Identity, error) { + s.calls++ + return s.identity, s.err +} + +// --- ChainProvider --- + +func TestChainProvider_FirstMatchWins(t *testing.T) { + idA := &Identity{UserID: "a", Source: "first"} + idB := &Identity{UserID: "b", Source: "second"} + a := &stubProvider{name: "first", identity: idA} + b := &stubProvider{name: "second", identity: idB} + + chain := NewChainProvider(a, b) + got, err := chain.Verify(context.Background(), "tok", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if got != idA { + t.Errorf("identity = %+v, want %+v", got, idA) + } + if a.calls != 1 { + t.Errorf("first provider calls = %d, want 1", a.calls) + } + if b.calls != 0 { + t.Errorf("second provider should not be called, got %d calls", b.calls) + } +} + +func TestChainProvider_NotForMeYieldsToNext(t *testing.T) { + idB := &Identity{UserID: "b"} + a := &stubProvider{name: "first", err: ErrTokenNotForMe} + b := &stubProvider{name: "second", identity: idB} + + chain := NewChainProvider(a, b) + got, err := chain.Verify(context.Background(), "tok", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if got != idB { + t.Errorf("identity = %+v, want %+v", got, idB) + } + if a.calls != 1 || b.calls != 1 { + t.Errorf("call counts: a=%d b=%d, want 1,1", a.calls, b.calls) + } +} + +func TestChainProvider_RejectedStopsChain(t *testing.T) { + a := &stubProvider{name: "first", err: ErrTokenRejected} + b := &stubProvider{name: "second", identity: &Identity{UserID: "b"}} + + chain := NewChainProvider(a, b) + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected", err) + } + if b.calls != 0 { + t.Errorf("second provider should not be called after rejection, got %d calls", b.calls) + } +} + +func TestChainProvider_InvalidStopsChain(t *testing.T) { + a := &stubProvider{name: "first", err: ErrInvalidToken} + b := &stubProvider{name: "second", identity: &Identity{UserID: "b"}} + + chain := NewChainProvider(a, b) + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken", err) + } + if b.calls != 0 { + t.Errorf("second provider should not be called after invalid, got %d calls", b.calls) + } +} + +func TestChainProvider_GenericErrorFailsClosed(t *testing.T) { + // Infrastructure errors (network, etc.) must NOT fall through to the + // next provider — that would let attackers evade a temporarily-down + // upstream by waiting for it to fail. + netErr := errors.New("connection refused") + a := &stubProvider{name: "first", err: netErr} + b := &stubProvider{name: "second", identity: &Identity{UserID: "b"}} + + chain := NewChainProvider(a, b) + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, netErr) { + t.Fatalf("err = %v, want network error", err) + } + if b.calls != 0 { + t.Errorf("fail-closed violated: second provider was called after generic error") + } +} + +func TestChainProvider_EmptyChainYields(t *testing.T) { + chain := NewChainProvider() + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrTokenNotForMe) { + t.Errorf("empty chain err = %v, want ErrTokenNotForMe", err) + } +} + +func TestChainProvider_AllYieldFinalErrorIsNotForMe(t *testing.T) { + a := &stubProvider{name: "a", err: ErrTokenNotForMe} + b := &stubProvider{name: "b", err: ErrTokenNotForMe} + chain := NewChainProvider(a, b) + + _, err := chain.Verify(context.Background(), "tok", nil) + if !errors.Is(err, ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestChainProvider_DefensiveSliceCopy(t *testing.T) { + a := &stubProvider{name: "a", identity: &Identity{UserID: "a"}} + b := &stubProvider{name: "b", identity: &Identity{UserID: "b"}} + srcSlice := []Provider{a, b} + + chain := NewChainProvider(srcSlice...) + // Mutate the source — chain must be unaffected. + srcSlice[0] = &stubProvider{name: "tamper", err: ErrTokenRejected} + + got, err := chain.Verify(context.Background(), "tok", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if got.UserID != "a" { + t.Errorf("chain saw mutated source; identity = %+v", got) + } +} + +func TestChainProvider_ProvidersReturnsCopy(t *testing.T) { + a := &stubProvider{name: "a"} + chain := NewChainProvider(a) + + listed := chain.Providers() + // Mutate the returned slice — internal state must be unchanged. + listed[0] = &stubProvider{name: "tamper"} + + listed2 := chain.Providers() + if listed2[0].Name() != "a" { + t.Errorf("Providers() did not return a defensive copy; got %q", listed2[0].Name()) + } +} + +func TestChainProvider_NameIsChain(t *testing.T) { + if got := (&ChainProvider{}).Name(); got != "chain" { + t.Errorf("Name = %q, want chain", got) + } +} + +// --- PrependChain --- + +func TestPrependChain_FlattenedChain(t *testing.T) { + inner := NewChainProvider( + &stubProvider{name: "b"}, + &stubProvider{name: "c"}, + ) + combined := PrependChain(inner, &stubProvider{name: "a"}) + + names := providerNames(combined.Providers()) + want := []string{"a", "b", "c"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v (flatten failed)", names, want) + } +} + +func TestPrependChain_NonChainSingle(t *testing.T) { + single := &stubProvider{name: "single"} + combined := PrependChain(single, &stubProvider{name: "first"}) + names := providerNames(combined.Providers()) + want := []string{"first", "single"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v", names, want) + } +} + +func TestPrependChain_NilChain(t *testing.T) { + combined := PrependChain(nil, &stubProvider{name: "first"}) + names := providerNames(combined.Providers()) + want := []string{"first"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v", names, want) + } +} + +func TestPrependChain_NoPrependProviders(t *testing.T) { + inner := NewChainProvider(&stubProvider{name: "x"}) + combined := PrependChain(inner) + names := providerNames(combined.Providers()) + want := []string{"x"} + if !reflect.DeepEqual(names, want) { + t.Errorf("names = %v, want %v", names, want) + } +} + +// --- Headers --- + +func TestHeaders_GetCaseInsensitive(t *testing.T) { + h := Headers{"X-Org-ID": "acme"} + tests := []struct { + key string + want string + }{ + {"X-Org-ID", "acme"}, + {"x-org-id", "acme"}, + {"X-ORG-ID", "acme"}, + {"X-Other", ""}, + } + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + if got := h.Get(tt.key); got != tt.want { + t.Errorf("Get(%q) = %q, want %q", tt.key, got, tt.want) + } + }) + } +} + +func TestHeaders_EmptyMap(t *testing.T) { + var h Headers + if got := h.Get("anything"); got != "" { + t.Errorf("Get on nil Headers = %q, want empty", got) + } +} + +// --- Context --- + +func TestContext_RoundTrip(t *testing.T) { + want := &Identity{UserID: "alice", Source: "test"} + ctx := WithIdentity(context.Background(), want) + got := IdentityFromContext(ctx) + if got != want { + t.Errorf("IdentityFromContext = %+v, want %+v", got, want) + } +} + +func TestContext_NilIdentityIsNoOp(t *testing.T) { + parent := context.Background() + ctx := WithIdentity(parent, nil) + if ctx != parent { + t.Error("WithIdentity(nil) should return parent ctx unchanged") + } +} + +func TestContext_MissingReturnsNil(t *testing.T) { + if got := IdentityFromContext(context.Background()); got != nil { + t.Errorf("IdentityFromContext on empty ctx = %+v, want nil", got) + } +} + +func TestContext_NilContextReturnsNil(t *testing.T) { + // Intentional: confirm we don't panic on a defensive nil-context call. + var ctx context.Context //nolint:gocritic // intentional nil-context safety check + if got := IdentityFromContext(ctx); got != nil { + t.Errorf("IdentityFromContext(nil) = %+v, want nil", got) + } +} + +// --- Registry --- + +// registryGuard isolates registry mutations from the global state so the +// concurrently-running provider package tests (httpverifier, statictoken) +// don't interfere. +// +// Strategy: snapshot the registry, run the test against a wiped registry, +// then restore. +func registryGuard(t *testing.T) { + t.Helper() + registryMu.Lock() + snapshot := maps.Clone(registry) + registryMu.Unlock() + t.Cleanup(func() { + registryMu.Lock() + registry = snapshot + registryMu.Unlock() + }) + resetRegistryForTest() +} + +func TestRegister_RoundTrip(t *testing.T) { + registryGuard(t) + + Register("custom", func(_ map[string]any) (Provider, error) { + return &stubProvider{name: "custom"}, nil + }) + + p, err := Build("custom", nil) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "custom" { + t.Errorf("Name = %q, want custom", p.Name()) + } +} + +func TestRegister_DuplicatePanics(t *testing.T) { + registryGuard(t) + + Register("dup", func(_ map[string]any) (Provider, error) { + return &stubProvider{}, nil + }) + + defer func() { + if r := recover(); r == nil { + t.Error("duplicate Register did not panic") + } + }() + Register("dup", func(_ map[string]any) (Provider, error) { + return &stubProvider{}, nil + }) +} + +func TestRegister_EmptyNamePanics(t *testing.T) { + registryGuard(t) + + defer func() { + if r := recover(); r == nil { + t.Error("Register with empty name did not panic") + } + }() + Register("", func(_ map[string]any) (Provider, error) { + return &stubProvider{}, nil + }) +} + +func TestRegister_NilFactoryPanics(t *testing.T) { + registryGuard(t) + + defer func() { + if r := recover(); r == nil { + t.Error("Register with nil factory did not panic") + } + }() + Register("nil_fac", nil) +} + +func TestBuild_UnknownTypeReturnsError(t *testing.T) { + registryGuard(t) + + _, err := Build("does-not-exist", nil) + if err == nil { + t.Fatal("Build of unknown type returned nil error") + } +} + +func TestBuild_FactoryErrorIsWrapped(t *testing.T) { + registryGuard(t) + + custom := errors.New("custom failure") + Register("failing", func(_ map[string]any) (Provider, error) { + return nil, custom + }) + + _, err := Build("failing", nil) + if !errors.Is(err, custom) { + t.Fatalf("err = %v, want wrapped custom failure", err) + } +} + +func TestRegisteredTypes_SortedAndDeduplicated(t *testing.T) { + registryGuard(t) + + Register("zeta", func(_ map[string]any) (Provider, error) { return &stubProvider{}, nil }) + Register("alpha", func(_ map[string]any) (Provider, error) { return &stubProvider{}, nil }) + Register("mu", func(_ map[string]any) (Provider, error) { return &stubProvider{}, nil }) + + got := RegisteredTypes() + want := []string{"alpha", "mu", "zeta"} + if !sort.StringsAreSorted(got) { + t.Errorf("RegisteredTypes not sorted: %v", got) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("RegisteredTypes = %v, want %v", got, want) + } +} + +func TestRegistry_ConcurrentReads(t *testing.T) { + registryGuard(t) + + Register("a", func(_ map[string]any) (Provider, error) { return &stubProvider{name: "a"}, nil }) + Register("b", func(_ map[string]any) (Provider, error) { return &stubProvider{name: "b"}, nil }) + + var wg sync.WaitGroup + for range 50 { + wg.Go(func() { + _, _ = Build("a", nil) + _, _ = Build("b", nil) + _ = RegisteredTypes() + }) + } + wg.Wait() +} + +// --- UnmarshalSettings --- + +type sampleSettings struct { + Issuer string `yaml:"issuer"` + Audience string `yaml:"audience"` + Nested map[string]any `yaml:"nested,omitempty"` +} + +func TestUnmarshalSettings(t *testing.T) { + in := map[string]any{ + "issuer": "https://example.com", + "audience": "api://forge", + "nested": map[string]any{ + "k": "v", + }, + } + + var out sampleSettings + if err := UnmarshalSettings(in, &out); err != nil { + t.Fatalf("UnmarshalSettings: %v", err) + } + if out.Issuer != "https://example.com" { + t.Errorf("Issuer = %q, want https://example.com", out.Issuer) + } + if out.Audience != "api://forge" { + t.Errorf("Audience = %q, want api://forge", out.Audience) + } + if out.Nested["k"] != "v" { + t.Errorf("Nested[k] = %v, want v", out.Nested["k"]) + } +} + +func TestUnmarshalSettings_NilOut(t *testing.T) { + if err := UnmarshalSettings(map[string]any{"x": 1}, nil); err == nil { + t.Error("UnmarshalSettings with nil out returned nil error") + } +} + +// --- TokenKind --- + +func TestTokenKind(t *testing.T) { + tests := []struct { + name string + token string + want string + }{ + {"empty", "", "empty"}, + {"jwt three segments", "eyJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJ4In0.signature", "jwt"}, + {"opaque random", "abc123xyz", "opaque"}, + {"opaque one dot", "abc.def", "opaque"}, + {"opaque four segments", "a.b.c.d", "opaque"}, + {"jwt with empty segments still has 2 dots", "..", "jwt"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TokenKind(tt.token); got != tt.want { + t.Errorf("TokenKind(%q) = %q, want %q", tt.token, got, tt.want) + } + }) + } +} + +// --- helpers --- + +func providerNames(ps []Provider) []string { + out := make([]string, len(ps)) + for i, p := range ps { + out[i] = p.Name() + } + return out +} diff --git a/forge-core/auth/providers/httpverifier/provider.go b/forge-core/auth/providers/httpverifier/provider.go new file mode 100644 index 0000000..90c5f39 --- /dev/null +++ b/forge-core/auth/providers/httpverifier/provider.go @@ -0,0 +1,205 @@ +// Package httpverifier implements the legacy external auth provider: +// POST a JSON envelope to a verifier URL and trust its response. +// +// This is the provider that the historical --auth-url / FORGE_AUTH_URL flag +// configures. The request/response shape is preserved byte-for-byte so +// existing custom verifier services keep working unchanged. +// +// Request: POST {URL} +// Content-Type: application/json +// { "token": "", "org_id": "" } +// +// Response: 200 OK +// { "valid": bool, "error": "...", "user_id": "...", +// "org_id": "...", "email": "...", "workspace_id": "..." } +// +// The HTTP verifier claims every token presented to it — it never returns +// ErrTokenNotForMe. When placed in a chain, it is typically the terminator. +package httpverifier + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this provider. +const ProviderName = "http_verifier" + +// DefaultTimeout is the per-request timeout used when Config.Timeout is unset. +const DefaultTimeout = 10 * time.Second + +// Config controls the http_verifier provider. +type Config struct { + // URL is the verifier endpoint. Required. + URL string `yaml:"url"` + + // DefaultOrg is the org_id sent to the verifier when no X-Org-ID + // header (or org-id / org_id variant) is present on the request. + DefaultOrg string `yaml:"default_org,omitempty"` + + // Timeout caps each verify call. Defaults to DefaultTimeout. + Timeout time.Duration `yaml:"timeout,omitempty"` + + // HTTPClient overrides the default client. Injectable for tests. + HTTPClient *http.Client `yaml:"-"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.URL == "" { + return fmt.Errorf("%w: url required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider against a remote HTTP verifier. +type Provider struct { + cfg Config + client *http.Client +} + +// New constructs a Provider after validating cfg. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.Timeout == 0 { + cfg.Timeout = DefaultTimeout + } + client := cfg.HTTPClient + if client == nil { + client = &http.Client{Timeout: cfg.Timeout} + } + return &Provider{cfg: cfg, client: client}, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// verifyRequest is the JSON envelope POSTed to the verifier URL. +// Preserved byte-for-byte from the pre-refactor implementation. +type verifyRequest struct { + Token string `json:"token"` + OrgID string `json:"org_id"` +} + +// verifyResponse is the JSON envelope returned by the verifier URL. +// Preserved byte-for-byte from the pre-refactor implementation. +type verifyResponse struct { + Valid bool `json:"valid"` + Error string `json:"error,omitempty"` + UserID string `json:"user_id,omitempty"` + OrgID string `json:"org_id,omitempty"` + Email string `json:"email,omitempty"` + WorkspaceID string `json:"workspace_id,omitempty"` +} + +// Verify implements auth.Provider. It POSTs the bearer token (with the +// caller's org_id) to the configured verifier URL and translates the +// response into an Identity or one of the sentinel errors. +// +// Mapping (review #6 — separated "token bad" from "verifier down"): +// +// - HTTP 200 + valid:true → (Identity, nil) +// - HTTP 200 + valid:false → ErrTokenRejected +// - HTTP 401 → ErrTokenRejected +// - HTTP 4xx (other) → ErrTokenRejected (verifier denied — token-side) +// - HTTP 5xx → ErrProviderUnavailable (server-side failure) +// - Network / transport error → ErrProviderUnavailable +// - Response body undecodable → ErrProviderUnavailable (verifier returned garbage) +// - Local marshal / request-build err → ErrInvalidToken (extremely rare; bug in caller path) +// +// This provider does not return ErrTokenNotForMe. +func (p *Provider) Verify(ctx context.Context, token string, headers auth.Headers) (*auth.Identity, error) { + orgID := resolveOrgID(headers, p.cfg.DefaultOrg) + + body, err := json.Marshal(verifyRequest{Token: token, OrgID: orgID}) + if err != nil { + // Local marshal failure — should never happen with the fixed + // struct shape. Keep ErrInvalidToken since this isn't the + // verifier's fault. + return nil, fmt.Errorf("%w: marshal request: %w", auth.ErrInvalidToken, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.URL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("%w: build request: %w", auth.ErrInvalidToken, err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.client.Do(req) + if err != nil { + // Network error reaching the verifier — provider is unreachable. + // Distinct from a token problem; audit reflects this. + return nil, fmt.Errorf("%w: call verifier: %w", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + + switch { + case resp.StatusCode == http.StatusOK: + // fall through to decode body + case resp.StatusCode == http.StatusUnauthorized: + return nil, auth.ErrTokenRejected + case resp.StatusCode >= 500 && resp.StatusCode < 600: + // 5xx — verifier broken or overloaded. We can't say whether the + // token is valid; surface as provider unavailable. + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return nil, fmt.Errorf("%w: verifier returned status %d", auth.ErrProviderUnavailable, resp.StatusCode) + case resp.StatusCode >= 400 && resp.StatusCode < 500: + // Non-401 4xx — verifier explicitly refused the request (e.g., + // 400 bad request, 403 forbidden). Treat as token-side rejection. + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return nil, fmt.Errorf("%w: verifier returned status %d", auth.ErrTokenRejected, resp.StatusCode) + default: + // 1xx / 3xx — undefined for this contract. Treat as unavailable. + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return nil, fmt.Errorf("%w: verifier returned unexpected status %d", auth.ErrProviderUnavailable, resp.StatusCode) + } + + var result verifyResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + // Verifier returned 200 but the body isn't the contract shape — + // provider misbehavior, not a token issue. + return nil, fmt.Errorf("%w: decode response: %w", auth.ErrProviderUnavailable, err) + } + if !result.Valid { + return nil, auth.ErrTokenRejected + } + + return &auth.Identity{ + UserID: result.UserID, + Email: result.Email, + OrgID: result.OrgID, + WorkspaceID: result.WorkspaceID, + Source: ProviderName, + }, nil +} + +// resolveOrgID picks the org_id sent to the verifier with this precedence: +// X-Org-ID header > "org-id" > "org_id" > config default. Matches the +// pre-refactor behavior in middleware.go. +func resolveOrgID(headers auth.Headers, fallback string) string { + for _, key := range []string{"X-Org-ID", "org-id", "org_id"} { + if v := headers.Get(key); v != "" { + return v + } + } + return fallback +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/httpverifier/provider_test.go b/forge-core/auth/providers/httpverifier/provider_test.go new file mode 100644 index 0000000..b35a0ee --- /dev/null +++ b/forge-core/auth/providers/httpverifier/provider_test.go @@ -0,0 +1,312 @@ +package httpverifier_test + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/httpverifier" +) + +// fakeVerifier returns a test server that implements the verifier contract. +// `handler` lets each test customize the response. +func fakeVerifier(t *testing.T, handler func(req map[string]any) (status int, body any)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("verifier got method %q, want POST", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("verifier got Content-Type %q, want application/json", ct) + } + body, _ := io.ReadAll(r.Body) + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("verifier got non-JSON body: %v", err) + } + status, respBody := handler(req) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if respBody != nil { + _ = json.NewEncoder(w).Encode(respBody) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func TestNew_ValidationErrors(t *testing.T) { + t.Run("missing url", func(t *testing.T) { + _, err := httpverifier.New(httpverifier.Config{}) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } + }) +} + +func TestVerify_HappyPath(t *testing.T) { + srv := fakeVerifier(t, func(req map[string]any) (int, any) { + if req["token"] != "good-token" { + t.Errorf("verifier got token %q, want good-token", req["token"]) + } + return http.StatusOK, map[string]any{ + "valid": true, + "user_id": "user-1", + "org_id": "org-1", + "email": "user@example.com", + "workspace_id": "ws-1", + } + }) + + p, err := httpverifier.New(httpverifier.Config{URL: srv.URL}) + if err != nil { + t.Fatalf("New: %v", err) + } + + id, err := p.Verify(context.Background(), "good-token", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } + if id.UserID != "user-1" || id.Email != "user@example.com" || id.OrgID != "org-1" || id.WorkspaceID != "ws-1" { + t.Errorf("identity = %+v, fields missing", id) + } + if id.Source != httpverifier.ProviderName { + t.Errorf("identity.Source = %q, want %q", id.Source, httpverifier.ProviderName) + } +} + +func TestVerify_ValidFalse_ReturnsRejected(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusOK, map[string]any{"valid": false, "error": "bad token"} + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "bad-token", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected", err) + } +} + +func TestVerify_401_ReturnsRejected(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusUnauthorized, nil + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected", err) + } +} + +func TestVerify_500_ReturnsProviderUnavailable(t *testing.T) { + // Review #6: 5xx is the verifier's fault, not the token's. Distinct + // sentinel so operators triaging audit logs don't chase token issues + // when the actual problem is verifier downtime. + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusInternalServerError, map[string]any{"error": "boom"} + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } + if errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err also matched ErrInvalidToken — sentinels must not overlap: %v", err) + } +} + +func TestVerify_502BadGateway_ReturnsProviderUnavailable(t *testing.T) { + // Other 5xx codes — same classification. + srv := fakeVerifier(t, func(map[string]any) (int, any) { + return http.StatusBadGateway, nil + }) + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestVerify_NetworkError_ReturnsProviderUnavailable(t *testing.T) { + p, _ := httpverifier.New(httpverifier.Config{ + URL: "http://127.0.0.1:0/never", // port 0 is invalid, will fail to connect + Timeout: 100 * time.Millisecond, + }) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestVerify_Non401_4xx_ReturnsRejected(t *testing.T) { + // 4xx other than 401 (e.g., 400, 403) means the verifier + // explicitly refused the request — token-side, not server-side. + for _, code := range []int{http.StatusBadRequest, http.StatusForbidden} { + t.Run(http.StatusText(code), func(t *testing.T) { + srv := fakeVerifier(t, func(map[string]any) (int, any) { return code, nil }) + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("status %d → err = %v, want ErrTokenRejected", code, err) + } + }) + } +} + +func TestVerify_UndecodableBody_ReturnsProviderUnavailable(t *testing.T) { + // 200 OK but body isn't the contract JSON — verifier misbehavior, + // not a token issue. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("not even close to JSON {{{}")) + })) + defer srv.Close() + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + _, err := p.Verify(context.Background(), "x", nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Fatalf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestVerify_OrgID_Precedence(t *testing.T) { + tests := []struct { + name string + headers auth.Headers + want string + }{ + { + name: "header X-Org-ID wins", + headers: auth.Headers{"X-Org-ID": "header-org", "org-id": "lower-org"}, + want: "header-org", + }, + { + name: "lowercase org-id when X-Org-ID absent", + headers: auth.Headers{"org-id": "lower-org"}, + want: "lower-org", + }, + { + name: "snake_case org_id when others absent", + headers: auth.Headers{"org_id": "snake-org"}, + want: "snake-org", + }, + { + name: "config default when no headers", + headers: auth.Headers{}, + want: "default-org", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotOrgID atomic.Value + srv := fakeVerifier(t, func(req map[string]any) (int, any) { + gotOrgID.Store(req["org_id"]) + return http.StatusOK, map[string]any{"valid": true, "user_id": "u"} + }) + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL, DefaultOrg: "default-org"}) + if _, err := p.Verify(context.Background(), "tok", tt.headers); err != nil { + t.Fatalf("Verify: %v", err) + } + if got := gotOrgID.Load(); got != tt.want { + t.Errorf("verifier received org_id = %v, want %q", got, tt.want) + } + }) + } +} + +func TestVerify_WireFormat_PreservedExactly(t *testing.T) { + // Black-box check: the JSON body sent must contain exactly "token" and + // "org_id" fields — no more, no less. This guards against accidental + // schema drift from the legacy --auth-url contract. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var got map[string]any + _ = json.Unmarshal(body, &got) + if len(got) != 2 { + t.Errorf("request body has %d fields, want 2: %v", len(got), got) + } + if _, ok := got["token"]; !ok { + t.Error("request body missing 'token' field") + } + if _, ok := got["org_id"]; !ok { + t.Error("request body missing 'org_id' field") + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"valid": true}) + })) + defer srv.Close() + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL}) + if _, err := p.Verify(context.Background(), "tok", nil); err != nil { + t.Fatalf("Verify: %v", err) + } +} + +func TestVerify_RespectsContextCancellation(t *testing.T) { + // Server blocks until either the request context is cancelled OR the + // test ends. The serverDone channel guarantees srv.Close() doesn't + // hang waiting for the handler to return under -race scheduling. + serverDone := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-serverDone: + } + })) + defer func() { + close(serverDone) + srv.Close() + }() + + p, _ := httpverifier.New(httpverifier.Config{URL: srv.URL, Timeout: 5 * time.Second}) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := p.Verify(ctx, "tok", nil) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("Verify returned nil error despite context timeout") + } + if elapsed > 1*time.Second { + t.Errorf("Verify took %v, expected <1s (context should have cancelled)", elapsed) + } +} + +func TestRegisteredViaFactory(t *testing.T) { + // Confirm the package registered itself with the auth package on import. + p, err := auth.Build("http_verifier", map[string]any{"url": "https://example.com/verify"}) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "http_verifier" { + t.Errorf("Name = %q, want http_verifier", p.Name()) + } +} + +func TestFactory_UnknownSettings_AreIgnored(t *testing.T) { + // Forward-compat: unknown YAML keys must not break construction. + _, err := auth.Build("http_verifier", map[string]any{ + "url": "https://example.com/verify", + "unknown_field": "future", + }) + if err != nil { + t.Fatalf("Build with unknown field: %v", err) + } +} diff --git a/forge-core/auth/providers/oidc/claims.go b/forge-core/auth/providers/oidc/claims.go new file mode 100644 index 0000000..de9e8d9 --- /dev/null +++ b/forge-core/auth/providers/oidc/claims.go @@ -0,0 +1,142 @@ +package oidc + +import ( + "maps" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// ClaimMap configures which JWT claim names are read into each Identity +// field. Empty fields fall back to OIDC standard claim names. +type ClaimMap struct { + UserID string `yaml:"user_id,omitempty"` + Email string `yaml:"email,omitempty"` + OrgID string `yaml:"org_id,omitempty"` + WorkspaceID string `yaml:"workspace_id,omitempty"` + Groups string `yaml:"groups,omitempty"` +} + +// defaults fills empty ClaimMap fields with OIDC-standard claim names. +// Returns a copy — the input is not mutated. +func (m ClaimMap) defaults() ClaimMap { + if m.UserID == "" { + m.UserID = "sub" + } + if m.Email == "" { + m.Email = "email" + } + if m.OrgID == "" { + m.OrgID = "org_id" + } + if m.WorkspaceID == "" { + m.WorkspaceID = "workspace_id" + } + if m.Groups == "" { + m.Groups = "groups" + } + return m +} + +// mapClaims builds an Identity from JWT claims using the configured +// ClaimMap. Header overrides (X-Org-ID and variants) take precedence over +// claim-derived OrgID — this preserves the per-request tenant routing +// behavior the http_verifier provider had. +// +// The raw claim map is also copied into Identity.Claims so that downstream +// authorization layers (Phase 4+) can read provider-specific claims +// without going back through the JWT. +func mapClaims(m ClaimMap, claims jwt.MapClaims, headers auth.Headers) *auth.Identity { + cm := m.defaults() + id := &auth.Identity{ + UserID: stringClaim(claims, cm.UserID), + Email: stringClaim(claims, cm.Email), + OrgID: stringClaim(claims, cm.OrgID), + WorkspaceID: stringClaim(claims, cm.WorkspaceID), + Groups: stringSliceClaim(claims, cm.Groups), + Claims: copyClaims(claims), + Source: ProviderName, + } + + // Header override: X-Org-ID (and lowercase variants) takes precedence + // over the claim-derived OrgID. The http_verifier provider has the + // same behavior for parity. + for _, key := range []string{"X-Org-ID", "org-id", "org_id"} { + if v := headers.Get(key); v != "" { + id.OrgID = v + break + } + } + + return id +} + +// stringClaim returns the named claim as a string, or "" if missing or +// not a string. +func stringClaim(claims jwt.MapClaims, name string) string { + v, ok := claims[name] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +// stringSliceClaim returns the named claim as a string slice. Tolerates +// the common shapes IdPs use: []string, []any of strings, or a single +// string (which becomes a one-element slice). +func stringSliceClaim(claims jwt.MapClaims, name string) []string { + v, ok := claims[name] + if !ok { + return nil + } + switch val := v.(type) { + case []string: + out := make([]string, len(val)) + copy(out, val) + return out + case []any: + out := make([]string, 0, len(val)) + for _, item := range val { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return out + case string: + if val == "" { + return nil + } + return []string{val} + default: + return nil + } +} + +// copyClaims returns a shallow copy of the claim map so the caller can +// expose it on Identity without sharing the underlying map with the +// parsed-token state. +// +// WARNING (review #11f): the returned map contains EVERY claim the IdP +// emitted on the token — `sub`, `email`, `iss`, `aud`, `exp`, `iat`, +// `nbf`, plus any custom claims the issuer adds (group memberships, +// internal IDs, profile fields, sometimes raw PII). Downstream consumers +// reading Identity.Claims will see all of it. +// +// Recommendations for consumers: +// - Prefer the mapped, typed fields on Identity (UserID, Email, OrgID, +// Groups) where they suffice — those have a fixed contract. +// - Treat Identity.Claims as an escape hatch for provider-specific +// authorization logic, not as something to log or relay verbatim. +// - When a future authz layer needs a claim-allowlist, that's the +// right place to add it — not here. This package serves the raw +// auth principal; filtering is policy. +func copyClaims(claims jwt.MapClaims) map[string]any { + if len(claims) == 0 { + return nil + } + out := make(map[string]any, len(claims)) + maps.Copy(out, claims) + return out +} diff --git a/forge-core/auth/providers/oidc/discovery.go b/forge-core/auth/providers/oidc/discovery.go new file mode 100644 index 0000000..3ce6c53 --- /dev/null +++ b/forge-core/auth/providers/oidc/discovery.go @@ -0,0 +1,105 @@ +package oidc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" +) + +// providerMetadata is the subset of the OIDC discovery document we read. +// Per the OIDC spec, this document is stable for the lifetime of an issuer, +// so we cache it forever after the first successful load. +type providerMetadata struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` +} + +// discovery lazily fetches and caches the OIDC discovery document. +// +// Concurrency model: a sync.Once gates the initial load so concurrent +// requests don't race. Subsequent calls return the cached result without +// further locking once `done` is true. If the first load fails, every +// caller sees the same error — but the next call re-attempts (we don't +// want a transient startup failure to permanently brick the provider). +type discovery struct { + issuer string + client *http.Client + + mu sync.Mutex + loaded bool + meta providerMetadata +} + +func newDiscovery(issuer string, client *http.Client) *discovery { + return &discovery{ + issuer: strings.TrimRight(issuer, "/"), + client: client, + } +} + +// EnsureLoaded fetches the discovery doc on first call (or after a failed +// previous attempt). Safe for concurrent use; only one HTTP call is in +// flight per discovery instance at a time. +func (d *discovery) EnsureLoaded(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + if d.loaded { + return nil + } + meta, err := d.fetch(ctx) + if err != nil { + return err + } + // Per OIDC spec, the discovery doc's `issuer` must match the issuer + // the client used to fetch it. Catches misconfiguration where someone + // points at a discovery doc for a different tenant. + // + // Trim trailing slashes on both sides — IdPs and operators disagree + // on whether the canonical form has one (Auth0 emits with, many + // Okta deployments without). The trim is normalization, not a + // security weakening — distinct hosts remain distinct after trim. + // See Config.normalize() in provider.go. + if strings.TrimRight(meta.Issuer, "/") != d.issuer { + return fmt.Errorf("oidc: discovery issuer %q does not match configured issuer %q", meta.Issuer, d.issuer) + } + if meta.JWKSURI == "" { + return fmt.Errorf("oidc: discovery doc for %q has empty jwks_uri", d.issuer) + } + d.meta = meta + d.loaded = true + return nil +} + +// JWKSURI returns the JWKS endpoint URL from the loaded discovery doc. +// Callers must call EnsureLoaded first. +func (d *discovery) JWKSURI() string { + d.mu.Lock() + defer d.mu.Unlock() + return d.meta.JWKSURI +} + +func (d *discovery) fetch(ctx context.Context) (providerMetadata, error) { + url := d.issuer + "/.well-known/openid-configuration" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return providerMetadata{}, fmt.Errorf("oidc: build discovery request: %w", err) + } + resp, err := d.client.Do(req) + if err != nil { + return providerMetadata{}, fmt.Errorf("oidc: discovery fetch: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + return providerMetadata{}, fmt.Errorf("oidc: discovery returned status %d", resp.StatusCode) + } + var meta providerMetadata + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return providerMetadata{}, fmt.Errorf("oidc: discovery decode: %w", err) + } + return meta, nil +} diff --git a/forge-core/auth/providers/oidc/export_test.go b/forge-core/auth/providers/oidc/export_test.go new file mode 100644 index 0000000..8bd6a98 --- /dev/null +++ b/forge-core/auth/providers/oidc/export_test.go @@ -0,0 +1,13 @@ +package oidc + +import "time" + +// SetCacheClockForTest replaces the JWKS cache's time source. Exposed only +// in _test builds so tests can advance virtual time past TTL boundaries +// without sleeping. +// +// Not part of the public API. Test code in this package's external test +// file (oidc_test) uses this via the package-import boundary. +func SetCacheClockForTest(p *Provider, now func() time.Time) { + p.jwks.now = now +} diff --git a/forge-core/auth/providers/oidc/jwks.go b/forge-core/auth/providers/oidc/jwks.go new file mode 100644 index 0000000..6d60437 --- /dev/null +++ b/forge-core/auth/providers/oidc/jwks.go @@ -0,0 +1,426 @@ +package oidc + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "sync" + "time" +) + +// jwk is one entry of a JWKS document. We only model the fields we read. +// +// Reference: RFC 7517 (JWK) and RFC 7518 (JWA). +type jwk struct { + Kty string `json:"kty"` // key type ("RSA" or "EC") + Use string `json:"use,omitempty"` // "sig" expected + Alg string `json:"alg,omitempty"` // algorithm hint (advisory) + Kid string `json:"kid,omitempty"` // key ID + + // RSA fields + N string `json:"n,omitempty"` + E string `json:"e,omitempty"` + + // EC fields + Crv string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` +} + +// parsedKey is a JWK after decoding the public key material plus the +// algorithm advertised by the JWKS. The provider trusts the JWKS-declared +// algorithm — NOT the algorithm asserted in the token header. +type parsedKey struct { + Public any // *rsa.PublicKey or *ecdsa.PublicKey + Alg string // RS256, ES256, etc. +} + +// jwksCache is a TTL-bounded cache of parsed JWK keys keyed by `kid`. +// +// Refresh strategy: +// +// 1. First lookup ever: fetch JWKS. +// 2. TTL expired (now - lastSuccessful > ttl): fetch JWKS on the request +// path. This is what invalidates keys that the IdP has revoked — +// without it, a compromised key removed from the issuer's JWKS would +// remain trusted by Forge until process restart. +// 3. kid not in cache (and TTL not expired): fetch JWKS to pick up new +// keys an IdP may have added since last refresh. +// 4. kid still not in cache after refresh: return ErrKeyNotFound. +// 5. kid in cache and TTL not expired: return immediately (hot path). +// +// Two separate timestamps track refresh state so failures don't extend the +// TTL window: +// - lastSuccessful: timestamp of the most recent successful refresh. +// Drives TTL accounting; only advanced on success. +// - lastAttempt: timestamp of the most recent attempt (success or fail). +// Drives failure backpressure so a dead JWKS endpoint isn't hammered. +// +// Concurrency: a single fetch is in flight at a time; concurrent misses +// coalesce behind the write lock. +type jwksCache struct { + jwksURL func(ctx context.Context) (string, error) + httpClient *http.Client + ttl time.Duration + // refetchGrace prevents a malformed kid from triggering one JWKS fetch + // per request. Within this window, repeated unknown-kid lookups skip + // the refetch. + refetchGrace time.Duration + // now returns the current time. Injectable for tests so TTL expiry + // can be exercised without real time passing. + now func() time.Time + + mu sync.RWMutex + keys map[string]parsedKey + lastSuccessful time.Time // last successful JWKS load — drives TTL + lastAttempt time.Time // last refresh attempt (success or fail) — drives error backoff + // lastMissRefresh is the last time we refreshed *because of* an + // unknown kid and the kid was still not present after refresh. + // Used to suppress stampedes of refetches for the same bad kid. + lastMissRefresh time.Time + lastErr error // last fetch error, for diagnostics +} + +const ( + defaultJWKSTTL = 1 * time.Hour + minJWKSTTL = 5 * time.Minute + defaultRefetchGrace = 30 * time.Second +) + +func newJWKSCache(jwksURL func(ctx context.Context) (string, error), client *http.Client, ttl time.Duration) *jwksCache { + if ttl == 0 { + ttl = defaultJWKSTTL + } + // Only clamp the minimum when a TTL was actually requested — leaving + // callers to opt out of TTL entirely would require an explicit zero + // (which the previous branch promoted to the default). The clamp + // protects against operators setting an absurdly short value that + // would hammer the IdP. + if ttl < minJWKSTTL { + ttl = minJWKSTTL + } + return &jwksCache{ + jwksURL: jwksURL, + httpClient: client, + ttl: ttl, + refetchGrace: defaultRefetchGrace, + now: time.Now, + keys: map[string]parsedKey{}, + } +} + +// ttlExpired reports whether the cache's last successful refresh is older +// than the configured TTL. Returns true when there's no successful fetch +// yet (initial state). Must be called with at least an RLock held. +func (c *jwksCache) ttlExpired() bool { + if c.lastSuccessful.IsZero() { + return true + } + return c.now().Sub(c.lastSuccessful) > c.ttl +} + +// backoffActive reports whether we should suppress a refresh attempt +// because a recent one failed. Prevents a dead JWKS endpoint from being +// hit on every request during an outage — without this gate, N concurrent +// callers against a flapping IdP would amplify into N IdP hits per +// request burst (a DDoS amplifier). Must be called with at least an +// RLock held. +// +// Test coverage: TestJWKS_FailedRefreshDoesNotHammerIdP in provider_test.go +// pins the "exactly one IdP hit per refetchGrace window during failure" +// contract — fires 50 unknown-kid requests in a backoff window and asserts +// JWKS endpoint is called once. +func (c *jwksCache) backoffActive() bool { + if c.lastErr == nil { + return false + } + if c.lastAttempt.IsZero() { + return false + } + return c.now().Sub(c.lastAttempt) < c.refetchGrace +} + +// ErrKeyNotFound is returned when a `kid` is not in the cache and a +// refresh did not produce it. +var ErrKeyNotFound = fmt.Errorf("oidc: signing key not found") + +// Get returns the parsed key for the given kid, refreshing the JWKS cache +// if necessary. Safe for concurrent use. +// +// Refresh semantics: +// - TTL expired (cache stale): refresh on the request path. Required +// for revocation to take effect — without this, a key removed from +// the IdP's JWKS would remain trusted by Forge until process restart. +// - kid in cache AND cache fresh: return immediately (hot path). +// - kid missing AND not recently refreshed for a miss: refresh once. +// - kid missing AND recently refreshed for a miss (within refetchGrace): +// deny without re-fetching — prevents bad kids hammering the endpoint. +// - last refresh failed AND within failure backoff: serve from existing +// cache without re-attempt — protects a dead IdP from request storms. +func (c *jwksCache) Get(ctx context.Context, kid string) (parsedKey, error) { + // Fast path: lock-free read. Treat a stale (TTL-expired) cache as if + // every kid were a miss, so we fall into the write path below where + // the actual refresh happens. + c.mu.RLock() + stale := c.ttlExpired() + var ( + key parsedKey + hit bool + ) + if !stale { + key, hit = c.keys[kid] + } + missCooldown := !c.lastMissRefresh.IsZero() && c.now().Sub(c.lastMissRefresh) < c.refetchGrace + errBackoff := c.backoffActive() + c.mu.RUnlock() + + if hit { + return key, nil + } + // During error backoff we serve from the (possibly empty) cache — + // don't re-attempt a known-broken endpoint. Note that on the very + // first call (never-fetched cache + endpoint already broken), this + // path is impossible because lastErr would still be nil. + if errBackoff { + return parsedKey{}, ErrKeyNotFound + } + // Don't refresh just to look up a kid we already concluded is absent. + // TTL-expired cache always tries a refresh (security), regardless of + // the unknown-kid grace window. + if !stale && missCooldown { + return parsedKey{}, ErrKeyNotFound + } + + // Take the write lock so concurrent refresh attempts coalesce. + c.mu.Lock() + defer c.mu.Unlock() + + // Re-check under write lock — another goroutine may have refreshed + // while we were waiting. + if !c.ttlExpired() { + if key, ok := c.keys[kid]; ok { + return key, nil + } + if !c.lastMissRefresh.IsZero() && c.now().Sub(c.lastMissRefresh) < c.refetchGrace { + return parsedKey{}, ErrKeyNotFound + } + } + + if err := c.refreshLocked(ctx); err != nil { + return parsedKey{}, err + } + if key, ok := c.keys[kid]; ok { + return key, nil + } + // Refresh happened but kid is still missing — remember that so we + // don't hammer JWKS on the next unknown-kid request. + c.lastMissRefresh = c.now() + return parsedKey{}, ErrKeyNotFound +} + +// refreshLocked fetches the JWKS endpoint and replaces the cache. Must be +// called with c.mu held for writing. +// refreshLocked fetches and parses the JWKS. lastAttempt is always +// updated (drives error backoff). lastSuccessful is updated ONLY when a +// full parse succeeds — failures must not extend the TTL window. +func (c *jwksCache) refreshLocked(ctx context.Context) error { + c.lastAttempt = c.now() + + url, err := c.jwksURL(ctx) + if err != nil { + c.lastErr = err + return fmt.Errorf("oidc: resolve jwks_uri: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + c.lastErr = err + return fmt.Errorf("oidc: build jwks request: %w", err) + } + resp, err := c.httpClient.Do(req) + if err != nil { + c.lastErr = err + return fmt.Errorf("oidc: jwks fetch: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + _, _ = io.CopyN(io.Discard, resp.Body, 1024) + err := fmt.Errorf("oidc: jwks returned status %d", resp.StatusCode) + c.lastErr = err + return err + } + + var body struct { + Keys []jwk `json:"keys"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + c.lastErr = err + return fmt.Errorf("oidc: jwks decode: %w", err) + } + + parsed := make(map[string]parsedKey, len(body.Keys)) + for _, k := range body.Keys { + // Only signature-use keys. Empty `use` is permitted (treated as + // "sig" since `key_ops` is not modeled). + if k.Use != "" && k.Use != "sig" { + continue + } + // Reject HMAC algorithms — using shared-secret HMAC with JWKS is + // a footgun: anyone who can read the JWKS can sign tokens. + if k.Alg != "" && !isAllowedAlg(k.Alg) { + continue + } + pk, alg, err := parseJWK(k) + if err != nil { + // Skip unparseable keys rather than failing the whole load — + // real-world JWKS may include keys for algorithms we don't + // support (e.g., RSA-OAEP for encryption). + continue + } + if k.Kid == "" { + // Without kid we can't look up; skip. Tokens without kid are + // also rejected at the provider level. + continue + } + parsed[k.Kid] = parsedKey{Public: pk, Alg: alg} + } + + c.keys = parsed + c.lastSuccessful = c.now() + c.lastErr = nil + // Clear the unknown-kid grace window on a successful refresh so the + // next miss can re-attempt a refresh immediately (the prior miss may + // have been for a kid that's now present in the freshly-loaded JWKS). + c.lastMissRefresh = time.Time{} + return nil +} + +// parseJWK extracts the public key material and derives the signing +// algorithm from the JWK. The returned alg is the value the provider +// will require the token header to match — empty string means +// "any allowed asymmetric alg permitted by the key type" (only happens +// for RSA, where the same key can sign RS256/384/512 or PS256/384/512; +// per RFC 7518 the application is meant to infer the alg from context). +// +// Review #11b: we previously defaulted an RSA JWK with no `alg` field +// to "RS256", which made the cross-check at provider.go reject +// legitimate PS256-signed tokens against the same key. EC keys aren't +// affected — their alg is derived from the curve, not advertised. +func parseJWK(k jwk) (any, string, error) { + switch k.Kty { + case "RSA": + pk, err := parseRSAPublicKey(k) + if err != nil { + return nil, "", err + } + // Preserve empty alg so the provider can fall back to header-alg + // validation against the whitelist (still gated by isAllowedAlg). + return pk, k.Alg, nil + case "EC": + pk, alg, err := parseECDSAPublicKey(k) + if err != nil { + return nil, "", err + } + return pk, alg, nil + default: + return nil, "", fmt.Errorf("unsupported kty %q", k.Kty) + } +} + +func parseRSAPublicKey(k jwk) (*rsa.PublicKey, error) { + nBytes, err := base64.RawURLEncoding.DecodeString(k.N) + if err != nil { + return nil, fmt.Errorf("decode RSA n: %w", err) + } + eBytes, err := base64.RawURLEncoding.DecodeString(k.E) + if err != nil { + return nil, fmt.Errorf("decode RSA e: %w", err) + } + // e is a big-endian variable-length integer. + var e int + switch len(eBytes) { + case 0: + return nil, fmt.Errorf("empty RSA exponent") + case 1: + e = int(eBytes[0]) + case 2: + e = int(binary.BigEndian.Uint16(eBytes)) + case 3: + padded := make([]byte, 4) + copy(padded[1:], eBytes) + e = int(binary.BigEndian.Uint32(padded)) + case 4: + e = int(binary.BigEndian.Uint32(eBytes)) + default: + return nil, fmt.Errorf("RSA exponent too large: %d bytes", len(eBytes)) + } + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(nBytes), + E: e, + }, nil +} + +func parseECDSAPublicKey(k jwk) (*ecdsa.PublicKey, string, error) { + var ( + curve elliptic.Curve + alg string + ) + switch k.Crv { + case "P-256": + curve = elliptic.P256() + alg = "ES256" + case "P-384": + curve = elliptic.P384() + alg = "ES384" + case "P-521": + curve = elliptic.P521() + alg = "ES512" + default: + return nil, "", fmt.Errorf("unsupported curve %q", k.Crv) + } + xBytes, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + return nil, "", fmt.Errorf("decode EC x: %w", err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + return nil, "", fmt.Errorf("decode EC y: %w", err) + } + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(xBytes), + Y: new(big.Int).SetBytes(yBytes), + }, alg, nil +} + +// allowedAlgs lists the asymmetric signing algorithms accepted by the +// provider. Symmetric (HMAC) and "none" are NEVER accepted, even if a +// JWKS advertises them. +var allowedAlgs = map[string]struct{}{ + "RS256": {}, "RS384": {}, "RS512": {}, + "PS256": {}, "PS384": {}, "PS512": {}, + "ES256": {}, "ES384": {}, "ES512": {}, +} + +func isAllowedAlg(alg string) bool { + _, ok := allowedAlgs[alg] + return ok +} + +// allowedAlgNames returns the algorithm names in a slice — for passing to +// jwt.WithValidMethods. +func allowedAlgNames() []string { + out := make([]string, 0, len(allowedAlgs)) + for a := range allowedAlgs { + out = append(out, a) + } + return out +} diff --git a/forge-core/auth/providers/oidc/provider.go b/forge-core/auth/providers/oidc/provider.go new file mode 100644 index 0000000..47c985d --- /dev/null +++ b/forge-core/auth/providers/oidc/provider.go @@ -0,0 +1,377 @@ +// Package oidc implements an auth.Provider that verifies JWT bearer tokens +// against an OpenID Connect issuer. +// +// It works with any compliant OIDC provider — Auth0, Keycloak, Azure AD, +// Google Workspace, Ping, JumpCloud, and (in OIDC-only mode) Okta. +// +// Behavior at a glance: +// +// - On first verify, the discovery document is fetched once from +// {issuer}/.well-known/openid-configuration and cached forever (the +// issuer is stable per OIDC spec). +// +// - JWKS is fetched lazily and cached. On unknown-kid, the JWKS is +// refetched once before declaring the key unknown. +// +// - The signing algorithm is taken from the JWKS entry, NOT from the +// token header. This defends against algorithm-confusion attacks +// (e.g., a token claiming `alg: HS256` against an RSA JWKS). +// +// - alg=none and HMAC algorithms (HS256/384/512) are never accepted. +// +// - Standard claims are validated: iss exact match, aud contains +// configured Audience (with optional azp == ClientID fallback), +// exp/nbf within ClockSkew leeway. +// +// - Claim → Identity mapping is configurable via ClaimMap. The X-Org-ID +// header overrides the claim-derived OrgID for per-request tenant +// routing. +package oidc + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this +// provider in the auth registry. +const ProviderName = "oidc" + +// Default operational tunables. +const ( + DefaultClockSkew = 30 * time.Second + DefaultHTTPTimeout = 10 * time.Second +) + +// Config controls the OIDC provider. +type Config struct { + // Issuer is the full OIDC issuer URL (no trailing slash). Required. + // The token's `iss` claim must match this value exactly. + Issuer string `yaml:"issuer"` + + // Audience is the expected `aud` claim value. Required. If the token + // has multiple audiences, the configured Audience must be one of them. + Audience string `yaml:"audience"` + + // ClientID is an optional secondary audience check: if set, a token + // whose `aud` does not contain Audience is still accepted when its + // `azp` (authorized party) claim equals ClientID. + ClientID string `yaml:"client_id,omitempty"` + + // JWKSURL overrides the JWKS endpoint discovered via the OIDC + // discovery document. Most users leave this empty. + JWKSURL string `yaml:"jwks_url,omitempty"` + + // JWKSCacheTTL caps the maximum age of cached JWKS keys. Defaults to + // 1 hour. Values below 5 minutes are silently clamped up to avoid + // hammering the IdP. + JWKSCacheTTL time.Duration `yaml:"jwks_cache_ttl,omitempty"` + + // ClockSkew is the leeway applied to `exp` and `nbf` validation. + // Defaults to 30 seconds. + ClockSkew time.Duration `yaml:"clock_skew,omitempty"` + + // ClaimMap configures which JWT claim names map to Identity fields. + // Empty fields use OIDC defaults (sub, email, org_id, …). + ClaimMap ClaimMap `yaml:"claim_map,omitempty"` + + // HTTPClient overrides the default client. Injectable for tests. + HTTPClient *http.Client `yaml:"-"` +} + +// Validate returns ErrProviderNotConfigured when required fields are +// missing. +func (c Config) Validate() error { + if c.Issuer == "" { + return fmt.Errorf("%w: issuer required", auth.ErrProviderNotConfigured) + } + if c.Audience == "" { + return fmt.Errorf("%w: audience required", auth.ErrProviderNotConfigured) + } + return nil +} + +// normalize returns a copy of the Config with the Issuer in canonical +// form (trailing slash stripped). Called once at New() so every +// downstream comparison — discovery doc check, JWT iss validation, +// JWKS URL construction — uses the same form. +// +// Why this matters: OIDC IdPs are inconsistent about trailing slashes. +// Auth0 emits "iss": "https://tenant.auth0.com/", many enterprise +// Okta tenants emit without, etc. Operators paste whatever they see in +// their admin UI into forge.yaml. We trim once at the boundary so the +// mismatch never reaches a comparison. +func (c Config) normalize() Config { + c.Issuer = strings.TrimRight(c.Issuer, "/") + return c +} + +// Provider implements auth.Provider for OIDC issuers. +type Provider struct { + cfg Config + client *http.Client + discovery *discovery + jwks *jwksCache + parser *jwt.Parser + clockSkew time.Duration +} + +// New constructs a Provider after validating cfg. No network I/O happens +// here — discovery and JWKS are fetched lazily on first Verify. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + // Normalize trailing slash once so every downstream comparison + // (discovery doc, token iss, JWKS URL construction) uses the same + // canonical form. Review finding #2. + cfg = cfg.normalize() + if cfg.ClockSkew == 0 { + cfg.ClockSkew = DefaultClockSkew + } + client := cfg.HTTPClient + if client == nil { + client = &http.Client{Timeout: DefaultHTTPTimeout} + } + + p := &Provider{ + cfg: cfg, + client: client, + clockSkew: cfg.ClockSkew, + } + + p.discovery = newDiscovery(cfg.Issuer, client) + p.jwks = newJWKSCache(p.resolveJWKSURL, client, cfg.JWKSCacheTTL) + + // Configure the JWT parser once. + // + // Issuer is validated manually (post-parse) instead of via + // jwt.WithIssuer — that helper does strict string equality on the + // token's iss claim, which fails when the IdP emits with a trailing + // slash but our config normalized it off (or vice versa). The manual + // check trims both sides. Same reasoning as the discovery doc + // comparison in EnsureLoaded. + // + // Audience is similarly manual because of the azp fallback below. + parserOpts := []jwt.ParserOption{ + jwt.WithValidMethods(allowedAlgNames()), + jwt.WithLeeway(cfg.ClockSkew), + jwt.WithExpirationRequired(), + } + p.parser = jwt.NewParser(parserOpts...) + + return p, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// resolveJWKSURL returns the configured override JWKSURL if set, otherwise +// loads discovery and returns the discovered jwks_uri. +func (p *Provider) resolveJWKSURL(ctx context.Context) (string, error) { + if p.cfg.JWKSURL != "" { + return p.cfg.JWKSURL, nil + } + if err := p.discovery.EnsureLoaded(ctx); err != nil { + return "", err + } + return p.discovery.JWKSURI(), nil +} + +// Verify implements auth.Provider. +// +// The verification flow: +// 1. Parse token structurally (without signature verification yet). +// 2. Extract kid from header; reject tokens without kid. +// 3. Look up kid in JWKS cache (refreshing on miss). +// 4. Cross-check token's `alg` against JWKS-declared alg (defends against +// algorithm confusion). +// 5. Re-parse the token with the resolved key, validating iss/exp/nbf. +// 6. Validate audience (with azp fallback if ClientID is configured). +// 7. Map claims to Identity, applying header overrides. +func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Headers) (*auth.Identity, error) { + // Quick structural check before doing crypto work. JWTs are three + // base64url segments separated by dots. + if !looksLikeJWT(tokenStr) { + return nil, auth.ErrTokenNotForMe + } + + // First pass: parse the header to extract kid + alg without verifying + // the signature. + parsedHeader, _, err := jwt.NewParser().ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("%w: parse header: %w", auth.ErrInvalidToken, err) + } + kid, _ := parsedHeader.Header["kid"].(string) + if kid == "" { + return nil, fmt.Errorf("%w: token header missing kid", auth.ErrInvalidToken) + } + headerAlg, _ := parsedHeader.Header["alg"].(string) + if !isAllowedAlg(headerAlg) { + // Catches alg=none, HMAC, etc. before any further work. + return nil, fmt.Errorf("%w: disallowed alg %q", auth.ErrInvalidToken, headerAlg) + } + + // Look up the key. + key, err := p.jwks.Get(ctx, kid) + if err != nil { + if errors.Is(err, ErrKeyNotFound) { + return nil, fmt.Errorf("%w: kid %q not found in JWKS", auth.ErrInvalidToken, kid) + } + return nil, err // transport / decode error — fail-closed via chain + } + + // Cross-check: when the JWKS entry declares an alg, the token's alg + // must match it exactly (algorithm-confusion defense). When the JWKS + // entry has no alg (RFC 7518 says "infer from context"), trust the + // token's header alg as long as it's in the asymmetric whitelist — + // which was already enforced at line ~195 above. Review #11b: + // previously the empty-alg case defaulted to "RS256", which rejected + // legitimate PS256/PS384/PS512 tokens signed by the same RSA key. + if key.Alg != "" && key.Alg != headerAlg { + return nil, fmt.Errorf("%w: token alg %q does not match JWKS alg %q for kid %q", + auth.ErrInvalidToken, headerAlg, key.Alg, kid) + } + + // Second pass: verify signature + standard claims (iss/exp/nbf). + claims := jwt.MapClaims{} + _, err = p.parser.ParseWithClaims(tokenStr, claims, func(_ *jwt.Token) (any, error) { + return key.Public, nil + }) + if err != nil { + return nil, classifyJWTError(err) + } + + // Issuer validation (manual — trims trailing slash on both sides so + // IdPs that emit iss with/without a trailing slash interop with + // configs that use the opposite form). See Config.normalize and + // review finding #2. + if err := p.checkIssuer(claims); err != nil { + return nil, err + } + + // Audience validation (with azp fallback). + if err := p.checkAudience(claims); err != nil { + return nil, err + } + + return mapClaims(p.cfg.ClaimMap, claims, headers), nil +} + +// checkIssuer validates the token's `iss` claim against the configured +// Issuer. Both values are trimmed of trailing slashes before comparison +// so trailing-slash drift between the IdP and the operator's configured +// value doesn't cause spurious rejections. +// +// Security note: trimming a single trailing slash cannot enable issuer +// impersonation — "https://attacker.example" and "https://legit.example" +// remain distinct. The trim is purely a normalization fix. +func (p *Provider) checkIssuer(claims jwt.MapClaims) error { + iss, _ := claims["iss"].(string) + if iss == "" { + return fmt.Errorf("%w: missing iss claim", auth.ErrTokenRejected) + } + if strings.TrimRight(iss, "/") != p.cfg.Issuer { + return fmt.Errorf("%w: issuer %q does not match configured %q", + auth.ErrTokenRejected, iss, p.cfg.Issuer) + } + return nil +} + +// looksLikeJWT does a cheap structural check: exactly three non-empty +// dot-separated segments. Used to yield to the next provider on non-JWT +// tokens (e.g., opaque tokens that another chain entry would handle). +func looksLikeJWT(s string) bool { + if len(s) < 5 { + return false + } + segments := 1 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + segments++ + if segments > 3 { + return false + } + } + } + return segments == 3 +} + +// checkAudience validates that the token's `aud` claim contains the +// configured Audience. If ClientID is set and aud does not contain +// Audience, accepts when `azp` claim equals ClientID. +func (p *Provider) checkAudience(claims jwt.MapClaims) error { + auds := audienceSlice(claims) + if slices.Contains(auds, p.cfg.Audience) { + return nil + } + if p.cfg.ClientID != "" { + if azp, _ := claims["azp"].(string); azp == p.cfg.ClientID { + return nil + } + } + return fmt.Errorf("%w: audience %v does not match configured %q", auth.ErrTokenRejected, auds, p.cfg.Audience) +} + +// audienceSlice normalizes the `aud` claim into []string. RFC 7519 allows +// `aud` to be either a string or a string array. +func audienceSlice(claims jwt.MapClaims) []string { + switch v := claims["aud"].(type) { + case string: + return []string{v} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return out + case []string: + return v + default: + return nil + } +} + +// classifyJWTError maps a parsing/verification error from jwt/v5 to one of +// the auth-package sentinels. +func classifyJWTError(err error) error { + switch { + case errors.Is(err, jwt.ErrTokenExpired): + return fmt.Errorf("%w: token expired", auth.ErrTokenRejected) + case errors.Is(err, jwt.ErrTokenNotValidYet): + return fmt.Errorf("%w: token not yet valid", auth.ErrTokenRejected) + case errors.Is(err, jwt.ErrTokenInvalidIssuer): + return fmt.Errorf("%w: issuer mismatch", auth.ErrTokenRejected) + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + return fmt.Errorf("%w: signature invalid", auth.ErrInvalidToken) + case errors.Is(err, jwt.ErrTokenMalformed): + return fmt.Errorf("%w: token malformed", auth.ErrInvalidToken) + case errors.Is(err, jwt.ErrTokenUnverifiable): + return fmt.Errorf("%w: token unverifiable", auth.ErrInvalidToken) + case errors.Is(err, jwt.ErrTokenRequiredClaimMissing): + return fmt.Errorf("%w: required claim missing", auth.ErrInvalidToken) + default: + return fmt.Errorf("%w: %w", auth.ErrInvalidToken, err) + } +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/oidc/provider_test.go b/forge-core/auth/providers/oidc/provider_test.go new file mode 100644 index 0000000..296d4c5 --- /dev/null +++ b/forge-core/auth/providers/oidc/provider_test.go @@ -0,0 +1,1236 @@ +package oidc_test + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/oidc" +) + +const testAudience = "api://forge" + +func newProvider(t *testing.T, fi *fakeIssuer, mod func(*oidc.Config)) *oidc.Provider { + t.Helper() + cfg := oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + } + if mod != nil { + mod(&cfg) + } + p, err := oidc.New(cfg) + if err != nil { + t.Fatalf("oidc.New: %v", err) + } + return p +} + +// --- Construction / Validation --- + +func TestNew_Validation(t *testing.T) { + tests := []struct { + name string + cfg oidc.Config + wantErr error + }{ + { + name: "empty config", + cfg: oidc.Config{}, + wantErr: auth.ErrProviderNotConfigured, + }, + { + name: "missing audience", + cfg: oidc.Config{Issuer: "https://example.com"}, + wantErr: auth.ErrProviderNotConfigured, + }, + { + name: "missing issuer", + cfg: oidc.Config{Audience: "api://forge"}, + wantErr: auth.ErrProviderNotConfigured, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := oidc.New(tt.cfg) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("err = %v, want %v", err, tt.wantErr) + } + }) + } +} + +// --- Happy paths --- + +func TestVerify_RSA256_HappyPath(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["email"] = "alice@example.com" + claims["groups"] = []string{"engineers", "oncall"} + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } + if id.UserID != "user-123" { + t.Errorf("UserID = %q, want user-123", id.UserID) + } + if id.Email != "alice@example.com" { + t.Errorf("Email = %q, want alice@example.com", id.Email) + } + if len(id.Groups) != 2 || id.Groups[0] != "engineers" || id.Groups[1] != "oncall" { + t.Errorf("Groups = %v, want [engineers oncall]", id.Groups) + } + if id.Source != "oidc" { + t.Errorf("Source = %q, want oidc", id.Source) + } +} + +func TestVerify_ECDSA_HappyPath(t *testing.T) { + fi := newFakeIssuer(t) + fi.addECDSAKey("ec-1") + p := newProvider(t, fi, nil) + + tok := fi.SignWith("ec-1", fi.DefaultClaims(testAudience)) + id, err := p.Verify(context.Background(), tok, nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.UserID != "user-123" { + t.Errorf("UserID = %q, want user-123", id.UserID) + } +} + +// --- Algorithm-confusion / security --- + +func TestVerify_RejectsAlgNone(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + tok := SignUnsigned(fi.DefaultClaims(testAudience)) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (alg=none must be rejected)", err) + } +} + +func TestVerify_RejectsHMAC(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + tok := SignHMAC("key-1", fi.DefaultClaims(testAudience), []byte("any-secret")) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (HMAC must be rejected)", err) + } +} + +func TestVerify_JWKWithoutAlgAcceptsAnyAsymmetric(t *testing.T) { + // Review #11b: when the JWKS entry omits the `alg` field (RFC 7518 + // says "infer from context"), the provider must NOT lock the key + // to "RS256" — that rejected legitimate PS256 tokens signed by the + // same RSA key. The fix preserves empty alg and falls back to the + // header's alg as long as it's in the asymmetric whitelist. + fi := newFakeIssuer(t) + // Override the JWKS entry to omit its alg advertisement. + fi.keys["key-1"].algInJWKS = "" + + p := newProvider(t, fi, nil) + + // Token signed with RS256 — works (the key was generated as RSA-2048). + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("RS256 token with empty-alg JWK rejected: %v", err) + } + + // Now switch the signing method to PS256 (uses the same key + // material — RSA-PSS just changes the padding) and verify the + // provider accepts it. + fi.keys["key-1"].method = jwt.SigningMethodPS256 + tok = fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("PS256 token with empty-alg JWK rejected (regression of #11b): %v", err) + } +} + +func TestVerify_JWKExplicitAlgStillStrict(t *testing.T) { + // Counterpart: when the JWKS entry DOES declare an alg, the + // cross-check must still enforce it. This guards the + // algorithm-confusion defense — if a JWK says "this key is for + // RS256 only", a token claiming PS256 against that key is + // rejected. + fi := newFakeIssuer(t) + fi.keys["key-1"].algInJWKS = "RS256" // explicit + + p := newProvider(t, fi, nil) + + // Sign with PS256 against the same key material. + fi.keys["key-1"].method = jwt.SigningMethodPS256 + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("explicit JWK alg=RS256 must reject PS256 token; got err = %v", err) + } +} + +func TestVerify_AlgInHeaderMustMatchJWKS(t *testing.T) { + // Algorithm-confusion defense: if the token claims an alg that + // differs from the JWKS-declared alg for that kid, reject. + fi := newFakeIssuer(t) + // Replace the JWKS-declared alg for key-1 with RS512 while still + // signing the token with RS256 (the key's actual method). + fi.keys["key-1"].algInJWKS = "RS512" + + p := newProvider(t, fi, nil) + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (alg mismatch)", err) + } +} + +func TestVerify_RejectsTokenWithoutKid(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Sign manually, without kid in header. + claims := fi.DefaultClaims(testAudience) + priv := fi.keys["key-1"].priv + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signed, err := tok.SignedString(priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + _, err = p.Verify(context.Background(), signed, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (missing kid)", err) + } +} + +// --- iss / aud / azp / exp / nbf --- + +func TestVerify_RejectsWrongIssuer(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["iss"] = "https://different-issuer.example.com" + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (iss mismatch)", err) + } +} + +func TestVerify_RejectsWrongAudience(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims("api://different-service") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (aud mismatch)", err) + } +} + +func TestVerify_AcceptsAudienceInArray(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["aud"] = []string{"api://other", testAudience, "api://third"} + tok := fi.SignWith("key-1", claims) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("multi-audience token rejected: %v", err) + } +} + +func TestVerify_AcceptsAZPAsFallback(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClientID = "client-abc" + }) + + claims := fi.DefaultClaims("api://different") // aud doesn't match + claims["azp"] = "client-abc" // but azp does + tok := fi.SignWith("key-1", claims) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("azp-fallback token rejected: %v", err) + } +} + +func TestVerify_RejectsExpiredToken(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClockSkew = 1 * time.Second + }) + + claims := fi.DefaultClaims(testAudience) + claims["exp"] = time.Now().Add(-10 * time.Minute).Unix() + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (expired)", err) + } +} + +func TestVerify_ClockSkewToleratesNearExpiration(t *testing.T) { + fi := newFakeIssuer(t) + // Configure 60s leeway; token expired 30s ago — should be accepted. + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClockSkew = 60 * time.Second + }) + + claims := fi.DefaultClaims(testAudience) + claims["exp"] = time.Now().Add(-30 * time.Second).Unix() + tok := fi.SignWith("key-1", claims) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("token within ClockSkew rejected: %v", err) + } +} + +func TestVerify_RejectsNotYetValid(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClockSkew = 1 * time.Second + }) + + claims := fi.DefaultClaims(testAudience) + claims["nbf"] = time.Now().Add(10 * time.Minute).Unix() + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Fatalf("err = %v, want ErrTokenRejected (nbf in future)", err) + } +} + +// --- JWKS cache / rotation --- + +func TestJWKS_LazyLoad(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Provider construction must not fetch JWKS (lazy). + if fi.jwksFetches != 0 { + t.Errorf("JWKS fetched %d times during New (must be 0)", fi.jwksFetches) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("first Verify: %v", err) + } + if fi.jwksFetches != 1 { + t.Errorf("JWKS fetched %d times after first Verify, want 1", fi.jwksFetches) + } + + // Second Verify must reuse the cache. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("second Verify: %v", err) + } + if fi.jwksFetches != 1 { + t.Errorf("JWKS fetched %d times after second Verify, want 1 (cache miss)", fi.jwksFetches) + } +} + +func TestJWKS_RotationTriggersRefresh(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Initial token with key-1. + tok1 := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok1, nil); err != nil { + t.Fatalf("initial Verify: %v", err) + } + + // IdP rotates: adds key-2, then a token signed with key-2 arrives. + fi.addRSAKey("key-2") + tok2 := fi.SignWith("key-2", fi.DefaultClaims(testAudience)) + + if _, err := p.Verify(context.Background(), tok2, nil); err != nil { + t.Fatalf("post-rotation Verify: %v", err) + } + // Provider should have refreshed JWKS once on the unknown kid. + if fi.jwksFetches != 2 { + t.Errorf("JWKS fetched %d times after rotation, want 2 (lazy + refresh)", fi.jwksFetches) + } +} + +func TestJWKS_UnknownKidAfterRefresh(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Sign a token whose header advertises a kid that doesn't exist on + // the issuer. We do this by signing with key-1 then editing the + // header — easier: build the signed-string manually. + claims := fi.DefaultClaims(testAudience) + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = "ghost-kid" + signed, err := tok.SignedString(fi.keys["key-1"].priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + + _, err = p.Verify(context.Background(), signed, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("err = %v, want ErrInvalidToken (kid not in JWKS)", err) + } +} + +func TestJWKS_RefetchGracePreventsStampede(t *testing.T) { + // Two requests with the same unknown kid in quick succession should + // only trigger ONE JWKS fetch — the grace window should suppress + // the second. + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Warm cache with a valid request first. + if _, err := p.Verify(context.Background(), fi.SignWith("key-1", fi.DefaultClaims(testAudience)), nil); err != nil { + t.Fatalf("warmup: %v", err) + } + startFetches := fi.jwksFetches + + // Two back-to-back unknown-kid requests. + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, fi.DefaultClaims(testAudience)) + tok.Header["kid"] = "ghost" + signed, _ := tok.SignedString(fi.keys["key-1"].priv) + + _, _ = p.Verify(context.Background(), signed, nil) + _, _ = p.Verify(context.Background(), signed, nil) + + addedFetches := fi.jwksFetches - startFetches + if addedFetches > 1 { + t.Errorf("JWKS fetched %d times for two unknown-kid requests, want ≤1 (stampede prevention)", addedFetches) + } +} + +// --- ClaimMap --- + +func TestClaimMap_Defaults(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["email"] = "x@y" + claims["org_id"] = "tenant-1" + claims["workspace_id"] = "ws-2" + claims["groups"] = []string{"a"} + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.UserID != "user-123" || id.Email != "x@y" || id.OrgID != "tenant-1" || id.WorkspaceID != "ws-2" { + t.Errorf("default ClaimMap mapping wrong: %+v", id) + } +} + +func TestClaimMap_CustomNames(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.ClaimMap = oidc.ClaimMap{ + UserID: "uid", + Email: "mail", + OrgID: "tenant", + Groups: "roles", + } + }) + + claims := fi.DefaultClaims(testAudience) + claims["uid"] = "custom-user" + claims["mail"] = "x@y" + claims["tenant"] = "tenant-99" + claims["roles"] = []string{"admin"} + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.UserID != "custom-user" || id.Email != "x@y" || id.OrgID != "tenant-99" || len(id.Groups) != 1 || id.Groups[0] != "admin" { + t.Errorf("custom ClaimMap mapping wrong: %+v", id) + } +} + +func TestClaimMap_HeaderOverridesOrgID(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["org_id"] = "claim-org" + + id, err := p.Verify( + context.Background(), + fi.SignWith("key-1", claims), + auth.Headers{"X-Org-ID": "header-org"}, + ) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.OrgID != "header-org" { + t.Errorf("OrgID = %q, want header-org (header should override claim)", id.OrgID) + } +} + +func TestClaimMap_GroupsToleratesSingleString(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["groups"] = "single-group" + + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if len(id.Groups) != 1 || id.Groups[0] != "single-group" { + t.Errorf("Groups = %v, want [single-group]", id.Groups) + } +} + +// --- Non-JWT tokens yield --- + +func TestVerify_NonJWTYields(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + // Opaque-token shapes another provider might handle. + cases := []string{"", "opaque-token", "abc.def", "abc.def.ghi.jkl", "no-dots"} + for _, tok := range cases { + t.Run(tok, func(t *testing.T) { + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("token %q: err = %v, want ErrTokenNotForMe", tok, err) + } + }) + } +} + +// --- Discovery --- + +func TestDiscovery_IssuerMismatchFails(t *testing.T) { + // Build an issuer whose discovery doc returns a DIFFERENT issuer URL + // than the one we hit. Catches misconfiguration where someone copies + // an Okta discovery URL but configures the wrong issuer string. + mismatchSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{"issuer":"https://different.example.com","jwks_uri":"https://different.example.com/jwks"}`)) + })) + defer mismatchSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: mismatchSrv.URL, + Audience: testAudience, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Use a syntactically-valid JWT — header includes kid + alg so we + // reach the JWKS-fetch step where discovery is consulted. + // Header: {"alg":"RS256","kid":"k1"} Payload: {"iss":"x"} + dummyJWT := "eyJhbGciOiJSUzI1NiIsImtpZCI6ImsxIn0.eyJpc3MiOiJ4In0.AAAA" + _, err = p.Verify(context.Background(), dummyJWT, nil) + if err == nil { + t.Fatal("expected discovery mismatch error, got nil") + } + if !strings.Contains(err.Error(), "discovery issuer") { + t.Errorf("err = %v, want discovery-issuer-mismatch message", err) + } +} + +func TestDiscovery_JWKSURLOverrideSkipsDiscovery(t *testing.T) { + fi := newFakeIssuer(t) + jwksFetchesBefore := fi.jwksFetches + + // Configure JWKSURL explicitly — discovery should never be called. + // Implicitly verified by the fact that the fakeIssuer's discovery + // handler increments jwksFetches only on /jwks (not /.well-known/), + // so we check that signing still works. + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: fi.IssuerURL() + "/jwks", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("Verify: %v", err) + } + if fi.jwksFetches != jwksFetchesBefore+1 { + t.Errorf("JWKS fetches = %d, want %d", fi.jwksFetches, jwksFetchesBefore+1) + } +} + +// --- Concurrency / race safety --- + +func TestVerify_Concurrent(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + const N = 50 + errs := make(chan error, N) + for range N { + go func() { + _, err := p.Verify(context.Background(), tok, nil) + errs <- err + }() + } + for range N { + if err := <-errs; err != nil { + t.Fatalf("concurrent Verify: %v", err) + } + } +} + +// --- HTTP client / context --- + +func TestVerify_RespectsContextCancellation(t *testing.T) { + // Slow JWKS endpoint: handler blocks until the test ends. + done := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-done: + } + })) + defer func() { close(done); srv.Close() }() + + p, err := oidc.New(oidc.Config{ + Issuer: srv.URL, + Audience: testAudience, + JWKSURL: srv.URL + "/jwks", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + dummyJWT := "eyJhbGciOiJSUzI1NiIsImtpZCI6Inh4eCJ9.eyJpc3MiOiJ4In0.signature" + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + start := time.Now() + _, _ = p.Verify(ctx, dummyJWT, nil) + elapsed := time.Since(start) + + if elapsed > 2*time.Second { + t.Errorf("Verify took %v after ctx cancellation; expected <2s", elapsed) + } +} + +// --- Factory / registration --- + +func TestRegisteredViaFactory(t *testing.T) { + p, err := auth.Build("oidc", map[string]any{ + "issuer": "https://example.com", + "audience": "api://forge", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "oidc" { + t.Errorf("Name = %q, want oidc", p.Name()) + } +} + +func TestFactory_MissingIssuerErrors(t *testing.T) { + _, err := auth.Build("oidc", map[string]any{ + "audience": "api://forge", + }) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } +} + +func TestFactory_UnknownKeysAreIgnored(t *testing.T) { + // Forward compat: future YAML keys must not break construction. + _, err := auth.Build("oidc", map[string]any{ + "issuer": "https://example.com", + "audience": "api://forge", + "future_setting": "ignored", + "another_unknown": 42, + }) + if err != nil { + t.Fatalf("Build with unknown keys: %v", err) + } +} + +// --- Smoke: identity Claims is a copy --- + +func TestIdentity_ClaimsIsCopiedFromToken(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["custom"] = "value" + id, err := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Claims["custom"] != "value" { + t.Errorf("Claims[custom] = %v, want value", id.Claims["custom"]) + } + // Mutate the returned map — provider state must not be affected. + id.Claims["custom"] = "MUTATED" + + id2, _ := p.Verify(context.Background(), fi.SignWith("key-1", claims), nil) + if id2.Claims["custom"] != "value" { + t.Errorf("Claims map shared across Verify calls — got %v after mutation", id2.Claims["custom"]) + } +} + +// Make atomic referenced so the import isn't dropped if we later remove +// the concurrent-fetches counter. +var _ atomic.Int32 + +// --- Error-path coverage --- + +func TestJWKS_MalformedRSAKey_IsSkipped(t *testing.T) { + // JWKS contains a valid key AND a malformed one. The malformed one + // is silently skipped — load must succeed and the valid one usable. + fi := newFakeIssuer(t) + jwksSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{ + "keys": [ + {"kty":"RSA","kid":"broken","alg":"RS256","n":"!!!not-base64!!!","e":"AQAB"}, + {"kty":"unknown","kid":"weird"}, + {"kty":"RSA","kid":"no-kid-here","alg":"RS256","n":"","e":""} + ] + }`)) + })) + defer jwksSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: jwksSrv.URL, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + _, err = p.Verify(context.Background(), tok, nil) + // Should fail because the broken JWKS has no usable keys. + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (no valid keys in JWKS)", err) + } +} + +func TestJWKS_EndpointReturns500(t *testing.T) { + fi := newFakeIssuer(t) + jwksSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer jwksSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: jwksSrv.URL, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + _, err = p.Verify(context.Background(), tok, nil) + if err == nil { + t.Fatal("expected error when JWKS endpoint returns 500") + } +} + +func TestJWKS_HMACAlgInJWKS_IsRejected(t *testing.T) { + // If the JWKS advertises an HMAC alg, we must skip that key entirely. + fi := newFakeIssuer(t) + jwksSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{ + "keys": [ + {"kty":"oct","kid":"hmac-key","alg":"HS256","k":"some-secret"} + ] + }`)) + })) + defer jwksSrv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: fi.IssuerURL(), + Audience: testAudience, + JWKSURL: jwksSrv.URL, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Build a token claiming HS256 with kid hmac-key. + tok := SignHMAC("hmac-key", fi.DefaultClaims(testAudience), []byte("some-secret")) + _, err = p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (HMAC in JWKS must be skipped)", err) + } +} + +func TestVerify_TokenWithoutAudClaim(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + delete(claims, "aud") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (missing aud)", err) + } +} + +// --- Issuer trailing-slash normalization (review finding #2) --- +// +// IdPs disagree on whether the canonical issuer has a trailing slash. +// Auth0 emits "iss": "https://tenant.auth0.com/"; most Okta deployments +// do not. Operators paste whatever they see into forge.yaml. These +// tests guarantee both directions interop. + +func TestIssuer_ConfigWithoutSlash_TokenWithSlash(t *testing.T) { + fi := newFakeIssuer(t) + // Config the user-facing way: no trailing slash. + p := newProvider(t, fi, func(c *oidc.Config) { + c.Issuer = fi.IssuerURL() // already no slash + }) + + // Mint a token whose iss has a trailing slash (mimics Auth0). + claims := fi.DefaultClaims(testAudience) + claims["iss"] = fi.IssuerURL() + "/" + tok := fi.SignWith("key-1", claims) + + id, err := p.Verify(context.Background(), tok, nil) + if err != nil { + t.Fatalf("token iss with slash, config without — should accept, got: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } +} + +func TestIssuer_ConfigWithSlash_TokenWithoutSlash(t *testing.T) { + fi := newFakeIssuer(t) + // Config the other way: user pastes with trailing slash. + p := newProvider(t, fi, func(c *oidc.Config) { + c.Issuer = fi.IssuerURL() + "/" + }) + + // Token has no slash (the more common case). + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("config with slash, token without — should accept, got: %v", err) + } +} + +func TestIssuer_BothWithoutSlash_StillWorks(t *testing.T) { + // Regression: ensure the trim didn't break the most common case. + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("config without slash + token without slash should accept: %v", err) + } +} + +func TestIssuer_BothWithSlash_StillWorks(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.Issuer = fi.IssuerURL() + "/" + }) + claims := fi.DefaultClaims(testAudience) + claims["iss"] = fi.IssuerURL() + "/" + tok := fi.SignWith("key-1", claims) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("config + token both with slash should accept: %v", err) + } +} + +func TestIssuer_GenuineMismatchStillRejected(t *testing.T) { + // Security guard: trimming trailing slashes must not collapse + // genuinely-different issuers into the same comparison value. + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + claims["iss"] = "https://attacker.example.com/" // completely different host + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("genuine issuer mismatch err = %v, want ErrTokenRejected", err) + } +} + +func TestIssuer_TokenMissingIssClaim(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + delete(claims, "iss") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("missing iss err = %v, want ErrTokenRejected", err) + } +} + +func TestDiscovery_TrailingSlashTolerated(t *testing.T) { + // Discovery side: IdP's doc has issuer with trailing slash; + // operator configured without. Must not bomb at startup-ish phase. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Deliberately use a different host than the test server's URL + // to simulate normalization-only mismatch — wait, we need it to + // be the same host, just different slash form. + })) + srv.Close() // we just needed the type + + // Build a real fake issuer and add a slash to the doc's response. + docHost := "" + dynamicSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + // Emit issuer WITH trailing slash even though docHost has none. + _, _ = w.Write([]byte(`{"issuer":"` + docHost + `/","jwks_uri":"` + docHost + `/jwks"}`)) + case "/jwks": + _, _ = w.Write([]byte(`{"keys":[]}`)) + default: + http.NotFound(w, r) + } + })) + defer dynamicSrv.Close() + docHost = dynamicSrv.URL + + p, err := oidc.New(oidc.Config{ + Issuer: docHost, // no trailing slash + Audience: testAudience, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Trigger the discovery path with a real JWT-shaped string. + // Discovery happens lazily on first JWKS lookup. We're testing that + // the comparison doesn't reject due to trailing-slash mismatch. + dummyJWT := "eyJhbGciOiJSUzI1NiIsImtpZCI6ImsxIn0.eyJpc3MiOiJ4In0.AAAA" + _, err = p.Verify(context.Background(), dummyJWT, nil) + // Expect an ErrInvalidToken (kid not found in empty JWKS) — NOT a + // "discovery issuer mismatch" error. The previous bug surfaced as + // the latter. + if err != nil && strings.Contains(err.Error(), "discovery issuer") { + t.Errorf("discovery rejected due to trailing-slash drift: %v", err) + } +} + +// --- TTL-driven refresh (review finding #1) --- +// +// These tests guarantee that JWKS keys revoked at the IdP eventually stop +// being accepted by Forge. Before the fix, the cache's ttl field was dead +// code and revoked keys remained trusted until process restart. + +func TestJWKS_TTLExpiryForcesRefresh(t *testing.T) { + // Build the provider with a 1-minute TTL. We can't reach the cache's + // injectable clock from outside the package directly, so we exercise + // TTL behavior by setting a tiny TTL and advancing wall time via the + // `now` field on the provider's JWKS cache through the exported + // SetCacheClockForTest test hook below. + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.JWKSCacheTTL = 10 * time.Minute // any value above the 5-min clamp + }) + + // Drive a virtual clock so we can move past TTL without sleeping. + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("first Verify: %v", err) + } + fetchesAfterFirst := fi.jwksFetches + + // Same kid, still within TTL → no new JWKS fetch. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("second Verify within TTL: %v", err) + } + if fi.jwksFetches != fetchesAfterFirst { + t.Errorf("JWKS refetched within TTL (%d → %d) — should have served from cache", + fetchesAfterFirst, fi.jwksFetches) + } + + // Advance virtual time past TTL. + clk.advance(11 * time.Minute) + + // Same kid, but TTL expired → MUST refetch. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("third Verify after TTL: %v", err) + } + if fi.jwksFetches != fetchesAfterFirst+1 { + t.Errorf("JWKS not refetched after TTL expiry (fetches: %d, expected %d)", + fi.jwksFetches, fetchesAfterFirst+1) + } +} + +func TestJWKS_RevokedKeyEventuallyRejected(t *testing.T) { + // The security-critical case: IdP removes a key from JWKS (revocation). + // Within TTL, Forge still trusts it (acceptable — that's the TTL window). + // After TTL, Forge picks up the new JWKS without the revoked key and + // must reject tokens signed by it. + fi := newFakeIssuer(t) + p := newProvider(t, fi, func(c *oidc.Config) { + c.JWKSCacheTTL = 10 * time.Minute + }) + + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + // Add a second key the issuer rotates between. + fi.addRSAKey("key-revoked") + + // Token is signed by key-revoked. + tok := fi.SignWith("key-revoked", fi.DefaultClaims(testAudience)) + + // Initial verify — cache warms with both keys present. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("pre-revocation verify: %v", err) + } + + // IdP revokes the key (removes it from JWKS). + delete(fi.keys, "key-revoked") + for i, k := range fi.keyOrder { + if k == "key-revoked" { + fi.keyOrder = append(fi.keyOrder[:i], fi.keyOrder[i+1:]...) + break + } + } + + // Within TTL: Forge still trusts the old key (acceptable revocation lag). + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("within-TTL verify failed unexpectedly: %v", err) + } + + // Advance past TTL → refresh picks up the new JWKS without key-revoked. + clk.advance(11 * time.Minute) + + // After refresh, the token's kid is no longer in JWKS → ErrInvalidToken. + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Fatalf("post-TTL verify err = %v, want ErrInvalidToken (kid removed by IdP)", err) + } +} + +func TestJWKS_FailedRefreshDoesNotHammerIdP(t *testing.T) { + // Review #5 explicit guard: when refreshLocked fails, every + // subsequent unknown-kid (or any miss) request inside the same + // refetchGrace window must NOT trigger another IdP fetch. A + // flapping IdP otherwise turns into a DDoS amplifier — N concurrent + // callers × broken IdP = N IdP hits per request burst. + // + // The mechanism is in jwks.go's backoffActive(): on any refresh + // error, lastErr+lastAttempt are set; subsequent Get calls within + // refetchGrace short-circuit with ErrKeyNotFound, without re-calling + // refreshLocked. This test pins that behavior. + fi := newFakeIssuer(t) + var jwksHits atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/jwks" { + jwksHits.Add(1) + w.WriteHeader(http.StatusInternalServerError) + return + } + fi.serve(w, r) + })) + defer srv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: srv.URL, + Audience: testAudience, + JWKSURL: srv.URL + "/jwks", + JWKSCacheTTL: 10 * time.Minute, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + // Build a syntactically-valid JWT pointing at a kid the broken + // JWKS would never produce — guarantees we hit the refresh path + // (cache miss + TTL expired since lastSuccessful is zero). + tok := buildJWTWithKid(t, fi.keys["key-1"].priv, "unknown-kid") + + // First call attempts to refresh, fails (500), records the error + // and bumps lastAttempt → from now backoff is active. + _, _ = p.Verify(context.Background(), tok, nil) + firstWindowHits := jwksHits.Load() + if firstWindowHits != 1 { + t.Fatalf("first Verify did not attempt JWKS (hits=%d)", firstWindowHits) + } + + // Fire a burst of requests inside the backoff window. NONE of these + // should reach the JWKS endpoint — backoffActive() short-circuits. + for range 50 { + _, _ = p.Verify(context.Background(), tok, nil) + } + if got := jwksHits.Load(); got != firstWindowHits { + t.Errorf("DDoS amplifier: 50 unknown-kid requests during backoff produced %d IdP hits (want %d)", + got, firstWindowHits) + } + + // Advance past refetchGrace (default 30s). The next request should + // be allowed to re-attempt the IdP exactly once. + clk.advance(35 * time.Second) + _, _ = p.Verify(context.Background(), tok, nil) + if got := jwksHits.Load(); got != firstWindowHits+1 { + t.Errorf("post-grace retry: hits = %d, want %d (one retry per grace window)", + got, firstWindowHits+1) + } + + // And another burst still inside the new grace window — same lock-out. + for range 50 { + _, _ = p.Verify(context.Background(), tok, nil) + } + if got := jwksHits.Load(); got != firstWindowHits+1 { + t.Errorf("second backoff window leaked: hits = %d, want %d", + got, firstWindowHits+1) + } +} + +// buildJWTWithKid signs a token with the given private key but stamps an +// arbitrary kid in the header. Used by the JWKS hammer test to force the +// provider into the unknown-kid + refresh path. +func buildJWTWithKid(t *testing.T, priv any, kid string) string { + t.Helper() + claims := jwt.MapClaims{ + "iss": "x", + "aud": testAudience, + "sub": "x", + "exp": time.Now().Add(15 * time.Minute).Unix(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + s, err := tok.SignedString(priv) + if err != nil { + t.Fatalf("sign: %v", err) + } + return s +} + +func TestJWKS_FailedRefreshDoesNotExtendTTL(t *testing.T) { + // Bug guard: a failed refresh must not bump lastSuccessful. Otherwise + // an IdP outage during a TTL-refresh window would silently extend + // trust in already-cached keys. + // + // Strategy: warm the cache, then make the JWKS endpoint return 500, + // advance past TTL, attempt a Verify. The refresh fails, but the + // existing keys remain available — and the next attempt after the + // failure backoff should retry (proving lastSuccessful did NOT advance). + fi := newFakeIssuer(t) + failJWKS := atomic.Bool{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if failJWKS.Load() && r.URL.Path == "/jwks" { + w.WriteHeader(http.StatusInternalServerError) + return + } + fi.serve(w, r) + })) + defer srv.Close() + + p, err := oidc.New(oidc.Config{ + Issuer: srv.URL, + Audience: testAudience, + JWKSURL: srv.URL + "/jwks", + JWKSCacheTTL: 10 * time.Minute, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Patch the fakeIssuer's IssuerURL so claim's iss matches. + fi.server = srv + + clk := &virtualClock{now: time.Now()} + oidc.SetCacheClockForTest(p, clk.Now) + + tok := fi.SignWith("key-1", fi.DefaultClaims(testAudience)) + + // Warm the cache. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatalf("warmup: %v", err) + } + + // Endpoint goes down; advance past TTL. + failJWKS.Store(true) + clk.advance(11 * time.Minute) + + // Refresh attempt fails. The TTL clock was NOT extended by the + // failure, so subsequent calls outside the backoff window will + // re-attempt. + _, err = p.Verify(context.Background(), tok, nil) + if err == nil { + t.Errorf("expected error during JWKS outage, got nil") + } + + // Endpoint comes back, but we're still in error-backoff window. + // Advance past refetchGrace (30s) so a retry can happen. + failJWKS.Store(false) + clk.advance(35 * time.Second) + + // Next verify should retry the refresh and succeed. + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Errorf("post-recovery verify failed: %v", err) + } +} + +// virtualClock is a controllable time source for tests. +type virtualClock struct { + now time.Time +} + +func (c *virtualClock) Now() time.Time { return c.now } +func (c *virtualClock) advance(d time.Duration) { + c.now = c.now.Add(d) +} + +func TestVerify_TokenWithoutExp(t *testing.T) { + fi := newFakeIssuer(t) + p := newProvider(t, fi, nil) + + claims := fi.DefaultClaims(testAudience) + delete(claims, "exp") + tok := fi.SignWith("key-1", claims) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (missing exp)", err) + } +} diff --git a/forge-core/auth/providers/oidc/testing_helper_test.go b/forge-core/auth/providers/oidc/testing_helper_test.go new file mode 100644 index 0000000..3185912 --- /dev/null +++ b/forge-core/auth/providers/oidc/testing_helper_test.go @@ -0,0 +1,220 @@ +package oidc_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// fakeIssuer is a minimal in-memory OIDC issuer: serves the discovery +// document and JWKS, and helps tests sign tokens with the matching key. +type fakeIssuer struct { + t *testing.T + server *httptest.Server + // keys keyed by kid → signer + keys map[string]*signingKey + // keyOrder remembers insertion order so JWKS responses are stable. + keyOrder []string + // jwksFetches counts how many times /jwks was hit (race-safe via the + // server's serialized handler). + jwksFetches int +} + +type signingKey struct { + kid string + method jwt.SigningMethod + priv any // *rsa.PrivateKey or *ecdsa.PrivateKey + pub any // *rsa.PublicKey or *ecdsa.PublicKey + // alg is what the JWKS will declare; mismatch with `method` is what + // algorithm-confusion tests need. + algInJWKS string +} + +// newFakeIssuer builds a default RSA-keyed issuer with kid "key-1". +func newFakeIssuer(t *testing.T) *fakeIssuer { + t.Helper() + fi := &fakeIssuer{ + t: t, + keys: map[string]*signingKey{}, + } + fi.addRSAKey("key-1") + fi.server = httptest.NewServer(http.HandlerFunc(fi.serve)) + t.Cleanup(fi.server.Close) + return fi +} + +// IssuerURL is the canonical issuer URL (no trailing slash). +func (fi *fakeIssuer) IssuerURL() string { return fi.server.URL } + +// addRSAKey generates a fresh RSA key and registers it with the given kid. +func (fi *fakeIssuer) addRSAKey(kid string) *signingKey { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + fi.t.Fatalf("generate RSA key: %v", err) + } + k := &signingKey{ + kid: kid, + method: jwt.SigningMethodRS256, + priv: priv, + pub: &priv.PublicKey, + algInJWKS: "RS256", + } + fi.keys[kid] = k + fi.keyOrder = append(fi.keyOrder, kid) + return k +} + +// addECDSAKey generates a fresh P-256 ECDSA key. +func (fi *fakeIssuer) addECDSAKey(kid string) *signingKey { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + fi.t.Fatalf("generate ECDSA key: %v", err) + } + k := &signingKey{ + kid: kid, + method: jwt.SigningMethodES256, + priv: priv, + pub: &priv.PublicKey, + algInJWKS: "ES256", + } + fi.keys[kid] = k + fi.keyOrder = append(fi.keyOrder, kid) + return k +} + +// SignWith builds a signed token using the named key. Claims is the full +// claim map; helpers below build common claim sets. +func (fi *fakeIssuer) SignWith(kid string, claims jwt.MapClaims) string { + fi.t.Helper() + k, ok := fi.keys[kid] + if !ok { + fi.t.Fatalf("unknown signing key %q", kid) + } + tok := jwt.NewWithClaims(k.method, claims) + tok.Header["kid"] = kid + s, err := tok.SignedString(k.priv) + if err != nil { + fi.t.Fatalf("sign: %v", err) + } + return s +} + +// SignUnsigned returns a token with `alg: none` and no signature segment. +// jwt/v5 doesn't expose a clean way to construct these; we do it manually. +func SignUnsigned(claims jwt.MapClaims) string { + header := map[string]any{"alg": "none", "typ": "JWT", "kid": "none-kid"} + hb, _ := json.Marshal(header) + cb, _ := json.Marshal(claims) + enc := base64.RawURLEncoding.EncodeToString + return enc(hb) + "." + enc(cb) + "." +} + +// SignHMAC returns a token signed with HS256 — must be rejected by the +// provider since HMAC + JWKS is a footgun (anyone with the JWKS could +// produce one). +func SignHMAC(kid string, claims jwt.MapClaims, secret []byte) string { + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tok.Header["kid"] = kid + s, _ := tok.SignedString(secret) + return s +} + +// DefaultClaims returns a claim set that satisfies the default fakeIssuer +// + a single audience. +func (fi *fakeIssuer) DefaultClaims(audience string) jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": fi.IssuerURL(), + "aud": audience, + "sub": "user-123", + "iat": now.Unix(), + "exp": now.Add(15 * time.Minute).Unix(), + } +} + +// serve handles both the discovery endpoint and the JWKS endpoint. +func (fi *fakeIssuer) serve(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": fi.IssuerURL(), + "jwks_uri": fi.IssuerURL() + "/jwks", + }) + case "/jwks": + fi.jwksFetches++ + _ = json.NewEncoder(w).Encode(fi.jwksDoc()) + default: + http.NotFound(w, r) + } +} + +// jwksDoc constructs the JWKS document from the current key set. +func (fi *fakeIssuer) jwksDoc() map[string]any { + keys := make([]map[string]any, 0, len(fi.keyOrder)) + for _, kid := range fi.keyOrder { + k := fi.keys[kid] + switch pub := k.pub.(type) { + case *rsa.PublicKey: + keys = append(keys, map[string]any{ + "kty": "RSA", + "kid": kid, + "use": "sig", + "alg": k.algInJWKS, + "n": base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString(intToBigEndian(pub.E)), + }) + case *ecdsa.PublicKey: + keys = append(keys, map[string]any{ + "kty": "EC", + "kid": kid, + "use": "sig", + "alg": k.algInJWKS, + "crv": ecCurveName(pub.Curve), + "x": base64.RawURLEncoding.EncodeToString(pub.X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(pub.Y.Bytes()), + }) + default: + panic(fmt.Sprintf("unsupported key type: %T", pub)) + } + } + return map[string]any{"keys": keys} +} + +func intToBigEndian(e int) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(e)) + // Strip leading zero bytes — RFC 7518 demands no leading zeros. + for len(buf) > 1 && buf[0] == 0 { + buf = buf[1:] + } + return buf +} + +func ecCurveName(c elliptic.Curve) string { + switch c { + case elliptic.P256(): + return "P-256" + case elliptic.P384(): + return "P-384" + case elliptic.P521(): + return "P-521" + default: + return "" + } +} + +// Compile-time assertion: tests reference math/big to ensure the import +// graph is exercised in case future refactors lose the dep. +var _ = big.NewInt diff --git a/forge-core/auth/providers/statictoken/provider.go b/forge-core/auth/providers/statictoken/provider.go new file mode 100644 index 0000000..338063a --- /dev/null +++ b/forge-core/auth/providers/statictoken/provider.go @@ -0,0 +1,133 @@ +// Package statictoken implements a Provider that matches the presented +// bearer token against a single expected value using constant-time +// comparison. +// +// Two intended uses: +// +// 1. Channel adapter loopback. The runner generates a random per-process +// token, configures a statictoken Provider with it (placed at the head +// of the chain), and shares the same token with Slack/Telegram adapters +// so their callbacks into the local A2A server authenticate cheaply +// without touching an upstream IdP. +// +// 2. Local dev / CI. A fixed token configured via env var lets developers +// hit a running agent with `curl -H "Authorization: Bearer $FORGE_DEV_TOKEN"` +// without setting up an IdP. +// +// Mismatch returns ErrTokenNotForMe (yield to next provider), not +// ErrTokenRejected — the loopback token is "not for me" from the +// perspective of an external client, and chain semantics require yielding +// in that case. +package statictoken + +import ( + "context" + "crypto/sha256" + "crypto/subtle" + "fmt" + "maps" + "os" + "slices" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this provider. +const ProviderName = "static_token" + +// Config controls the static_token provider. +type Config struct { + // Token is the expected bearer value (literal). Prefer TokenEnv for + // non-test use — putting a secret in YAML is a footgun. + Token string `yaml:"token,omitempty"` + + // TokenEnv names an environment variable that holds the token at + // construction time. Read once in New(); subsequent env changes do + // not affect the running provider. + TokenEnv string `yaml:"token_env,omitempty"` + + // Identity is returned on a successful match (a defensive copy is + // returned to callers). If Source is empty it defaults to "static_token". + Identity auth.Identity `yaml:"identity,omitempty"` +} + +// Validate returns ErrProviderNotConfigured when neither Token nor TokenEnv +// resolves to a non-empty value. +func (c Config) Validate() error { + if c.Token == "" && c.TokenEnv == "" { + return fmt.Errorf("%w: token or token_env required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider with a constant-time token compare. +type Provider struct { + expected []byte + identity auth.Identity +} + +// New constructs a Provider after resolving the token (TokenEnv takes +// precedence over Token literal). Returns ErrProviderNotConfigured if the +// resolved token is empty. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + token := cfg.Token + if cfg.TokenEnv != "" { + if v := os.Getenv(cfg.TokenEnv); v != "" { + token = v + } + } + if token == "" { + return nil, fmt.Errorf("%w: resolved token is empty", auth.ErrProviderNotConfigured) + } + id := cfg.Identity + if id.Source == "" { + id.Source = ProviderName + } + return &Provider{ + expected: []byte(token), + identity: id, + }, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. Constant-time compare against the +// configured token. Mismatch yields to the next provider via ErrTokenNotForMe. +// +// Length-leak guard (review #11a): subtle.ConstantTimeCompare returns 0 +// immediately when the two slices have different lengths, without +// inspecting any bytes — that early return leaks the expected token's +// length via measurable timing differences over many trials. We hash +// both sides to SHA-256 first so the comparison is always over +// equal-length (32-byte) digests. Hash → constant-time compare is the +// standard mitigation for this class. +func (p *Provider) Verify(_ context.Context, token string, _ auth.Headers) (*auth.Identity, error) { + presented := sha256.Sum256([]byte(token)) + expected := sha256.Sum256(p.expected) + if subtle.ConstantTimeCompare(presented[:], expected[:]) != 1 { + return nil, auth.ErrTokenNotForMe + } + // Return a defensive copy so callers can't mutate the configured identity. + id := p.identity + if id.Groups != nil { + id.Groups = slices.Clone(id.Groups) + } + if id.Claims != nil { + id.Claims = maps.Clone(id.Claims) + } + return &id, nil +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/statictoken/provider_test.go b/forge-core/auth/providers/statictoken/provider_test.go new file mode 100644 index 0000000..4f16401 --- /dev/null +++ b/forge-core/auth/providers/statictoken/provider_test.go @@ -0,0 +1,213 @@ +package statictoken_test + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/statictoken" +) + +func TestNew_ValidationErrors(t *testing.T) { + t.Run("empty config", func(t *testing.T) { + _, err := statictoken.New(statictoken.Config{}) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } + }) + + t.Run("token_env points at unset variable", func(t *testing.T) { + t.Setenv("FORGE_TEST_TOKEN_DOES_NOT_EXIST", "") + _, err := statictoken.New(statictoken.Config{TokenEnv: "FORGE_TEST_TOKEN_DOES_NOT_EXIST"}) + if !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Fatalf("err = %v, want ErrProviderNotConfigured", err) + } + }) +} + +func TestVerify_MatchReturnsIdentity(t *testing.T) { + p, err := statictoken.New(statictoken.Config{ + Token: "secret-token", + Identity: auth.Identity{ + UserID: "internal", + Email: "system@forge", + Source: "internal", + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + id, err := p.Verify(context.Background(), "secret-token", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id == nil { + t.Fatal("Verify returned nil identity") + } + if id.UserID != "internal" || id.Email != "system@forge" { + t.Errorf("identity = %+v", id) + } + if id.Source != "internal" { + t.Errorf("identity.Source = %q, want %q", id.Source, "internal") + } +} + +func TestVerify_MatchDefaultSource(t *testing.T) { + p, _ := statictoken.New(statictoken.Config{Token: "secret-token"}) + id, err := p.Verify(context.Background(), "secret-token", nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "static_token" { + t.Errorf("identity.Source = %q, want %q (default)", id.Source, "static_token") + } +} + +func TestVerify_MismatchYieldsToNext(t *testing.T) { + p, _ := statictoken.New(statictoken.Config{Token: "expected"}) + + _, err := p.Verify(context.Background(), "different", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Fatalf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestVerify_EmptyTokenYields(t *testing.T) { + p, _ := statictoken.New(statictoken.Config{Token: "expected"}) + + _, err := p.Verify(context.Background(), "", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Fatalf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestTokenEnv_PrecedenceOverLiteral(t *testing.T) { + t.Setenv("FORGE_TEST_OVERRIDE_TOKEN", "from-env") + + p, err := statictoken.New(statictoken.Config{ + Token: "from-literal", + TokenEnv: "FORGE_TEST_OVERRIDE_TOKEN", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Env value should win. + if _, err := p.Verify(context.Background(), "from-env", nil); err != nil { + t.Errorf("env token rejected: %v", err) + } + if _, err := p.Verify(context.Background(), "from-literal", nil); !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("literal token accepted (should be overridden by env): %v", err) + } +} + +func TestTokenEnv_FallsBackToLiteralWhenEnvEmpty(t *testing.T) { + t.Setenv("FORGE_TEST_UNSET_TOKEN", "") + + p, err := statictoken.New(statictoken.Config{ + Token: "from-literal", + TokenEnv: "FORGE_TEST_UNSET_TOKEN", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + if _, err := p.Verify(context.Background(), "from-literal", nil); err != nil { + t.Errorf("literal token rejected when env was empty: %v", err) + } +} + +func TestIdentityIsDefensivelyCopied(t *testing.T) { + cfg := statictoken.Config{ + Token: "tok", + Identity: auth.Identity{ + UserID: "u", + Groups: []string{"a", "b"}, + Claims: map[string]any{"role": "admin"}, + }, + } + p, _ := statictoken.New(cfg) + + id1, _ := p.Verify(context.Background(), "tok", nil) + // Mutate the returned identity. + id1.Groups[0] = "MUTATED" + id1.Claims["role"] = "MUTATED" + + id2, _ := p.Verify(context.Background(), "tok", nil) + if id2.Groups[0] != "a" { + t.Errorf("Groups[0] = %q after mutation, want %q (defensive copy failed)", id2.Groups[0], "a") + } + if id2.Claims["role"] != "admin" { + t.Errorf("Claims[role] = %v after mutation, want admin (defensive copy failed)", id2.Claims["role"]) + } +} + +func TestVerify_MismatchedLengths_AreRejected(t *testing.T) { + // Review #11a: hash-then-compare path must still reject tokens of + // any length that aren't the configured one. Specifically, a token + // that's just a prefix or a longer-tail variant of the real one + // must NOT compare equal — verifying the hash digest comparison + // works the same way the byte-level compare did for length-equal + // inputs. + p, _ := statictoken.New(statictoken.Config{Token: "abcdef0123456789"}) + + for _, candidate := range []string{ + "", // empty + "abcdef", // prefix + "abcdef0123456789-extra-bytes", // longer + "different-but-same-length-as-real-secret", // same-length mismatch + "abcdef0123456788", // off-by-one in last char + } { + if _, err := p.Verify(context.Background(), candidate, nil); !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("candidate %q: err = %v, want ErrTokenNotForMe", candidate, err) + } + } + + // Sanity: the real token still works. + if _, err := p.Verify(context.Background(), "abcdef0123456789", nil); err != nil { + t.Errorf("real token rejected: %v", err) + } +} + +func TestConcurrentVerify_RaceSafe(t *testing.T) { + // Race-safety smoke. Run with `go test -race`. + p, _ := statictoken.New(statictoken.Config{Token: "tok"}) + + var wg sync.WaitGroup + for range 100 { + wg.Go(func() { + _, _ = p.Verify(context.Background(), "tok", nil) + _, _ = p.Verify(context.Background(), "wrong", nil) + }) + } + wg.Wait() +} + +func TestRegisteredViaFactory(t *testing.T) { + p, err := auth.Build("static_token", map[string]any{ + "token": "abc", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "static_token" { + t.Errorf("Name = %q, want static_token", p.Name()) + } +} + +func TestFactory_UsesTokenEnv(t *testing.T) { + t.Setenv("FORGE_TEST_FACTORY_TOKEN", "via-env") + + p, err := auth.Build("static_token", map[string]any{ + "token_env": "FORGE_TEST_FACTORY_TOKEN", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if _, err := p.Verify(context.Background(), "via-env", nil); err != nil { + t.Errorf("env-resolved token rejected: %v", err) + } +} diff --git a/forge-core/auth/registry.go b/forge-core/auth/registry.go new file mode 100644 index 0000000..4e6ff9d --- /dev/null +++ b/forge-core/auth/registry.go @@ -0,0 +1,81 @@ +package auth + +import ( + "fmt" + "sort" + "sync" +) + +// Factory constructs a Provider from a freeform settings map (typically +// the `settings` block of an `auth.providers[]` entry in forge.yaml). +// +// A Factory should: +// - Unmarshal settings into its typed Config via UnmarshalSettings. +// - Call Config.Validate() and return clear errors. +// - Return ErrProviderNotConfigured for missing required fields. +// +// A Factory must not perform network I/O — Verify() does that lazily. +type Factory func(settings map[string]any) (Provider, error) + +var ( + registryMu sync.RWMutex + registry = map[string]Factory{} +) + +// Register adds a provider factory under the given type name. Intended to be +// called from package init() in each provider subpackage. Panics on duplicate +// registration — a duplicate is a programming error that must fail loud at +// startup, not at the first request. +func Register(typeName string, factory Factory) { + if typeName == "" { + panic("auth: Register called with empty typeName") + } + if factory == nil { + panic("auth: Register called with nil factory: " + typeName) + } + registryMu.Lock() + defer registryMu.Unlock() + if _, exists := registry[typeName]; exists { + panic("auth: provider already registered: " + typeName) + } + registry[typeName] = factory +} + +// Build constructs a Provider for the given type name using the registered +// factory. Returns an error if the type is not registered. +func Build(typeName string, settings map[string]any) (Provider, error) { + registryMu.RLock() + factory, ok := registry[typeName] + registryMu.RUnlock() + if !ok { + return nil, fmt.Errorf("auth: unknown provider type %q (registered: %v)", typeName, RegisteredTypes()) + } + p, err := factory(settings) + if err != nil { + return nil, fmt.Errorf("auth: build %q: %w", typeName, err) + } + return p, nil +} + +// RegisteredTypes returns a sorted, deduplicated slice of registered provider +// type names. Used by config validation and the wizard meta endpoint to expose +// the set of available provider types. +func RegisteredTypes() []string { + registryMu.RLock() + defer registryMu.RUnlock() + out := make([]string, 0, len(registry)) + for name := range registry { + out = append(out, name) + } + sort.Strings(out) + return out +} + +// resetRegistryForTest clears all registered providers. Test-only helper — +// not exported in non-test builds. (Kept package-private here; tests in the +// same package can call it directly.) +func resetRegistryForTest() { + registryMu.Lock() + defer registryMu.Unlock() + registry = map[string]Factory{} +} diff --git a/forge-core/auth/settings.go b/forge-core/auth/settings.go new file mode 100644 index 0000000..2c61b8d --- /dev/null +++ b/forge-core/auth/settings.go @@ -0,0 +1,27 @@ +package auth + +import ( + "fmt" + + "gopkg.in/yaml.v3" +) + +// UnmarshalSettings decodes a freeform settings map into a typed Config +// struct, honoring `yaml:"..."` tags on the destination. Implemented as a +// yaml.Marshal + yaml.Unmarshal roundtrip so each provider can keep its +// Config strongly typed while the public schema stays map[string]any. +// +// Pass a pointer to your Config struct as `out`. +func UnmarshalSettings(in map[string]any, out any) error { + if out == nil { + return fmt.Errorf("auth: UnmarshalSettings: out is nil") + } + data, err := yaml.Marshal(in) + if err != nil { + return fmt.Errorf("auth: settings marshal: %w", err) + } + if err := yaml.Unmarshal(data, out); err != nil { + return fmt.Errorf("auth: settings unmarshal: %w", err) + } + return nil +} diff --git a/forge-core/go.mod b/forge-core/go.mod index 1b2f92b..bae6d93 100644 --- a/forge-core/go.mod +++ b/forge-core/go.mod @@ -9,6 +9,7 @@ require ( ) require ( + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect golang.org/x/crypto v0.48.0 // indirect diff --git a/forge-core/go.sum b/forge-core/go.sum index f2cd70a..7e485de 100644 --- a/forge-core/go.sum +++ b/forge-core/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/forge-core/runtime/audit.go b/forge-core/runtime/audit.go index f6af274..f8e26bb 100644 --- a/forge-core/runtime/audit.go +++ b/forge-core/runtime/audit.go @@ -23,8 +23,19 @@ const ( AuditScheduleComplete = "schedule_complete" AuditScheduleSkip = "schedule_skip" AuditScheduleModify = "schedule_modify" - AuditAuthSuccess = "auth_success" - AuditAuthFailure = "auth_failure" + + // Auth events. Carry no PII (no email, no claims, no token bytes) — + // only the audit subject (UserID), tenant (OrgID), and structural + // metadata. See the auth-package middleware for the emitter. + EventAuthVerify = "auth_verify" // every successful auth decision + EventAuthFail = "auth_fail" // every failed auth decision (with reason code) + + // Deprecated: use EventAuthVerify. Kept as a string alias so any + // audit-log consumer that grep'd for "auth_success" can be migrated. + // Scheduled for removal in v0.11.0. + AuditAuthSuccess = EventAuthVerify + // Deprecated: use EventAuthFail. Same migration window as AuditAuthSuccess. + AuditAuthFailure = EventAuthFail ) // AuditEvent is a single structured audit record emitted as NDJSON. diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go new file mode 100644 index 0000000..b96b1e0 --- /dev/null +++ b/forge-core/security/auth_domains.go @@ -0,0 +1,105 @@ +package security + +import ( + "net/url" + "sort" + + "github.com/initializ/forge/forge-core/types" +) + +// AuthDomains extracts the outbound hosts that auth providers must be able +// to reach: OIDC issuers, http_verifier URLs, future Okta tenants, etc. +// +// These domains are merged into the egress allowlist BEFORE the egress +// enforcer is constructed, so configuring an OIDC provider does not +// silently fail at runtime with a network-blocked JWKS fetch. +// +// Returned hosts are deduplicated and sorted for stable test output. +// Empty/malformed URLs are skipped (validation happens elsewhere). +func AuthDomains(cfg types.AuthConfig) []string { + if len(cfg.Providers) == 0 { + return nil + } + seen := map[string]struct{}{} + for _, p := range cfg.Providers { + for _, raw := range authProviderURLs(p) { + host := hostFromURL(raw) + if host == "" { + continue + } + seen[host] = struct{}{} + } + } + if len(seen) == 0 { + return nil + } + out := make([]string, 0, len(seen)) + for h := range seen { + out = append(out, h) + } + sort.Strings(out) + return out +} + +// authProviderURLs returns the URLs that a given provider needs to reach. +// Centralizes per-provider extraction so adding a new provider type in +// the registry only requires touching one map. +func authProviderURLs(p types.AuthProvider) []string { + switch p.Type { + case "oidc": + return []string{ + settingString(p.Settings, "issuer"), + settingString(p.Settings, "jwks_url"), + } + case "http_verifier": + return []string{ + settingString(p.Settings, "url"), + } + // static_token has no outbound; not listed + // okta (Phase 3) will be added here with issuer + api domain + default: + return nil + } +} + +func settingString(m map[string]any, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +// hostFromURL extracts the bare host (without port) from a URL string. +// Returns "" for unparseable values — the caller skips those silently +// since real validation happens in validate.ValidateAuthConfig. +// +// Port stripping is intentional and depends on a cross-package contract: +// the egress matcher itself is port-agnostic at the three callsites +// where it's consulted — +// +// - egress_enforcer.go: req.URL.Hostname() before matcher.IsAllowed +// - egress_proxy.go: extractHost() via net.SplitHostPort +// - safe_dialer.go: net.SplitHostPort(addr) +// +// As long as those callsites strip the port off the OUTBOUND host before +// matching, a hostname entry in the allowlist suffices for any port. If +// any of those callsites is ever changed to pass host:port to the +// matcher, this function (and TestAuthDomains_AssumesPortAgnosticMatcher) +// must change too — an issuer at https://example.com:8443 would otherwise +// be silently blocked because the allowlist would contain only +// "example.com" while the dialer asked for "example.com:8443". +func hostFromURL(raw string) string { + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil || u.Host == "" { + return "" + } + return u.Hostname() +} diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go new file mode 100644 index 0000000..c7a1616 --- /dev/null +++ b/forge-core/security/auth_domains_test.go @@ -0,0 +1,195 @@ +package security_test + +import ( + "reflect" + "testing" + + "github.com/initializ/forge/forge-core/security" + "github.com/initializ/forge/forge-core/types" +) + +func TestAuthDomains_Empty(t *testing.T) { + if got := security.AuthDomains(types.AuthConfig{}); got != nil { + t.Errorf("AuthDomains(empty) = %v, want nil", got) + } +} + +func TestAuthDomains_OIDCIssuer(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }}, + }, + }) + want := []string{"login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_HTTPVerifier(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "http_verifier", Settings: map[string]any{ + "url": "https://verify.example.com/verify", + }}, + }, + }) + want := []string{"verify.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_OIDCWithExplicitJWKSURL(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + "jwks_url": "https://keys.example.com/.well-known/jwks.json", + }}, + }, + }) + want := []string{"keys.example.com", "login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_MultiProviderDedup(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }}, + {Type: "http_verifier", Settings: map[string]any{ + "url": "https://login.example.com/verify", + }}, + {Type: "static_token", Settings: map[string]any{"token_env": "X"}}, + }, + }) + want := []string{"login.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v (must dedup across providers)", got, want) + } +} + +func TestAuthDomains_StaticTokenContributesNothing(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "static_token", Settings: map[string]any{"token_env": "X"}}, + }, + }) + if got != nil { + t.Errorf("static_token-only config should contribute no domains, got %v", got) + } +} + +func TestAuthDomains_MalformedURLsSkipped(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "://not a url", + "audience": "api://forge", + }}, + {Type: "oidc", Settings: map[string]any{ + "issuer": "https://valid.example.com", + "audience": "api://forge", + }}, + }, + }) + // Malformed URL silently skipped (validate package handles surface errors). + want := []string{"valid.example.com"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_PortStripped(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "http://localhost:8080/realms/dev", + "audience": "api://forge", + }}, + }, + }) + want := []string{"localhost"} + if !reflect.DeepEqual(got, want) { + t.Errorf("AuthDomains = %v, want %v (port must be stripped)", got, want) + } +} + +func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "future_provider", Settings: map[string]any{"url": "https://x.example.com"}}, + }, + }) + if got != nil { + t.Errorf("unknown provider type returned domains: %v (must be nil — extractor not registered)", got) + } +} + +// TestAuthDomains_AssumesPortAgnosticMatcher pins the cross-package +// contract that AuthDomains relies on. AuthDomains strips ports +// (e.g., "https://login.example.com:8443" → "login.example.com") on the +// assumption that the egress matcher will ALSO strip ports off outbound +// hosts before checking the allowlist. If that assumption ever breaks +// (someone makes the matcher port-aware), AuthDomains needs to flip to +// emitting host:port, OR every existing forge.yaml that points at a +// non-443 IdP suddenly silently blocks at JWKS-fetch time. +// +// This test catches that drift early: it builds a matcher with the +// hostname-only allowlist AuthDomains would produce, then asks it +// whether the SAME hostname WITH a port is allowed. If the matcher +// answers "no", the contract is broken and someone needs to look at +// security/auth_domains.go. +func TestAuthDomains_AssumesPortAgnosticMatcher(t *testing.T) { + hosts := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Settings: map[string]any{ + "issuer": "http://localhost:8080/realms/dev", + "audience": "api://forge", + }}, + }, + }) + if len(hosts) == 0 { + t.Fatal("AuthDomains returned no hosts for a configured OIDC issuer") + } + + // Build a matcher with exactly what the runner would feed it. + matcher := security.NewDomainMatcher(security.ModeAllowlist, hosts) + + // What we want: matching the raw hostname must succeed. + if !matcher.IsAllowed("localhost") { + t.Fatalf("matcher rejected the very hostname AuthDomains produced: %v", hosts) + } + + // What this guards: the matcher MUST also be port-agnostic. The + // dialer / enforcer call IsAllowed with hostname-only strings (port + // is stripped via net.SplitHostPort or url.Hostname()). If a future + // change ever makes the matcher inspect ports, this contract + // changes — and this test will need to flip together with + // AuthDomains' port-stripping behavior. + // + // We can't directly probe "matcher accepts localhost:8080" because + // the matcher's documented input is hostname-only. The integration + // contract is enforced at the dialer layer (see egress_enforcer.go:40 + // — req.URL.Hostname() strips the port before matcher.IsAllowed). + // What we CAN do here is assert the matcher's input shape: passing + // a host:port string MUST NOT match an allowlist with the bare host. + // That guarantees we'll notice if the matcher silently grew port + // awareness — IsAllowed("localhost:8080") would then start matching + // a "localhost" allowlist entry and this assertion would flip. + if matcher.IsAllowed("localhost:8080") { + t.Fatal("matcher unexpectedly accepts host:port form — the contract " + + "AuthDomains assumes (port stripped before IsAllowed) may have " + + "changed. Review security/auth_domains.go and the documented " + + "callsites (egress_enforcer.go, egress_proxy.go, safe_dialer.go).") + } +} diff --git a/forge-core/types/config.go b/forge-core/types/config.go index faf90af..2294c53 100644 --- a/forge-core/types/config.go +++ b/forge-core/types/config.go @@ -22,12 +22,42 @@ type ForgeConfig struct { Skills SkillsRef `yaml:"skills,omitempty"` Memory MemoryConfig `yaml:"memory,omitempty"` Secrets SecretsConfig `yaml:"secrets,omitempty"` + Auth AuthConfig `yaml:"auth,omitempty"` Schedules []ScheduleConfig `yaml:"schedules,omitempty"` CORSOrigins []string `yaml:"cors_origins,omitempty"` Package PackageConfig `yaml:"package,omitempty"` GuardrailsPath string `yaml:"guardrails_path,omitempty"` // path to guardrails.json (default: "guardrails.json") } +// AuthConfig declares the auth provider chain for the A2A server. Mirrors +// the secrets.providers pattern: each entry is { type, settings } and the +// runner builds them in order via auth.Registry.BuildChain. +// +// Backward compatibility: if AuthConfig.Providers is empty, the legacy +// --auth-url / FORGE_AUTH_URL / FORGE_AUTH_ORG_ID flow synthesizes a +// single-element http_verifier chain (unchanged from pre-PR3 behavior). +type AuthConfig struct { + // Required indicates whether auth is mandatory. When false (default), + // the runtime treats Providers as the source of truth — operators may + // still opt out via --no-auth on localhost. Reserved for future + // TUI/UI gating logic. + Required bool `yaml:"required,omitempty"` + + // Providers is the ordered list of auth providers that compose into + // the A2A server's auth chain. First-match wins. + Providers []AuthProvider `yaml:"providers,omitempty"` +} + +// AuthProvider is one entry in AuthConfig.Providers. The Type names a +// factory registered with the auth package (e.g., "oidc", "http_verifier", +// "static_token", and — in Phase 3 — "okta"). Settings is unmarshaled +// into the provider-specific Config struct via auth.UnmarshalSettings. +type AuthProvider struct { + Type string `yaml:"type"` + Name string `yaml:"name,omitempty"` + Settings map[string]any `yaml:"settings,omitempty"` +} + // ScheduleConfig defines a recurring scheduled task in forge.yaml. type ScheduleConfig struct { ID string `yaml:"id"` diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go new file mode 100644 index 0000000..78ead5b --- /dev/null +++ b/forge-core/validate/auth.go @@ -0,0 +1,93 @@ +package validate + +import ( + "fmt" + + "github.com/initializ/forge/forge-core/types" +) + +// knownAuthProviderTypes is the closed set of auth provider types that +// validate accepts. It is intentionally NOT derived from auth.RegisteredTypes() +// — validation must work at build time (forge validate) without importing +// the runtime provider implementations, which is what allows the forge-core +// `validate` package to stay free of side-effect imports. +// +// New provider types must be added here when they ship. +var knownAuthProviderTypes = map[string]bool{ + "http_verifier": true, + "static_token": true, + "oidc": true, + // "okta": true, // Phase 3 (v0.11.0) +} + +// ValidateAuthConfig adds errors and warnings for a forge.yaml auth: block. +// Empty AuthConfig is valid (legacy --auth-url path remains the fallback). +func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { + if len(cfg.Providers) == 0 { + if cfg.Required { + r.Errors = append(r.Errors, "auth.required is true but auth.providers is empty") + } + return + } + + seenNames := map[string]int{} + for i, p := range cfg.Providers { + prefix := fmt.Sprintf("auth.providers[%d]", i) + + if p.Type == "" { + r.Errors = append(r.Errors, prefix+": type is required") + continue + } + if !knownAuthProviderTypes[p.Type] { + r.Errors = append(r.Errors, fmt.Sprintf("%s: unknown type %q (known: http_verifier, static_token, oidc)", prefix, p.Type)) + continue + } + + if p.Name != "" { + seenNames[p.Name]++ + if seenNames[p.Name] > 1 { + r.Warnings = append(r.Warnings, fmt.Sprintf("%s: duplicate name %q across providers", prefix, p.Name)) + } + } + + validateProviderSettings(prefix, p, r) + } +} + +// validateProviderSettings checks the type-specific required keys in the +// settings block. We re-validate here (the provider's own Validate() +// runs at runtime construction) so `forge validate` catches errors before +// `forge run`. +func validateProviderSettings(prefix string, p types.AuthProvider, r *ValidationResult) { + switch p.Type { + case "http_verifier": + if asString(p.Settings, "url") == "" { + r.Errors = append(r.Errors, prefix+" (http_verifier): settings.url is required") + } + case "static_token": + if asString(p.Settings, "token") == "" && asString(p.Settings, "token_env") == "" { + r.Errors = append(r.Errors, prefix+" (static_token): settings.token or settings.token_env is required") + } + if asString(p.Settings, "token") != "" { + r.Warnings = append(r.Warnings, prefix+" (static_token): literal 'token' in YAML is a footgun — prefer 'token_env'") + } + case "oidc": + if asString(p.Settings, "issuer") == "" { + r.Errors = append(r.Errors, prefix+" (oidc): settings.issuer is required") + } + if asString(p.Settings, "audience") == "" { + r.Errors = append(r.Errors, prefix+" (oidc): settings.audience is required") + } + } +} + +// asString reads a string-valued setting, returning "" for missing or +// non-string values. +func asString(m map[string]any, key string) string { + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} diff --git a/forge-core/validate/auth_test.go b/forge-core/validate/auth_test.go new file mode 100644 index 0000000..a4813fb --- /dev/null +++ b/forge-core/validate/auth_test.go @@ -0,0 +1,218 @@ +package validate + +import ( + "strings" + "testing" + + "github.com/initializ/forge/forge-core/types" +) + +func TestValidateAuthConfig_EmptyIsValid(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{}, r) + if !r.IsValid() { + t.Errorf("empty AuthConfig should be valid, got errors: %v", r.Errors) + } +} + +func TestValidateAuthConfig_RequiredWithoutProviders(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{Required: true}, r) + if r.IsValid() { + t.Error("required=true with no providers should be an error") + } + if !containsSubstr(r.Errors, "required") { + t.Errorf("missing 'required' in error message: %v", r.Errors) + } +} + +func TestValidateAuthConfig_HTTPVerifier(t *testing.T) { + tests := []struct { + name string + provider types.AuthProvider + wantErr string + }{ + { + name: "valid", + provider: types.AuthProvider{ + Type: "http_verifier", + Settings: map[string]any{"url": "https://verify.example.com"}, + }, + }, + { + name: "missing url", + provider: types.AuthProvider{ + Type: "http_verifier", + Settings: map[string]any{}, + }, + wantErr: "settings.url is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{Providers: []types.AuthProvider{tt.provider}}, r) + if tt.wantErr == "" { + if !r.IsValid() { + t.Errorf("want valid, got errors: %v", r.Errors) + } + } else { + if !containsSubstr(r.Errors, tt.wantErr) { + t.Errorf("want error containing %q, got %v", tt.wantErr, r.Errors) + } + } + }) + } +} + +func TestValidateAuthConfig_OIDC(t *testing.T) { + tests := []struct { + name string + settings map[string]any + wantErr string + }{ + { + name: "valid", + settings: map[string]any{ + "issuer": "https://login.example.com", + "audience": "api://forge", + }, + }, + { + name: "missing issuer", + settings: map[string]any{"audience": "api://forge"}, + wantErr: "settings.issuer is required", + }, + { + name: "missing audience", + settings: map[string]any{"issuer": "https://login.example.com"}, + wantErr: "settings.audience is required", + }, + { + name: "both missing", + settings: map[string]any{}, + wantErr: "issuer is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Type: "oidc", Settings: tt.settings}}, + }, r) + if tt.wantErr == "" { + if !r.IsValid() { + t.Errorf("want valid, got errors: %v", r.Errors) + } + } else { + if !containsSubstr(r.Errors, tt.wantErr) { + t.Errorf("want error containing %q, got %v", tt.wantErr, r.Errors) + } + } + }) + } +} + +func TestValidateAuthConfig_StaticToken(t *testing.T) { + tests := []struct { + name string + settings map[string]any + wantErr string + wantWarn string + }{ + { + name: "valid with token_env", + settings: map[string]any{"token_env": "FORGE_TOKEN"}, + }, + { + name: "valid with token (warns)", + settings: map[string]any{"token": "literal-secret"}, + wantWarn: "footgun", + }, + { + name: "neither token nor token_env", + settings: map[string]any{}, + wantErr: "settings.token or settings.token_env is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Type: "static_token", Settings: tt.settings}}, + }, r) + if tt.wantErr != "" && !containsSubstr(r.Errors, tt.wantErr) { + t.Errorf("want error containing %q, got %v", tt.wantErr, r.Errors) + } + if tt.wantWarn != "" && !containsSubstr(r.Warnings, tt.wantWarn) { + t.Errorf("want warning containing %q, got %v", tt.wantWarn, r.Warnings) + } + }) + } +} + +func TestValidateAuthConfig_UnknownType(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Type: "ldap"}}, + }, r) + if r.IsValid() { + t.Error("unknown provider type should fail validation") + } + if !containsSubstr(r.Errors, "unknown type") { + t.Errorf("want 'unknown type' in errors, got %v", r.Errors) + } +} + +func TestValidateAuthConfig_MissingType(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{{Settings: map[string]any{"url": "https://x"}}}, + }, r) + if r.IsValid() { + t.Error("provider without type should fail validation") + } +} + +func TestValidateAuthConfig_DuplicateNamesWarn(t *testing.T) { + r := &ValidationResult{} + ValidateAuthConfig(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc", Name: "sso", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + {Type: "oidc", Name: "sso", Settings: map[string]any{"issuer": "https://x", "audience": "y"}}, + }, + }, r) + if !containsSubstr(r.Warnings, "duplicate name") { + t.Errorf("want duplicate-name warning, got %v", r.Warnings) + } +} + +func TestValidateForgeConfig_IncludesAuth(t *testing.T) { + // Ensure the top-level ValidateForgeConfig calls ValidateAuthConfig. + cfg := &types.ForgeConfig{ + AgentID: "test-agent", + Version: "0.1.0", + Auth: types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "oidc"}, // missing issuer + audience + }, + }, + } + r := ValidateForgeConfig(cfg) + if r.IsValid() { + t.Fatal("config with invalid oidc settings should fail validation") + } + if !containsSubstr(r.Errors, "issuer is required") { + t.Errorf("want oidc-issuer error to surface via ValidateForgeConfig, got %v", r.Errors) + } +} + +// containsSubstr returns true if any string in ss contains substr. +func containsSubstr(ss []string, substr string) bool { + for _, s := range ss { + if strings.Contains(s, substr) { + return true + } + } + return false +} diff --git a/forge-core/validate/forge_config.go b/forge-core/validate/forge_config.go index 7ff5da5..e6fa43f 100644 --- a/forge-core/validate/forge_config.go +++ b/forge-core/validate/forge_config.go @@ -118,5 +118,7 @@ func ValidateForgeConfig(cfg *types.ForgeConfig) *ValidationResult { } } + ValidateAuthConfig(cfg.Auth, r) + return r } diff --git a/forge-ui/handlers_create.go b/forge-ui/handlers_create.go index 781bc36..7cd1eed 100644 --- a/forge-ui/handlers_create.go +++ b/forge-ui/handlers_create.go @@ -2,6 +2,7 @@ package forgeui import ( "encoding/json" + "fmt" "net/http" "os" "path/filepath" @@ -118,6 +119,16 @@ func (s *UIServer) handleGetWizardMeta(w http.ResponseWriter, _ *http.Request) { } } + // Auth provider types — server-driven list. Adding a new provider + // (e.g., Okta in Phase 3) means appending one entry here; the + // frontend renders the picker from this metadata. + meta.AuthProviderTypes = []AuthProviderTypeMeta{ + {Type: "none", Label: "None", Description: "Anonymous access — no auth: block written"}, + {Type: "oidc", Label: "OIDC (JWT)", Description: "Auth0, Keycloak, Azure AD, Google, Okta-OIDC, …"}, + {Type: "http_verifier", Label: "HTTP Verifier", Description: "Legacy — POST tokens to your own /verify endpoint"}, + {Type: "custom", Label: "Custom", Description: "Write a commented stub, edit forge.yaml manually"}, + } + writeJSON(w, http.StatusOK, meta) } @@ -143,6 +154,16 @@ func (s *UIServer) handleCreateAgent(w http.ResponseWriter, r *http.Request) { return } + // Server-side validation of the auth payload before scaffolding the + // agent on disk. Without this, a buggy frontend could write a + // malformed auth: block into forge.yaml that only fails at + // `forge run` — after the agent directory has been created and the + // user thinks creation succeeded. Review #9. + if err := validateAuthPayload(opts.Auth); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + agentDir, err := s.cfg.CreateFunc(opts) if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) @@ -164,6 +185,47 @@ func (s *UIServer) handleCreateAgent(w http.ResponseWriter, r *http.Request) { }) } +// validateAuthPayload runs the forge-core validate package over the +// wizard's auth payload before the agent is scaffolded on disk. Returns +// nil for benign shapes (absent, "none", "custom") and for valid +// provider configs; returns an error with the validation messages +// surfaced when settings are malformed. +// +// The translation from the wizard's AuthCreateOptions → types.AuthConfig +// mirrors exactly what cmd/ui.go's createFunc does later, so this check +// has the same view of the inputs as the eventual scaffold. +func validateAuthPayload(a *AuthCreateOptions) error { + if a == nil { + return nil + } + switch a.Mode { + case "", "none", "custom": + // Nothing to validate — these modes don't write a provider entry. + return nil + } + + authYAML := types.AuthConfig{ + // Required is set true here because the wizard's renderAuthBlock + // always emits required:true when a provider is chosen. Keeping + // the validation view aligned with the rendered YAML avoids + // drift between "wizard accepts" and "forge validate accepts". + Required: true, + Providers: []types.AuthProvider{ + { + Type: a.Mode, + Settings: a.Settings, + }, + }, + } + + result := &validate.ValidationResult{} + validate.ValidateAuthConfig(authYAML, result) + if !result.IsValid() { + return fmt.Errorf("auth: %s", strings.Join(result.Errors, "; ")) + } + return nil +} + // handleOAuthStart initiates the OAuth browser flow for a provider. // The flow opens the user's browser for authentication and waits for the callback. func (s *UIServer) handleOAuthStart(w http.ResponseWriter, r *http.Request) { diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go index 031c710..906de6d 100644 --- a/forge-ui/handlers_create_test.go +++ b/forge-ui/handlers_create_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" ) @@ -385,4 +386,348 @@ func TestHandleGetWizardMeta(t *testing.T) { if len(meta.WebSearchProviders) == 0 { t.Error("expected non-empty web_search_providers list") } + + // PR6: auth_provider_types is server-driven so the frontend doesn't + // hardcode the list. Must include the four founding types. + if len(meta.AuthProviderTypes) != 4 { + t.Errorf("auth_provider_types len = %d, want 4", len(meta.AuthProviderTypes)) + } + wantTypes := map[string]bool{"none": false, "oidc": false, "http_verifier": false, "custom": false} + for _, a := range meta.AuthProviderTypes { + if _, ok := wantTypes[a.Type]; !ok { + t.Errorf("unexpected auth type %q", a.Type) + continue + } + wantTypes[a.Type] = true + if a.Label == "" || a.Description == "" { + t.Errorf("auth type %q missing label/description", a.Type) + } + } + for typ, seen := range wantTypes { + if !seen { + t.Errorf("missing required auth type %q", typ) + } + } +} + +// setupCreateWithCapture returns a UIServer whose CreateFunc records the +// last AgentCreateOptions it received. Used to assert that opts.Auth +// round-trips through the JSON boundary. +func setupCreateWithCapture(t *testing.T) (*UIServer, *AgentCreateOptions) { + t.Helper() + root := t.TempDir() + captured := &AgentCreateOptions{} + mockCreate := func(opts AgentCreateOptions) (string, error) { + *captured = opts + return filepath.Join(root, opts.Name), nil + } + srv := NewUIServer(UIServerConfig{ + Port: 4200, + WorkDir: root, + ExePath: "/usr/bin/false", + CreateFunc: mockCreate, + AgentPort: 9100, + }) + return srv, captured +} + +func TestHandleCreateAgent_WithAuthPayload(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "auth-test-agent", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge" + } + } + }`) + + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + } + if captured.Auth == nil { + t.Fatal("captured Auth is nil — opts.Auth did not round-trip") + } + if captured.Auth.Mode != "oidc" { + t.Errorf("Auth.Mode = %q, want oidc", captured.Auth.Mode) + } + if captured.Auth.Settings["issuer"] != "https://login.example.com" { + t.Errorf("issuer = %v, want https://login.example.com", captured.Auth.Settings["issuer"]) + } + if captured.Auth.Settings["audience"] != "api://forge" { + t.Errorf("audience = %v, want api://forge", captured.Auth.Settings["audience"]) + } +} + +// --- Server-side auth payload validation (review #9) --- +// +// Without this, the wizard could write malformed auth: blocks to disk +// (e.g., oidc without issuer/audience) — creation succeeds, `forge run` +// fails later, user confused. + +func TestHandleCreateAgent_OIDCMissingIssuerRejected(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + // Audience set but no issuer — should fail validation before + // CreateFunc is invoked. + body := []byte(`{ + "name": "bad-oidc", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { "audience": "api://forge" } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (body: %s)", w.Code, w.Body.String()) + } + // CreateFunc must NOT have been called — the agent directory should + // never exist on disk for malformed auth. + if captured.Name != "" { + t.Errorf("CreateFunc was called despite validation failure (captured: %+v)", captured) + } + if !strings.Contains(w.Body.String(), "issuer is required") { + t.Errorf("error body should name the missing field, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_OIDCMissingAudienceRejected(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "bad-oidc-2", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { "issuer": "https://login.example.com" } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", w.Code) + } + if captured.Name != "" { + t.Errorf("CreateFunc called despite validation failure") + } + if !strings.Contains(w.Body.String(), "audience is required") { + t.Errorf("error body should name missing audience, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_OIDCBothFieldsAccepted(t *testing.T) { + // Regression: valid oidc payload still works end-to-end. + srv, captured := setupCreateWithCapture(t) + body := []byte(`{ + "name": "good-oidc", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge" + } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201 (body: %s)", w.Code, w.Body.String()) + } + if captured.Auth == nil || captured.Auth.Mode != "oidc" { + t.Errorf("expected captured Auth.Mode = oidc, got %+v", captured.Auth) + } +} + +func TestHandleCreateAgent_HTTPVerifierMissingURLRejected(t *testing.T) { + srv, _ := setupCreateWithCapture(t) + body := []byte(`{ + "name": "bad-http", + "model_provider": "openai", + "auth": { "mode": "http_verifier", "settings": {} } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } + if !strings.Contains(w.Body.String(), "url is required") { + t.Errorf("error body should name missing url, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_UnknownAuthModeRejected(t *testing.T) { + srv, _ := setupCreateWithCapture(t) + body := []byte(`{ + "name": "weird", + "model_provider": "openai", + "auth": { "mode": "ldap", "settings": {} } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", w.Code) + } + if !strings.Contains(w.Body.String(), "unknown type") { + t.Errorf("error should name unknown-type, got: %s", w.Body.String()) + } +} + +func TestHandleCreateAgent_NoneAndCustomSkipValidation(t *testing.T) { + // mode "none" and "custom" don't produce providers — there's + // nothing to validate. They must round-trip cleanly. + for _, mode := range []string{"none", "custom"} { + t.Run(mode, func(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + body := []byte(`{ + "name": "agent-` + mode + `", + "model_provider": "openai", + "auth": { "mode": "` + mode + `", "settings": {} } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + if w.Code != http.StatusCreated { + t.Errorf("mode %q: status = %d, want 201", mode, w.Code) + } + if captured.Auth == nil || captured.Auth.Mode != mode { + t.Errorf("Auth.Mode = %v, want %q", captured.Auth, mode) + } + }) + } +} + +// TestHandleCreateAgent_NestedClaimMapShape pins the contract the +// frontend's buildAuthPayload promises (app.js): a flat `groups_claim` +// from the wizard is translated client-side into the nested shape +// `settings.claim_map.groups`. The backend doesn't run that +// translation — but if the frontend ever regresses (and app.js is +// hand-edited, so it's plausible), the request shape arriving here +// would change and downstream YAML rendering would silently drop the +// custom groups claim. +// +// This test documents what the backend EXPECTS to receive after the +// JS translation runs. If a future contributor changes app.js's +// buildAuthPayload to send the flat shape, this test still passes +// (we accept anything in Settings) but a documented JS-side test +// in app.js would catch the drift. Until JS tests exist, this is the +// canonical reference for the post-translation payload shape. +// Review #12.5. +func TestHandleCreateAgent_NestedClaimMapShape(t *testing.T) { + srv, captured := setupCreateWithCapture(t) + + // Shape AFTER the JS translation runs: groups_claim has become + // claim_map.groups. This is what the backend should see. + body := []byte(`{ + "name": "agent-with-roles", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge", + "claim_map": { "groups": "roles" } + } + } + }`) + + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201 (body: %s)", w.Code, w.Body.String()) + } + if captured.Auth == nil { + t.Fatal("captured Auth is nil") + } + cm, ok := captured.Auth.Settings["claim_map"].(map[string]any) + if !ok { + t.Fatalf("claim_map not present or wrong type: %v", captured.Auth.Settings["claim_map"]) + } + if cm["groups"] != "roles" { + t.Errorf("claim_map.groups = %v, want roles", cm["groups"]) + } + + // Negative half: the FLAT shape (pre-translation) should NOT + // round-trip as if it had been translated. If the frontend ever + // stops calling buildAuthPayload, the backend receives the raw + // form and the YAML will be missing claim_map entirely. + if _, hasFlat := captured.Auth.Settings["groups_claim"]; hasFlat { + t.Errorf("settings contains flat groups_claim — frontend translation may have skipped") + } +} + +func TestHandleCreateAgent_FlatGroupsClaim_PreservedAsExtraField(t *testing.T) { + // Counterpart: if the frontend sends the FLAT shape by mistake, + // the backend accepts the request (no validation rejects unknown + // keys) but the claim_map nested mapping is missing — which means + // the generated forge.yaml won't include claim_map.groups, and the + // agent runtime will fall back to the default "groups" claim. + // + // This documents the failure mode so if anyone investigates + // "my custom group claim isn't being read," this test is the + // breadcrumb. + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "agent-flat", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge", + "groups_claim": "roles" + } + } + }`) + + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + // Request is structurally valid (issuer + audience present). It just + // won't behave as the user expected at runtime. + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201", w.Code) + } + if _, hasNested := captured.Auth.Settings["claim_map"]; hasNested { + t.Error("backend should not synthesize claim_map from groups_claim — that's the frontend's job") + } +} + +func TestHandleCreateAgent_WithoutAuthPayload(t *testing.T) { + // Omitting `auth` from the request is still a valid create — the + // agent gets anonymous access by default. + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{"name": "no-auth-agent", "model_provider": "openai"}`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d", w.Code, http.StatusCreated) + } + if captured.Auth != nil { + t.Errorf("Auth = %v, want nil for omitted field", captured.Auth) + } } diff --git a/forge-ui/server.go b/forge-ui/server.go index 5f14d59..613c9a5 100644 --- a/forge-ui/server.go +++ b/forge-ui/server.go @@ -3,7 +3,6 @@ package forgeui import ( "context" "fmt" - "io/fs" "log" "net" "net/http" @@ -103,12 +102,10 @@ func (s *UIServer) Start(ctx context.Context) error { mux.HandleFunc("GET /api/agents/{id}/skill-builder/context", s.handleSkillBuilderContext) mux.HandleFunc("GET /api/agents/{id}/skill-builder/provider", s.handleSkillBuilderProvider) - // Static file serving with SPA fallback - distFS, err := fs.Sub(static.FS, "dist") - if err != nil { - return fmt.Errorf("creating sub filesystem: %w", err) - } - fileServer := http.FileServer(http.FS(distFS)) + // Static file serving with SPA fallback. The embedded FS is rooted + // directly at the static assets — no "dist/" subdirectory (review #8 + // removed the misleading dist/ naming). + fileServer := http.FileServer(http.FS(static.FS)) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // Try to serve the file directly path := r.URL.Path @@ -117,8 +114,8 @@ func (s *UIServer) Start(ctx context.Context) error { return } - // Check if the file exists - f, err := distFS.Open(strings.TrimPrefix(path, "/")) + // Check if the file exists in the embedded FS. + f, err := static.FS.Open(strings.TrimPrefix(path, "/")) if err == nil { _ = f.Close() fileServer.ServeHTTP(w, r) diff --git a/forge-ui/static/dist/app.js b/forge-ui/static/app.js similarity index 92% rename from forge-ui/static/dist/app.js rename to forge-ui/static/app.js index 6eed95a..d4bc86b 100644 --- a/forge-ui/static/dist/app.js +++ b/forge-ui/static/app.js @@ -1121,7 +1121,56 @@ function loadMonaco() { // ── Create Agent Wizard ────────────────────────────────────── -const WIZARD_STEPS = ['Name', 'Provider', 'Model & Key', 'Channels', 'Tools', 'Skills', 'Fallback', 'Env Vars', 'Review']; +const WIZARD_STEPS = ['Name', 'Provider', 'Model & Key', 'Channels', 'Tools', 'Skills', 'Fallback', 'Env Vars', 'Authentication', 'Review']; + +// buildAuthPayload reshapes the wizard's form.auth into the JSON shape +// the backend (cmd/ui.go → scaffold) expects. The wizard collects +// `groups_claim` as a flat field for usability; we translate it into the +// nested `claim_map: {groups: }` shape here. Empty optional fields +// are stripped so the generated forge.yaml stays clean. +// authReviewLabel renders the one-line summary shown on the Review step. +function authReviewLabel(auth) { + const mode = auth?.mode || 'none'; + const s = auth?.settings || {}; + if (mode === 'none') return 'Anonymous (no auth)'; + if (mode === 'custom') return 'Custom (edit forge.yaml)'; + if (mode === 'oidc') { + try { + return s.issuer ? 'OIDC · ' + new URL(s.issuer).hostname : 'OIDC'; + } catch (e) { + return 'OIDC'; + } + } + if (mode === 'http_verifier') { + try { + return s.url ? 'HTTP Verifier · ' + new URL(s.url).hostname : 'HTTP Verifier'; + } catch (e) { + return 'HTTP Verifier'; + } + } + return mode; +} + +function buildAuthPayload(auth) { + const mode = auth?.mode || 'none'; + const src = auth?.settings || {}; + if (mode === 'none' || mode === 'custom') { + return { mode, settings: {} }; + } + const settings = {}; + if (mode === 'oidc') { + const issuer = (src.issuer || '').replace(/\/+$/, ''); + if (issuer) settings.issuer = issuer; + if (src.audience) settings.audience = src.audience; + if (src.groups_claim && src.groups_claim !== 'groups') { + settings.claim_map = { groups: src.groups_claim }; + } + } else if (mode === 'http_verifier') { + if (src.url) settings.url = src.url; + if (src.default_org) settings.default_org = src.default_org; + } + return { mode, settings }; +} // Provider-specific API key labels, placeholders, and env var names const PROVIDER_KEY_INFO = { @@ -1168,6 +1217,10 @@ function CreatePage() { fallbacks: [], // [{provider, api_key}] passphrase: '', env_vars: {}, + // A2A server auth chain (PR6). The wizard's Authentication step + // sets `mode` to one of {none, oidc, http_verifier, custom} and + // populates `settings` for non-trivial modes. + auth: { mode: 'none', settings: {} }, }); const [creating, setCreating] = useState(false); const [oauthLoading, setOauthLoading] = useState(false); @@ -1250,6 +1303,15 @@ function CreatePage() { } return form.model_name.trim().length > 0; } + // Authentication step: required fields must be filled when a + // non-trivial mode is chosen. None / Custom always advance. + if (step === 8) { + const mode = form.auth?.mode || 'none'; + const s = form.auth?.settings || {}; + if (mode === 'oidc') return !!(s.issuer && s.audience); + if (mode === 'http_verifier') return !!s.url; + return true; // none / custom + } return true; }, [step, form, oauthDone]); @@ -1257,7 +1319,10 @@ function CreatePage() { setCreating(true); setError(null); try { - const result = await createAgent(form); + // Reshape the auth payload before submission so the backend gets + // the typed settings the scaffold expects (e.g., claim_map nesting). + const payload = { ...form, auth: buildAuthPayload(form.auth) }; + const result = await createAgent(payload); setSuccess(result); setTimeout(() => rescanAgents().catch(() => {}), 500); } catch (err) { @@ -1750,7 +1815,104 @@ function CreatePage() { `; } - case 8: { // Review + case 8: { // Authentication (A2A server auth chain — PR6) + const authTypes = meta?.auth_provider_types || []; + const mode = form.auth?.mode || 'none'; + const settings = form.auth?.settings || {}; + + const setAuthMode = (newMode) => { + // Switching modes always resets settings so stale fields + // (e.g., OIDC issuer when user switches to HTTP Verifier) + // don't leak into the submitted payload. + updateForm('auth', { mode: newMode, settings: {} }); + }; + const setAuthSetting = (key, val) => { + updateForm('auth', { mode, settings: { ...settings, [key]: val } }); + }; + + return html` +
+
Authentication
+
+ How should the agent's A2A server authenticate incoming requests? Default: anonymous. +
+ +
+ ${authTypes.map(opt => html` + + `)} +
+ + ${mode === 'oidc' && html` +
+
+ OIDC Configuration +
+ + setAuthSetting('issuer', e.target.value)} autocomplete="off" /> + + + setAuthSetting('audience', e.target.value)} autocomplete="off" /> + + + setAuthSetting('groups_claim', e.target.value)} autocomplete="off" /> +
+ Defaults to "groups". Set to e.g. "roles" if your IdP uses a different claim name. +
+
+ `} + + ${mode === 'http_verifier' && html` +
+
+ HTTP Verifier Configuration +
+ + setAuthSetting('url', e.target.value)} autocomplete="off" /> + + + setAuthSetting('default_org', e.target.value)} autocomplete="off" /> +
+ `} + + ${mode === 'custom' && html` +
+ A commented stub will be written to forge.yaml. Edit it after creation to plug in providers we don't expose in the wizard (e.g., static_token). +
+ `} +
+ `; + } + case 9: { // Review const configuredEnvCount = Object.values(form.env_vars).filter(v => v).length; const authLabel = form.auth_method === 'oauth' ? 'OAuth' : (form.api_key ? '\u2713 API Key provided' : 'Not set'); return html` @@ -1817,6 +1979,10 @@ function CreatePage() {
Secret Encryption
${form.passphrase ? '\u2713 Passphrase set' : 'None (plaintext .env)'}
+
+
A2A Auth
+
${authReviewLabel(form.auth)}
+
`; } diff --git a/forge-ui/static/embed.go b/forge-ui/static/embed.go index 4c00508..13f2500 100644 --- a/forge-ui/static/embed.go +++ b/forge-ui/static/embed.go @@ -2,5 +2,19 @@ package static import "embed" -//go:embed all:dist +// FS exposes the agent dashboard's bundled static assets (index.html, +// app.js, style.css, and the Monaco editor under monaco/) as an +// embed.FS so they ship inside the forge binary. +// +// Files are listed explicitly here rather than embedding the whole +// directory so that this file itself (embed.go) is NOT served as an +// asset. If a new top-level static file is added, it needs to be +// added to this directive. +// +// Note: there is no build step — app.js and style.css are hand-edited +// in place, NOT generated from sources elsewhere. The directory used +// to be called "dist/" but that naming was misleading (review #8); it +// signaled "build artifact" when in fact the directory is the source. +// +//go:embed app.js style.css index.html monaco var FS embed.FS diff --git a/forge-ui/static/dist/index.html b/forge-ui/static/index.html similarity index 100% rename from forge-ui/static/dist/index.html rename to forge-ui/static/index.html diff --git a/forge-ui/static/dist/monaco/build.sh b/forge-ui/static/monaco/build.sh similarity index 100% rename from forge-ui/static/dist/monaco/build.sh rename to forge-ui/static/monaco/build.sh diff --git a/forge-ui/static/dist/monaco/editor.css b/forge-ui/static/monaco/editor.css similarity index 100% rename from forge-ui/static/dist/monaco/editor.css rename to forge-ui/static/monaco/editor.css diff --git a/forge-ui/static/dist/monaco/editor.js b/forge-ui/static/monaco/editor.js similarity index 100% rename from forge-ui/static/dist/monaco/editor.js rename to forge-ui/static/monaco/editor.js diff --git a/forge-ui/static/dist/monaco/editor.worker.js b/forge-ui/static/monaco/editor.worker.js similarity index 100% rename from forge-ui/static/dist/monaco/editor.worker.js rename to forge-ui/static/monaco/editor.worker.js diff --git a/forge-ui/static/dist/style.css b/forge-ui/static/style.css similarity index 100% rename from forge-ui/static/dist/style.css rename to forge-ui/static/style.css diff --git a/forge-ui/types.go b/forge-ui/types.go index e2c598e..d628890 100644 --- a/forge-ui/types.go +++ b/forge-ui/types.go @@ -167,6 +167,17 @@ type AgentCreateOptions struct { Passphrase string `json:"passphrase,omitempty"` EnvVars map[string]string `json:"env_vars,omitempty"` Force bool `json:"force,omitempty"` + Auth *AuthCreateOptions `json:"auth,omitempty"` // A2A server auth chain (PR6+) +} + +// AuthCreateOptions describes the auth chain selection the web wizard +// captured. Mode is one of "none", "oidc", "http_verifier", "custom". +// Settings is the provider-type-specific settings block (issuer, audience, +// url, default_org, claim_map, …). Mirrors the TUI wizard's contract so +// both surfaces feed the same scaffold path. +type AuthCreateOptions struct { + Mode string `json:"mode"` + Settings map[string]any `json:"settings,omitempty"` } // FallbackProvider describes a fallback LLM provider with its API key. @@ -250,4 +261,15 @@ type WizardMetadata struct { Skills []SkillBrowserEntry `json:"skills"` ProviderModels map[string]ProviderModels `json:"provider_models"` WebSearchProviders []WebSearchProviderOption `json:"web_search_providers"` + AuthProviderTypes []AuthProviderTypeMeta `json:"auth_provider_types"` +} + +// AuthProviderTypeMeta describes one selectable auth provider type so the +// frontend renders its picker from server-driven metadata (no hardcoded +// provider list in JavaScript). When a new provider ships (e.g., Okta in +// Phase 3), append one entry here and the wizard picks it up. +type AuthProviderTypeMeta struct { + Type string `json:"type"` // "none", "oidc", "http_verifier", "custom" + Label string `json:"label"` // human-readable label for the picker + Description string `json:"description"` // single-line description under the label } diff --git a/go.work.sum b/go.work.sum index b5df3da..903f097 100644 --- a/go.work.sum +++ b/go.work.sum @@ -3,16 +3,12 @@ github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=