diff --git a/.claude/commands/sync-docs.md b/.claude/commands/sync-docs.md index f7f050d..6c4538f 100644 --- a/.claude/commands/sync-docs.md +++ b/.claude/commands/sync-docs.md @@ -6,28 +6,33 @@ After feature work, update the affected documentation to reflect code changes. 1. **Identify changed files** — Run `git diff main --name-only` to find modified Go files. -2. **Map files to docs** — Use this mapping to determine which docs need updates: +2. **Map files to docs** — Use this mapping to determine which docs need updates. + Paths match the actual `docs/` tree (nested under `core-concepts/`, `reference/`, + `security/`, `deployment/`, `skills/`). | Changed path pattern | Affected docs | |---------------------|---------------| - | `forge-core/runtime/` | `docs/runtime.md`, `docs/hooks.md` | - | `forge-core/security/` | `docs/security/overview.md`, `docs/security/egress.md` | - | `forge-core/tools/` | `docs/tools.md` | - | `forge-core/llm/` | `docs/runtime.md` | - | `forge-core/memory/` | `docs/memory.md` | - | `forge-core/scheduler/` | `docs/scheduling.md` | - | `forge-core/secrets/` | `docs/security/secrets.md` | - | `forge-core/skills/` | `docs/skills.md` | - | `forge-core/channels/` | `docs/channels.md` | - | `forge-cli/cmd/` | `docs/commands.md` | - | `forge-cli/runtime/` | `docs/runtime.md` | - | `forge-cli/server/` | `docs/architecture.md` | - | `forge-cli/channels/` | `docs/channels.md` | - | `forge-cli/tools/` | `docs/tools.md` | - | `forge-plugins/` | `docs/channels.md`, `docs/plugins.md` | - | `forge-ui/` | `docs/dashboard.md` | - | `forge-skills/` | `docs/skills.md` | - | `forge.yaml` / `types/` | `docs/configuration.md` | + | `forge-core/auth/` | `docs/security/authentication.md`, `docs/security/audit-logging.md` | + | `forge-core/runtime/` | `docs/core-concepts/runtime-engine.md`, `docs/core-concepts/hooks.md` | + | `forge-core/security/` | `docs/security/overview.md`, `docs/security/egress-control.md` | + | `forge-core/tools/` | `docs/core-concepts/tools-and-builtins.md` | + | `forge-core/llm/` | `docs/core-concepts/runtime-engine.md` | + | `forge-core/memory/` | `docs/core-concepts/memory-system.md` | + | `forge-core/scheduler/` | `docs/core-concepts/scheduling.md` | + | `forge-core/secrets/` | `docs/security/secret-management.md` | + | `forge-core/channels/` | `docs/core-concepts/channels.md` | + | `forge-core/validate/` | `docs/reference/forge-yaml-schema.md` | + | `forge-cli/cmd/` | `docs/reference/cli-reference.md` | + | `forge-cli/runtime/` | `docs/core-concepts/runtime-engine.md` | + | `forge-cli/server/` | `docs/core-concepts/how-forge-works.md` | + | `forge-cli/channels/` | `docs/core-concepts/channels.md` | + | `forge-cli/tools/` | `docs/core-concepts/tools-and-builtins.md` | + | `forge-cli/internal/tui/` | `docs/reference/cli-reference.md` (wizard flow) | + | `forge-plugins/` | `docs/core-concepts/channels.md`, `docs/reference/framework-plugins.md` | + | `forge-ui/` | `docs/reference/web-dashboard.md` | + | `forge-skills/` | `docs/skills/writing-custom-skills.md`, `docs/skills/contributing-a-skill.md` | + | `forge-core/types/` / `forge.yaml` | `docs/reference/forge-yaml-schema.md` | + | `CHANGELOG.md` | (rendered into release notes; no per-doc sync needed) | 3. **Read the diff** — For each mapped doc, read the relevant `git diff main` output to understand what changed. diff --git a/.gitignore b/.gitignore index 8815ab0..500bc45 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,8 @@ profile.cov # OS files .DS_Store + +# Auth provider docs — kept on disk locally but not version-controlled. +# Source-of-truth docs live in the design folder; in-repo copies are +# scratch space until we decide on the doc-site delivery story. +docs/auth/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d657540 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,125 @@ +# Changelog + +## v0.11.0 — Phase 2: cloud-native auth providers (in progress) + +### Added + +- **`aws_sigv4` auth provider.** Authenticate AWS-IAM callers by reflecting + their Sigv4 signature to AWS STS `GetCallerIdentity`. No `aws-sdk-go-v2` + dependency. +- **`gcp_iap` auth provider.** Verify the JWT IAP forwards as + `X-Goog-Iap-Jwt-Assertion` when Forge sits behind a GCP HTTPS Load + Balancer with IAP enabled. +- **`azure_ad` auth provider.** Verify Microsoft Entra ID Bearer tokens + with tenant lock-in and optional Microsoft Graph group enrichment. +- Non-interactive `forge init` flags for the three new providers: + `--auth-aws-region`, `--auth-aws-allowed-principal` (repeatable), + `--auth-gcp-iap-audience`, `--auth-azure-tenant`, + `--auth-azure-multi-tenant`, `--auth-azure-groups-mode`. +- Web UI exposes the three new types via the `/api/wizard-meta` endpoint; + server-side validation rejects malformed payloads before scaffold. +- `egress_hosts` automatically extended for each new provider + (`sts..amazonaws.com`, `www.gstatic.com`, + `login.microsoftonline.com`, `graph.microsoft.com` when applicable). + +### Changed + +- Middleware now consults the auth chain **even when no Bearer token is + extracted**, so non-Bearer formats (Sigv4 `Authorization`, IAP + `X-Goog-Iap-Jwt-Assertion`) can be recognized. Existing Bearer + JWT + flows are unchanged. +- `auth.HeadersFromRequest` widened with `X-Goog-Iap-Jwt-Assertion` + for `gcp_iap`. Providers that don't consume this header are unaffected. +- `auth.TokenKind` recognizes the `forge-aws-v1.` Bearer prefix and + returns `"sigv4"`. The audit `token_kind` field now has five possible + values: `empty`, `opaque`, `jwt`, `sigv4`, `iap_jwt`. +- `validate.ValidateAuthConfig` admits the three new provider types and + enforces their per-type required keys (`aws_sigv4.region`, + `gcp_iap.audience`, `azure_ad.audience`, `azure_ad.tenant_id`-unless- + multi-tenant, `azure_ad.groups_mode` whitelist). + +### Notes for upgraders + +- **No forge.yaml changes are required** for callers continuing to use + Phase 1 providers (`static_token`, `oidc`, `http_verifier`). Phase 1 + test suite passes without modification. +- If you wrote a custom provider that inspects headers, the `Headers` + map now contains additional keys. Existing keys are unchanged. +- The `oidc` package gained an internal `SkipIssuerCheck` field carrying + `yaml:"-"` — it cannot be set via `forge.yaml` and is reachable only + from Go callers (currently only `azure_ad` multi-tenant). Operators see + no change. + +### `allowed_accounts` shortcut for whole-account trust + +For "any IAM principal in these AWS accounts" without writing +glob patterns: + +```yaml +auth: + providers: + - type: aws_sigv4 + settings: + region: us-east-1 + allowed_accounts: ["412664885516", "109887654321"] +``` + +Internally expands to the canonical glob set covering all identity +shapes (IAM users, IAM roles, STS assumed-roles, federated users) +for each account. Composes with `allowed_principals` — you can list +specific roles AND whole accounts in the same provider entry. + +For AWS-Org-wide trust without enumerating accounts, use AWS IAM +Identity Center (SSO) — SSO permission sets gate Org membership at +sign-in, and you can match Identity Center-assumed roles with the +existing `allowed_principals` globs. + +### `azure_ad.allowed_tenants` — explicit allowlist for multi-tenant mode + +```yaml +auth: + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true + allowed_tenants: + - "00000000-1111-2222-3333-444444444444" # partner A + - "55555555-6666-7777-8888-999999999999" # partner B +``` + +When `allow_multi_tenant: true`, the `tid` claim must be in +`allowed_tenants` (case-insensitive GUID match). Empty list + +multi-tenant remains the documented "any tenant globally" mode for +back-compat, but `forge validate` now emits a warning when the list +is empty to make the trade-off explicit. Non-interactive flag: +`--auth-azure-allowed-tenant` (repeatable). + +### TUI wizard supports Phase 2 providers + +`forge init`'s TUI picker now includes `AWS Sigv4 (IAM)`, +`GCP Identity-Aware Proxy`, and `Azure AD / Entra ID` entries with +step-by-step input flows. AAD is single-tenant in the TUI; +multi-tenant remains a deliberate YAML edit (security default). + +### Client experience for `aws_sigv4` + +The client side is a Bearer token with a 3-line mint: + +```python +import boto3, base64 +url = boto3.client('sts', region_name='us-east-1').generate_presigned_url( + 'get_caller_identity', ExpiresIn=900) +token = 'forge-aws-v1.' + base64.urlsafe_b64encode(url.encode()).rstrip(b'=').decode() + +requests.post(forge_url, headers={'Authorization': f'Bearer {token}'}, data=msg) +``` + +Pattern is identical to `aws-iam-authenticator` for EKS. Reference client +in `scripts/forge-aws-sign.py` — use it directly or as a template for +Go / Java / Node clients. Wire format is documented in the package +docstring of `forge-core/auth/providers/aws_sigv4/provider.go`. + +### Known deferred work + +- (none for Phase 2) diff --git a/README.md b/README.md index 35c05ff..dbd8bce 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ You write a `SKILL.md`. Forge compiles it into a secure, runnable agent with egr | Document | Description | |----------|-------------| | [Security Overview](docs/security/overview.md) | Complete security architecture | +| [Authentication](docs/security/authentication.md) | Pluggable auth providers — OIDC, AWS Sigv4, GCP IAP, Azure AD | | [Egress Security](docs/security/egress-control.md) | Egress enforcement deep dive | | [Secrets](docs/security/secret-management.md) | Encrypted secret management | | [Build Signing](docs/security/build-signing.md) | Ed25519 signing and verification | diff --git a/docs/reference/cli-reference.md b/docs/reference/cli-reference.md index 3120608..4b6a077 100644 --- a/docs/reference/cli-reference.md +++ b/docs/reference/cli-reference.md @@ -39,6 +39,23 @@ forge init [name] [flags] | `--org-id` | | | OpenAI Organization ID (enterprise) | | `--from-skills` | | | Path to a SKILL.md file for auto-configuration | | `--non-interactive` | | `false` | Skip interactive prompts | +| `--auth` | | | Auth mode: `none`, `oidc`, `http_verifier`, `aws_sigv4`, `gcp_iap`, `azure_ad`, `custom` | +| `--auth-issuer` | | | OIDC issuer URL (required with `--auth=oidc`) | +| `--auth-audience` | | | OIDC audience (required with `--auth=oidc`) | +| `--auth-url` | | | Verifier URL (required with `--auth=http_verifier`) | +| `--auth-default-org` | | | Default `org_id` for http_verifier | +| `--auth-groups-claim` | | | Custom JWT claim name for groups (oidc, default `groups`) | +| `--auth-aws-region` | | | AWS region for `aws_sigv4` (e.g. `us-east-1`) | +| `--auth-aws-audience` | | | Informational audience for `aws_sigv4` | +| `--auth-aws-allowed-principal` | | | Allowed principal glob for `aws_sigv4` (repeatable) | +| `--auth-aws-allowed-account` | | | Allowed AWS account ID for `aws_sigv4` (repeatable, 12-digit) | +| `--auth-aws-cache-ttl` | | `60s` | `aws_sigv4` identity cache TTL | +| `--auth-gcp-iap-audience` | | | Backend service ID for `gcp_iap` | +| `--auth-azure-tenant` | | | Entra tenant GUID for `azure_ad` | +| `--auth-azure-audience` | | | Audience (Application ID URI) for `azure_ad` | +| `--auth-azure-multi-tenant` | | `false` | Accept tokens from any Entra tenant | +| `--auth-azure-allowed-tenant` | | | Allowed Entra tenant GUID for multi-tenant `azure_ad` (repeatable) | +| `--auth-azure-groups-mode` | | `claim` | `azure_ad` groups mode: `claim` or `graph` | ### Generated Files @@ -84,8 +101,28 @@ forge init my-agent \ --api-key sk-... \ --org-id org-xxxxxxxxxxxxxxxxxxxxxxxx \ --non-interactive + +# AWS IAM auth (any caller in account 412664885516) +forge init my-agent \ + --model-provider ollama \ + --auth=aws_sigv4 \ + --auth-aws-region=us-east-1 \ + --auth-aws-allowed-account=412664885516 \ + --non-interactive + +# Azure AD multi-tenant with explicit partner allowlist +forge init my-agent \ + --model-provider ollama \ + --auth=azure_ad \ + --auth-azure-audience=api://forge \ + --auth-azure-multi-tenant \ + --auth-azure-allowed-tenant=00000000-1111-... \ + --auth-azure-allowed-tenant=55555555-6666-... \ + --non-interactive ``` +See [Authentication](../security/authentication.md) for the full auth provider reference. + --- ## `forge build` diff --git a/docs/reference/forge-yaml-schema.md b/docs/reference/forge-yaml-schema.md index 0407ce1..64a91bb 100644 --- a/docs/reference/forge-yaml-schema.md +++ b/docs/reference/forge-yaml-schema.md @@ -61,6 +61,51 @@ package: dest: "/usr/local/bin/custom-tool" # Install destination chmod: "0755" # File permissions +auth: # a2a HTTP-server auth chain (optional) + required: true # 401 every unauthenticated request + providers: # ordered; first match wins (fail-closed on rejection) + - type: "static_token" # local dev / shared-secret + settings: + token_env: "FORGE_AUTH_TOKEN" # env var name (preferred over literal `token:`) + - type: "oidc" # any IdP with OIDC discovery + settings: + issuer: "https://login.example.com/auth/realms/forge" + audience: "api://forge" + client_id: "" # optional azp fallback + jwks_url: "" # overrides discovery + jwks_cache_ttl: "1h" + clock_skew: "30s" + claim_map: {groups: "roles"} + - type: "http_verifier" # legacy external /verify endpoint + settings: + url: "https://auth.example.com/verify" + default_org: "acme" + timeout: "10s" + - type: "aws_sigv4" # Phase 2: AWS IAM via pre-signed STS URL + settings: + region: "us-east-1" # required + audience: "api://forge" # informational, emitted in audit Claims + allowed_accounts: # ergonomic: "anyone in these AWS accounts" + - "412664885516" + allowed_principals: # explicit globs (path.Match) + - "arn:aws:sts::412664885516:assumed-role/ci-deploy/*" + identity_cache_ttl: "60s" + max_token_expires: "15m" # caps caller's X-Amz-Expires claim + clock_skew: "5m" + - type: "gcp_iap" # Phase 2: GCP IAP-fronted Forge + settings: + audience: "/projects/PNUM/global/backendServices/BACKEND_ID" + jwks_refresh_ttl: "1h" + - type: "azure_ad" # Phase 2: Microsoft Entra ID + settings: + tenant_id: "00000000-1111-..." # required unless allow_multi_tenant + audience: "api://forge" + allow_multi_tenant: false + allowed_tenants: # required when multi-tenant + want allowlist + - "55555555-6666-..." + groups_mode: "claim" # "claim" | "graph" + graph_timeout: "5s" + secrets: providers: # Secret providers (order matters) - "encrypted-file" # AES-256-GCM encrypted file diff --git a/docs/reference/web-dashboard.md b/docs/reference/web-dashboard.md index afca134..a92f9e9 100644 --- a/docs/reference/web-dashboard.md +++ b/docs/reference/web-dashboard.md @@ -69,11 +69,42 @@ A multi-step wizard (web equivalent of `forge init`) that walks through the full | Tools | Select builtin tools; web_search shows Tavily vs Perplexity provider choice with API key input | | Skills | Browse registry skills by category with inline required/optional env var collection | | Fallback | Select backup LLM providers with API keys for automatic failover | +| Auth | Select a2a auth provider (see below) | +| Egress | Review the outbound allowlist (including auth-provider-derived hosts) | | Env & Security | Add extra env vars; set passphrase for AES-256-GCM secret encryption | | Review | Summary of all selections before creation | The wizard collects credentials inline at each step (matching the CLI TUI behavior) and supports all the same options: model selection, OAuth, web search providers, fallback chains, and encrypted secret storage. +### Auth step + +The Auth step exposes the same provider chain as `forge.yaml`'s +`auth.providers[]` block. Picker options: + +| Option | Effect | +|---|---| +| **None** | Anonymous access — no `auth:` block written | +| **OIDC (JWT)** | Generic OIDC IdP (Keycloak, Auth0, Okta, Google) — collects issuer + audience | +| **HTTP Verifier** | Legacy — POST tokens to your own `/verify` endpoint | +| **AWS Sigv4 (IAM)** | AWS-IAM-based callers — collects region + optional audience + comma-separated allowed accounts | +| **GCP Identity-Aware Proxy** | Forge behind GCP LB+IAP — collects backend service ID | +| **Azure AD / Entra ID** | Microsoft Entra ID (single-tenant in the wizard; multi-tenant requires editing `forge.yaml` as a deliberate security trade-off) | +| **Custom** | Comment stub — edit `forge.yaml` yourself | + +The Auth step runs **before** Egress, so the Egress step's review screen +displays auth-provider-derived hosts (e.g. `sts.us-east-1.amazonaws.com` +when AWS Sigv4 is selected) alongside the model / channel / tool / skill +hosts. The operator sees the full outbound surface for approval in one +place. + +The Web UI's `POST /api/create` endpoint additionally filters incoming +auth `settings` through a closed-key whitelist (defined in +`forge-core/validate/auth.go`) before scaffolding — unknown keys are +silently dropped rather than written to disk. + +See [Authentication](../security/authentication.md) for the full +provider reference. + ## Config Editor Edit `forge.yaml` for any agent with a Monaco-based YAML editor: diff --git a/docs/security/audit-logging.md b/docs/security/audit-logging.md index a850f6d..e376751 100644 --- a/docs/security/audit-logging.md +++ b/docs/security/audit-logging.md @@ -19,6 +19,8 @@ All runtime security events are emitted as structured NDJSON to stderr with corr | `egress_blocked` | Outbound request blocked (with domain, mode) | | `llm_call` | LLM API call completed (with token count) | | `guardrail_check` | Guardrail evaluation result | +| `auth_verify` | Inbound request authenticated successfully (with `provider`, `user_id`, `org_id`, `token_kind`) | +| `auth_fail` | Inbound request rejected (with `reason`, `token_kind`) | ### Example @@ -31,3 +33,81 @@ All runtime security events are emitted as structured NDJSON to stderr with corr ``` The `source` field distinguishes in-process enforcer events from subprocess proxy events. + +### Authentication events + +Every inbound request to `/tasks` emits exactly one of `auth_verify` or `auth_fail`. + +**Successful authentication:** + +```json +{ + "ts":"2026-05-24T00:50:01Z", + "event":"auth_verify", + "fields":{ + "method":"POST", + "path":"/tasks/send", + "provider":"aws_sigv4", + "user_id":"arn:aws:sts::412664885516:assumed-role/AWSReservedSSO_PowerUserAccess_.../Naveen", + "org_id":"412664885516", + "token_kind":"sigv4", + "groups_count":0, + "remote_addr":"[::1]:62297" + } +} +``` + +`user_id` is the canonical identifier the verifier returned (ARN for AWS, JWT +`sub` for OIDC/IAP/AAD). `org_id` is the AWS account, Entra tenant GUID, or +OIDC `tid`/`org_id`-mapped claim depending on the provider. + +**Failed authentication:** + +```json +{"ts":"...","event":"auth_fail","fields":{"reason":"rejected","token_kind":"sigv4","method":"POST","path":"/tasks/send","remote_addr":"[::1]:62200"}} +``` + +### Reason codes (`auth_fail.fields.reason`) + +| Reason | What it means | Operator action | +|---|---|---| +| `missing_token` | No auth-shaped headers at all | Caller forgot to authenticate | +| `not_for_me` | Bearer present but no provider claimed it | Wrong token format for the configured providers | +| `rejected` | Provider recognized + denied (allowlist miss, expired, bad sig, scope mismatch) | Check `allowed_principals` / `tenant_id` / token freshness | +| `invalid` | Token malformed (bad base64, unsupported alg, missing required field) | Token construction bug on the caller side | +| `provider_unavailable` | Verifier endpoint down (STS / JWKS / Graph 5xx, network error) | Provider-side incident; not a token issue | + +### Token kind values (`fields.token_kind`) + +Structural classification of what bytes were on the wire — safe to log: + +| Value | Shape | +|---|---| +| `empty` | No token / no auth-shaped headers | +| `opaque` | Bearer with non-JWT, non-sigv4 shape (channel adapter loopback, custom verifier tokens) | +| `jwt` | Bearer with three base64url segments (`oidc`, `azure_ad`) | +| `sigv4` | Bearer with `forge-aws-v1.` prefix (`aws_sigv4` pre-signed URL token) | +| `iap_jwt` | `X-Goog-Iap-Jwt-Assertion` header present (`gcp_iap`) — also stamped on successful verify even if Bearer was simultaneously present | + +### Audit pipeline grep recipes + +Who called my agent in the last hour, by ARN/email? + +```bash +jq -r 'select(.event=="auth_verify") | .fields.user_id' forge.log | sort | uniq -c +``` + +Why are requests failing? + +```bash +jq -r 'select(.event=="auth_fail") | .fields.reason' forge.log | sort | uniq -c +``` + +Which agents called this one (in a mesh)? + +```bash +jq -r 'select(.event=="auth_verify") | "\(.fields.user_id)"' forge.log | sort -u +``` + +See [Authentication](authentication.md) for the full provider chain and how +each provider populates these fields. diff --git a/docs/security/authentication.md b/docs/security/authentication.md new file mode 100644 index 0000000..3fd9a13 --- /dev/null +++ b/docs/security/authentication.md @@ -0,0 +1,423 @@ +--- +title: "Authentication Providers" +description: "Pluggable auth provider chain that gates Forge's /tasks endpoint — OIDC, AWS Sigv4, GCP IAP, Azure AD, and local-only static_token." +order: 6 +--- + +Forge's `a2a` HTTP server (the `/tasks` endpoint and friends) requires every +caller to authenticate through a pluggable provider chain configured in +`forge.yaml`. Each provider recognizes one token shape; the chain tries them +in order, first match wins, and the result lands in `Identity` for the +audit log and any downstream authz hook. + +## Provider matrix + +| Provider | Use case | Token format | Verifies against | Phase | +|---|---|---|---|---| +| `static_token` | Local dev, channel-adapter loopback | Shared secret | constant-time SHA-256 compare | 1 | +| `oidc` | Any IdP with OIDC discovery (Keycloak, Auth0, Okta, Google) | `Authorization: Bearer ` | Issuer's JWKS (TTL-cached, with backoff + stale-grace) | 1 | +| `http_verifier` | Custom verifier endpoint you operate | Opaque token | Your own `/verify` HTTP service | 1 | +| `aws_sigv4` | AWS-IAM-based callers (Lambda, EC2, EKS, IAM users) | `Authorization: Bearer forge-aws-v1.` | AWS STS `GetCallerIdentity` (pre-signed URL pattern) | 2 (v0.11.0) | +| `gcp_iap` | Forge behind GCP HTTPS LB + IAP | `X-Goog-Iap-Jwt-Assertion: ` | IAP's hardcoded JWKS at `www.gstatic.com` | 2 (v0.11.0) | +| `azure_ad` | Microsoft Entra ID tokens | `Authorization: Bearer ` | AAD JWKS via composed `oidc` provider + tenant gate | 2 (v0.11.0) | + +Forge holds **no IdP secrets**. All providers verify a caller-minted +credential against a third party (STS / GCP JWKS / AAD JWKS / your own +`/verify`), then stamp an `Identity` from what the verifier returned. + +## Chain semantics + +Each `Verify` returns one of: + +| Return | Meaning | Chain behavior | +|---|---|---| +| `Identity, nil` | Token accepted | Stops; chain returns this Identity | +| `nil, ErrTokenNotForMe` | "Not my format" | Continues to next provider | +| `nil, ErrTokenRejected` | "My format, but denied" | **Stops; 401** | +| `nil, ErrInvalidToken` | "Malformed" | **Stops; 401** | +| `nil, ErrProviderUnavailable` | "Can't reach my IdP" | **Stops; 401** (fail-closed) | + +The critical rule is **no fall-through on rejection**: if provider A +returns `ErrTokenRejected`, the chain does NOT try provider B. Otherwise +an attacker could downgrade by presenting a malformed token of type A and +hoping to be authenticated as type B. + +### Loopback `static_token` is auto-prepended + +Forge writes a random token to `.forge/runtime.token` (mode `0600`) on +startup and auto-prepends a `static_token` provider for it to the chain. +This is how channel adapters (Slack, Telegram, MS Teams) and the local +Web UI authenticate without you configuring anything. Anyone with read +access to `.forge/runtime.token` can call the a2a server. Treat that +file like an SSH key. + +### Non-Bearer auth headers (Phase 2) + +The middleware consults the chain **even when no `Authorization: Bearer` +was extracted**, provided a non-Bearer auth header is present +(`X-Goog-Iap-Jwt-Assertion`). When there are no auth-shaped headers at +all, the audit reason stays `missing_token` rather than widening to +`not_for_me` — operators can still distinguish "client didn't auth" from +"client tried a format we don't speak." + +## `forge.yaml` schema + +```yaml +auth: + required: true # 401 every unauthenticated request + providers: + - type: oidc | aws_sigv4 | gcp_iap | azure_ad | http_verifier | static_token + settings: + # provider-specific keys (see per-provider sections below) +``` + +Per-provider settings are validated by `forge validate`. Unknown keys +produce a warning (typo detection); the Web UI's `/api/create` endpoint +additionally filters to a closed-key whitelist before scaffolding so +malicious POST payloads can't drop arbitrary keys into `forge.yaml`. + +--- + +## `oidc` — Generic OIDC issuer + +The workhorse provider — any IdP with an OIDC discovery doc and JWKS. + +```yaml +auth: + required: true + providers: + - type: oidc + settings: + issuer: https://login.example.com/auth/realms/forge # required + audience: api://forge # required + client_id: my-spa # optional azp fallback + jwks_url: https://... # optional — overrides discovery + jwks_cache_ttl: 1h + clock_skew: 30s + claim_map: # remap claim names + groups: roles +``` + +- Algorithm whitelist: `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. `none` and HMAC are rejected before key lookup. +- JWKS is TTL-cached with backoff + stale-grace — token verification keeps working through brief JWKS outages. +- Issuer trailing-slash normalization handles the Auth0/Okta disagreement (`https://x/` vs `https://x`). + +--- + +## `aws_sigv4` — AWS IAM via pre-signed STS URL + +Authenticates callers by their AWS-IAM identity. Same pattern as +[`aws-iam-authenticator`](https://github.com/kubernetes-sigs/aws-iam-authenticator) +for EKS: caller pre-signs a `GetCallerIdentity` URL with their AWS SDK +and sends it as a Bearer token; Forge invokes that URL, STS validates +the signature against its own host and returns the canonical ARN. + +```yaml +auth: + required: true + providers: + - type: aws_sigv4 + settings: + region: us-east-1 # required + audience: api://forge # informational; in audit Claims + allowed_accounts: ["412664885516"] # ergonomic: "anyone in these accounts" + allowed_principals: # explicit globs (path.Match syntax) + - "arn:aws:sts::412664885516:assumed-role/ci-deploy/*" + identity_cache_ttl: 60s + max_token_expires: 15m # caps caller's X-Amz-Expires claim + clock_skew: 5m +``` + +### Wire format + +``` +Authorization: Bearer forge-aws-v1. +``` + +The base64-decoded payload is a complete pre-signed URL of the form: + +``` +https://sts..amazonaws.com/ + ?Action=GetCallerIdentity + &Version=2011-06-15 + &X-Amz-Algorithm=AWS4-HMAC-SHA256 + &X-Amz-Credential=///sts/aws4_request + &X-Amz-Date= + &X-Amz-Expires= + &X-Amz-SignedHeaders=host + &X-Amz-Signature= +``` + +### Client side (3 lines) + +```python +import boto3, base64, requests +from botocore.auth import SigV4QueryAuth +from botocore.awsrequest import AWSRequest + +creds = boto3.Session().get_credentials().get_frozen_credentials() +req = AWSRequest(method="GET", + url="https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15") +SigV4QueryAuth(creds, "sts", "us-east-1", expires=900).add_auth(req) +token = "forge-aws-v1." + base64.urlsafe_b64encode(req.url.encode()).rstrip(b"=").decode() + +requests.post(forge_url, headers={"Authorization": f"Bearer {token}"}, data=msg) +``` + +> `boto3.client('sts').generate_presigned_url('get_caller_identity', ...)` +> does **not** work — it signs as if the request were a POST, STS rejects +> the GET. Use the lower-level `SigV4QueryAuth` shown above. Same quirk +> `aws-iam-authenticator` works around internally. + +Reference client ships in [`scripts/forge-aws-sign.py`](../../scripts/forge-aws-sign.py). + +### `allowed_accounts` — "anyone in this account" + +The ergonomic shortcut for whole-account trust. Each 12-digit account ID +expands internally to the canonical glob set covering every STS identity +shape (IAM users, IAM roles, STS assumed-roles incl. SSO, federated +users). Composable with `allowed_principals`. + +### Org-wide trust without enumerating accounts + +There's no STS API to ask "is account X in Org Y?" — AWS deliberately +doesn't expose that. Two production paths: + +1. **AWS IAM Identity Center (SSO).** Every user's session is already an + assumed-role under `AWSReservedSSO_*`. Use a glob: + ```yaml + allowed_principals: + - "arn:aws:sts::ACCT:assumed-role/AWSReservedSSO_*/*" + ``` + Org membership is enforced by Identity Center at sign-in time. + +2. **Entry role with `aws:PrincipalOrgID` condition.** Customer creates + one IAM role in one account with a trust policy that allows anyone in + their Org to assume it. Forge's allowlist contains just that one + assumed-role ARN. The Org-membership check happens at AWS IAM, not in + Forge. + +### Security model + +- **No secret keys on Forge.** STS validates signatures. +- **SSRF guard.** Pre-signed URL host must be `sts..amazonaws.com` exactly; userinfo (`user:pass@`) and foreign hosts are rejected at parse time. +- **No HTTP redirects.** `CheckRedirect` is pinned to `ErrUseLastResponse` so a redirect off `sts.…` (e.g. MITM, TLS-inspecting proxy) can't substitute attacker bytes for the STS response. +- **Freshness gate.** Tokens claiming `X-Amz-Expires > 15min` are rejected; tokens whose `X-Amz-Date + Expires` window has lapsed (with 5min clock skew) are rejected. Bounds stolen-token replay independent of STS's own enforcement. +- **Cache bucketing on `hash(AKID, YYYYMMDD)`** — bounds stolen-key replay to one day worst-case. +- **No `aws-sdk-go-v2` dependency.** STS RPC is ~80 LOC of hand-rolled HTTP + XML. + +### Audit log shape + +```json +{ "event": "auth_verify", + "fields": { + "provider": "aws_sigv4", + "user_id": "arn:aws:sts::123456789012:assumed-role/ci-deploy/i-0abc", + "org_id": "123456789012", + "token_kind": "sigv4" + } +} +``` + +--- + +## `gcp_iap` — GCP Identity-Aware Proxy + +Verifies the JWT IAP forwards as `X-Goog-Iap-Jwt-Assertion` when Forge +sits behind a GCP HTTPS Load Balancer with IAP enabled. + +```yaml +auth: + required: true + providers: + - type: gcp_iap + settings: + audience: /projects/12345678/global/backendServices/9876543210 +``` + +`audience` is the backend service ID — find it in +**GCP Console → Security → IAP → Backend Services → Signed Header JWT Audience**. + +### Security model + +- **Hardcoded JWKS host** (`www.gstatic.com/iap/verify/public_key-jwk`). Operators cannot override — eliminates the "trust attacker's JWKS" failure mode. +- **ES256-only.** Any other alg rejected before key lookup. +- **JWKS merge-on-success.** A partial-but-valid JWKS response can't drop kids the stale-grace contract assumes are kept. +- **No HTTP redirects.** Same `ErrUseLastResponse` pin as `aws_sigv4`. +- **No GCP SDK dependency.** + +Sub `email` / `hd` (Workspace domain) flow through to `Identity.Claims` +for downstream policy. + +--- + +## `azure_ad` — Microsoft Entra ID + +Composes the Phase 1 `oidc` provider for signature verification; layers +AAD-specific concerns on top. + +### Single-tenant (the safe default) + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + tenant_id: 00000000-1111-2222-3333-444444444444 + audience: api://forge + groups_mode: claim # or "graph" +``` + +`tid` claim must equal `tenant_id`; iss is double-checked via OIDC. + +### Multi-tenant with explicit allowlist + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true + allowed_tenants: # case-insensitive GUID match + - "00000000-1111-2222-3333-444444444444" + - "55555555-6666-7777-8888-999999999999" +``` + +### Multi-tenant "any tenant globally" (high-risk) + +```yaml +auth: + required: true + providers: + - type: azure_ad + settings: + audience: api://forge + allow_multi_tenant: true + # allowed_tenants intentionally omitted +``` + +`forge validate` emits a warning so this trade-off is loud, not silent. + +### Groups overage (graph mode) + +When `groups_mode: graph` and the JWT's `groups` claim is empty (AAD +truncates at ~200 groups), Forge calls Microsoft Graph +`/me/transitiveMemberOf` using the **caller's** Bearer to fetch the +full list. Forge holds no Graph credentials of its own. Soft-fails on +Graph 5xx (returns Identity with empty Groups rather than blocking +prod traffic). + +### Security model + +- **Composition over inheritance.** No JWT verify or JWKS code in `azure_ad/` — all crypto lives in `oidc`. +- **Tenant gate.** Single-tenant: `tid == tenant_id`. Multi-tenant + allowlist: `tid ∈ allowed_tenants`. Multi-tenant + empty: no tid check (high-risk, warned). +- **Internal `skip_issuer_check` flag** carries `yaml:"-"` — unreachable from `forge.yaml`, only set by this package when `allow_multi_tenant=true`. +- **No HTTP redirects** on Graph client. Graph nextLink scheme + host both validated to prevent Bearer-downgrade via `https→http` same-host redirects. + +--- + +## `static_token` — Shared secret + +Loopback / dev use. Provider does constant-time SHA-256 comparison so +length-leak / timing attacks are blocked. + +```yaml +auth: + required: true + providers: + - type: static_token + settings: + token_env: FORGE_AUTH_TOKEN # prefer env over literal +``` + +`token:` (literal value in YAML) is also accepted but produces a warning. + +--- + +## `http_verifier` — External `/verify` endpoint + +Legacy / custom — you operate the verifier; Forge POSTs the token to it. + +```yaml +auth: + required: true + providers: + - type: http_verifier + settings: + url: https://auth.example.com/verify + default_org: acme + timeout: 10s +``` + +Same wire format as the pre-Phase-1 `--auth-url` flag. + +--- + +## Egress allowlist auto-extension + +Configuring an auth provider automatically adds the hosts it needs to +the egress allowlist: + +| Provider | Host(s) auto-added | +|---|---| +| `oidc` | ``, `` if explicit | +| `http_verifier` | `` | +| `aws_sigv4` | `sts..amazonaws.com` | +| `gcp_iap` | `www.gstatic.com` | +| `azure_ad` | `login.microsoftonline.com` (+ `graph.microsoft.com` when `groups_mode: graph`) | + +`forge init`'s wizard runs the Auth step before the Egress step, so +operators see the full outbound surface for review in one screen. + +## Wizard / CLI + +`forge init` interactive TUI: pick auth type → enter region / audience / +tenant / etc. → done. Non-interactive equivalent via flags: + +```bash +forge init --non-interactive \ + --name my-agent \ + --model-provider ollama \ + --auth=aws_sigv4 \ + --auth-aws-region=us-east-1 \ + --auth-aws-audience=api://forge \ + --auth-aws-allowed-account=412664885516 +``` + +See [CLI Reference](../reference/cli-reference.md) for the full flag set. + +--- + +## Mesh patterns (agent-to-agent) + +When an agent calls another agent, the receiver's auth provider gates +the call the same way it would for a human or CI. Two common patterns: + +**Single-account "fleet" model.** Every agent runs as a workload in +one dedicated AWS account with its own IAM role; every agent's +`forge.yaml` has `allowed_accounts: []`. Trust boundary = +the account. Onboarding a new agent = create one IAM role; no other +agent's config changes. + +**Per-pair allowlist.** Sensitive agents (touching money, PII, customer +data) override the broad account allowlist with explicit +`allowed_principals` patterns for the specific calling agents allowed. + +See [Audit Logging](audit-logging.md) for how to grep `user_id` across +audit events to map the actual call graph. + +--- + +## Related Documentation + +| Document | Description | +|----------|-------------| +| [Audit Logging](audit-logging.md) | `auth_verify` / `auth_fail` event shape, reason codes, `token_kind` values | +| [Egress Security](egress-control.md) | Auth-host auto-allowlist and how it composes with operator-set domains | +| [Trust Model](trust-model.md) | Caller → Forge trust boundary; what Forge does and doesn't trust | +| [forge.yaml Schema](../reference/forge-yaml-schema.md) | Full YAML reference including `auth:` block | +| [CLI Reference](../reference/cli-reference.md) | `forge init` auth flags | +| [Web Dashboard](../reference/web-dashboard.md) | Auth provider options in the create flow | diff --git a/docs/security/egress-control.md b/docs/security/egress-control.md index afc58b5..b1bc00b 100644 --- a/docs/security/egress-control.md +++ b/docs/security/egress-control.md @@ -212,8 +212,28 @@ The resolver (`forge-core/security/resolver.go`) combines all domain sources: - Start with explicit domains from `forge.yaml` - Add tool-inferred domains - Add capability bundle domains + - **Add auth-provider-derived domains** (see below) - Deduplicate and sort +### Auth-provider domain auto-extension + +Configuring an `auth.providers[]` entry adds the host(s) the provider needs +to reach to the allowlist automatically — operators don't have to remember +to add `sts.us-east-1.amazonaws.com` themselves when they configure +`aws_sigv4`. The `security.AuthDomains` helper centralizes the mapping: + +| Provider | Host(s) added | +|---|---| +| `oidc` | host of `issuer` URL, host of explicit `jwks_url` if set | +| `http_verifier` | host of `url` | +| `aws_sigv4` | `sts..amazonaws.com` (+ test-mode `sts_endpoint` override host) | +| `gcp_iap` | `www.gstatic.com` (hardcoded, IAP JWKS lives there) | +| `azure_ad` | `login.microsoftonline.com` (+ `graph.microsoft.com` when `groups_mode: graph`) | + +`forge init`'s wizard runs the Auth step **before** Egress so the operator +sees the full outbound surface for review in a single screen. See +[Authentication](authentication.md) for the per-provider auth model. + ## Build Artifacts The `EgressStage` generates: diff --git a/docs/security/overview.md b/docs/security/overview.md index 897f604..a2e4515 100644 --- a/docs/security/overview.md +++ b/docs/security/overview.md @@ -18,6 +18,10 @@ Forge's security is organized in layers, each addressing a different threat surf │ Global Guardrails │ │ (content filtering, PII, jailbreak) │ ├──────────────────────────────────────────────────────────────┤ +│ Authentication (a2a) │ +│ (Pluggable provider chain: OIDC, AWS Sigv4, GCP IAP, │ +│ Azure AD, http_verifier, static_token loopback) │ +├──────────────────────────────────────────────────────────────┤ │ Egress Enforcement │ │ (EgressEnforcer + EgressProxy + SafeDialer + NetworkPolicy) │ ├──────────────────────────────────────────────────────────────┤ @@ -39,6 +43,7 @@ Forge's security is organized in layers, each addressing a different threat surf ## Table of Contents - [Network Posture](#network-posture) +- [Authentication](#authentication) - [Egress Enforcement](#egress-enforcement) - [Execution Sandboxing](#execution-sandboxing) - [Secrets Management](#secrets-management) @@ -48,6 +53,32 @@ Forge's security is organized in layers, each addressing a different threat surf - [Container Security](#container-security) - [Related Documentation](#related-documentation) +## Authentication + +The `/tasks` HTTP endpoint requires every caller to authenticate through a +pluggable provider chain configured in `forge.yaml`. Forge ships six provider +types and holds no IdP secrets — every provider verifies against a third +party (STS, GCP JWKS, AAD JWKS, your custom verifier) or a local file. + +| Provider | Use case | Wire format | +|---|---|---| +| `static_token` | Loopback (channel adapters, local Web UI dashboard) | Bearer literal | +| `oidc` | Generic OIDC IdP (Keycloak, Auth0, Okta, Google) | Bearer JWT | +| `aws_sigv4` | AWS IAM identities (Lambda, EC2, EKS, SSO users) | `Bearer forge-aws-v1.` | +| `gcp_iap` | Behind GCP HTTPS LB + IAP | `X-Goog-Iap-Jwt-Assertion: ` | +| `azure_ad` | Microsoft Entra ID | Bearer JWT | +| `http_verifier` | Custom `/verify` endpoint you operate | Opaque token | + +The chain is first-match-wins with **fail-closed on rejection** — a malformed +token of type A doesn't fall through to type B. The local Web UI dashboard +and channel adapters use an auto-prepended `static_token` (the `runtime.token` +file under `.forge/`) so they keep working regardless of how external auth is +configured. + +See [Authentication Providers](authentication.md) for the complete reference, +including per-provider security model, client-side recipes, and mesh +(agent-to-agent) patterns. + --- ## Network Posture @@ -267,6 +298,7 @@ Production builds enforce: | Document | Description | |----------|-------------| +| [Authentication](authentication.md) | Pluggable auth providers (OIDC, AWS Sigv4, GCP IAP, Azure AD, etc.) gating the a2a HTTP server | | [Egress Security](egress-control.md) | Deep dive into egress enforcement: IP validation, SafeDialer, profiles, modes, domain matching, proxy architecture, NetworkPolicy | | [Secrets Management](secret-management.md) | Encrypted storage, per-agent secrets, passphrase handling | | [Build Signing & Verification](build-signing.md) | Key management, build signing, runtime verification | diff --git a/forge-cli/cmd/init.go b/forge-cli/cmd/init.go index fca28d4..ba1e71b 100644 --- a/forge-cli/cmd/init.go +++ b/forge-cli/cmd/init.go @@ -135,12 +135,27 @@ func init() { // Auth chain (PR5+). All optional. When --auth is unset or "none", no // auth: block is written and the agent runs anonymously. - initCmd.Flags().String("auth", "", "auth mode: none, oidc, http_verifier, custom") + initCmd.Flags().String("auth", "", "auth mode: none, oidc, http_verifier, aws_sigv4, gcp_iap, azure_ad, custom") initCmd.Flags().String("auth-issuer", "", "OIDC issuer URL (required with --auth=oidc)") initCmd.Flags().String("auth-audience", "", "OIDC audience (required with --auth=oidc)") initCmd.Flags().String("auth-url", "", "verifier URL (required with --auth=http_verifier)") initCmd.Flags().String("auth-default-org", "", "default org_id for http_verifier (optional)") initCmd.Flags().String("auth-groups-claim", "", "claim name for groups (oidc, default: groups)") + + // Phase 2: per-provider flags. Namespaced so help text stays grouped. + initCmd.Flags().String("auth-aws-region", "", "AWS region for aws_sigv4 (e.g. us-east-1)") + initCmd.Flags().String("auth-aws-audience", "", "informational audience for aws_sigv4") + initCmd.Flags().StringSlice("auth-aws-allowed-principal", nil, "allowed principal glob for aws_sigv4 (repeatable)") + initCmd.Flags().StringSlice("auth-aws-allowed-account", nil, "allowed AWS account ID for aws_sigv4 (repeatable; ergonomic shortcut for whole-account access)") + initCmd.Flags().String("auth-aws-cache-ttl", "", "aws_sigv4 identity cache TTL (default 60s)") + + initCmd.Flags().String("auth-gcp-iap-audience", "", "audience for gcp_iap (backend service ID)") + + initCmd.Flags().String("auth-azure-tenant", "", "Entra tenant GUID (required unless --auth-azure-multi-tenant)") + initCmd.Flags().String("auth-azure-audience", "", "Entra application ID URI / audience") + initCmd.Flags().Bool("auth-azure-multi-tenant", false, "accept tokens from any Entra tenant (default false)") + initCmd.Flags().StringSlice("auth-azure-allowed-tenant", nil, "allowed Entra tenant GUID for multi-tenant mode (repeatable; empty list = any tenant globally)") + initCmd.Flags().String("auth-azure-groups-mode", "", "azure_ad groups mode: claim (default) or graph") } func runInit(cmd *cobra.Command, args []string) error { @@ -254,13 +269,25 @@ func collectInteractive(opts *initOptions) error { } } - // Build the egress derivation callback (avoids circular import) - deriveEgressFn := func(provider string, channels, tools, selectedSkills []string, envVars map[string]string) []string { + // Build the egress derivation callback (avoids circular import). + // Note: the auth params get the auth step's choice forwarded into + // deriveEgressDomains, which delegates to authEgressHostsFromSettings + // — the same translation used by `forge init --auth=…` so TUI and + // non-interactive paths produce identical egress lists. + deriveEgressFn := func( + provider string, + channels, tools, selectedSkills []string, + envVars map[string]string, + authMode string, + authSettings map[string]any, + ) []string { tmpOpts := &initOptions{ ModelProvider: provider, Channels: channels, BuiltinTools: tools, EnvVars: envVars, + AuthMode: authMode, + AuthSettings: authSettings, } selectedInfos := lookupSelectedSkills(selectedSkills) return deriveEgressDomains(tmpOpts, selectedInfos) @@ -281,7 +308,14 @@ func collectInteractive(opts *initOptions) error { return validateWebSearchKey(provider, key) } - // Build step list + // Build step list. + // + // Auth comes BEFORE Egress so the operator's auth choice — and the + // hosts it requires (sts..amazonaws.com, login.microsoftonline.com, + // etc.) — appear in the Egress review for confirmation. Without this + // ordering, the egress list would miss auth hosts and a Forge instance + // could fail at runtime because its OIDC discovery or STS call gets + // blocked by the very allowlist the wizard just rendered. wizardSteps := []tui.Step{ steps.NewNameStep(styles, opts.Name), steps.NewProviderStep(styles, validateKeyFn, oauthFlowFn), @@ -289,8 +323,8 @@ func collectInteractive(opts *initOptions) error { steps.NewChannelStep(styles), steps.NewToolsStep(styles, toolInfos, validateWebSearchKeyFn), steps.NewSkillsStep(styles, skillInfos), - steps.NewEgressStep(styles, deriveEgressFn), steps.NewAuthStep(styles), + steps.NewEgressStep(styles, deriveEgressFn), steps.NewReviewStep(styles), // scaffold is handled by the caller after collectInteractive returns } diff --git a/forge-cli/cmd/init_auth.go b/forge-cli/cmd/init_auth.go index 962aff8..e38e41e 100644 --- a/forge-cli/cmd/init_auth.go +++ b/forge-cli/cmd/init_auth.go @@ -4,6 +4,7 @@ import ( "fmt" "net/url" "sort" + "strconv" "strings" "github.com/spf13/cobra" @@ -49,8 +50,74 @@ func buildAuthFromFlags(cmd *cobra.Command, mode string) (settings map[string]an egressHosts = []string{host} } return settings, egressHosts, nil + case "aws_sigv4": + region, _ := cmd.Flags().GetString("auth-aws-region") + audience, _ := cmd.Flags().GetString("auth-aws-audience") + allowedPrincipals, _ := cmd.Flags().GetStringSlice("auth-aws-allowed-principal") + allowedAccounts, _ := cmd.Flags().GetStringSlice("auth-aws-allowed-account") + cacheTTL, _ := cmd.Flags().GetString("auth-aws-cache-ttl") + if region == "" { + return nil, nil, fmt.Errorf("--auth=aws_sigv4 requires --auth-aws-region") + } + settings = map[string]any{"region": region} + if audience != "" { + settings["audience"] = audience + } + if len(allowedPrincipals) > 0 { + settings["allowed_principals"] = allowedPrincipals + } + if len(allowedAccounts) > 0 { + settings["allowed_accounts"] = allowedAccounts + } + if cacheTTL != "" { + settings["identity_cache_ttl"] = cacheTTL + } + egressHosts = []string{"sts." + region + ".amazonaws.com"} + return settings, egressHosts, nil + case "gcp_iap": + audience, _ := cmd.Flags().GetString("auth-gcp-iap-audience") + if audience == "" { + return nil, nil, fmt.Errorf("--auth=gcp_iap requires --auth-gcp-iap-audience") + } + settings = map[string]any{"audience": audience} + // IAP JWKS host is hardcoded (decision §9.4). + egressHosts = []string{"www.gstatic.com"} + return settings, egressHosts, nil + case "azure_ad": + tenant, _ := cmd.Flags().GetString("auth-azure-tenant") + audience, _ := cmd.Flags().GetString("auth-azure-audience") + multiTenant, _ := cmd.Flags().GetBool("auth-azure-multi-tenant") + allowedTenants, _ := cmd.Flags().GetStringSlice("auth-azure-allowed-tenant") + groupsMode, _ := cmd.Flags().GetString("auth-azure-groups-mode") + if audience == "" { + return nil, nil, fmt.Errorf("--auth=azure_ad requires --auth-azure-audience") + } + if !multiTenant && tenant == "" { + return nil, nil, fmt.Errorf("--auth=azure_ad requires --auth-azure-tenant unless --auth-azure-multi-tenant=true") + } + if !multiTenant && len(allowedTenants) > 0 { + return nil, nil, fmt.Errorf("--auth-azure-allowed-tenant is only meaningful with --auth-azure-multi-tenant=true") + } + settings = map[string]any{"audience": audience} + if tenant != "" { + settings["tenant_id"] = tenant + } + if multiTenant { + settings["allow_multi_tenant"] = true + } + if len(allowedTenants) > 0 { + settings["allowed_tenants"] = allowedTenants + } + if groupsMode != "" { + settings["groups_mode"] = groupsMode + } + egressHosts = []string{"login.microsoftonline.com"} + if groupsMode == "graph" { + egressHosts = append(egressHosts, "graph.microsoft.com") + } + return settings, egressHosts, nil default: - return nil, nil, fmt.Errorf("unknown --auth value %q (supported: none, oidc, http_verifier, custom)", mode) + return nil, nil, fmt.Errorf("unknown --auth value %q (supported: none, oidc, http_verifier, aws_sigv4, gcp_iap, azure_ad, custom)", mode) } } @@ -77,6 +144,19 @@ func authEgressHostsFromSettings(mode string, settings map[string]any) []string if h := hostFromURL(asStringSetting(settings, "url")); h != "" { hosts = append(hosts, h) } + case "aws_sigv4": + region := asStringSetting(settings, "region") + if region != "" { + hosts = append(hosts, "sts."+region+".amazonaws.com") + } + case "gcp_iap": + // Decision §9.4: hardcoded JWKS host. + hosts = append(hosts, "www.gstatic.com") + case "azure_ad": + hosts = append(hosts, "login.microsoftonline.com") + if asStringSetting(settings, "groups_mode") == "graph" { + hosts = append(hosts, "graph.microsoft.com") + } } return hosts } @@ -196,6 +276,25 @@ func writeYAMLMap(b *strings.Builder, m map[string]any, indent string) { case map[string]any: fmt.Fprintf(b, "%s%s:\n", indent, k) writeYAMLMap(b, val, indent+" ") + case []string: + // Phase 2: aws_sigv4's allowed_principals is the only []string + // currently in the auth settings schema. ARNs frequently contain + // `:` so each entry goes through yamlScalar for proper quoting. + fmt.Fprintf(b, "%s%s:\n", indent, k) + for _, item := range val { + fmt.Fprintf(b, "%s - %s\n", indent, yamlScalar(item)) + } + case []any: + // Defensive: when settings come from YAML unmarshal, lists land + // as []any. Coerce per-element and reuse the string path. + fmt.Fprintf(b, "%s%s:\n", indent, k) + for _, item := range val { + if s, ok := item.(string); ok { + fmt.Fprintf(b, "%s - %s\n", indent, yamlScalar(s)) + } else { + fmt.Fprintf(b, "%s - %v\n", indent, item) + } + } case string: fmt.Fprintf(b, "%s%s: %s\n", indent, k, yamlScalar(val)) default: @@ -270,5 +369,119 @@ func needsYAMLQuoting(s string) bool { case "true", "false", "yes", "no", "on", "off", "null", "~": return true } + // Anything that resembles a YAML 1.1 / 1.2 number must be quoted + // to preserve string semantics. The auth-settings schema rarely + // produces these shapes (account IDs ARE all-digit; others are + // theoretical), but the docstring says "false negatives are bugs" + // and the Web UI POST path can supply arbitrary strings. + // + // Covers: + // - All-digit: "412664885516" (AWS account) + // - Leading-zero: "010" (YAML 1.1 octal) + // - Hex: "0x1A" (YAML 1.1 hex) + // - Octal: "0o17" (YAML 1.2 octal) + // - Binary: "0b10" (YAML 1.2 binary) + // - Scientific notation: "1e10", "1.5E-3" + // - Decimal float: "3.14" + // - Signed: "-5", "+7" + // - Special floats: ".inf", ".nan", "-.inf" + if looksNumeric(s) { + return true + } return false } + +// looksNumeric reports whether s would parse as a YAML number under any +// of YAML 1.1 / 1.2 / yaml.v3's relaxed coercion rules. Conservative — +// false positives just produce extra quotes. +func looksNumeric(s string) bool { + if s == "" { + return false + } + // YAML special floats: case-insensitive. + switch strings.ToLower(s) { + case ".inf", ".nan", "-.inf", "+.inf": + return true + } + // Strip a leading sign once for the remaining shape tests. + body := s + if body[0] == '+' || body[0] == '-' { + body = body[1:] + if body == "" { + return false + } + } + // Hex / Octal / Binary prefixes. + if len(body) >= 2 && body[0] == '0' { + switch body[1] { + case 'x', 'X': + return allHexDigits(body[2:]) + case 'o', 'O': + return allOctalDigits(body[2:]) + case 'b', 'B': + return allBinaryDigits(body[2:]) + } + } + // Plain decimal integer (incl. leading-zero like "010", which + // YAML 1.1 would parse as octal). + if isAllDigits(body) { + return true + } + // Decimal float / scientific notation: rely on Go's strconv.ParseFloat + // — if it accepts the bytes as a float, YAML will too. + if _, err := strconv.ParseFloat(body, 64); err == nil { + return true + } + return false +} + +func allHexDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + c := s[i] + ok := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') + if !ok { + return false + } + } + return true +} + +func allOctalDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] < '0' || s[i] > '7' { + return false + } + } + return true +} + +func allBinaryDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] != '0' && s[i] != '1' { + return false + } + } + return true +} + +// isAllDigits returns true if s is non-empty and consists only of ASCII digits. +func isAllDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] < '0' || s[i] > '9' { + return false + } + } + return true +} diff --git a/forge-cli/cmd/init_auth_test.go b/forge-cli/cmd/init_auth_test.go index 3c6a78e..88288a3 100644 --- a/forge-cli/cmd/init_auth_test.go +++ b/forge-cli/cmd/init_auth_test.go @@ -216,6 +216,17 @@ func mockInitCmdWithFlags() *cobra.Command { c.Flags().String("auth-url", "", "") c.Flags().String("auth-default-org", "", "") c.Flags().String("auth-groups-claim", "", "") + // Phase 2 flags: + c.Flags().String("auth-aws-region", "", "") + c.Flags().String("auth-aws-audience", "", "") + c.Flags().StringSlice("auth-aws-allowed-principal", nil, "") + c.Flags().String("auth-aws-cache-ttl", "", "") + c.Flags().String("auth-gcp-iap-audience", "", "") + c.Flags().String("auth-azure-tenant", "", "") + c.Flags().String("auth-azure-audience", "", "") + c.Flags().Bool("auth-azure-multi-tenant", false, "") + c.Flags().StringSlice("auth-azure-allowed-tenant", nil, "") + c.Flags().String("auth-azure-groups-mode", "", "") return c } @@ -427,3 +438,223 @@ func TestMergeEgressDomains_SortedOutput(t *testing.T) { t.Errorf("output not sorted:\n got %v\n want %v", got, want) } } + +// --- Phase 2 renderer + flag tests --- + +func TestRenderAuthBlock_AWSSigv4_Minimal(t *testing.T) { + got := renderAuthBlock("aws_sigv4", map[string]any{"region": "us-east-1"}) + if !strings.Contains(got, "type: aws_sigv4") { + t.Errorf("missing type line:\n%s", got) + } + if !strings.Contains(got, "region: us-east-1") { + t.Errorf("missing region:\n%s", got) + } + if strings.Contains(got, "audience:") { + t.Errorf("audience should not be emitted when unset:\n%s", got) + } +} + +func TestRenderAuthBlock_AWSSigv4_ARNsList(t *testing.T) { + wantARNs := []string{ + "arn:aws:sts::123:assumed-role/ci-deploy/*", + "arn:aws:sts::123:assumed-role/forge-runner/*", + } + got := renderAuthBlock("aws_sigv4", map[string]any{ + "region": "us-east-1", + "allowed_principals": wantARNs, + }) + if !strings.Contains(got, "allowed_principals:") { + t.Errorf("missing allowed_principals header:\n%s", got) + } + // ARN contains ":" but not ": " (colon-space), so YAML doesn't need + // to quote it — verify round-trip parse instead of insisting on quotes. + parsed := parseAuthBlockYAML(t, got) + provs, _ := parsed["auth"].(map[string]any)["providers"].([]any) + settings := provs[0].(map[string]any)["settings"].(map[string]any) + gotARNs := toStringSlice(settings["allowed_principals"]) + if !reflect.DeepEqual(gotARNs, wantARNs) { + t.Errorf("ARN round-trip mismatch:\n got %v\n want %v", gotARNs, wantARNs) + } +} + +func TestRenderAuthBlock_GCPIAP(t *testing.T) { + wantAud := "/projects/12345/global/backendServices/67890" + got := renderAuthBlock("gcp_iap", map[string]any{ + "audience": wantAud, + }) + if !strings.Contains(got, "type: gcp_iap") { + t.Errorf("missing type:\n%s", got) + } + parsed := parseAuthBlockYAML(t, got) + provs, _ := parsed["auth"].(map[string]any)["providers"].([]any) + gotAud := provs[0].(map[string]any)["settings"].(map[string]any)["audience"] + if gotAud != wantAud { + t.Errorf("audience round-trip mismatch:\n got %v\n want %v", gotAud, wantAud) + } +} + +func parseAuthBlockYAML(t *testing.T, block string) map[string]any { + t.Helper() + var out map[string]any + if err := yaml.Unmarshal([]byte(block), &out); err != nil { + t.Fatalf("renderAuthBlock output is not valid YAML: %v\n%s", err, block) + } + return out +} + +func toStringSlice(v any) []string { + arr, ok := v.([]any) + if !ok { + return nil + } + out := make([]string, len(arr)) + for i, x := range arr { + out[i], _ = x.(string) + } + return out +} + +func TestRenderAuthBlock_AzureAD_SingleTenant(t *testing.T) { + got := renderAuthBlock("azure_ad", map[string]any{ + "tenant_id": "00000000-1111-2222-3333-444444444444", + "audience": "api://forge", + }) + if !strings.Contains(got, "tenant_id: 00000000-1111-2222-3333-444444444444") { + t.Errorf("missing tenant_id:\n%s", got) + } + if strings.Contains(got, "allow_multi_tenant") { + t.Errorf("allow_multi_tenant should be omitted when not set:\n%s", got) + } +} + +func TestRenderAuthBlock_AzureAD_MultiTenant(t *testing.T) { + got := renderAuthBlock("azure_ad", map[string]any{ + "audience": "api://forge", + "allow_multi_tenant": true, + }) + if !strings.Contains(got, "allow_multi_tenant: true") { + t.Errorf("missing allow_multi_tenant true:\n%s", got) + } + if strings.Contains(got, "tenant_id") { + t.Errorf("tenant_id should not appear when multi-tenant:\n%s", got) + } +} + +func TestRenderAuthBlock_AzureAD_GroupsModeGraph(t *testing.T) { + got := renderAuthBlock("azure_ad", map[string]any{ + "tenant_id": "abc", + "audience": "api://forge", + "groups_mode": "graph", + }) + if !strings.Contains(got, "groups_mode: graph") { + t.Errorf("missing groups_mode:\n%s", got) + } +} + +func TestBuildAuthFromFlags_AWSSigv4_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-aws-region", "us-east-1") + _ = c.Flags().Set("auth-aws-audience", "api://forge") + _ = c.Flags().Set("auth-aws-allowed-principal", "arn:aws:sts::123:assumed-role/x/*") + + settings, hosts, err := buildAuthFromFlags(c, "aws_sigv4") + if err != nil { + t.Fatal(err) + } + if settings["region"] != "us-east-1" { + t.Errorf("region = %v", settings["region"]) + } + if !reflect.DeepEqual(hosts, []string{"sts.us-east-1.amazonaws.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_AWSSigv4_MissingRegion(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "aws_sigv4") + if err == nil { + t.Fatal("expected error when region missing") + } +} + +func TestBuildAuthFromFlags_GCPIAP_HappyPath(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-gcp-iap-audience", "/projects/12345/...") + settings, hosts, err := buildAuthFromFlags(c, "gcp_iap") + if err != nil { + t.Fatal(err) + } + if settings["audience"] == "" { + t.Error("audience not captured") + } + if !reflect.DeepEqual(hosts, []string{"www.gstatic.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_GCPIAP_MissingAudience(t *testing.T) { + c := mockInitCmdWithFlags() + _, _, err := buildAuthFromFlags(c, "gcp_iap") + if err == nil { + t.Fatal("expected error when audience missing") + } +} + +func TestBuildAuthFromFlags_AzureAD_RequiresTenant(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-azure-audience", "api://forge") + _, _, err := buildAuthFromFlags(c, "azure_ad") + if err == nil { + t.Fatal("expected error when tenant + non-multi-tenant") + } +} + +func TestBuildAuthFromFlags_AzureAD_MultiTenantOK(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-azure-audience", "api://forge") + _ = c.Flags().Set("auth-azure-multi-tenant", "true") + settings, hosts, err := buildAuthFromFlags(c, "azure_ad") + if err != nil { + t.Fatal(err) + } + if settings["allow_multi_tenant"] != true { + t.Errorf("allow_multi_tenant not set: %v", settings) + } + if !reflect.DeepEqual(hosts, []string{"login.microsoftonline.com"}) { + t.Errorf("hosts = %v", hosts) + } +} + +func TestBuildAuthFromFlags_AzureAD_GraphModeAddsGraphHost(t *testing.T) { + c := mockInitCmdWithFlags() + _ = c.Flags().Set("auth-azure-audience", "api://forge") + _ = c.Flags().Set("auth-azure-tenant", "abc-tid") + _ = c.Flags().Set("auth-azure-groups-mode", "graph") + _, hosts, err := buildAuthFromFlags(c, "azure_ad") + if err != nil { + t.Fatal(err) + } + want := []string{"login.microsoftonline.com", "graph.microsoft.com"} + if !reflect.DeepEqual(hosts, want) { + t.Errorf("hosts = %v, want %v", hosts, want) + } +} + +func TestAuthEgressHostsFromSettings_Phase2(t *testing.T) { + cases := []struct { + mode string + settings map[string]any + want []string + }{ + {"aws_sigv4", map[string]any{"region": "us-east-1"}, []string{"sts.us-east-1.amazonaws.com"}}, + {"gcp_iap", map[string]any{"audience": "x"}, []string{"www.gstatic.com"}}, + {"azure_ad", map[string]any{"audience": "x"}, []string{"login.microsoftonline.com"}}, + {"azure_ad", map[string]any{"audience": "x", "groups_mode": "graph"}, []string{"login.microsoftonline.com", "graph.microsoft.com"}}, + } + for _, tc := range cases { + got := authEgressHostsFromSettings(tc.mode, tc.settings) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("%s %v: got %v, want %v", tc.mode, tc.settings, got, tc.want) + } + } +} diff --git a/forge-cli/cmd/init_egress.go b/forge-cli/cmd/init_egress.go index 92b632f..fac10b7 100644 --- a/forge-cli/cmd/init_egress.go +++ b/forge-cli/cmd/init_egress.go @@ -93,6 +93,18 @@ func deriveEgressDomains(opts *initOptions, skills []contract.SkillDescriptor) [ } } + // 5. Auth-provider domains — same translation the non-interactive + // --auth=… path uses, so TUI and CLI render identical egress lists. + // Examples: + // oidc → host extracted from issuer URL + // aws_sigv4 → sts..amazonaws.com + // gcp_iap → www.gstatic.com (hardcoded §9.4) + // azure_ad → login.microsoftonline.com (+ graph.microsoft.com + // when groups_mode=graph) + for _, h := range authEgressHostsFromSettings(opts.AuthMode, opts.AuthSettings) { + add(h) + } + sort.Strings(domains) return domains } diff --git a/forge-cli/cmd/init_test.go b/forge-cli/cmd/init_test.go index 279b3b8..8bcb4e2 100644 --- a/forge-cli/cmd/init_test.go +++ b/forge-cli/cmd/init_test.go @@ -513,6 +513,115 @@ func TestDeriveEgressDomains_Empty(t *testing.T) { } } +// TestDeriveEgressDomains_AuthProviderHostsMerged confirms the wizard +// re-order invariant: a chosen Auth provider contributes its required +// hosts (STS for aws_sigv4, AAD authority for azure_ad, etc.) to the +// same egress list a user reviews in the Egress step. Pins the contract +// that the operator never has to add auth hosts manually after the wizard. +func TestDeriveEgressDomains_AuthProviderHostsMerged(t *testing.T) { + cases := []struct { + name string + mode string + set map[string]any + want []string + }{ + { + name: "aws_sigv4 us-east-1", + mode: "aws_sigv4", + set: map[string]any{"region": "us-east-1"}, + want: []string{"sts.us-east-1.amazonaws.com"}, + }, + { + name: "aws_sigv4 eu-west-2", + mode: "aws_sigv4", + set: map[string]any{"region": "eu-west-2"}, + want: []string{"sts.eu-west-2.amazonaws.com"}, + }, + { + name: "gcp_iap", + mode: "gcp_iap", + set: map[string]any{"audience": "/projects/x/global/backendServices/y"}, + want: []string{"www.gstatic.com"}, + }, + { + name: "azure_ad claim", + mode: "azure_ad", + set: map[string]any{"audience": "api://forge", "tenant_id": "abc"}, + want: []string{"login.microsoftonline.com"}, + }, + { + name: "azure_ad graph mode adds graph host", + mode: "azure_ad", + set: map[string]any{"audience": "api://forge", "tenant_id": "abc", "groups_mode": "graph"}, + want: []string{"graph.microsoft.com", "login.microsoftonline.com"}, + }, + { + name: "oidc issuer host", + mode: "oidc", + set: map[string]any{"issuer": "https://login.example.com", "audience": "api://forge"}, + want: []string{"login.example.com"}, + }, + { + name: "none → no auth hosts added", + mode: "none", + set: nil, + want: nil, + }, + { + name: "custom → no auth hosts added", + mode: "custom", + set: nil, + want: nil, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + opts := &initOptions{ + ModelProvider: "ollama", // no provider hosts → exposes only auth hosts + EnvVars: map[string]string{}, + AuthMode: tc.mode, + AuthSettings: tc.set, + } + got := deriveEgressDomains(opts, nil) + if len(got) != len(tc.want) { + t.Fatalf("len(got) = %d (%v), want %d (%v)", len(got), got, len(tc.want), tc.want) + } + for i, d := range tc.want { + if got[i] != d { + t.Errorf("got[%d] = %q, want %q (full got: %v)", i, got[i], d, got) + } + } + }) + } +} + +// TestDeriveEgressDomains_AuthHostsMergeNotOverwrite confirms auth hosts +// coexist with provider/channel/skill hosts — the auth pass is additive, +// not exclusive. +func TestDeriveEgressDomains_AuthHostsMergeNotOverwrite(t *testing.T) { + opts := &initOptions{ + ModelProvider: "openai", + Channels: []string{"slack"}, + EnvVars: map[string]string{}, + AuthMode: "aws_sigv4", + AuthSettings: map[string]any{"region": "us-east-1"}, + } + got := deriveEgressDomains(opts, nil) + have := map[string]bool{} + for _, d := range got { + have[d] = true + } + for _, want := range []string{ + "api.openai.com", // model provider + "slack.com", // channel + "sts.us-east-1.amazonaws.com", // auth + } { + if !have[want] { + t.Errorf("missing %q in merged egress list: %v", want, got) + } + } +} + func TestBuildEnvVars(t *testing.T) { opts := &initOptions{ ModelProvider: "openai", diff --git a/forge-cli/internal/tui/steps/egress_step.go b/forge-cli/internal/tui/steps/egress_step.go index 93739bd..592ac16 100644 --- a/forge-cli/internal/tui/steps/egress_step.go +++ b/forge-cli/internal/tui/steps/egress_step.go @@ -10,7 +10,19 @@ import ( ) // DeriveEgressFunc computes egress domains from wizard context. -type DeriveEgressFunc func(provider string, channels, tools, skills []string, envVars map[string]string) []string +// +// authMode + authSettings are the user's choice from the preceding Auth step. +// Auth-derived hosts (e.g. sts..amazonaws.com for aws_sigv4, +// login.microsoftonline.com for azure_ad) are merged into the egress list so +// the Egress review displays the FULL outbound surface — operators see and +// confirm the auth hosts alongside provider/channel/tool/skill hosts. +type DeriveEgressFunc func( + provider string, + channels, tools, skills []string, + envVars map[string]string, + authMode string, + authSettings map[string]any, +) []string // EgressStep handles egress domain review. type EgressStep struct { @@ -32,6 +44,10 @@ func NewEgressStep(styles *tui.StyleSet, deriveFn DeriveEgressFunc) *EgressStep } // Prepare computes egress domains using the accumulated wizard context. +// +// The Auth step runs BEFORE Egress in the wizard order (see init.go), so by +// the time Prepare runs, ctx.AuthMode and ctx.AuthSettings reflect the +// operator's choice and we can compute auth-derived hosts. func (s *EgressStep) Prepare(ctx *tui.WizardContext) { var channels []string if ctx.Channel != "" && ctx.Channel != "none" { @@ -40,7 +56,7 @@ func (s *EgressStep) Prepare(ctx *tui.WizardContext) { s.domains = nil if s.deriveFn != nil { - s.domains = s.deriveFn(ctx.Provider, channels, ctx.BuiltinTools, ctx.Skills, ctx.EnvVars) + s.domains = s.deriveFn(ctx.Provider, channels, ctx.BuiltinTools, ctx.Skills, ctx.EnvVars, ctx.AuthMode, ctx.AuthSettings) } s.empty = len(s.domains) == 0 @@ -126,6 +142,13 @@ func (s *EgressStep) Apply(ctx *tui.WizardContext) { // inferSource guesses the source of an egress domain based on context. func inferSource(domain string, ctx *tui.WizardContext) string { + // Auth provider domains — checked FIRST so an OIDC issuer host like + // "login.example.com" is correctly attributed to "oidc auth" rather + // than falling through to a generic "configured" label. + if src := authProviderForDomain(domain, ctx); src != "" { + return src + } + // Provider domains providerDomains := map[string]string{ "api.openai.com": "model provider", @@ -170,3 +193,85 @@ func inferSource(domain string, ctx *tui.WizardContext) string { return "configured" } + +// authProviderForDomain returns a human-friendly label when `domain` is +// known to be required by the operator's chosen auth provider, or "" when +// the domain wasn't sourced from auth. +// +// The matching is intentionally narrow: we compare against the hosts each +// provider actually contributes (computed elsewhere via authEgressHostsFromSettings). +// For oidc/http_verifier, the host is dynamic (issuer/url-derived) so we +// match by exact string against what ctx.AuthSettings says the issuer +// resolves to. +func authProviderForDomain(domain string, ctx *tui.WizardContext) string { + if ctx == nil || ctx.AuthMode == "" || ctx.AuthMode == "none" || ctx.AuthMode == "custom" { + return "" + } + switch ctx.AuthMode { + case "aws_sigv4": + // sts..amazonaws.com is the only contributed host. + region, _ := ctx.AuthSettings["region"].(string) + if region != "" && domain == "sts."+region+".amazonaws.com" { + return "aws_sigv4 auth" + } + case "gcp_iap": + if domain == "www.gstatic.com" { + return "gcp_iap auth" + } + case "azure_ad": + if domain == "login.microsoftonline.com" { + return "azure_ad auth" + } + if domain == "graph.microsoft.com" { + return "azure_ad auth (graph)" + } + case "oidc": + // Best-effort: match the configured issuer's host. + issuer, _ := ctx.AuthSettings["issuer"].(string) + if h := hostOf(issuer); h != "" && h == domain { + return "oidc auth" + } + jwks, _ := ctx.AuthSettings["jwks_url"].(string) + if h := hostOf(jwks); h != "" && h == domain { + return "oidc auth (jwks)" + } + case "http_verifier": + url, _ := ctx.AuthSettings["url"].(string) + if h := hostOf(url); h != "" && h == domain { + return "http_verifier auth" + } + } + return "" +} + +// hostOf extracts the host portion of a URL, returning "" on parse failure +// or missing host. Kept local to avoid pulling net/url into a hot iteration +// path elsewhere. +func hostOf(raw string) string { + if raw == "" { + return "" + } + // Cheap manual split; matches what `(*url.URL).Hostname()` does for the + // well-formed inputs the wizard collects. + const sep = "://" + i := indexOf(raw, sep) + if i < 0 { + return "" + } + rest := raw[i+len(sep):] + for k := 0; k < len(rest); k++ { + if rest[k] == '/' || rest[k] == ':' { + return rest[:k] + } + } + return rest +} + +func indexOf(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} diff --git a/forge-cli/internal/tui/steps/step_auth.go b/forge-cli/internal/tui/steps/step_auth.go index 000ef16..82321d0 100644 --- a/forge-cli/internal/tui/steps/step_auth.go +++ b/forge-cli/internal/tui/steps/step_auth.go @@ -22,6 +22,16 @@ const ( authOIDCGroupsClaimPhase authHTTPURLPhase authHTTPOrgPhase + // Phase 2: aws_sigv4 + authAWSRegionPhase + authAWSAudiencePhase + authAWSAccountsPhase + // Phase 2: gcp_iap + authGCPIAPAudiencePhase + // Phase 2: azure_ad (single-tenant only; multi-tenant requires YAML + // edit so it's a deliberate choice rather than an accidental toggle) + authAADTenantPhase + authAADAudiencePhase authDonePhase ) @@ -31,6 +41,9 @@ const ( AuthModeNone = "none" AuthModeOIDC = "oidc" AuthModeHTTPVerifier = "http_verifier" + AuthModeAWSSigv4 = "aws_sigv4" + AuthModeGCPIAP = "gcp_iap" + AuthModeAzureAD = "azure_ad" AuthModeCustom = "custom" ) @@ -67,6 +80,20 @@ type AuthStep struct { httpURL string httpOrg string + // aws_sigv4 settings + awsRegion string + awsAudience string + awsAccounts []string // 12-digit AWS account IDs + + // gcp_iap settings + gcpAudience string + + // azure_ad settings (TUI is single-tenant only; multi-tenant + // requires editing forge.yaml directly so it stays a deliberate + // security decision) + aadTenant string + aadAudience string + complete bool } @@ -91,6 +118,24 @@ func NewAuthStep(styles *tui.StyleSet) *AuthStep { Description: "Legacy — POST tokens to your own /verify endpoint", Icon: "🔁", }, + { + Label: "AWS Sigv4 (IAM)", + Value: AuthModeAWSSigv4, + Description: "Verify AWS-IAM callers via STS GetCallerIdentity (Phase 2)", + Icon: "🅰️", + }, + { + Label: "GCP Identity-Aware Proxy", + Value: AuthModeGCPIAP, + Description: "Forge behind a GCP HTTPS LB+IAP (Phase 2)", + Icon: "🇬", + }, + { + Label: "Azure AD / Entra ID", + Value: AuthModeAzureAD, + Description: "Single-tenant Entra tokens (Phase 2 — multi-tenant via YAML)", + Icon: "🇦", + }, { Label: "Custom", Value: AuthModeCustom, @@ -144,7 +189,10 @@ func (s *AuthStep) Update(msg tea.Msg) (tui.Step, tea.Cmd) { case authSelectPhase: return s.updateSelect(msg) case authOIDCIssuerPhase, authOIDCAudiencePhase, authOIDCGroupsClaimPhase, - authHTTPURLPhase, authHTTPOrgPhase: + authHTTPURLPhase, authHTTPOrgPhase, + authAWSRegionPhase, authAWSAudiencePhase, authAWSAccountsPhase, + authGCPIAPAudiencePhase, + authAADTenantPhase, authAADAudiencePhase: return s.updateInput(msg) } return s, nil @@ -182,6 +230,30 @@ func (s *AuthStep) updateSelect(msg tea.Msg) (tui.Step, tea.Cmd) { validateHTTPSURL, ) return s, s.input.Init() + case AuthModeAWSSigv4: + s.phase = authAWSRegionPhase + s.input = s.newTextInput( + "AWS region", + "us-east-1", + validateAWSRegion, + ) + return s, s.input.Init() + case AuthModeGCPIAP: + s.phase = authGCPIAPAudiencePhase + s.input = s.newTextInput( + "IAP audience (backend service ID from GCP console)", + "/projects/PNUM/global/backendServices/BACKEND_ID", + validateNonEmpty, + ) + return s, s.input.Init() + case AuthModeAzureAD: + s.phase = authAADTenantPhase + s.input = s.newTextInput( + "Entra tenant ID (GUID)", + "00000000-0000-0000-0000-000000000000", + validateNonEmpty, + ) + return s, s.input.Init() } return s, cmd } @@ -242,6 +314,57 @@ func (s *AuthStep) updateInput(msg tea.Msg) (tui.Step, tea.Cmd) { s.complete = true s.phase = authDonePhase return s, doneCmd() + + // --- aws_sigv4 --- + case authAWSRegionPhase: + s.awsRegion = v + s.phase = authAWSAudiencePhase + s.input = s.newTextInput( + "Audience (informational; press Enter to skip)", + "api://forge", + nil, + ) + return s, s.input.Init() + + case authAWSAudiencePhase: + s.awsAudience = v + s.phase = authAWSAccountsPhase + s.input = s.newTextInput( + "Allowed AWS accounts (comma-separated 12-digit IDs; Enter to skip)", + "412664885516,109887654321", + validateAccountList, + ) + return s, s.input.Init() + + case authAWSAccountsPhase: + s.awsAccounts = parseAccountList(v) + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + + // --- gcp_iap --- + case authGCPIAPAudiencePhase: + s.gcpAudience = v + s.complete = true + s.phase = authDonePhase + return s, doneCmd() + + // --- azure_ad (single-tenant) --- + case authAADTenantPhase: + s.aadTenant = v + s.phase = authAADAudiencePhase + s.input = s.newTextInput( + "Audience (Application ID URI)", + "api://forge", + validateNonEmpty, + ) + return s, s.input.Init() + + case authAADAudiencePhase: + s.aadAudience = v + s.complete = true + s.phase = authDonePhase + return s, doneCmd() } return s, cmd @@ -299,6 +422,15 @@ func (s *AuthStep) Summary() string { return "HTTP Verifier" } return "HTTP Verifier · " + host + case AuthModeAWSSigv4: + if len(s.awsAccounts) > 0 { + return fmt.Sprintf("AWS Sigv4 · %s · %d account(s)", s.awsRegion, len(s.awsAccounts)) + } + return "AWS Sigv4 · " + s.awsRegion + case AuthModeGCPIAP: + return "GCP IAP" + case AuthModeAzureAD: + return "Azure AD · single-tenant" case AuthModeCustom: return "Custom (edit forge.yaml)" } @@ -341,6 +473,25 @@ func (s *AuthStep) Apply(ctx *tui.WizardContext) { if host := hostnameOrEmpty(s.httpURL); host != "" { ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, host) } + case AuthModeAWSSigv4: + settings := map[string]any{"region": s.awsRegion} + if s.awsAudience != "" { + settings["audience"] = s.awsAudience + } + if len(s.awsAccounts) > 0 { + settings["allowed_accounts"] = s.awsAccounts + } + ctx.AuthSettings = settings + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, "sts."+s.awsRegion+".amazonaws.com") + case AuthModeGCPIAP: + ctx.AuthSettings = map[string]any{"audience": s.gcpAudience} + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, "www.gstatic.com") + case AuthModeAzureAD: + ctx.AuthSettings = map[string]any{ + "tenant_id": s.aadTenant, + "audience": s.aadAudience, + } + ctx.AuthEgressHosts = appendUnique(ctx.AuthEgressHosts, "login.microsoftonline.com") default: // None / Custom: nothing to attach. ctx.AuthSettings = nil @@ -376,6 +527,60 @@ func validateNonEmpty(val string) error { return nil } +// validateAWSRegion checks for the canonical AWS region shape +// (xx-direction-N). Doesn't enumerate regions because AWS adds them +// often; STS at startup fails fast if the host doesn't resolve. +func validateAWSRegion(val string) error { + v := strings.TrimSpace(val) + if v == "" { + return fmt.Errorf("required") + } + // xx-direction-N, e.g. us-east-1, eu-west-2, ap-southeast-1 + parts := strings.Split(v, "-") + if len(parts) < 3 { + return fmt.Errorf("expected AWS region shape like 'us-east-1'") + } + return nil +} + +// validateAccountList accepts a comma-separated list of 12-digit AWS +// account IDs. Empty input is OK (the field is optional). Each entry +// is validated; one bad entry blocks the whole field. +func validateAccountList(val string) error { + v := strings.TrimSpace(val) + if v == "" { + return nil + } + for raw := range strings.SplitSeq(v, ",") { + acct := strings.TrimSpace(raw) + if len(acct) != 12 { + return fmt.Errorf("account %q: expected 12 digits", acct) + } + for _, c := range acct { + if c < '0' || c > '9' { + return fmt.Errorf("account %q: must be digits only", acct) + } + } + } + return nil +} + +// parseAccountList splits a comma-separated input into a trimmed slice. +// Empty input returns nil so the caller can omit the YAML key. +func parseAccountList(val string) []string { + v := strings.TrimSpace(val) + if v == "" { + return nil + } + var out []string + for raw := range strings.SplitSeq(v, ",") { + if s := strings.TrimSpace(raw); s != "" { + out = append(out, s) + } + } + return out +} + // hostnameOrEmpty returns the bare host (no port) from a URL, or "" on // parse failure. func hostnameOrEmpty(raw string) string { diff --git a/forge-cli/internal/tui/steps/step_auth_test.go b/forge-cli/internal/tui/steps/step_auth_test.go index 81e3ddf..f313b6a 100644 --- a/forge-cli/internal/tui/steps/step_auth_test.go +++ b/forge-cli/internal/tui/steps/step_auth_test.go @@ -86,10 +86,11 @@ func TestAuthStep_NoneIsDefault(t *testing.T) { func TestAuthStep_Custom(t *testing.T) { s := newTestAuthStep(t) - // Move down 3 times to reach "Custom" (index 3). - s = press(t, s, "down") - s = press(t, s, "down") - s = press(t, s, "down") + // Picker order: 0=None 1=OIDC 2=HTTPVerifier 3=AWS 4=GCP 5=AAD 6=Custom. + // Six downs to reach Custom at index 6. + for range 6 { + s = press(t, s, "down") + } s = press(t, s, "enter") if !s.Complete() { t.Fatal("expected Complete after selecting Custom") @@ -396,3 +397,228 @@ func TestAppendUnique(t *testing.T) { t.Errorf("appendUnique new failed: %v", out) } } + +// --- Phase 2: aws_sigv4 / gcp_iap / azure_ad --- + +func TestAuthStep_AWSSigv4_FullFlow_WithAccounts(t *testing.T) { + s := newTestAuthStep(t) + + // Picker order: 0=None 1=OIDC 2=HTTPVerifier 3=AWS 4=GCP 5=AAD 6=Custom + // Navigate to AWS (3 downs). + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authAWSRegionPhase { + t.Fatalf("phase = %v, want AWS region", s.phase) + } + + // Region + s = typeIn(t, s, "us-east-1") + s = press(t, s, "enter") + if s.phase != authAWSAudiencePhase { + t.Fatalf("phase = %v, want AWS audience", s.phase) + } + + // Audience (optional — press Enter to skip) + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + if s.phase != authAWSAccountsPhase { + t.Fatalf("phase = %v, want AWS accounts", s.phase) + } + + // Accounts (comma-separated) + s = typeIn(t, s, "412664885516, 109887654321") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after accounts") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeAWSSigv4 { + t.Errorf("AuthMode = %q, want aws_sigv4", ctx.AuthMode) + } + if ctx.AuthSettings["region"] != "us-east-1" { + t.Errorf("region = %v", ctx.AuthSettings["region"]) + } + if ctx.AuthSettings["audience"] != "api://forge" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + accts, _ := ctx.AuthSettings["allowed_accounts"].([]string) + if !reflect.DeepEqual(accts, []string{"412664885516", "109887654321"}) { + t.Errorf("allowed_accounts = %v", accts) + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"sts.us-east-1.amazonaws.com"}) { + t.Errorf("egress hosts = %v", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_AWSSigv4_AllowAudienceSkip(t *testing.T) { + // Skipping the audience field (pressing Enter on empty input) and + // the accounts field should still complete cleanly. + s := newTestAuthStep(t) + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") // pick AWS + s = typeIn(t, s, "us-east-1") + s = press(t, s, "enter") // region → audience + s = press(t, s, "enter") // audience (empty) → accounts + s = press(t, s, "enter") // accounts (empty) → done + if !s.Complete() { + t.Fatal("expected Complete after skipping optional fields") + } + ctx := tui.NewWizardContext() + s.Apply(ctx) + if _, ok := ctx.AuthSettings["audience"]; ok { + t.Errorf("audience should be omitted when empty: %v", ctx.AuthSettings) + } + if _, ok := ctx.AuthSettings["allowed_accounts"]; ok { + t.Errorf("allowed_accounts should be omitted when empty: %v", ctx.AuthSettings) + } +} + +func TestAuthStep_AWSSigv4_RejectsMalformedAccount(t *testing.T) { + // validateAccountList enforces the 12-digit shape before letting + // the user advance past the accounts field. + if err := validateAccountList("notanaccount"); err == nil { + t.Error("expected error on malformed account ID") + } + if err := validateAccountList("412664885516,bad,109887654321"); err == nil { + t.Error("expected error when any account in the list is bad") + } + if err := validateAccountList(""); err != nil { + t.Errorf("empty input should be allowed (optional field), got %v", err) + } + if err := validateAccountList("412664885516 , 109887654321 "); err != nil { + t.Errorf("trimmed whitespace should be tolerated, got %v", err) + } +} + +func TestAuthStep_GCPIAP_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + // 4 downs to reach GCP IAP. + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "down") + s = press(t, s, "enter") + if s.phase != authGCPIAPAudiencePhase { + t.Fatalf("phase = %v, want GCP audience", s.phase) + } + + s = typeIn(t, s, "/projects/12345/global/backendServices/67890") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after audience") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeGCPIAP { + t.Errorf("AuthMode = %q, want gcp_iap", ctx.AuthMode) + } + if ctx.AuthSettings["audience"] != "/projects/12345/global/backendServices/67890" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"www.gstatic.com"}) { + t.Errorf("egress hosts = %v, want [www.gstatic.com]", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_AzureAD_FullFlow(t *testing.T) { + s := newTestAuthStep(t) + // 5 downs to reach Azure AD. + for range 5 { + s = press(t, s, "down") + } + s = press(t, s, "enter") + if s.phase != authAADTenantPhase { + t.Fatalf("phase = %v, want AAD tenant", s.phase) + } + + s = typeIn(t, s, "00000000-1111-2222-3333-444444444444") + s = press(t, s, "enter") + if s.phase != authAADAudiencePhase { + t.Fatalf("phase = %v, want AAD audience", s.phase) + } + + s = typeIn(t, s, "api://forge") + s = press(t, s, "enter") + if !s.Complete() { + t.Fatal("expected Complete after audience") + } + + ctx := tui.NewWizardContext() + s.Apply(ctx) + if ctx.AuthMode != AuthModeAzureAD { + t.Errorf("AuthMode = %q", ctx.AuthMode) + } + if ctx.AuthSettings["tenant_id"] != "00000000-1111-2222-3333-444444444444" { + t.Errorf("tenant_id = %v", ctx.AuthSettings["tenant_id"]) + } + if ctx.AuthSettings["audience"] != "api://forge" { + t.Errorf("audience = %v", ctx.AuthSettings["audience"]) + } + if _, ok := ctx.AuthSettings["allow_multi_tenant"]; ok { + t.Errorf("TUI must NOT set allow_multi_tenant — multi-tenant requires YAML edit (security)") + } + if !reflect.DeepEqual(ctx.AuthEgressHosts, []string{"login.microsoftonline.com"}) { + t.Errorf("egress hosts = %v", ctx.AuthEgressHosts) + } +} + +func TestAuthStep_Summary_Phase2(t *testing.T) { + cases := []struct { + setup func(*AuthStep) + want string + }{ + {setup: func(s *AuthStep) { s.mode = AuthModeAWSSigv4; s.awsRegion = "us-east-1" }, want: "AWS Sigv4 · us-east-1"}, + {setup: func(s *AuthStep) { + s.mode = AuthModeAWSSigv4 + s.awsRegion = "us-east-1" + s.awsAccounts = []string{"A", "B"} + }, want: "AWS Sigv4 · us-east-1 · 2 account(s)"}, + {setup: func(s *AuthStep) { s.mode = AuthModeGCPIAP }, want: "GCP IAP"}, + {setup: func(s *AuthStep) { s.mode = AuthModeAzureAD }, want: "Azure AD · single-tenant"}, + } + for _, tc := range cases { + s := newTestAuthStep(t) + tc.setup(s) + if got := s.Summary(); got != tc.want { + t.Errorf("Summary = %q, want %q", got, tc.want) + } + } +} + +func TestValidateAWSRegion(t *testing.T) { + ok := []string{"us-east-1", "eu-west-2", "ap-southeast-1", "ca-central-1"} + for _, s := range ok { + if err := validateAWSRegion(s); err != nil { + t.Errorf("validateAWSRegion(%q) = %v, want nil", s, err) + } + } + bad := []string{"", "us", "us-east", "useast1"} + for _, s := range bad { + if err := validateAWSRegion(s); err == nil { + t.Errorf("validateAWSRegion(%q) returned nil", s) + } + } +} + +func TestParseAccountList(t *testing.T) { + cases := map[string][]string{ + "": nil, + " ": nil, + "412664885516": {"412664885516"}, + "a, b ,c": {"a", "b", "c"}, + "a,,b": {"a", "b"}, + } + for in, want := range cases { + got := parseAccountList(in) + if !reflect.DeepEqual(got, want) { + t.Errorf("parseAccountList(%q) = %v, want %v", in, got, want) + } + } +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 551bd48..b3b6a89 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -17,10 +17,15 @@ import ( "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/agentspec" "github.com/initializ/forge/forge-core/auth" + // Side-effect imports: each provider sub-package registers its factory + // with the auth registry via init() so forge.yaml `auth.providers[]` + // blocks construct successfully via auth.Build("", settings). + // Listed here even when the package is also referenced directly + // (httpverifier, statictoken) for grep-ability. + _ "github.com/initializ/forge/forge-core/auth/providers/aws_sigv4" + _ "github.com/initializ/forge/forge-core/auth/providers/azure_ad" + _ "github.com/initializ/forge/forge-core/auth/providers/gcp_iap" "github.com/initializ/forge/forge-core/auth/providers/httpverifier" - // Side-effect import: registers the "oidc" provider with the auth registry - // so forge.yaml `auth: { type: oidc }` blocks construct successfully via - // auth.Build("oidc", settings). The package is not referenced directly. _ "github.com/initializ/forge/forge-core/auth/providers/oidc" "github.com/initializ/forge/forge-core/auth/providers/statictoken" "github.com/initializ/forge/forge-core/llm" diff --git a/forge-core/auth/middleware.go b/forge-core/auth/middleware.go index cb09a87..13af495 100644 --- a/forge-core/auth/middleware.go +++ b/forge-core/auth/middleware.go @@ -59,9 +59,9 @@ type MiddlewareOptions struct { // - identity is non-nil and err is nil on success. // - identity is nil and err carries the chain error on failure // (or auth.ErrMissingBearer when the header was absent). - // - tokenKind is "jwt", "opaque", or "empty" — structural metadata - // safe to log. The token itself is NOT passed; callers must not - // try to recover it from the request. + // - tokenKind is "jwt", "opaque", "sigv4", "iap_jwt", or "empty" — + // structural metadata safe to log. The token itself is NOT + // passed; callers must not try to recover it from the request. // // Callbacks should be cheap — they run on the request hot path. OnAuth func(r *http.Request, identity *Identity, err error, tokenKind string) @@ -109,7 +109,18 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { token := extractBearerToken(r) kind := TokenKind(token) - if token == "" { + // Phase 2: gcp_iap doesn't use a Bearer token — it reads + // X-Goog-Iap-Jwt-Assertion. Surface that in the audit kind + // and let the chain run even on empty Bearer when IAP is + // the format in play. aws_sigv4 (Phase 2 pre-signed URL + // pattern) DOES use a Bearer token, so no special-case here. + iapHeader := r.Header.Get("X-Goog-Iap-Jwt-Assertion") + if kind == "empty" && iapHeader != "" { + kind = "iap_jwt" + } + hasNonBearerAuth := token == "" && iapHeader != "" + + if token == "" && !hasNonBearerAuth { notifyAuth(opts.OnAuth, r, nil, ErrMissingBearer, kind) writeAuthError(w, "valid bearer token required") return @@ -122,6 +133,16 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { return } + // Phase 2 (Review M4): refine token_kind from the structural + // shape to the actual provider that matched. The structural + // kind says "what bytes were on the wire"; the post-verify + // kind says "which auth path succeeded." A request with both + // a Bearer JWT AND an X-Goog-Iap-Jwt-Assertion would record + // kind="jwt" under the structural rule even though gcp_iap + // was the verifier — that mis-attributes IAP-fronted traffic + // in audit dashboards. + kind = refineTokenKind(kind, identity.Source) + notifyAuth(opts.OnAuth, r, identity, nil, kind) ctx := WithIdentity(r.Context(), identity) @@ -130,6 +151,28 @@ func Middleware(opts MiddlewareOptions) func(http.Handler) http.Handler { } } +// refineTokenKind upgrades the audit token_kind from the pre-verify +// structural classification to one that reflects which auth path +// actually succeeded. +// +// Today the only refinement is gcp_iap → "iap_jwt": that provider +// reads X-Goog-Iap-Jwt-Assertion (not the Bearer slot), so the +// structural rule cannot detect it when a Bearer is ALSO present. +// Refining post-verify keeps the audit signal clean even when an +// IAP-fronted Forge instance ALSO carries a Bearer JWT for app-level +// auth chaining. +// +// Other providers (oidc, azure_ad, aws_sigv4, http_verifier, +// static_token) don't need refinement: their structural kind already +// matches the auth path (aws_sigv4 has its own "forge-aws-v1." prefix +// → "sigv4"; everything else is just "jwt"/"opaque"). +func refineTokenKind(structural, providerSource string) string { + if providerSource == "gcp_iap" { + return "iap_jwt" + } + return structural +} + // notifyAuth invokes the OnAuth callback if set, swallowing the nil check // at the call sites so the main middleware body stays readable. func notifyAuth(cb func(*http.Request, *Identity, error, string), r *http.Request, id *Identity, err error, kind string) { diff --git a/forge-core/auth/middleware_test.go b/forge-core/auth/middleware_test.go index f163e2f..238d0b0 100644 --- a/forge-core/auth/middleware_test.go +++ b/forge-core/auth/middleware_test.go @@ -443,6 +443,258 @@ func contains(s, sub string) bool { return false } +// --- Phase 2: non-Bearer auth formats reach the chain --- + +// headerCapturingProvider records the headers it sees and lets the test +// script the response. Used by Phase 2 middleware tests to assert that +// providers consuming non-Bearer formats (Sigv4, IAP) actually receive +// what they need. +type headerCapturingProvider struct { + sawToken string + sawHeaders Headers + identity *Identity + err error +} + +func (p *headerCapturingProvider) Name() string { return "test_capture" } +func (p *headerCapturingProvider) Verify(_ context.Context, token string, h Headers) (*Identity, error) { + p.sawToken = token + p.sawHeaders = h + return p.identity, p.err +} + +func TestMiddleware_Sigv4BearerReachesChain(t *testing.T) { + // Phase 2: aws_sigv4 uses the pre-signed URL Bearer-token pattern. + // The chain receives the Bearer token directly via the standard path + // — no non-Bearer-header handling needed. + spy := &headerCapturingProvider{err: ErrTokenNotForMe} + opts := MiddlewareOptions{ + Chain: NewChainProvider(spy), + SkipPaths: DefaultSkipPaths(), + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "Bearer forge-aws-v1.aHR0cHM6Ly9zdHM") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if spy.sawToken == "" || !startsWithLocal(spy.sawToken, "forge-aws-v1.") { + t.Errorf("chain saw token = %q, want forge-aws-v1 prefixed", spy.sawToken) + } +} + +func TestMiddleware_IAPHeaderReachesChain(t *testing.T) { + // Phase 2 change: gcp_iap's X-Goog-Iap-Jwt-Assertion is enough — no + // Authorization header at all is needed for the chain to be consulted. + spy := &headerCapturingProvider{err: ErrTokenNotForMe} + opts := MiddlewareOptions{ + Chain: NewChainProvider(spy), + SkipPaths: DefaultSkipPaths(), + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJ.eyJ.sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if spy.sawHeaders == nil { + t.Fatal("chain was never invoked — middleware short-circuited on empty Bearer") + } + if got := spy.sawHeaders.Get("X-Goog-Iap-Jwt-Assertion"); got != "eyJ.eyJ.sig" { + t.Errorf("chain saw IAP header = %q, want eyJ.eyJ.sig", got) + } +} + +func TestMiddleware_NoAuthHeaders_PreservesMissingTokenReason(t *testing.T) { + // Review #4 contract regression check: when the caller did NOT attempt + // auth at all (no Bearer, no Sigv4 Authorization, no IAP header), the + // audit reason MUST stay ErrMissingBearer — not be widened to "not_for_me". + // This lets ops dashboards still differentiate "client didn't auth" from + // "client tried a format we don't speak." + spyCalled := false + chain := NewChainProvider(&headerCapturingProvider{ + err: ErrTokenNotForMe, + identity: nil, + }) + + var gotErr error + opts := MiddlewareOptions{ + Chain: chain, + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, err error, _ string) { + spyCalled = true + gotErr = err + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Request with NO auth-shaped headers at all. + req := httptest.NewRequest("POST", "/tasks", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if !spyCalled { + t.Fatal("OnAuth callback never fired") + } + if !errors.Is(gotErr, ErrMissingBearer) { + t.Errorf("OnAuth err = %v, want ErrMissingBearer (review #4 regression)", gotErr) + } + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestMiddleware_TokenKind_Sigv4FromForgeAwsToken(t *testing.T) { + // Phase 2: aws_sigv4 uses the "forge-aws-v1." Bearer token + // pattern. TokenKind on that prefix returns "sigv4" so audit dashboards + // can count Sigv4 traffic distinctly from generic Bearer. + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "Bearer forge-aws-v1.aHR0cHM6Ly9zdHMudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "sigv4" { + t.Errorf("token kind = %q, want sigv4", gotKind) + } +} + +func TestMiddleware_TokenKind_RefinedToIapJwtWhenGCPIAPVerifies(t *testing.T) { + // Review M4: structural kind is "jwt" when a Bearer is present, but + // if gcp_iap was actually the verifier (because it read the IAP + // header instead of the Bearer), the audit kind must be "iap_jwt" + // so dashboards count IAP-fronted traffic correctly. + idFromIAP := &Identity{UserID: "iap-user", Source: "gcp_iap"} + chain := NewChainProvider(&headerCapturingProvider{ + identity: idFromIAP, + }) + + var gotKind string + opts := MiddlewareOptions{ + Chain: chain, + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, id *Identity, _ error, kind string) { + if id != nil && id.Source == "gcp_iap" { + gotKind = kind + } + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + // Both a Bearer JWT (structural kind="jwt") AND an IAP header present. + // The chain (stubbed to return Source="gcp_iap") simulates gcp_iap + // being the actual verifier. + req.Header.Set("Authorization", "Bearer eyJ.eyJ.sig") + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJ.eyJ.iap-sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "iap_jwt" { + t.Errorf("token kind = %q, want iap_jwt (refined post-verify from provider Source)", gotKind) + } +} + +func TestMiddleware_TokenKind_JWT_NotRefinedForOIDCProviders(t *testing.T) { + // Counter-test: when oidc / azure_ad is the verifier (Source != + // gcp_iap), the structural "jwt" kind stays — we don't over-refine. + id := &Identity{UserID: "alice", Source: "oidc"} + chain := NewChainProvider(&headerCapturingProvider{identity: id}) + + var gotKind string + handler := Middleware(MiddlewareOptions{ + Chain: chain, + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, id *Identity, _ error, kind string) { + if id != nil { + gotKind = kind + } + }, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("Authorization", "Bearer eyJ.eyJ.sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "jwt" { + t.Errorf("token kind = %q, want jwt (no refinement for oidc Source)", gotKind) + } +} + +func TestMiddleware_TokenKind_IapJwtOnIAPHeader(t *testing.T) { + // Phase 2: when the only auth header is X-Goog-Iap-Jwt-Assertion, + // audit emits token_kind="iap_jwt" (not "empty"). Pinning this so + // the GCP IAP audit signal is distinguishable from no-auth requests. + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJ.eyJ.sig") + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "iap_jwt" { + t.Errorf("token kind = %q, want iap_jwt", gotKind) + } +} + +func TestMiddleware_TokenKind_EmptyWhenTrulyNoAuth(t *testing.T) { + // Counterpart to the test above: when the caller didn't attempt auth + // at all, token_kind should be "empty" — not silently widened. + var gotKind string + opts := MiddlewareOptions{ + Chain: NewChainProvider(&headerCapturingProvider{err: ErrTokenNotForMe}), + SkipPaths: DefaultSkipPaths(), + OnAuth: func(_ *http.Request, _ *Identity, _ error, kind string) { + gotKind = kind + }, + } + handler := Middleware(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/tasks", nil) + handler.ServeHTTP(httptest.NewRecorder(), req) + + if gotKind != "empty" { + t.Errorf("token kind = %q, want empty", gotKind) + } +} + +// startsWithLocal is the middleware_test.go equivalent of provider_test.go's +// startsWith helper. Kept local so the two test files don't depend on each +// other's helpers. +func startsWithLocal(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + func TestClassifyAuthFailure(t *testing.T) { tests := []struct { name string diff --git a/forge-core/auth/provider.go b/forge-core/auth/provider.go index 527eac2..f394ff4 100644 --- a/forge-core/auth/provider.go +++ b/forge-core/auth/provider.go @@ -89,20 +89,31 @@ func (h Headers) Get(key string) string { // HeadersFromRequest extracts the well-known headers providers may use. // Keep this list narrow — providers should be explicit about the contract. +// +// X-Goog-Iap-Jwt-Assertion is included for gcp_iap, which doesn't use a +// Bearer token. All other Phase 2 providers (aws_sigv4 with the pre-signed +// URL pattern, azure_ad) ride the standard Bearer path and don't need +// extra header surface here. func HeadersFromRequest(r *http.Request) Headers { return Headers{ - "X-Org-ID": r.Header.Get("X-Org-ID"), - "X-Request-ID": r.Header.Get("X-Request-ID"), - "org-id": r.Header.Get("org-id"), - "org_id": r.Header.Get("org_id"), + "X-Org-ID": r.Header.Get("X-Org-ID"), + "X-Request-ID": r.Header.Get("X-Request-ID"), + "org-id": r.Header.Get("org-id"), + "org_id": r.Header.Get("org_id"), + "X-Goog-Iap-Jwt-Assertion": r.Header.Get("X-Goog-Iap-Jwt-Assertion"), } } // TokenKind classifies a presented bearer token structurally — useful for // audit logging without leaking the token itself. // -// "jwt" → three base64url segments separated by dots -// "opaque" → anything else (Okta access tokens, custom verifier tokens, dev secrets) +// "empty" → empty token +// "sigv4" → forge-aws-v1. (AWS Sigv4 via pre-signed URL pattern; +// +// the magic prefix mirrors aws-iam-authenticator's "k8s-aws-v1.") +// +// "jwt" → three base64url segments separated by dots +// "opaque" → anything else (custom verifier tokens, dev secrets, etc.) // // This is a CHEAP structural check — it does not parse or validate. // Never log the token; this helper is safe to log. @@ -110,6 +121,9 @@ func TokenKind(token string) string { if token == "" { return "empty" } + if strings.HasPrefix(token, "forge-aws-v1.") { + return "sigv4" + } dots := 0 for i := 0; i < len(token); i++ { if token[i] == '.' { diff --git a/forge-core/auth/provider_test.go b/forge-core/auth/provider_test.go index c175b69..a4958e6 100644 --- a/forge-core/auth/provider_test.go +++ b/forge-core/auth/provider_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "maps" + "net/http" "reflect" "sort" "sync" @@ -464,6 +465,15 @@ func TestTokenKind(t *testing.T) { {"opaque one dot", "abc.def", "opaque"}, {"opaque four segments", "a.b.c.d", "opaque"}, {"jwt with empty segments still has 2 dots", "..", "jwt"}, + // Phase 2: aws_sigv4 uses the pre-signed URL pattern, encoded + // as a Bearer token with the "forge-aws-v1." prefix (mirrors + // aws-iam-authenticator's "k8s-aws-v1." convention). + {"sigv4 forge-aws-v1 token", "forge-aws-v1.aHR0cHM6Ly9zdHMudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20", "sigv4"}, + {"sigv4 prefix only", "forge-aws-v1.", "sigv4"}, + // Defensive: a token that looks similar but with the wrong + // version suffix or missing the period must NOT be sigv4. + {"wrong version prefix", "forge-aws-v2.something", "opaque"}, + {"prefix missing trailing dot", "forge-aws-v1xyz", "opaque"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -474,6 +484,59 @@ func TestTokenKind(t *testing.T) { } } +// --- HeadersFromRequest --- + +func TestHeadersFromRequest_Phase1Headers(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("X-Org-ID", "acme") + req.Header.Set("X-Request-ID", "req-1") + req.Header.Set("org-id", "lower") + req.Header.Set("org_id", "snake") + + h := HeadersFromRequest(req) + if h["X-Org-ID"] != "acme" { + t.Errorf("X-Org-ID = %q, want acme", h["X-Org-ID"]) + } + if h["X-Request-ID"] != "req-1" { + t.Errorf("X-Request-ID = %q, want req-1", h["X-Request-ID"]) + } + if h["org-id"] != "lower" { + t.Errorf("org-id = %q, want lower", h["org-id"]) + } + if h["org_id"] != "snake" { + t.Errorf("org_id = %q, want snake", h["org_id"]) + } +} + +func TestHeadersFromRequest_Phase2Headers(t *testing.T) { + // Phase 2: HeadersFromRequest only widens for gcp_iap. aws_sigv4 + // rides the Bearer path (pre-signed URL pattern); azure_ad rides + // the Bearer path (standard JWT). The widened header surface is + // intentionally narrow. + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("X-Goog-Iap-Jwt-Assertion", "eyJabc.eyJdef.sig") + + h := HeadersFromRequest(req) + + if got := h.Get("X-Goog-Iap-Jwt-Assertion"); got != "eyJabc.eyJdef.sig" { + t.Errorf("X-Goog-Iap-Jwt-Assertion = %q, want eyJabc.eyJdef.sig", got) + } +} + +func TestHeadersFromRequest_AbsentHeadersAreEmpty(t *testing.T) { + // Providers must not assume any header is present — absence is normal + // and means "this format isn't here, yield to the next provider." + req, _ := http.NewRequest("POST", "/", nil) + h := HeadersFromRequest(req) + for _, k := range []string{ + "X-Goog-Iap-Jwt-Assertion", + } { + if got := h.Get(k); got != "" { + t.Errorf("%s should be empty on request with no headers, got %q", k, got) + } + } +} + // --- helpers --- func providerNames(ps []Provider) []string { diff --git a/forge-core/auth/providers/aws_sigv4/arn_matcher.go b/forge-core/auth/providers/aws_sigv4/arn_matcher.go new file mode 100644 index 0000000..2a57ef2 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/arn_matcher.go @@ -0,0 +1,83 @@ +package aws_sigv4 + +import ( + "errors" + "fmt" + "path" +) + +// ArnMatcher checks a caller's ARN against a list of shell-style globs. +// Empty list = allow any IAM principal. Invalid patterns fail at startup, +// never at request time. +// +// Decision §9.3: shell-style globs via path.Match. Supports * and ? but +// NOT regex syntax — simpler mental model, less footgun. +type ArnMatcher struct { + patterns []string +} + +// NewArnMatcher validates each pattern via path.Match("", pattern). A +// malformed pattern is a config bug; we want it surfaced at Factory time. +func NewArnMatcher(patterns []string) (*ArnMatcher, error) { + for _, p := range patterns { + if _, err := path.Match(p, ""); err != nil { + return nil, fmt.Errorf("invalid glob %q: %w", p, err) + } + } + return &ArnMatcher{patterns: patterns}, nil +} + +// Match returns true if the ARN matches any pattern, or if the matcher +// has no patterns (which means "allow any principal"). +// +// NOTE: STS GetCallerIdentity returns the ASSUMED-ROLE ARN form +// ("arn:aws:sts::ACCOUNT:assumed-role/RoleName/SessionName"), not the +// role's own ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). Operators must +// write patterns against the assumed-role form — PR6 docs spell this out. +func (m *ArnMatcher) Match(arn string) bool { + if len(m.patterns) == 0 { + return true + } + for _, p := range m.patterns { + if ok, _ := path.Match(p, arn); ok { + return true + } + } + return false +} + +// validateAccountID checks for an AWS account ID — 12 ASCII digits. +// Catches typos (region names, role ARNs pasted by mistake) at Factory +// time so a misconfigured allowed_accounts entry doesn't silently +// become an unreachable pattern. +func validateAccountID(s string) error { + if len(s) != 12 { + return fmt.Errorf("account ID %q: expected 12 digits, got %d chars", s, len(s)) + } + for i := 0; i < len(s); i++ { + c := s[i] + if c < '0' || c > '9' { + return errors.New("account ID must be 12 digits") + } + } + return nil +} + +// expandAccountGlobs returns the canonical ARN glob set that covers +// every STS identity shape in a given account. The four shapes: +// +// arn:aws:iam:::user/ — direct IAM user +// arn:aws:iam:::role/ — direct IAM role (rare; usually only EC2/Lambda) +// arn:aws:sts:::assumed-role// — SSO, AssumeRole, IRSA +// arn:aws:sts:::federated-user/ — SAML/web-identity federation +// +// path.Match's `*` doesn't cross `/`, so the assumed-role glob needs +// two `*` segments to span both RoleName and SessionName. +func expandAccountGlobs(acct string) []string { + return []string{ + "arn:aws:iam::" + acct + ":user/*", + "arn:aws:iam::" + acct + ":role/*", + "arn:aws:sts::" + acct + ":assumed-role/*/*", + "arn:aws:sts::" + acct + ":federated-user/*", + } +} diff --git a/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go b/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go new file mode 100644 index 0000000..22b02a6 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/arn_matcher_test.go @@ -0,0 +1,122 @@ +package aws_sigv4 + +import ( + "testing" +) + +func TestArnMatcher_EmptyAllowsAll(t *testing.T) { + m, err := NewArnMatcher(nil) + if err != nil { + t.Fatalf("NewArnMatcher: %v", err) + } + if !m.Match("arn:aws:iam::123:role/anything") { + t.Error("empty matcher should allow any ARN") + } +} + +func TestArnMatcher_Glob(t *testing.T) { + m, err := NewArnMatcher([]string{ + "arn:aws:iam::123:role/forge-*", + "arn:aws:sts::123:assumed-role/ci-deploy/*", + }) + if err != nil { + t.Fatalf("NewArnMatcher: %v", err) + } + cases := map[string]bool{ + "arn:aws:iam::123:role/forge-deploy": true, + "arn:aws:iam::123:role/forge-runner": true, + "arn:aws:iam::123:role/forge": false, + "arn:aws:iam::123:role/other-deploy": false, + "arn:aws:iam::456:role/forge-deploy": false, + "arn:aws:sts::123:assumed-role/ci-deploy/session-1": true, + "arn:aws:sts::123:assumed-role/ci-deploy/another-sess": true, + "arn:aws:sts::123:assumed-role/wrong-role/session-1": false, + } + for in, want := range cases { + if got := m.Match(in); got != want { + t.Errorf("Match(%q) = %v, want %v", in, got, want) + } + } +} + +func TestArnMatcher_InvalidPatternFailsAtStartup(t *testing.T) { + if _, err := NewArnMatcher([]string{"["}); err == nil { + t.Fatal("expected error for malformed glob") + } +} + +func TestArnMatcher_AssumedRoleNeedsTwoStars(t *testing.T) { + // Reminder for operators: STS returns the assumed-role ARN, not the + // IAM role ARN. A pattern that matches "arn:aws:iam::ACCT:role/X-*" + // will NOT match "arn:aws:sts::ACCT:assumed-role/X-foo/session". + // PR6 docs spell this out. + m, _ := NewArnMatcher([]string{"arn:aws:iam::123:role/forge-*"}) + if m.Match("arn:aws:sts::123:assumed-role/forge-deploy/session") { + t.Error("IAM role pattern should NOT match STS assumed-role ARN — docs invariant") + } +} + +func TestValidateAccountID(t *testing.T) { + good := []string{"412664885516", "000000000000", "999999999999"} + for _, s := range good { + if err := validateAccountID(s); err != nil { + t.Errorf("validateAccountID(%q) = %v, want nil", s, err) + } + } + bad := []string{ + "", // empty + "1234", // too short + "4126648855161", // 13 chars + "us-east-1", // not digits + "arn:aws:iam::1:", // ARN form + "412 664885516", // space + } + for _, s := range bad { + if err := validateAccountID(s); err == nil { + t.Errorf("validateAccountID(%q) returned nil, want error", s) + } + } +} + +func TestExpandAccountGlobs(t *testing.T) { + got := expandAccountGlobs("412664885516") + want := map[string]bool{ + "arn:aws:iam::412664885516:user/*": true, + "arn:aws:iam::412664885516:role/*": true, + "arn:aws:sts::412664885516:assumed-role/*/*": true, + "arn:aws:sts::412664885516:federated-user/*": true, + } + if len(got) != len(want) { + t.Fatalf("got %d patterns, want %d", len(got), len(want)) + } + for _, g := range got { + if !want[g] { + t.Errorf("unexpected pattern %q", g) + } + } +} + +// End-to-end: each expanded pattern matches the realistic ARN it's +// supposed to cover. +func TestExpandedAccountPatterns_MatchRealisticARNs(t *testing.T) { + m, err := NewArnMatcher(expandAccountGlobs("412664885516")) + if err != nil { + t.Fatalf("NewArnMatcher: %v", err) + } + cases := map[string]bool{ + // Should match (in-account): + "arn:aws:iam::412664885516:user/alice": true, + "arn:aws:iam::412664885516:role/ec2-instance-role": true, + "arn:aws:sts::412664885516:assumed-role/AWSReservedSSO_PowerUserAccess_abc/naveen": true, + "arn:aws:sts::412664885516:assumed-role/ci-deploy/session-id": true, + "arn:aws:sts::412664885516:federated-user/saml-jane": true, + // Should NOT match (different account): + "arn:aws:iam::999999999999:user/eve": false, + "arn:aws:sts::999999999999:assumed-role/SomeRole/session": false, + } + for arn, want := range cases { + if got := m.Match(arn); got != want { + t.Errorf("Match(%q) = %v, want %v", arn, got, want) + } + } +} diff --git a/forge-core/auth/providers/aws_sigv4/identity_cache.go b/forge-core/auth/providers/aws_sigv4/identity_cache.go new file mode 100644 index 0000000..72ff56b --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/identity_cache.go @@ -0,0 +1,73 @@ +package aws_sigv4 + +import ( + "sync" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// IdentityCache holds verified Identities keyed by hash(AKID|YYYYMMDD). +// Bounds the stolen-key window: a leaked AKID is honored until either +// IdentityCacheTTL elapses OR midnight UTC (date bucket rolls over), +// whichever is sooner. +// +// Never caches rejections or errors — Phase 1 invariant: errors are +// not entities. +type IdentityCache struct { + ttl time.Duration + mu sync.RWMutex + data map[string]cacheEntry + now func() time.Time +} + +type cacheEntry struct { + id *auth.Identity + expireAt time.Time +} + +// NewIdentityCache returns an empty cache with the given TTL. time.Now is +// the default clock; tests can swap it via the exported NowFunc field. +func NewIdentityCache(ttl time.Duration) *IdentityCache { + return &IdentityCache{ + ttl: ttl, + data: make(map[string]cacheEntry), + now: time.Now, + } +} + +// Get returns the cached identity if present and not expired. +func (c *IdentityCache) Get(key string) (*auth.Identity, bool) { + c.mu.RLock() + e, ok := c.data[key] + c.mu.RUnlock() + if !ok || c.now().After(e.expireAt) { + return nil, false + } + return e.id, true +} + +// Put stores the identity under key with a fresh TTL. Overwriting an +// existing entry does NOT extend the previous expiry — it replaces it. +// (Refusing to extend prevents a "refresh-just-before-expiry" attack +// from holding a stolen credential alive indefinitely.) +// +// Opportunistic eviction sweeps expired entries when the map grows past +// 10k. Bounds memory under sustained miss without needing a background +// goroutine. +func (c *IdentityCache) Put(key string, id *auth.Identity) { + c.mu.Lock() + c.data[key] = cacheEntry{id: id, expireAt: c.now().Add(c.ttl)} + if len(c.data) > 10_000 { + now := c.now() + for k, e := range c.data { + if now.After(e.expireAt) { + delete(c.data, k) + } + } + } + c.mu.Unlock() +} + +// setNow is a test-only hook for swapping the clock. +func (c *IdentityCache) setNow(fn func() time.Time) { c.now = fn } diff --git a/forge-core/auth/providers/aws_sigv4/identity_cache_test.go b/forge-core/auth/providers/aws_sigv4/identity_cache_test.go new file mode 100644 index 0000000..7e05e2f --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/identity_cache_test.go @@ -0,0 +1,92 @@ +package aws_sigv4 + +import ( + "strconv" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +func TestIdentityCache_HitMiss(t *testing.T) { + c := NewIdentityCache(time.Minute) + if _, ok := c.Get("k1"); ok { + t.Error("empty cache returned hit") + } + id := &auth.Identity{UserID: "arn1"} + c.Put("k1", id) + got, ok := c.Get("k1") + if !ok || got != id { + t.Errorf("Get hit = (%v, %v), want id, true", got, ok) + } +} + +func TestIdentityCache_Expiry(t *testing.T) { + now := time.Unix(1_700_000_000, 0) + c := NewIdentityCache(60 * time.Second) + c.setNow(func() time.Time { return now }) + + c.Put("k", &auth.Identity{UserID: "arn"}) + + if _, ok := c.Get("k"); !ok { + t.Fatal("expected hit before expiry") + } + + now = now.Add(61 * time.Second) + if _, ok := c.Get("k"); ok { + t.Error("expected miss after TTL expiry") + } +} + +func TestIdentityCache_PutDoesNotExtendExpiry(t *testing.T) { + // Defense against refresh-just-before-expiry holding a stolen + // credential alive indefinitely. Put MUST replace the entry's + // expireAt, not extend it. + now := time.Unix(1_700_000_000, 0) + c := NewIdentityCache(60 * time.Second) + c.setNow(func() time.Time { return now }) + + c.Put("k", &auth.Identity{UserID: "arn"}) + // Refresh at t+50s — expireAt becomes t+110s. + now = now.Add(50 * time.Second) + c.Put("k", &auth.Identity{UserID: "arn"}) + + // At t+120s, original-expiry+TTL would still be valid; replacement + // makes the cap t+110s. Verify it's expired. + now = now.Add(70 * time.Second) // t+120s + if _, ok := c.Get("k"); ok { + t.Error("Put extended TTL beyond the bound — refresh-extends-stolen-key bug") + } +} + +func TestIdentityCache_OpportunisticEviction(t *testing.T) { + // Force the map past the 10_000 threshold with expired entries; one + // more Put should trigger a sweep that drops them. + now := time.Unix(1_700_000_000, 0) + c := NewIdentityCache(time.Second) + c.setNow(func() time.Time { return now }) + + id := &auth.Identity{UserID: "arn"} + for i := range 10_001 { + // strconv.Itoa(i), NOT string(rune(i)): the rune form maps + // surrogate code points to U+FFFD (the replacement char), so + // all of {0xD800..0xDFFF} would collide on one cache key and + // the map never actually reaches 10_001 entries — the + // threshold-eviction test would silently no-op. (Review NIT.) + c.Put(strconv.Itoa(i), id) + } + now = now.Add(2 * time.Second) // expire everything + + c.Put("fresh", id) // triggers sweep + + if got, ok := c.Get("fresh"); !ok || got != id { + t.Fatal("fresh entry lost during sweep") + } + // Map should have been pruned — at most a few entries remain + c.mu.RLock() + size := len(c.data) + c.mu.RUnlock() + if size > 10 { + t.Errorf("eviction did not run; cache size = %d", size) + } +} diff --git a/forge-core/auth/providers/aws_sigv4/provider.go b/forge-core/auth/providers/aws_sigv4/provider.go new file mode 100644 index 0000000..1714ceb --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/provider.go @@ -0,0 +1,362 @@ +// Package aws_sigv4 authenticates AWS-IAM callers using the pre-signed +// URL pattern. Same approach as aws-iam-authenticator (the EKS auth bridge): +// the caller uses their AWS SDK to compute a pre-signed STS +// GetCallerIdentity URL, wraps it as a Bearer token, and sends it to +// Forge. Forge invokes the URL on STS, which validates the signature +// (it was signed for STS's host) and returns the canonical +// ARN / Account / UserID. Forge stamps an Identity from that response. +// +// Forge never possesses the caller's AWS secret key. The cryptographic +// work happens once on the caller side via standard SDK calls. +// +// # Client-side contract (3 lines) +// +// # Python / boto3 +// import boto3, base64 +// url = boto3.client('sts', region_name='us-east-1').generate_presigned_url( +// 'get_caller_identity', ExpiresIn=900) +// token = 'forge-aws-v1.' + base64.urlsafe_b64encode(url.encode()).rstrip(b'=').decode() +// requests.post(forge_url, headers={'Authorization': f'Bearer {token}'}, data=msg) +// +// Reference client in scripts/forge-aws-sign.py. +// +// # Wire format +// +// Authorization: Bearer forge-aws-v1. +// +// The base64-decoded payload is a complete pre-signed URL of the form: +// +// https://sts..amazonaws.com/ +// ?Action=GetCallerIdentity +// &Version=2011-06-15 +// &X-Amz-Algorithm=AWS4-HMAC-SHA256 +// &X-Amz-Credential=///sts/aws4_request +// &X-Amz-Date= +// &X-Amz-Expires= +// &X-Amz-SignedHeaders=host +// &X-Amz-Signature= +// +// # SSRF guard +// +// Before invoking the URL, Forge validates the host matches +// sts..amazonaws.com exactly. A token whose URL +// points anywhere else is rejected — the token must not be usable to +// coerce Forge into calling an arbitrary internal endpoint. +// +// # Caching +// +// Verified identities are cached for IdentityCacheTTL keyed on +// hash(AKID, YYYYMMDD), extracted from the token's X-Amz-Credential. +// Rotating AKID or rolling past midnight UTC invalidates the bucket. +// Errors are never cached. +// +// # Decisions +// +// §9.1 — no aws-sdk-go-v2 dependency. The STS RPC is hand-rolled HTTP + +// XML; trade-off is smaller attack surface and no transitive deps. +// +// §9.3 — allowed_principals are shell-style globs (path.Match). +// +// # Audit reason codes (Phase 1 contract) +// +// rejected — STS 4xx (expired/bad sig), ARN allowlist miss, +// URL host mismatch, region scope mismatch +// provider_unavailable — STS 5xx, network failure, parse failure +// invalid — token format malformed, base64 fails, URL fails, +// missing required query params +// not_for_me — token didn't start with "forge-aws-v1." +package aws_sigv4 + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the type name used to register and reference this provider. +const ProviderName = "aws_sigv4" + +// Defaults. +const ( + defaultIdentityCacheTTL = 60 * time.Second + defaultHTTPTimeout = 5 * time.Second + + // Cap the token's self-declared lifetime (X-Amz-Expires). AWS allows + // up to 7 days for some services; STS GetCallerIdentity in practice + // is signed for 15 min by all standard SDKs. Anything longer is + // suspect. + defaultMaxTokenExpires = 15 * time.Minute + + // Clock skew tolerance between Forge and the caller. Generous enough + // for typical NTP drift but tight enough to bound stolen-token replay. + defaultClockSkew = 5 * time.Minute +) + +// Config controls the aws_sigv4 provider. +type Config struct { + // Region is the AWS region whose STS endpoint validates signatures. + // REQUIRED. The pre-signed URL's host MUST match + // sts..amazonaws.com exactly. + Region string `yaml:"region"` + + // Audience is informational only — emitted in the audit log's Claims + // payload. STS itself doesn't enforce it. + Audience string `yaml:"audience,omitempty"` + + // AllowedPrincipals is an optional list of shell-style globs (§9.3) + // matched against the STS-returned ARN. Empty list means "allow any + // IAM principal that has a valid AWS key" — fine for single-tenant + // dev, never appropriate for production. + // + // Patterns match the STS assumed-role ARN form + // ("arn:aws:sts::ACCOUNT:assumed-role/RoleName/SessionName"), NOT the + // IAM role ARN ("arn:aws:iam::ACCOUNT:role/RoleName"). + AllowedPrincipals []string `yaml:"allowed_principals,omitempty"` + + // AllowedAccounts is an ergonomic shortcut for the common case of + // "anyone in these AWS accounts." Each entry is an account ID + // (12 digits); New() expands each into the canonical glob set: + // + // arn:aws:iam:::user/* + // arn:aws:iam:::role/* + // arn:aws:sts:::assumed-role/*/* + // arn:aws:sts:::federated-user/* + // + // covering every shape STS returns. The expansion is appended to + // AllowedPrincipals — operators can mix the two: list specific + // roles in AllowedPrincipals AND whole accounts in AllowedAccounts. + // + // For AWS-Org-wide trust without enumerating accounts, see the + // docstring section on AWS Identity Center / SSO. + AllowedAccounts []string `yaml:"allowed_accounts,omitempty"` + + // IdentityCacheTTL bounds how long a verified Identity is reused + // without re-checking with STS. Defaults to 60s. + IdentityCacheTTL time.Duration `yaml:"identity_cache_ttl,omitempty"` + + // MaxTokenExpires caps the X-Amz-Expires value the caller can stamp + // into the pre-signed URL. Defaults to 15min (matches what all + // standard AWS SDKs produce for GetCallerIdentity). Belt-and-braces + // gate on top of STS's own freshness enforcement. + MaxTokenExpires time.Duration `yaml:"max_token_expires,omitempty"` + + // ClockSkew is the tolerance window for clock drift between Forge + // and the caller when evaluating token freshness. Defaults to 5min. + ClockSkew time.Duration `yaml:"clock_skew,omitempty"` + + // STSEndpoint is a TEST-ONLY override that changes the expected + // pre-signed URL host (and relaxes the https requirement). Production + // should leave this empty. + STSEndpoint string `yaml:"sts_endpoint,omitempty"` + + // HTTPTimeout caps each STS call. Defaults to 5s. + HTTPTimeout time.Duration `yaml:"http_timeout,omitempty"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.Region == "" { + return fmt.Errorf("%w: region required", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider for AWS-IAM callers. +type Provider struct { + cfg Config + expectedHost string // computed once: sts..amazonaws.com or test override + requireHTTPS bool // false only when STSEndpoint test override is in use + cache *IdentityCache + sts *STSClient + matcher *ArnMatcher + + // now is injectable for tests; defaults to time.Now in New(). + now func() time.Time +} + +// New constructs a Provider after validating cfg. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.IdentityCacheTTL == 0 { + cfg.IdentityCacheTTL = defaultIdentityCacheTTL + } + if cfg.HTTPTimeout == 0 { + cfg.HTTPTimeout = defaultHTTPTimeout + } + if cfg.MaxTokenExpires == 0 { + cfg.MaxTokenExpires = defaultMaxTokenExpires + } + if cfg.ClockSkew == 0 { + cfg.ClockSkew = defaultClockSkew + } + // Expand the AllowedAccounts shortcut into canonical globs and merge + // with operator-supplied AllowedPrincipals. We expand here (rather + // than at Match time) so the patterns are validated once at Factory. + expanded := append([]string(nil), cfg.AllowedPrincipals...) + for _, acct := range cfg.AllowedAccounts { + if err := validateAccountID(acct); err != nil { + return nil, fmt.Errorf("aws_sigv4: allowed_accounts: %w", err) + } + expanded = append(expanded, expandAccountGlobs(acct)...) + } + matcher, err := NewArnMatcher(expanded) + if err != nil { + return nil, fmt.Errorf("aws_sigv4: allowed_principals: %w", err) + } + + expectedHost := fmt.Sprintf("sts.%s.amazonaws.com", cfg.Region) + requireHTTPS := true + if cfg.STSEndpoint != "" { + h, scheme := hostAndSchemeOf(cfg.STSEndpoint) + if h != "" { + expectedHost = h + } + // Test override: allow plain http for httptest servers. + if scheme == "http" { + requireHTTPS = false + } + } + + return &Provider{ + cfg: cfg, + expectedHost: expectedHost, + requireHTTPS: requireHTTPS, + cache: NewIdentityCache(cfg.IdentityCacheTTL), + sts: NewSTSClient(cfg.Region, "", cfg.HTTPTimeout), + matcher: matcher, + now: time.Now, + }, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. +// +// The middleware extracts the Bearer token and passes it here. If the +// token doesn't start with "forge-aws-v1." we yield to the next chain +// entry. Otherwise we validate the embedded pre-signed URL, GET it on +// STS, and stamp an Identity from the returned ARN. +func (p *Provider) Verify(ctx context.Context, token string, _ auth.Headers) (*auth.Identity, error) { + parsed, err := ParseToken(token, p.expectedHost, p.requireHTTPS) + if err != nil { + // Only the prefix check distinguishes "this isn't my token" + // from "this IS my token but malformed." Use a sentinel so the + // chain can fall through on the former but stop on the latter. + if err.Error() == "missing forge-aws-v1 prefix" { + return nil, auth.ErrTokenNotForMe + } + return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + + // Region in the credential scope must match our configured region. + // Defends against cross-region replay (e.g. a token pre-signed for + // us-east-1 hitting a Forge instance configured for eu-west-1). + if parsed.Region != p.cfg.Region { + return nil, fmt.Errorf("%w: token region %q != configured %s", auth.ErrTokenRejected, parsed.Region, p.cfg.Region) + } + + // Parser-side freshness check (Review M2). Belt-and-braces on top of + // STS's own ~15min enforcement: caps the X-Amz-Expires that any + // caller can claim, and rejects already-expired tokens before we + // ever round-trip to STS. This bounds the stolen-token replay + // window to MaxTokenExpires + ClockSkew, independent of our cache TTL. + if err := parsed.CheckFreshness(p.now(), p.cfg.MaxTokenExpires, p.cfg.ClockSkew); err != nil { + return nil, fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) + } + + cacheKey := dateBucketKey(parsed.AKID, parsed.Date) + if id, ok := p.cache.Get(cacheKey); ok { + return id, nil + } + + // IMPORTANT: pass RawURL (the byte-for-byte original from the token), + // NOT parsed.URL.String(). The latter would re-encode query params + // via net/url's rules and invalidate the signature. + caller, err := p.sts.GetCallerIdentity(ctx, parsed.RawURL) + if err != nil { + return nil, err // STSClient already wraps with the right sentinel + } + + if !p.matcher.Match(caller.Arn) { + return nil, fmt.Errorf("%w: ARN %q not in allowed_principals", auth.ErrTokenRejected, caller.Arn) + } + + id := &auth.Identity{ + UserID: caller.Arn, + OrgID: caller.Account, + Source: ProviderName, + Claims: map[string]any{ + "user_id": caller.UserID, + "arn": caller.Arn, + "account": caller.Account, + "audience": p.cfg.Audience, + }, + } + p.cache.Put(cacheKey, id) + return id, nil +} + +// dateBucketKey hashes (AKID, YYYYMMDD) so two requests from the same +// AKID on the same day collapse to a single STS call per +// IdentityCacheTTL window. Hashing protects against length-leak / log-scan +// reads of the cache key. +func dateBucketKey(akid, date string) string { + bucket := date + if len(bucket) > 8 { + bucket = bucket[:8] // YYYYMMDD + } + sum := sha256.Sum256([]byte(akid + "|" + bucket)) + return hex.EncodeToString(sum[:]) +} + +// hostAndSchemeOf is a forgiving parser used only at Factory time for the +// STSEndpoint test override. Returns (host, scheme) or empty strings on +// parse failure. +func hostAndSchemeOf(raw string) (host, scheme string) { + // Accept both bare "host:port" and full "scheme://host:port/path" forms. + // For the test override we only care about the host portion (for + // matching) and whether scheme is http (so we can relax the https check). + if i := indexOf(raw, "://"); i >= 0 { + scheme = raw[:i] + rest := raw[i+3:] + for k := 0; k < len(rest); k++ { + if rest[k] == '/' { + return rest[:k], scheme + } + } + return rest, scheme + } + // No scheme: assume host:port form + for k := 0; k < len(raw); k++ { + if raw[k] == '/' { + return raw[:k], "" + } + } + return raw, "" +} + +func indexOf(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/aws_sigv4/provider_test.go b/forge-core/auth/providers/aws_sigv4/provider_test.go new file mode 100644 index 0000000..305a56c --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/provider_test.go @@ -0,0 +1,456 @@ +package aws_sigv4 + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// fixedTestTime is the wall-clock the test helpers pretend it is. +// Tokens minted by tokenFor() are signed at this instant (X-Amz-Date) +// and Provider.now is pinned to this instant via newTestProvider, so +// CheckFreshness sees an in-window token by default. Tests that want +// to exercise expiry/skew override Provider.now after construction. +var fixedTestTime = time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC) + +// tokenFor builds a forge-aws-v1 token whose embedded URL points at the +// test STS server. AKID + date + region + signature are placeholders; +// the fake STS doesn't validate them. X-Amz-Date is pinned to +// fixedTestTime so the freshness check passes by default. +func tokenFor(stsURL, akid, dateYYYYMMDD, region string) string { + q := url.Values{} + q.Set("Action", "GetCallerIdentity") + q.Set("Version", "2011-06-15") + q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") + q.Set("X-Amz-Credential", fmt.Sprintf("%s/%s/%s/sts/aws4_request", akid, dateYYYYMMDD, region)) + // Build X-Amz-Date from the (possibly day-rolled) dateYYYYMMDD plus + // fixedTestTime's HHMMSS so date-bucket rollover tests still work. + q.Set("X-Amz-Date", dateYYYYMMDD+"T"+fixedTestTime.UTC().Format("150405")+"Z") + q.Set("X-Amz-Expires", "900") + q.Set("X-Amz-SignedHeaders", "host") + q.Set("X-Amz-Signature", "fakesig"+akid) + full := stsURL + "/?" + q.Encode() + return TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(full)) +} + +func defaultToken(stsURL string) string { + return tokenFor(stsURL, "AKIAIOSFODNN7EXAMPLE", "20260524", "us-east-1") +} + +func newTestProvider(t *testing.T, sts http.Handler, opts ...func(*Config)) (*Provider, string) { + t.Helper() + srv := httptest.NewServer(sts) + t.Cleanup(srv.Close) + + cfg := Config{ + Region: "us-east-1", + Audience: "api://forge", + STSEndpoint: srv.URL, + IdentityCacheTTL: 60 * time.Second, + HTTPTimeout: 5 * time.Second, + } + for _, fn := range opts { + fn(&cfg) + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + // Pin the provider's clock to fixedTestTime so tokens minted by + // tokenFor() pass the M2 freshness check. + p.now = func() time.Time { return fixedTestTime } + return p, srv.URL +} + +func happySTS() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, happySTSXML) + }) +} + +func TestProvider_Name(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + if p.Name() != "aws_sigv4" { + t.Errorf("Name = %q, want aws_sigv4", p.Name()) + } +} + +func TestProvider_New_RequiresRegion(t *testing.T) { + _, err := New(Config{}) + if err == nil || !errors.Is(err, auth.ErrProviderNotConfigured) { + t.Errorf("err = %v, want wrapped ErrProviderNotConfigured", err) + } +} + +func TestProvider_New_RejectsInvalidGlob(t *testing.T) { + _, err := New(Config{Region: "us-east-1", AllowedPrincipals: []string{"["}}) + if err == nil { + t.Fatal("expected error for malformed glob") + } +} + +func TestProvider_NoPrefix_YieldsToChain(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "Bearer some.opaque.token", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestProvider_EmptyToken_YieldsToChain(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), "", nil) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestProvider_MalformedToken_Invalid(t *testing.T) { + p, _ := newTestProvider(t, happySTS()) + _, err := p.Verify(context.Background(), TokenPrefix+"!!!not-base64!!!", nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_HappyPath_ReturnsIdentity(t *testing.T) { + p, stsURL := newTestProvider(t, happySTS()) + id, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "aws_sigv4" { + t.Errorf("Source = %q", id.Source) + } + if id.UserID != "arn:aws:sts::123456789012:assumed-role/ci-deploy/session" { + t.Errorf("UserID = %q", id.UserID) + } + if id.OrgID != "123456789012" { + t.Errorf("OrgID = %q", id.OrgID) + } + if id.Claims["audience"] != "api://forge" { + t.Errorf("Claims[audience] = %v", id.Claims["audience"]) + } +} + +func TestProvider_RegionMismatch_Rejected(t *testing.T) { + // Token's credential scope says eu-west-1, provider configured us-east-1. + // Defends against cross-region token replay (the same AKID may be + // valid in either region, but the operator's allowlist applies only + // to the configured region). + p, stsURL := newTestProvider(t, happySTS()) + tok := tokenFor(stsURL, "AKIAEXAMPLE", "20260524", "eu-west-1") + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (cross-region)", err) + } +} + +func TestProvider_ForeignHost_Invalid(t *testing.T) { + // SSRF guard: a token whose URL points anywhere other than the + // expected STS host is rejected before we ever issue a request. + p, _ := newTestProvider(t, happySTS()) + hostile := "https://evil.example.com/?Action=GetCallerIdentity" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Credential=AKIA/20260524/us-east-1/sts/aws4_request" + + "&X-Amz-Signature=x" + tok := TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(hostile)) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (SSRF guard)", err) + } +} + +func TestProvider_AllowlistMiss_Rejected(t *testing.T) { + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedPrincipals = []string{"arn:aws:iam::999:role/*"} + }) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (allowlist miss)", err) + } +} + +func TestProvider_AllowlistHit_Succeeds(t *testing.T) { + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedPrincipals = []string{"arn:aws:sts::123456789012:assumed-role/ci-deploy/*"} + }) + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("Verify: %v", err) + } +} + +func TestProvider_AllowedAccounts_AllowsAnyIdentityInAccount(t *testing.T) { + // The fake STS returns assumed-role ARN for account 123456789012. + // AllowedAccounts=[123456789012] expands to globs covering all + // identity shapes in that account → the assumed-role ARN matches. + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedAccounts = []string{"123456789012"} + }) + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("Verify: %v", err) + } +} + +func TestProvider_AllowedAccounts_DifferentAccountRejected(t *testing.T) { + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedAccounts = []string{"999999999999"} + }) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_AllowedAccounts_RejectsMalformedAtFactory(t *testing.T) { + _, err := New(Config{ + Region: "us-east-1", + AllowedAccounts: []string{"not-an-account"}, + }) + if err == nil { + t.Fatal("expected error on malformed account ID") + } +} + +func TestProvider_AllowedAccounts_MergesWithAllowedPrincipals(t *testing.T) { + // Mix: account-wide grant for 123456789012 (covers the test STS + // response) + a specific role pattern for some other account. + // Verify the account-wide entry takes precedence. + p, stsURL := newTestProvider(t, happySTS(), func(c *Config) { + c.AllowedPrincipals = []string{"arn:aws:iam::999:role/specific"} + c.AllowedAccounts = []string{"123456789012"} + }) + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("Verify: %v", err) + } +} + +func TestProvider_STSDown_Unavailable(t *testing.T) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestProvider_STSRejects_Rejected(t *testing.T) { + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, "SignatureDoesNotMatch") + })) + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_CacheHit_AvoidsSTSCall(t *testing.T) { + var calls atomic.Int32 + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + _, _ = io.WriteString(w, happySTSXML) + })) + tok := defaultToken(stsURL) + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatal(err) + } + if _, err := p.Verify(context.Background(), tok, nil); err != nil { + t.Fatal(err) + } + if got := calls.Load(); got != 1 { + t.Errorf("STS calls = %d, want 1 (cache must hit)", got) + } +} + +func TestProvider_RejectedRequest_DoesNotPoisonCache(t *testing.T) { + var calls atomic.Int32 + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + if n == 1 { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, "") + return + } + _, _ = io.WriteString(w, happySTSXML) + })) + tok := defaultToken(stsURL) + _, err1 := p.Verify(context.Background(), tok, nil) + if !errors.Is(err1, auth.ErrTokenRejected) { + t.Fatalf("first Verify err = %v, want ErrTokenRejected", err1) + } + _, err2 := p.Verify(context.Background(), tok, nil) + if err2 != nil { + t.Fatalf("second Verify err = %v, want nil", err2) + } + if got := calls.Load(); got != 2 { + t.Errorf("STS calls = %d, want 2 (rejection must not poison cache)", got) + } +} + +func TestProvider_DateBucketRollover_TriggersFreshSTSCall(t *testing.T) { + var calls atomic.Int32 + p, stsURL := newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + _, _ = io.WriteString(w, happySTSXML) + })) + // Day 1: use the default fixedTestTime clock. + if _, err := p.Verify(context.Background(), tokenFor(stsURL, "AKIA", "20260524", "us-east-1"), nil); err != nil { + t.Fatal(err) + } + // Day 2: advance clock by 24h so the day-2 token is fresh under + // CheckFreshness (post-M2 freshness gate). + p.now = func() time.Time { return fixedTestTime.Add(24 * time.Hour) } + if _, err := p.Verify(context.Background(), tokenFor(stsURL, "AKIA", "20260525", "us-east-1"), nil); err != nil { + t.Fatal(err) + } + if got := calls.Load(); got != 2 { + t.Errorf("STS calls = %d, want 2 (date bucket rolled)", got) + } +} + +// --- Review M2: parser-side freshness --- + +func TestProvider_RejectsExpiredToken(t *testing.T) { + // Token's X-Amz-Date + Expires window is in the past relative to + // Provider.now. STS would also reject this, but we belt-and-brace. + p, stsURL := newTestProvider(t, happySTS()) + p.now = func() time.Time { + return fixedTestTime.Add(30 * time.Minute) // far past the 15min lifetime + 5min skew + } + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } + if err != nil && !strings.Contains(err.Error(), "expired") { + t.Errorf("err should mention 'expired'; got %v", err) + } +} + +func TestProvider_RejectsTokenFromFuture(t *testing.T) { + // Token signed beyond skew tolerance in the future. + p, stsURL := newTestProvider(t, happySTS()) + p.now = func() time.Time { + return fixedTestTime.Add(-1 * time.Hour) // way before the token's signing instant + } + _, err := p.Verify(context.Background(), defaultToken(stsURL), nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } + if err != nil && !strings.Contains(err.Error(), "future") { + t.Errorf("err should mention 'future'; got %v", err) + } +} + +func TestProvider_RejectsOverlyLongExpiresClaim(t *testing.T) { + // Caller crafted a token with X-Amz-Expires > 15min. STS would + // also reject the signature, but we belt-and-brace at parse-side. + p, stsURL := newTestProvider(t, happySTS()) + + // Build the token pointing at the test STS host so we pass the + // SSRF-guard host check; the freshness check then catches the + // oversized X-Amz-Expires. + q := url.Values{} + q.Set("Action", "GetCallerIdentity") + q.Set("Version", "2011-06-15") + q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") + q.Set("X-Amz-Credential", "AKIA/20260524/us-east-1/sts/aws4_request") + q.Set("X-Amz-Date", "20260524T010000Z") + q.Set("X-Amz-Expires", "3600") // 1 hour — exceeds our 15min cap + q.Set("X-Amz-SignedHeaders", "host") + q.Set("X-Amz-Signature", "abc") + full := stsURL + "/?" + q.Encode() + tok := TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(full)) + + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } + if err != nil && !strings.Contains(err.Error(), "X-Amz-Expires") { + t.Errorf("err should mention X-Amz-Expires cap; got %v", err) + } +} + +func TestProvider_AcceptsTokenAtEdgeOfSkewWindow(t *testing.T) { + // Just barely fresh: signed at fixedTestTime, expires after 15min, + // and now is fixedTestTime + 15min + 4min skew (still within 5min). + p, stsURL := newTestProvider(t, happySTS()) + p.now = func() time.Time { + return fixedTestTime.Add(15*time.Minute + 4*time.Minute) + } + if _, err := p.Verify(context.Background(), defaultToken(stsURL), nil); err != nil { + t.Errorf("token within skew window should pass, got %v", err) + } +} + +func TestProvider_RegisteredInRegistry(t *testing.T) { + p, err := auth.Build("aws_sigv4", map[string]any{ + "region": "us-east-1", + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "aws_sigv4" { + t.Errorf("Name = %q", p.Name()) + } +} + +func TestProvider_FactoryRejectsMissingRegion(t *testing.T) { + _, err := auth.Build("aws_sigv4", map[string]any{}) + if err == nil { + t.Fatal("expected error from factory when region is missing") + } +} + +func TestProvider_TokenPointingAtForeignSTSRegion_Invalid(t *testing.T) { + // Token URL says sts.eu-west-1.amazonaws.com — provider expects + // sts.us-east-1.amazonaws.com. The pre-validation host check should + // catch this before any STS call. + p, _ := newTestProvider(t, happySTS(), func(c *Config) { + // Use defaults so expectedHost stays sts.us-east-1.amazonaws.com. + // Drop STSEndpoint override to exercise the real host path. + c.STSEndpoint = "" + }) + // Force the provider to compare against the real sts host. + hostile := "https://sts.eu-west-1.amazonaws.com/?Action=GetCallerIdentity" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Credential=AKIA/20260524/eu-west-1/sts/aws4_request" + + "&X-Amz-Signature=x" + tok := TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(hostile)) + _, err := p.Verify(context.Background(), tok, nil) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (cross-region URL host)", err) + } +} + +// Sanity: with no STSEndpoint override the expectedHost matches AWS's +// real STS endpoint for the configured region — protects against +// accidental regressions to the prod host derivation. +func TestProvider_DefaultExpectedHost(t *testing.T) { + p, err := New(Config{Region: "us-east-1"}) + if err != nil { + t.Fatal(err) + } + if p.expectedHost != "sts.us-east-1.amazonaws.com" { + t.Errorf("expectedHost = %q", p.expectedHost) + } + if !p.requireHTTPS { + t.Error("requireHTTPS should be true by default") + } +} + +// startsWithBearer is a tiny helper for the token_kind tests in middleware +// (kept local so a test-only constant doesn't leak into the production package). +var _ = strings.HasPrefix diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go new file mode 100644 index 0000000..ebf6f31 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser.go @@ -0,0 +1,226 @@ +package aws_sigv4 + +import ( + "encoding/base64" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + "time" +) + +// TokenPrefix is the magic token-type marker that distinguishes +// forge-aws-v1 tokens from JWTs / opaque tokens in the Bearer slot. +// Mirrors the "k8s-aws-v1." convention from aws-iam-authenticator. +const TokenPrefix = "forge-aws-v1." + +// PresignedToken is the parsed view of a forge-aws-v1 Bearer token. +// The token wraps a pre-signed STS GetCallerIdentity URL; AKID + Date are +// extracted from the URL's X-Amz-Credential query param so the provider's +// identity cache can key on them without re-deriving for every Verify call. +// +// RawURL holds the original URL byte-for-byte as it appeared in the +// decoded token payload. Forge MUST use RawURL when invoking STS — round- +// tripping through Go's net/url package re-encodes query parameters in +// ways that differ from how the AWS SDK emitted them (e.g., percent- +// encoding of "/" in X-Amz-Credential, "+" in X-Amz-Security-Token), and +// any such re-encoding invalidates the signature. +type PresignedToken struct { + RawURL string // the exact URL the AWS SDK produced — preserve as-is + URL *url.URL // parsed view, for host validation and query inspection only + AKID string // for IdentityCache bucket key + Date string // YYYYMMDD scope date — for IdentityCache bucket key + Region string // from the credential scope (we cross-check against cfg.Region) + SigTime time.Time // parsed X-Amz-Date — used by CheckFreshness, not by ParseToken itself + Expires time.Duration // parsed X-Amz-Expires — used by CheckFreshness +} + +// ParseToken validates a forge-aws-v1 Bearer token end-to-end and returns +// the URL Forge should invoke on STS. +// +// expectedHost is sts..amazonaws.com for prod, or the test-mode +// override host (Config.STSEndpoint) for integration tests. +// +// Validation gates (in order — fail-fast): +// +// 1. Token starts with the TokenPrefix. +// 2. Body decodes as base64url. +// 3. Decoded payload parses as a URL. +// 4. URL scheme is https (or http when STSEndpoint test override is in use). +// 5. URL host matches expectedHost. +// 6. URL query has Action=GetCallerIdentity. +// 7. URL query has X-Amz-Algorithm=AWS4-HMAC-SHA256. +// 8. URL query has a non-empty X-Amz-Signature. +// 9. URL query has X-Amz-Credential parseable as AKID/YYYYMMDD/region/sts/aws4_request. +// +// Returns ErrTokenNotForMe only when (1) fails — the prefix is the only +// "shape" check; everything else is a malformed / rejected token from +// our perspective, classified as ErrInvalidToken or ErrTokenRejected +// by the caller. +func ParseToken(token, expectedHost string, requireHTTPS bool) (*PresignedToken, error) { + if !strings.HasPrefix(token, TokenPrefix) { + return nil, errors.New("missing forge-aws-v1 prefix") + } + encoded := strings.TrimPrefix(token, TokenPrefix) + if encoded == "" { + return nil, errors.New("forge-aws-v1 token has empty payload") + } + + // base64url decode — accept both padded and unpadded forms because + // SDKs disagree on whether to emit "=" padding. + raw, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + // fallback to standard base64url with padding + raw, err = base64.URLEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("base64url decode: %w", err) + } + } + + u, err := url.Parse(string(raw)) + if err != nil { + return nil, fmt.Errorf("decoded payload is not a URL: %w", err) + } + + if requireHTTPS && u.Scheme != "https" { + return nil, fmt.Errorf("URL scheme %q is not https", u.Scheme) + } + if u.Scheme != "https" && u.Scheme != "http" { + return nil, fmt.Errorf("URL scheme %q is not http(s)", u.Scheme) + } + // Reject userinfo BEFORE the host check. RFC 3986 separates + // userinfo from host, so net/url parses + // "https://user:pass@sts.us-east-1.amazonaws.com" into + // (u.User="user:pass", u.Host="sts.us-east-1.amazonaws.com") — + // the host check alone would let that token through. Then + // http.Client.Do would synthesize Authorization: Basic from + // u.User and ship attacker-controlled bytes to STS. STS ignores + // Basic (it uses the X-Amz-Signature query param), but we still + // don't want attacker bytes leaving the box. (Review M1.) + if u.User != nil { + return nil, errors.New("URL must not contain userinfo (RFC 3986 user:pass@ section)") + } + if !strings.EqualFold(u.Host, expectedHost) { + return nil, fmt.Errorf("URL host %q does not match expected %q (SSRF guard)", u.Host, expectedHost) + } + + q := u.Query() + if q.Get("Action") != "GetCallerIdentity" { + return nil, fmt.Errorf("URL Action=%q, want GetCallerIdentity", q.Get("Action")) + } + if q.Get("X-Amz-Algorithm") != "AWS4-HMAC-SHA256" { + return nil, fmt.Errorf("URL X-Amz-Algorithm=%q, want AWS4-HMAC-SHA256", q.Get("X-Amz-Algorithm")) + } + if q.Get("X-Amz-Signature") == "" { + return nil, errors.New("URL missing X-Amz-Signature") + } + + akid, date, region, err := parseCredentialScope(q.Get("X-Amz-Credential")) + if err != nil { + return nil, fmt.Errorf("X-Amz-Credential: %w", err) + } + + sigTime, err := parseAmzDate(q.Get("X-Amz-Date")) + if err != nil { + return nil, fmt.Errorf("X-Amz-Date: %w", err) + } + expires, err := parseAmzExpires(q.Get("X-Amz-Expires")) + if err != nil { + return nil, fmt.Errorf("X-Amz-Expires: %w", err) + } + + return &PresignedToken{ + RawURL: string(raw), + URL: u, + AKID: akid, + Date: date, + Region: region, + SigTime: sigTime, + Expires: expires, + }, nil +} + +// CheckFreshness rejects tokens whose self-declared lifetime exceeds +// maxExpires OR whose validity window has already lapsed (with skew +// for clock drift between Forge and the caller). This is defense in +// depth on top of STS's own ~15min enforcement: if STS ever accepts +// a stale token, our IdentityCache would happily serve the cached +// Identity for its full TTL. Parser-side freshness closes that gap. +// +// Caller passes `now` and the limits so this is unit-testable without +// time monkey-patching. Provider supplies them from its Config. +func (t *PresignedToken) CheckFreshness(now time.Time, maxExpires, skew time.Duration) error { + if t.Expires > maxExpires { + return fmt.Errorf("X-Amz-Expires=%s exceeds cap %s", t.Expires, maxExpires) + } + // Token's own self-declared expiry passed already (with skew tolerance). + if now.After(t.SigTime.Add(t.Expires).Add(skew)) { + return fmt.Errorf("token expired: signed at %s + %s lifetime + %s skew, now %s", + t.SigTime.UTC().Format(time.RFC3339), t.Expires, skew, now.UTC().Format(time.RFC3339)) + } + // Token from the future beyond our skew tolerance — either a wildly + // skewed client OR a malicious signer trying to extend the validity + // window. STS itself catches this; we belt-and-brace. + if t.SigTime.Sub(now) > skew { + return fmt.Errorf("token signed in the future: %s vs now %s (skew %s)", + t.SigTime.UTC().Format(time.RFC3339), now.UTC().Format(time.RFC3339), skew) + } + return nil +} + +// parseAmzDate parses an X-Amz-Date timestamp in its standard form +// "YYYYMMDDTHHMMSSZ" (e.g. "20260524T150405Z"). UTC by definition. +func parseAmzDate(s string) (time.Time, error) { + if s == "" { + return time.Time{}, errors.New("missing X-Amz-Date") + } + t, err := time.Parse("20060102T150405Z", s) + if err != nil { + return time.Time{}, fmt.Errorf("malformed %q: %v", s, err) + } + return t, nil +} + +// parseAmzExpires parses the X-Amz-Expires query value (seconds, as a +// decimal integer string). AWS SDKs constrain this to [1, 604800] +// (1s to 7 days) at signing time; we additionally cap at CheckFreshness +// time per the operator's maxExpires. +func parseAmzExpires(s string) (time.Duration, error) { + if s == "" { + return 0, errors.New("missing X-Amz-Expires") + } + n, err := strconv.Atoi(s) + if err != nil { + return 0, fmt.Errorf("not an integer: %q", s) + } + if n <= 0 { + return 0, fmt.Errorf("must be positive, got %d", n) + } + return time.Duration(n) * time.Second, nil +} + +// parseCredentialScope splits X-Amz-Credential into its five segments: +// +// AKID/YYYYMMDD/region/service/aws4_request +// +// Service MUST be "sts" and the tail MUST be "aws4_request". +func parseCredentialScope(cred string) (akid, date, region string, err error) { + if cred == "" { + return "", "", "", errors.New("missing X-Amz-Credential") + } + segs := strings.Split(cred, "/") + if len(segs) != 5 { + return "", "", "", fmt.Errorf("expected 5 /-separated parts, got %d", len(segs)) + } + if segs[3] != "sts" { + return "", "", "", fmt.Errorf("scope service=%q, want sts", segs[3]) + } + if segs[4] != "aws4_request" { + return "", "", "", fmt.Errorf("scope tail=%q, want aws4_request", segs[4]) + } + if segs[0] == "" || segs[1] == "" || segs[2] == "" { + return "", "", "", errors.New("empty AKID/date/region segment") + } + return segs[0], segs[1], segs[2], nil +} diff --git a/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go new file mode 100644 index 0000000..4ffdfb1 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sigv4_parser_test.go @@ -0,0 +1,301 @@ +package aws_sigv4 + +import ( + "encoding/base64" + "errors" + "strings" + "testing" + "time" +) + +// makeToken builds a forge-aws-v1 token from a complete URL. Helper for +// tests — production tokens come from the AWS SDK's Presign(). +func makeToken(rawURL string) string { + return TokenPrefix + base64.RawURLEncoding.EncodeToString([]byte(rawURL)) +} + +const validPresignedURL = "https://sts.us-east-1.amazonaws.com/" + + "?Action=GetCallerIdentity" + + "&Version=2011-06-15" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20260524%2Fus-east-1%2Fsts%2Faws4_request" + + "&X-Amz-Date=20260524T010000Z" + + "&X-Amz-Expires=900" + + "&X-Amz-SignedHeaders=host" + + "&X-Amz-Signature=abcd1234" + +const validHost = "sts.us-east-1.amazonaws.com" + +func TestParseToken_HappyPath(t *testing.T) { + tok := makeToken(validPresignedURL) + parsed, err := ParseToken(tok, validHost, true) + if err != nil { + t.Fatalf("ParseToken: %v", err) + } + if parsed.AKID != "AKIAIOSFODNN7EXAMPLE" { + t.Errorf("AKID = %q", parsed.AKID) + } + if parsed.Date != "20260524" { + t.Errorf("Date = %q", parsed.Date) + } + if parsed.Region != "us-east-1" { + t.Errorf("Region = %q", parsed.Region) + } + if parsed.URL == nil || parsed.URL.Host != validHost { + t.Errorf("URL host = %v", parsed.URL) + } +} + +func TestParseToken_MissingPrefix_NotForMe(t *testing.T) { + cases := []string{ + "", + "Bearer foo", + "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ4In0.sig", // JWT-shaped + "forge-aws-v0.something", // wrong version prefix + "AWS4-HMAC-SHA256 Credential=AKIA...", // old format + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + _, err := ParseToken(in, validHost, true) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "missing forge-aws-v1 prefix") { + t.Errorf("err = %v, want missing-prefix error (so caller maps to ErrTokenNotForMe)", err) + } + }) + } +} + +func TestParseToken_EmptyPayload(t *testing.T) { + _, err := ParseToken(TokenPrefix, validHost, true) + if err == nil { + t.Fatal("expected error on empty payload") + } +} + +func TestParseToken_RejectsUserinfo(t *testing.T) { + // Review M1: net/url parses "https://user:pass@sts.us-east-1.amazonaws.com" + // into u.User != nil and u.Host == "sts.us-east-1.amazonaws.com" — the + // host check alone passes. We must reject userinfo explicitly, otherwise + // http.Client.Do would synthesize Authorization: Basic and ship + // attacker bytes to STS. + hostile := strings.Replace( + validPresignedURL, + "https://sts.us-east-1.amazonaws.com/", + "https://attacker:secret@sts.us-east-1.amazonaws.com/", + 1, + ) + _, err := ParseToken(makeToken(hostile), validHost, true) + if err == nil { + t.Fatal("expected error on URL with userinfo") + } + if !strings.Contains(err.Error(), "userinfo") { + t.Errorf("err should mention userinfo; got %v", err) + } +} + +func TestParseToken_RejectsForeignHost(t *testing.T) { + // SSRF guard — even if base64 decodes to a syntactically valid URL, + // any non-STS host is rejected. + hostile := "https://evil.example.com/" + strings.Replace(validPresignedURL, "https://sts.us-east-1.amazonaws.com/", "", 1) + _, err := ParseToken(makeToken(hostile), validHost, true) + if err == nil || !strings.Contains(err.Error(), "SSRF") { + t.Errorf("err = %v, want SSRF-guard rejection", err) + } +} + +func TestParseToken_RejectsHTTPScheme_InProdMode(t *testing.T) { + httpURL := strings.Replace(validPresignedURL, "https://", "http://", 1) + _, err := ParseToken(makeToken(httpURL), validHost, true) + if err == nil || !strings.Contains(err.Error(), "scheme") { + t.Errorf("err = %v, want https-required rejection", err) + } +} + +func TestParseToken_AcceptsHTTP_InTestMode(t *testing.T) { + httpURL := strings.Replace(validPresignedURL, "https://", "http://", 1) + _, err := ParseToken(makeToken(httpURL), validHost, false) + if err != nil { + t.Errorf("test-mode http should be accepted, got %v", err) + } +} + +func TestParseToken_RejectsWrongAction(t *testing.T) { + u := strings.Replace(validPresignedURL, "Action=GetCallerIdentity", "Action=ListUsers", 1) + _, err := ParseToken(makeToken(u), validHost, true) + if err == nil || !strings.Contains(err.Error(), "Action") { + t.Errorf("err = %v, want Action-mismatch rejection", err) + } +} + +func TestParseToken_RejectsMissingSignature(t *testing.T) { + u := strings.Replace(validPresignedURL, "&X-Amz-Signature=abcd1234", "", 1) + _, err := ParseToken(makeToken(u), validHost, true) + if err == nil || !strings.Contains(err.Error(), "X-Amz-Signature") { + t.Errorf("err = %v, want missing-signature rejection", err) + } +} + +func TestParseToken_RejectsWrongAlgorithm(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Algorithm=AWS4-HMAC-SHA256", "X-Amz-Algorithm=AWS3-MD5", 1) + _, err := ParseToken(makeToken(u), validHost, true) + if err == nil || !strings.Contains(err.Error(), "X-Amz-Algorithm") { + t.Errorf("err = %v, want algorithm-mismatch rejection", err) + } +} + +func TestParseToken_RejectsBadBase64(t *testing.T) { + _, err := ParseToken(TokenPrefix+"!!!not-base64!!!", validHost, true) + if err == nil { + t.Fatal("expected error on malformed base64") + } +} + +// --- Review M2: parser surface for freshness --- + +func TestParseToken_PopulatesSigTimeAndExpires(t *testing.T) { + parsed, err := ParseToken(makeToken(validPresignedURL), validHost, true) + if err != nil { + t.Fatalf("ParseToken: %v", err) + } + wantTime := time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC) + if !parsed.SigTime.Equal(wantTime) { + t.Errorf("SigTime = %v, want %v", parsed.SigTime, wantTime) + } + if parsed.Expires != 900*time.Second { + t.Errorf("Expires = %v, want 900s", parsed.Expires) + } +} + +func TestParseToken_RejectsMissingAmzDate(t *testing.T) { + u := strings.Replace(validPresignedURL, "&X-Amz-Date=20260524T010000Z", "", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on missing X-Amz-Date") + } +} + +func TestParseToken_RejectsMalformedAmzDate(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Date=20260524T010000Z", "X-Amz-Date=not-a-date", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on malformed X-Amz-Date") + } +} + +func TestParseToken_RejectsMissingAmzExpires(t *testing.T) { + u := strings.Replace(validPresignedURL, "&X-Amz-Expires=900", "", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on missing X-Amz-Expires") + } +} + +func TestParseToken_RejectsNonNumericExpires(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Expires=900", "X-Amz-Expires=forever", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on non-numeric X-Amz-Expires") + } +} + +func TestParseToken_RejectsNonPositiveExpires(t *testing.T) { + u := strings.Replace(validPresignedURL, "X-Amz-Expires=900", "X-Amz-Expires=0", 1) + if _, err := ParseToken(makeToken(u), validHost, true); err == nil { + t.Fatal("expected error on zero X-Amz-Expires") + } +} + +func TestCheckFreshness_Expired(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC), + Expires: 15 * time.Minute, + } + now := time.Date(2026, 5, 24, 1, 25, 0, 0, time.UTC) // 25min later, beyond 15min+5min skew + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err == nil { + t.Fatal("expected expired error") + } +} + +func TestCheckFreshness_FromTheFuture(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 2, 0, 0, 0, time.UTC), + Expires: 15 * time.Minute, + } + now := time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC) // 1h before token's sign time + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err == nil { + t.Fatal("expected future-token error") + } +} + +func TestCheckFreshness_ExceedsExpiresCap(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC), + Expires: 1 * time.Hour, // exceeds the 15min cap + } + now := tok.SigTime + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err == nil { + t.Fatal("expected cap error") + } +} + +func TestCheckFreshness_HappyPathInsideSkew(t *testing.T) { + tok := &PresignedToken{ + SigTime: time.Date(2026, 5, 24, 1, 0, 0, 0, time.UTC), + Expires: 15 * time.Minute, + } + // 19 min later: past Expires but within skew tolerance. + now := tok.SigTime.Add(19 * time.Minute) + if err := tok.CheckFreshness(now, 15*time.Minute, 5*time.Minute); err != nil { + t.Errorf("token within skew window should pass, got %v", err) + } +} + +func TestParseCredentialScope_HappyPath(t *testing.T) { + akid, date, region, err := parseCredentialScope("AKIA123/20260524/us-east-1/sts/aws4_request") + if err != nil { + t.Fatalf("parseCredentialScope: %v", err) + } + if akid != "AKIA123" || date != "20260524" || region != "us-east-1" { + t.Errorf("got %q/%q/%q", akid, date, region) + } +} + +func TestParseCredentialScope_Malformed(t *testing.T) { + cases := map[string]string{ + "empty": "", + "too few segments": "AKIA/20260524/us-east-1/sts", + "too many segments": "AKIA/20260524/us-east-1/sts/aws4_request/extra", + "wrong service": "AKIA/20260524/us-east-1/s3/aws4_request", + "wrong tail": "AKIA/20260524/us-east-1/sts/aws3_request", + "empty AKID segment": "/20260524/us-east-1/sts/aws4_request", + } + for name, in := range cases { + t.Run(name, func(t *testing.T) { + if _, _, _, err := parseCredentialScope(in); err == nil { + t.Errorf("expected error for %q", in) + } + }) + } +} + +// FuzzParseToken — pure decoder must never panic, regardless of input. +func FuzzParseToken(f *testing.F) { + f.Add(makeToken(validPresignedURL)) + f.Add("") + f.Add(TokenPrefix) + f.Add(strings.Repeat(TokenPrefix, 10)) + f.Add(TokenPrefix + "AAAA====") + f.Fuzz(func(_ *testing.T, in string) { + _, _ = ParseToken(in, validHost, true) + }) +} + +// assert we map missing-prefix to a stable error message because the +// caller (provider.go) string-matches it to convert to ErrTokenNotForMe. +func TestParseToken_MissingPrefixErrorMessageIsStable(t *testing.T) { + _, err := ParseToken("garbage", validHost, true) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, err) || err.Error() != "missing forge-aws-v1 prefix" { + t.Errorf("err.Error() = %q, want exact string", err.Error()) + } +} diff --git a/forge-core/auth/providers/aws_sigv4/sts_client.go b/forge-core/auth/providers/aws_sigv4/sts_client.go new file mode 100644 index 0000000..7e33d18 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sts_client.go @@ -0,0 +1,132 @@ +package aws_sigv4 + +import ( + "context" + "encoding/xml" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// STSClient invokes a pre-signed STS GetCallerIdentity URL produced by the +// caller's AWS SDK. Because the signature is in the URL's query parameters +// (not headers tied to a destination host), the request validates against +// STS no matter who relays it — that's the design property that lets +// Forge act as a verifier without holding any AWS secrets. +// +// ~80 LOC of hand-rolled HTTP + XML. No aws-sdk-go-v2 dependency +// (decision §9.1). +type STSClient struct { + http *http.Client +} + +// NewSTSClient builds a client that GETs the URL the caller pre-signed. +// `region` is informational here (the URL itself carries the region in +// its credential scope and host); we keep the arg for symmetry with the +// pre-rewrite API and for future per-region tuning. +// +// CheckRedirect is pinned to ErrUseLastResponse. STS never legitimately +// issues 3xx; the parser-side host gate (sigv4_parser.go's expectedHost) +// only validates the FIRST hop. If we let Go's default policy auto-follow +// a 302, an attacker (MITM with a valid cert, TLS-inspecting corporate +// proxy, DNS hijack) could redirect us to a foreign URL whose body becomes +// the parsed STS XML — and that XML controls Identity.UserID/OrgID/Arn. +// Refuse redirects outright so the same-host guard actually holds. +func NewSTSClient(_ /* region */ string, _ /* legacyOverrideUnused */ string, timeout time.Duration) *STSClient { + return &STSClient{ + http: &http.Client{ + Timeout: timeout, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + } +} + +// CallerIdentity is the parsed STS response — the canonical identifiers +// Forge stamps into auth.Identity. +type CallerIdentity struct { + UserID string // e.g. "AROAJ...:session-name" + Arn string // e.g. "arn:aws:sts::123:assumed-role/ci-deploy/session" + Account string // e.g. "123456789012" +} + +// GetCallerIdentity GETs the pre-signed URL and parses the XML response. +// +// Error classification: +// +// 200 OK → CallerIdentity, nil +// 4xx → auth.ErrTokenRejected (caller's signature didn't validate; +// most often "SignatureDoesNotMatch", "ExpiredToken", or +// "InvalidClientTokenId") +// 5xx / network → auth.ErrProviderUnavailable (review #6 contract) +// parse failure → auth.ErrProviderUnavailable (unexpected response shape) +func (c *STSClient) GetCallerIdentity(ctx context.Context, presignedURL string) (*CallerIdentity, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, presignedURL, nil) + if err != nil { + return nil, fmt.Errorf("%w: build STS request: %v", auth.ErrProviderUnavailable, err) + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: STS RPC: %v", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + + // Real GetCallerIdentity responses are ~1 KiB; cap at 64 KiB. + raw, err := io.ReadAll(io.LimitReader(resp.Body, 64<<10)) + if err != nil { + return nil, fmt.Errorf("%w: read STS body: %v", auth.ErrProviderUnavailable, err) + } + + switch { + case resp.StatusCode == http.StatusOK: + return parseGetCallerIdentityResponse(raw) + case resp.StatusCode >= 400 && resp.StatusCode < 500: + return nil, fmt.Errorf("%w: STS rejected signature: %s", auth.ErrTokenRejected, summarize(raw)) + case resp.StatusCode >= 500: + return nil, fmt.Errorf("%w: STS HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + default: + return nil, fmt.Errorf("%w: STS unexpected status %d", auth.ErrProviderUnavailable, resp.StatusCode) + } +} + +// parseGetCallerIdentityResponse extracts the three canonical fields. +// STS always emits all three; absence is treated as a malformed reply. +func parseGetCallerIdentityResponse(raw []byte) (*CallerIdentity, error) { + var resp struct { + XMLName xml.Name `xml:"GetCallerIdentityResponse"` + Result struct { + UserID string `xml:"UserId"` + Account string `xml:"Account"` + Arn string `xml:"Arn"` + } `xml:"GetCallerIdentityResult"` + } + if err := xml.Unmarshal(raw, &resp); err != nil { + return nil, fmt.Errorf("%w: parse STS XML: %v", auth.ErrProviderUnavailable, err) + } + if resp.Result.Arn == "" || resp.Result.Account == "" || resp.Result.UserID == "" { + return nil, fmt.Errorf("%w: STS XML missing required fields", auth.ErrProviderUnavailable) + } + return &CallerIdentity{ + UserID: resp.Result.UserID, + Arn: resp.Result.Arn, + Account: resp.Result.Account, + }, nil +} + +// summarize returns a short, single-line, log-safe rendering of an STS +// error body. Caps at 200 chars and strips newlines so STS error text +// (which can echo the caller's headers in some shapes) never propagates +// to logs verbatim. +func summarize(raw []byte) string { + s := string(raw) + if len(s) > 200 { + s = s[:200] + "…" + } + return strings.ReplaceAll(s, "\n", " ") +} diff --git a/forge-core/auth/providers/aws_sigv4/sts_client_test.go b/forge-core/auth/providers/aws_sigv4/sts_client_test.go new file mode 100644 index 0000000..87e5127 --- /dev/null +++ b/forge-core/auth/providers/aws_sigv4/sts_client_test.go @@ -0,0 +1,188 @@ +package aws_sigv4 + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +const happySTSXML = ` + + AROAJ123:session + 123456789012 + arn:aws:sts::123456789012:assumed-role/ci-deploy/session + + req-id +` + +func TestSTSClient_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("STS got method %q, want GET", r.Method) + } + w.Header().Set("Content-Type", "text/xml") + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + id, err := c.GetCallerIdentity(context.Background(), srv.URL+"/?Action=GetCallerIdentity") + if err != nil { + t.Fatalf("GetCallerIdentity: %v", err) + } + if id.Arn != "arn:aws:sts::123456789012:assumed-role/ci-deploy/session" { + t.Errorf("Arn = %q", id.Arn) + } + if id.Account != "123456789012" { + t.Errorf("Account = %q", id.Account) + } + if id.UserID != "AROAJ123:session" { + t.Errorf("UserID = %q", id.UserID) + } +} + +func TestSTSClient_403_Rejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, "SignatureDoesNotMatch") + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestSTSClient_500_Unavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_NetworkError_Unavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + url := srv.URL + srv.Close() + + c := NewSTSClient("us-east-1", "", 1*time.Second) + _, err := c.GetCallerIdentity(context.Background(), url) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_BodyCap(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "") + _, _ = io.WriteString(w, strings.Repeat("A", 128<<10)) + _, _ = io.WriteString(w, "") + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error on oversized STS body") + } + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_MissingFieldsRejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, ` + x123 + `) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestSTSClient_RequestCount(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + c := NewSTSClient("us-east-1", "", 5*time.Second) + for range 3 { + if _, err := c.GetCallerIdentity(context.Background(), srv.URL); err != nil { + t.Fatalf("call: %v", err) + } + } + if calls.Load() != 3 { + t.Errorf("STS calls = %d, want 3", calls.Load()) + } +} + +func TestSTSClient_DoesNotFollowRedirects(t *testing.T) { + // Review B3: Go's default http.Client follows redirects up to 10 + // hops. The parser-side host gate only validates the first hop — + // auto-following a 302 to attacker-controlled bytes would let + // those bytes become the parsed STS XML and control the stamped + // Identity. Pin: any 3xx is treated as STS-unavailable (we never + // follow), so the same-host guard actually holds. + var hits atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits.Add(1) + w.Header().Set("Location", "https://attacker.example.com/") + w.WriteHeader(http.StatusFound) // 302 + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + _, err := c.GetCallerIdentity(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error on 302; client must not follow") + } + // 3xx falls into the "unexpected status" arm → unavailable. + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if hits.Load() != 1 { + t.Errorf("STS hit %d times, want exactly 1 (redirect was followed)", hits.Load()) + } +} + +func TestSTSClient_PreservesURLQueryString(t *testing.T) { + // The pre-signed URL carries the signature in query params; the + // client MUST send those verbatim to STS or STS will reject. + var capturedQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedQuery = r.URL.RawQuery + _, _ = io.WriteString(w, happySTSXML) + })) + defer srv.Close() + + c := NewSTSClient("us-east-1", "", 5*time.Second) + wantQuery := "Action=GetCallerIdentity&X-Amz-Signature=abc123" + _, _ = c.GetCallerIdentity(context.Background(), srv.URL+"/?"+wantQuery) + if capturedQuery != wantQuery { + t.Errorf("STS received query %q, want %q", capturedQuery, wantQuery) + } +} diff --git a/forge-core/auth/providers/azure_ad/graph_cache.go b/forge-core/auth/providers/azure_ad/graph_cache.go new file mode 100644 index 0000000..de67224 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_cache.go @@ -0,0 +1,61 @@ +package azure_ad + +import ( + "sync" + "time" +) + +// GraphCache holds enriched group memberships keyed by user ID, with a +// short TTL. Bounds how long a stale "removed from group" state stays +// cached after AAD's reality changes. +type GraphCache struct { + ttl time.Duration + mu sync.RWMutex + data map[string]graphEntry + now func() time.Time +} + +type graphEntry struct { + groups []string + expireAt time.Time +} + +// NewGraphCache builds an empty cache. +func NewGraphCache(ttl time.Duration) *GraphCache { + return &GraphCache{ + ttl: ttl, + data: make(map[string]graphEntry), + now: time.Now, + } +} + +// Get returns the cached groups for userID, or (nil, false) on miss/expiry. +// +// The returned slice is a defensive copy — callers that subsequently mutate +// their Identity.Groups (the auth.Identity layer treats Groups as a freely- +// mutable field) MUST NOT corrupt the cache. (Review NIT.) +func (c *GraphCache) Get(userID string) ([]string, bool) { + c.mu.RLock() + e, ok := c.data[userID] + c.mu.RUnlock() + if !ok || c.now().After(e.expireAt) { + return nil, false + } + return append([]string(nil), e.groups...), true +} + +// Put stores the groups under userID with a fresh TTL. Overwrites any +// prior entry (does not extend). +// +// Stores a defensive copy so subsequent caller mutations of the input +// slice don't reach back through cache hits. +func (c *GraphCache) Put(userID string, groups []string) { + c.mu.Lock() + c.data[userID] = graphEntry{ + groups: append([]string(nil), groups...), + expireAt: c.now().Add(c.ttl), + } + c.mu.Unlock() +} + +func (c *GraphCache) setNow(fn func() time.Time) { c.now = fn } diff --git a/forge-core/auth/providers/azure_ad/graph_cache_test.go b/forge-core/auth/providers/azure_ad/graph_cache_test.go new file mode 100644 index 0000000..847d00a --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_cache_test.go @@ -0,0 +1,73 @@ +package azure_ad + +import ( + "testing" + "time" +) + +func TestGraphCache_HitMiss(t *testing.T) { + c := NewGraphCache(time.Minute) + if _, ok := c.Get("user-1"); ok { + t.Error("empty cache returned hit") + } + c.Put("user-1", []string{"g1", "g2"}) + groups, ok := c.Get("user-1") + if !ok || len(groups) != 2 { + t.Errorf("Get = (%v, %v)", groups, ok) + } +} + +func TestGraphCache_GetReturnsDefensiveCopy(t *testing.T) { + // Caller mutating Identity.Groups must NOT corrupt the cache + // (Review NIT). + c := NewGraphCache(time.Minute) + c.Put("u", []string{"g1", "g2", "g3"}) + + got, ok := c.Get("u") + if !ok { + t.Fatal("expected hit") + } + got[0] = "tampered" + //nolint:staticcheck // intentional caller-side append; the result is + // discarded because we're testing that mutations of `got` don't + // reach back into the cache. + _ = append(got, "extra") + + again, _ := c.Get("u") + if again[0] != "g1" { + t.Errorf("cache was mutated by caller: %v", again) + } + if len(again) != 3 { + t.Errorf("cache slice length changed: %d", len(again)) + } +} + +func TestGraphCache_PutStoresDefensiveCopy(t *testing.T) { + // Caller mutating the input slice after Put must NOT bleed + // through into future Gets. + c := NewGraphCache(time.Minute) + src := []string{"g1", "g2"} + c.Put("u", src) + src[0] = "tampered" + + got, _ := c.Get("u") + if got[0] != "g1" { + t.Errorf("Put didn't copy: %v", got) + } +} + +func TestGraphCache_Expiry(t *testing.T) { + now := time.Unix(1_700_000_000, 0) + c := NewGraphCache(60 * time.Second) + c.setNow(func() time.Time { return now }) + + c.Put("user-1", []string{"g1"}) + + if _, ok := c.Get("user-1"); !ok { + t.Fatal("expected hit before expiry") + } + now = now.Add(61 * time.Second) + if _, ok := c.Get("user-1"); ok { + t.Error("expected miss after expiry") + } +} diff --git a/forge-core/auth/providers/azure_ad/graph_client.go b/forge-core/auth/providers/azure_ad/graph_client.go new file mode 100644 index 0000000..c4f4a7e --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_client.go @@ -0,0 +1,181 @@ +package azure_ad + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// GraphClient calls Microsoft Graph /me/transitiveMemberOf to enrich +// group memberships when the JWT's groups claim overflows (AAD truncates +// groups when the user is in more than ~200 of them). +// +// Forge holds NO Graph credentials of its own — the caller's Bearer +// token is reflected to Graph, which authorizes the read against the +// user's delegated permission (GroupMember.Read.All). +type GraphClient struct { + endpoint string // initial page URL + endpointHost string // pre-parsed for cheap per-page nextLink validation + endpointScheme string // pre-parsed for the same — guards against http://-downgrade nextLinks + http *http.Client +} + +const graphBaseURL = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id&$top=100" + +// NewGraphClient builds a client pointed at the real Graph endpoint. +func NewGraphClient(timeout time.Duration) *GraphClient { + return newGraphClientFor(graphBaseURL, timeout) +} + +// NewGraphClientWithEndpoint is a TEST-ONLY constructor for pointing at +// a fake Graph server. +func NewGraphClientWithEndpoint(endpoint string, timeout time.Duration) *GraphClient { + return newGraphClientFor(endpoint, timeout) +} + +func newGraphClientFor(endpoint string, timeout time.Duration) *GraphClient { + host, scheme := "", "" + if u, err := url.Parse(endpoint); err == nil { + host = u.Host + scheme = u.Scheme + } + return &GraphClient{ + endpoint: endpoint, + endpointHost: host, + endpointScheme: scheme, + http: &http.Client{ + Timeout: timeout, + // Reject HTTP redirects unconditionally. Graph paginates via + // application-layer @odata.nextLink (which we validate against + // the configured host + scheme); it does NOT need transport- + // layer 301/302/307s. Allowing the default redirect-follow + // policy would bypass our same-host guard — an attacker + // returning a 302 to a foreign URL would have the bearer + // reflected there. ErrUseLastResponse returns the 3xx as-is; + // our status-code switch then classifies it as unavailable. + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + } +} + +// TransitiveMemberOf walks the paginated response and returns the full +// list of (transitive) group object IDs the caller belongs to. The +// authHeader is reflected verbatim — Forge does not authenticate to Graph +// independently. +// +// Error classification: +// +// 401 / 403 → auth.ErrTokenRejected (caller's token missing +// GroupMember.Read.All consent) +// 5xx / network → auth.ErrProviderUnavailable +// @odata.nextLink pointing at a foreign host → error (never followed) +func (c *GraphClient) TransitiveMemberOf(ctx context.Context, _ string, authHeader string) ([]string, error) { + if authHeader == "" { + return nil, fmt.Errorf("%w: graph enrichment needs a forwardable Bearer", auth.ErrInvalidToken) + } + out := []string{} + next := c.endpoint + for next != "" { + if err := ensureGraphHost(c.endpointHost, c.endpointScheme, next); err != nil { + return nil, fmt.Errorf("%w: graph nextLink host: %v", auth.ErrProviderUnavailable, err) + } + page, nextURL, err := c.fetchPage(ctx, next, authHeader) + if err != nil { + return nil, err + } + out = append(out, page...) + next = nextURL + if len(out) > 5000 { + return nil, errors.New("graph response exceeds 5000 groups (likely misconfiguration)") + } + } + return out, nil +} + +func (c *GraphClient) fetchPage(ctx context.Context, u, authHeader string) (ids []string, next string, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, "", fmt.Errorf("%w: graph build request: %v", auth.ErrProviderUnavailable, err) + } + req.Header.Set("Authorization", authHeader) + req.Header.Set("Accept", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + return nil, "", fmt.Errorf("%w: graph fetch: %v", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1 MiB cap per page + if err != nil { + return nil, "", fmt.Errorf("%w: graph read: %v", auth.ErrProviderUnavailable, err) + } + + switch { + case resp.StatusCode == http.StatusOK: + // fall through to parse + case resp.StatusCode == http.StatusUnauthorized, resp.StatusCode == http.StatusForbidden: + return nil, "", fmt.Errorf("%w: graph %d (likely missing GroupMember.Read.All consent)", auth.ErrTokenRejected, resp.StatusCode) + case resp.StatusCode >= 500: + return nil, "", fmt.Errorf("%w: graph HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + default: + return nil, "", fmt.Errorf("%w: graph HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + } + + var page struct { + Value []struct { + ID string `json:"id"` + } `json:"value"` + NextLink string `json:"@odata.nextLink"` + } + if err := json.Unmarshal(body, &page); err != nil { + return nil, "", fmt.Errorf("%w: graph parse: %v", auth.ErrProviderUnavailable, err) + } + ids = make([]string, 0, len(page.Value)) + for _, g := range page.Value { + ids = append(ids, g.ID) + } + return ids, page.NextLink, nil +} + +// ensureGraphHost rejects @odata.nextLink values that point at a foreign +// host OR downgrade the scheme. Both checks matter: +// +// - Host: Graph paginates within graph.microsoft.com. A foreign-host +// nextLink would coerce Forge into sending the caller's Bearer to +// an attacker. +// - Scheme: Go's http.Client strips Authorization on cross-host +// redirects but NOT on cross-scheme (https→http) downgrades to the +// same host. A `nextLink: "http://graph.microsoft.com/..."` would +// pass the host check and then leak the Bearer in plaintext to +// anyone able to MITM the connection. Require scheme match too. +// +// `configuredHost` / `configuredScheme` are pre-parsed from +// GraphClient.endpoint at construction time so we don't reparse it for +// every paginated request. For tests, the configured scheme is http +// (httptest servers) and that's intentionally allowed. +func ensureGraphHost(configuredHost, configuredScheme, candidate string) error { + if candidate == "" { + return nil + } + got, err := url.Parse(candidate) + if err != nil { + return err + } + if !strings.EqualFold(configuredHost, got.Host) { + return fmt.Errorf("nextLink host %q does not match configured %q", got.Host, configuredHost) + } + if !strings.EqualFold(configuredScheme, got.Scheme) { + return fmt.Errorf("nextLink scheme %q does not match configured %q (Bearer-downgrade guard)", got.Scheme, configuredScheme) + } + return nil +} diff --git a/forge-core/auth/providers/azure_ad/graph_client_test.go b/forge-core/auth/providers/azure_ad/graph_client_test.go new file mode 100644 index 0000000..0ee2e6e --- /dev/null +++ b/forge-core/auth/providers/azure_ad/graph_client_test.go @@ -0,0 +1,211 @@ +package azure_ad + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +func TestGraph_HappyPath_Paginated(t *testing.T) { + var calls int + mux := http.NewServeMux() + var graphURL string + + mux.HandleFunc("/page1", func(w http.ResponseWriter, _ *http.Request) { + calls++ + _, _ = io.WriteString(w, `{"value":[{"id":"g1"},{"id":"g2"}],"@odata.nextLink":"`+graphURL+`/page2"}`) + }) + mux.HandleFunc("/page2", func(w http.ResponseWriter, _ *http.Request) { + calls++ + _, _ = io.WriteString(w, `{"value":[{"id":"g3"}]}`) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + graphURL = srv.URL + + c := NewGraphClientWithEndpoint(srv.URL+"/page1", 5*time.Second) + out, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if err != nil { + t.Fatalf("TransitiveMemberOf: %v", err) + } + if len(out) != 3 || out[0] != "g1" || out[2] != "g3" { + t.Errorf("groups = %v", out) + } + if calls != 2 { + t.Errorf("pages fetched = %d, want 2", calls) + } +} + +func TestGraph_401_Rejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestGraph_403_Rejected(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestGraph_500_Unavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestGraph_NoAuthHeader_Invalid(t *testing.T) { + c := NewGraphClientWithEndpoint("http://does.not.matter", 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "") + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestGraph_AuthHeaderReflected(t *testing.T) { + var captured string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = r.Header.Get("Authorization") + _, _ = io.WriteString(w, `{"value":[]}`) + })) + defer srv.Close() + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, _ = c.TransitiveMemberOf(context.Background(), "user-1", "Bearer the-token") + if captured != "Bearer the-token" { + t.Errorf("Graph got Authorization = %q", captured) + } +} + +func TestGraph_DefensivePaginationCap(t *testing.T) { + // Server keeps emitting @odata.nextLink to itself; expect cap to fire. + var srvURL string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, fmt.Sprintf(`{"value":[%s],"@odata.nextLink":"%s"}`, manyIDs(100), srvURL)) + })) + defer srv.Close() + srvURL = srv.URL + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer x") + if err == nil { + t.Fatal("expected defensive cap error") + } +} + +func TestGraphClient_DoesNotFollowRedirects(t *testing.T) { + // Review B2: ensureGraphHost only validates @odata.nextLink (the + // application-layer paginator). HTTP 301/302/307 from Graph were + // being auto-followed by Go's default policy, bypassing the host + // guard. An attacker (MITM, TLS-inspecting proxy, DNS hijack) + // returning a 302 to a foreign URL would have the response body + // JSON-unmarshalled as `value: [{id: ...}]` — those IDs become + // Identity.Groups, which the future Phase 4 authz layer will + // trust. Pin CheckRedirect = ErrUseLastResponse. + var hits int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits++ + w.Header().Set("Location", "https://attacker.example.com/value") + w.WriteHeader(http.StatusFound) + })) + defer srv.Close() + + c := NewGraphClientWithEndpoint(srv.URL, 5*time.Second) + _, err := c.TransitiveMemberOf(context.Background(), "user-1", "Bearer token") + if err == nil { + t.Fatal("expected error on 302; client must not follow") + } + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if hits != 1 { + t.Errorf("graph hit %d times, want 1 (redirect was followed)", hits) + } +} + +func TestEnsureGraphHost_RejectsForeignHost(t *testing.T) { + err := ensureGraphHost("graph.microsoft.com", "https", "https://evil.example.com/me/next") + if err == nil { + t.Fatal("expected error on foreign host") + } +} + +func TestEnsureGraphHost_AcceptsSameHost(t *testing.T) { + err := ensureGraphHost( + "graph.microsoft.com", "https", + "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$skiptoken=abc", + ) + if err != nil { + t.Errorf("same-host nextLink rejected: %v", err) + } +} + +func TestEnsureGraphHost_EmptyOK(t *testing.T) { + if err := ensureGraphHost("graph.microsoft.com", "https", ""); err != nil { + t.Errorf("empty nextLink should be ok, got %v", err) + } +} + +func TestEnsureGraphHost_RejectsSchemeDowngrade(t *testing.T) { + // Review B1: http://-on-same-host MUST be rejected. Go's http.Client + // keeps Authorization across same-host redirects, so a downgrade + // would leak the caller's Bearer in plaintext to anyone with a + // network position to MITM the response. + err := ensureGraphHost( + "graph.microsoft.com", "https", + "http://graph.microsoft.com/v1.0/me/transitiveMemberOf?$skiptoken=abc", + ) + if err == nil { + t.Fatal("expected error on http downgrade") + } + if !strings.Contains(err.Error(), "scheme") { + t.Errorf("err should mention scheme; got %v", err) + } +} + +func TestEnsureGraphHost_TestModeHTTPOK(t *testing.T) { + // When the configured endpoint itself is http:// (httptest servers + // in unit/integration tests), http://-same-host nextLinks must + // still validate. + err := ensureGraphHost( + "127.0.0.1:54321", "http", + "http://127.0.0.1:54321/page2", + ) + if err != nil { + t.Errorf("test-mode http nextLink should be accepted, got %v", err) + } +} + +// manyIDs returns a JSON snippet for `count` entries — used for the +// pagination cap test. +func manyIDs(count int) string { + parts := make([]string, count) + for i := range count { + parts[i] = fmt.Sprintf(`{"id":"g-%d"}`, i) + } + return strings.Join(parts, ",") +} diff --git a/forge-core/auth/providers/azure_ad/provider.go b/forge-core/auth/providers/azure_ad/provider.go new file mode 100644 index 0000000..95108fa --- /dev/null +++ b/forge-core/auth/providers/azure_ad/provider.go @@ -0,0 +1,287 @@ +// Package azure_ad authenticates Microsoft Entra ID (Azure AD) tokens. +// Composes the Phase 1 oidc.Provider (decision §9.2) for the heavy +// lifting — signature verify and base claim validation — and layers +// AAD-specific concerns on top: +// +// - Tenant lock-in via the `tid` claim +// - Optional Microsoft Graph group enrichment when the JWT's groups +// claim overflows (AAD truncates at ~200 groups) +// - Correct issuer template for single- vs. multi-tenant +// +// Decision §9.5: standard Bearer flow; no widened-header use. +// +// Audit reason codes (Phase 1 contract): +// +// rejected — bad signature, expired, tid mismatch, +// aud mismatch, Graph 401/403 +// invalid — missing tid, malformed claims, +// unsupported alg +// provider_unavailable — AAD JWKS down, Graph 5xx +// not_for_me — empty Bearer (delegates to OIDC's looksLikeJWT) +package azure_ad + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/oidc" +) + +// ProviderName is the registry name. +const ProviderName = "azure_ad" + +const ( + aadAuthorityBase = "https://login.microsoftonline.com" + + defaultGraphTimeout = 5 * time.Second + defaultJWKSCacheTTL = time.Hour + defaultGraphCacheTTL = 5 * time.Minute +) + +// Config controls the azure_ad provider. +type Config struct { + // TenantID is the Entra tenant GUID. REQUIRED unless AllowMultiTenant + // is true. + TenantID string `yaml:"tenant_id"` + + // Audience is REQUIRED. Typically the Application ID URI from the + // app registration (e.g. "api://forge"). + Audience string `yaml:"audience"` + + // AllowMultiTenant enables accepting tokens from Entra tenants other + // than the one in TenantID. Defaults to false (single-tenant — safe + // choice). When true: + // - the composed oidc.Provider's issuer-equality check is + // suppressed (the "common" issuer template has a {tenantid} + // placeholder that string-equality can't satisfy) + // - tenancy enforcement moves to AllowedTenants (below); see + // CHANGELOG for the security implications + AllowMultiTenant bool `yaml:"allow_multi_tenant,omitempty"` + + // AllowedTenants is an optional allowlist of Entra tenant GUIDs, + // matched against the JWT's `tid` claim. Only meaningful when + // AllowMultiTenant=true; ignored in single-tenant mode (TenantID + // is the gate there). + // + // Empty list + AllowMultiTenant=true = "any tenant globally" — + // the documented but high-risk shape. Set this list for the safer + // "these specific tenants only" semantic. + // + // Effort to set: customers know their partner tenants; operators + // just copy GUIDs in. There is no API to enumerate them. + AllowedTenants []string `yaml:"allowed_tenants,omitempty"` + + // GroupsMode is "claim" (default — uses the in-JWT groups/roles + // claim) or "graph" (queries Microsoft Graph when groups are missing, + // i.e. AAD overage). + GroupsMode string `yaml:"groups_mode,omitempty"` + + // GraphTimeout caps each Graph call. Default 5s. Only used when + // GroupsMode == "graph". + GraphTimeout time.Duration `yaml:"graph_timeout,omitempty"` + + // JWKSCacheTTL bounds the JWKS cache age. Defaults to 1h. + JWKSCacheTTL time.Duration `yaml:"jwks_cache_ttl,omitempty"` + + // GraphEndpoint is a TEST-ONLY override pointing at a fake Graph + // server. Empty in production. + GraphEndpoint string `yaml:"-"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.Audience == "" { + return fmt.Errorf("%w: audience required (e.g. api://forge)", auth.ErrProviderNotConfigured) + } + if !c.AllowMultiTenant && c.TenantID == "" { + return fmt.Errorf("%w: tenant_id required unless allow_multi_tenant=true", auth.ErrProviderNotConfigured) + } + if c.GroupsMode != "" && c.GroupsMode != "claim" && c.GroupsMode != "graph" { + return fmt.Errorf("%w: groups_mode must be 'claim' or 'graph', got %q", auth.ErrProviderNotConfigured, c.GroupsMode) + } + // allowed_tenants only makes sense with multi-tenant. Reject the + // combination at factory time so a typo'd config doesn't silently + // degrade to single-tenant behavior the operator didn't intend. + if !c.AllowMultiTenant && len(c.AllowedTenants) > 0 { + return fmt.Errorf("%w: allowed_tenants is only meaningful when allow_multi_tenant=true (single-tenant mode uses tenant_id directly)", auth.ErrProviderNotConfigured) + } + return nil +} + +// ExtractTenantID returns the "tid" claim, or "" if it's missing / +// non-string. The empty-return form lets callers distinguish "missing" +// from "wrong tenant" without a typed error. +func ExtractTenantID(claims map[string]any) string { + tid, _ := claims["tid"].(string) + return tid +} + +// Provider implements auth.Provider for AAD callers. +type Provider struct { + cfg Config + oidc *oidc.Provider // composition (decision §9.2) + graph *GraphClient // nil unless GroupsMode == "graph" + cache *GraphCache // nil unless GroupsMode == "graph" +} + +// New constructs a Provider after validating cfg. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.GroupsMode == "" { + cfg.GroupsMode = "claim" + } + if cfg.GraphTimeout == 0 { + cfg.GraphTimeout = defaultGraphTimeout + } + if cfg.JWKSCacheTTL == 0 { + cfg.JWKSCacheTTL = defaultJWKSCacheTTL + } + + inner, err := oidc.New(oidc.Config{ + Issuer: resolveIssuer(cfg), + Audience: cfg.Audience, + JWKSCacheTTL: cfg.JWKSCacheTTL, + SkipIssuerCheck: cfg.AllowMultiTenant, + }) + if err != nil { + return nil, fmt.Errorf("azure_ad: composing oidc provider: %w", err) + } + + p := &Provider{cfg: cfg, oidc: inner} + if cfg.GroupsMode == "graph" { + if cfg.GraphEndpoint != "" { + p.graph = NewGraphClientWithEndpoint(cfg.GraphEndpoint, cfg.GraphTimeout) + } else { + p.graph = NewGraphClient(cfg.GraphTimeout) + } + p.cache = NewGraphCache(defaultGraphCacheTTL) + } + return p, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. +func (p *Provider) Verify(ctx context.Context, token string, headers auth.Headers) (*auth.Identity, error) { + id, err := p.oidc.Verify(ctx, token, headers) + if err != nil { + return nil, err + } + + // Tenant gate. Three modes: + // + // - single-tenant (AllowMultiTenant=false): tid MUST equal TenantID + // - multi-tenant + AllowedTenants set: tid MUST be in the list + // - multi-tenant + AllowedTenants empty: no tid check at all + // ("any tenant globally" + // — the documented high-risk + // shape, explicitly opted into + // by leaving the list empty) + if !p.cfg.AllowMultiTenant { + tid := ExtractTenantID(id.Claims) + if tid == "" { + return nil, fmt.Errorf("%w: AAD token missing tid claim", auth.ErrInvalidToken) + } + if tid != p.cfg.TenantID { + return nil, fmt.Errorf("%w: tid mismatch", auth.ErrTokenRejected) + } + } else if len(p.cfg.AllowedTenants) > 0 { + // Multi-tenant with an explicit allowlist (Review M3): the + // composed oidc.Provider's iss check is suppressed and the + // single-tenant arm above is skipped, but operators who set + // AllowedTenants want explicit per-tenant trust — enforce it. + tid := ExtractTenantID(id.Claims) + if tid == "" { + return nil, fmt.Errorf("%w: AAD token missing tid claim", auth.ErrInvalidToken) + } + if !tenantInAllowlist(tid, p.cfg.AllowedTenants) { + return nil, fmt.Errorf("%w: tid %q not in allowed_tenants", auth.ErrTokenRejected, tid) + } + } + + // Optional Graph enrichment. + if p.cfg.GroupsMode == "graph" && needsEnrichment(id.Groups) { + if enriched, err := p.enrichGroups(ctx, id, headers); err == nil { + id.Groups = enriched + } + // Soft-fail on Graph failure: leave Groups empty rather than + // blocking prod traffic on a transient outage. Hard-fail mode + // (graph_required: true) is out of scope for v0.11. + } + + id.Source = ProviderName // overwrite oidc's "oidc" stamp + return id, nil +} + +// resolveIssuer picks the issuer URL passed to the composed OIDC +// provider. +// +// For SINGLE-TENANT (AllowMultiTenant=false): the full per-tenant +// authority URL. oidc.Provider's iss-equality check is in force, AND +// Verify() additionally enforces tid == TenantID. Double-gate. +// +// For MULTI-TENANT (AllowMultiTenant=true): the "common" endpoint, +// which serves JWKS for all Entra tenants. oidc.Provider's iss check +// is suppressed via SkipIssuerCheck because "common"'s issuer template +// (`https://login.microsoftonline.com/{tenantid}/v2.0`) cannot be +// satisfied by string equality. Tenancy gating then depends on +// AllowedTenants: +// - non-empty list: Verify() enforces tid ∈ AllowedTenants +// - empty list: no tid check anywhere — ANY Entra tenant in the +// world is accepted ("any-tenant" mode, opted into +// by deliberately omitting the list) +func resolveIssuer(cfg Config) string { + if cfg.AllowMultiTenant { + return aadAuthorityBase + "/common/v2.0" + } + return fmt.Sprintf("%s/%s/v2.0", aadAuthorityBase, cfg.TenantID) +} + +// tenantInAllowlist reports whether tid is one of the configured +// AllowedTenants entries. Match is case-insensitive because Entra +// emits GUIDs in lowercase but operators commonly paste them in +// either case from the Azure portal. +func tenantInAllowlist(tid string, allowed []string) bool { + for _, a := range allowed { + if strings.EqualFold(tid, a) { + return true + } + } + return false +} + +// needsEnrichment returns true when Graph should be consulted. AAD +// emits no `groups` claim (or a `_claim_names` indicator) when a user +// is in too many groups. Phase 1 OIDC surfaces empty Groups in that +// case — treating empty as "enrich" catches it. +func needsEnrichment(groups []string) bool { + return len(groups) == 0 +} + +func (p *Provider) enrichGroups(ctx context.Context, id *auth.Identity, headers auth.Headers) ([]string, error) { + if cached, ok := p.cache.Get(id.UserID); ok { + return cached, nil + } + groups, err := p.graph.TransitiveMemberOf(ctx, id.UserID, headers.Get("Authorization")) + if err != nil { + return nil, err + } + p.cache.Put(id.UserID, groups) + return groups, nil +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/azure_ad/provider_test.go b/forge-core/auth/providers/azure_ad/provider_test.go new file mode 100644 index 0000000..fcdf7a1 --- /dev/null +++ b/forge-core/auth/providers/azure_ad/provider_test.go @@ -0,0 +1,540 @@ +package azure_ad + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" + "github.com/initializ/forge/forge-core/auth/providers/oidc" +) + +// fakeAAD is a tiny in-memory AAD: serves an OIDC discovery doc + RS256 +// JWKS, lets tests sign tokens with the matching key. +type fakeAAD struct { + t *testing.T + priv *rsa.PrivateKey + pub *rsa.PublicKey + kid string + srv *httptest.Server +} + +func newFakeAAD(t *testing.T) *fakeAAD { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + f := &fakeAAD{t: t, priv: priv, pub: &priv.PublicKey, kid: "kid-1"} + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", f.serveDiscovery) + mux.HandleFunc("/keys", f.serveJWKS) + f.srv = httptest.NewServer(mux) + t.Cleanup(f.srv.Close) + return f +} + +func (f *fakeAAD) issuerURL() string { return f.srv.URL } + +func (f *fakeAAD) serveDiscovery(w http.ResponseWriter, _ *http.Request) { + doc := map[string]any{ + "issuer": f.srv.URL, + "jwks_uri": f.srv.URL + "/keys", + } + _ = json.NewEncoder(w).Encode(doc) +} + +func (f *fakeAAD) serveJWKS(w http.ResponseWriter, _ *http.Request) { + n := base64.RawURLEncoding.EncodeToString(f.pub.N.Bytes()) + eBytes := make([]byte, 8) + binary.BigEndian.PutUint64(eBytes, uint64(f.pub.E)) + i := 0 + for i < len(eBytes)-1 && eBytes[i] == 0 { + i++ + } + e := base64.RawURLEncoding.EncodeToString(eBytes[i:]) + _ = json.NewEncoder(w).Encode(map[string]any{ + "keys": []map[string]any{ + {"kty": "RSA", "kid": f.kid, "alg": "RS256", "use": "sig", "n": n, "e": e}, + }, + }) +} + +func (f *fakeAAD) sign(claims jwt.MapClaims) string { + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = f.kid + str, err := tok.SignedString(f.priv) + if err != nil { + panic(err) + } + return str +} + +func validAADClaims(iss, tenant, audience string) jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": iss, + "aud": audience, + "tid": tenant, + "sub": "user-1", + "email": "alice@example.com", + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + "nbf": now.Unix(), + } +} + +// newWithIssuer builds an azure_ad Provider whose composed OIDC points +// at a test-supplied issuer URL (rather than the AAD authority host). +// Production code uses the AAD authority template via New(); this helper +// lets tests substitute a fakeAAD without piping a TEST_ONLY field +// through the public Config surface. +func newWithIssuer(cfg Config, issuer string) (*Provider, error) { + if cfg.Audience == "" { + return nil, fmt.Errorf("audience required") + } + if cfg.GroupsMode == "" { + cfg.GroupsMode = "claim" + } + if cfg.GraphTimeout == 0 { + cfg.GraphTimeout = defaultGraphTimeout + } + if cfg.JWKSCacheTTL == 0 { + cfg.JWKSCacheTTL = defaultJWKSCacheTTL + } + inner, err := oidc.New(oidc.Config{ + Issuer: issuer, + Audience: cfg.Audience, + JWKSCacheTTL: cfg.JWKSCacheTTL, + SkipIssuerCheck: cfg.AllowMultiTenant, + }) + if err != nil { + return nil, fmt.Errorf("compose oidc: %w", err) + } + p := &Provider{cfg: cfg, oidc: inner} + if cfg.GroupsMode == "graph" { + if cfg.GraphEndpoint != "" { + p.graph = NewGraphClientWithEndpoint(cfg.GraphEndpoint, cfg.GraphTimeout) + } else { + p.graph = NewGraphClient(cfg.GraphTimeout) + } + p.cache = NewGraphCache(defaultGraphCacheTTL) + } + return p, nil +} + +func newTestProviderAADSingleTenant(t *testing.T, f *fakeAAD, tid string) *Provider { + t.Helper() + p, err := newWithIssuer(Config{ + Audience: "api://forge", + TenantID: tid, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + return p +} + +func newTestProviderAADMultiTenant(t *testing.T, f *fakeAAD) *Provider { + t.Helper() + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + return p +} + +func newTestProviderAADGraphMode(t *testing.T, f *fakeAAD, tid, graphURL string) *Provider { + t.Helper() + p, err := newWithIssuer(Config{ + Audience: "api://forge", + TenantID: tid, + GroupsMode: "graph", + GraphEndpoint: graphURL, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + return p +} + +// --- Tests --- + +func TestFactory_RejectsMissingAudience(t *testing.T) { + if _, err := auth.Build("azure_ad", map[string]any{ + "tenant_id": "00000000-0000-0000-0000-000000000000", + }); err == nil { + t.Fatal("expected error when audience is missing") + } +} + +func TestFactory_RejectsMissingTenantUnlessMultiTenant(t *testing.T) { + if _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + }); err == nil { + t.Fatal("expected error when tenant_id missing and not multi-tenant") + } +} + +func TestFactory_AcceptsMultiTenantWithoutTenant(t *testing.T) { + if _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "allow_multi_tenant": true, + }); err != nil { + t.Errorf("expected success, got %v", err) + } +} + +func TestFactory_RejectsBadGroupsMode(t *testing.T) { + _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "tenant_id": "00000000-0000-0000-0000-000000000000", + "groups_mode": "bogus", + }) + if err == nil { + t.Fatal("expected error for invalid groups_mode") + } +} + +func TestProvider_HappyPath_ClaimMode(t *testing.T) { + f := newFakeAAD(t) + wantTID := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADSingleTenant(t, f, wantTID) + + claims := validAADClaims(f.issuerURL(), wantTID, "api://forge") + claims["groups"] = []string{"g1", "g2"} + tok := f.sign(claims) + + id, err := p.Verify(context.Background(), tok, auth.Headers{}) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "azure_ad" { + t.Errorf("Source = %q, want azure_ad (composition must overwrite oidc stamp)", id.Source) + } + if id.UserID != "user-1" { + t.Errorf("UserID = %q", id.UserID) + } + if len(id.Groups) != 2 { + t.Errorf("Groups = %v", id.Groups) + } +} + +func TestProvider_WrongTenant_Rejected(t *testing.T) { + f := newFakeAAD(t) + wantTID := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADSingleTenant(t, f, wantTID) + + claims := validAADClaims(f.issuerURL(), "22222222-2222-2222-2222-222222222222", "api://forge") + tok := f.sign(claims) + + _, err := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (tid mismatch)", err) + } +} + +func TestProvider_MultiTenant_AcceptsArbitraryTenant(t *testing.T) { + // Documented high-risk shape: AllowMultiTenant=true with NO + // AllowedTenants list means "accept any tenant globally." + f := newFakeAAD(t) + p := newTestProviderAADMultiTenant(t, f) + + claims := validAADClaims(f.issuerURL(), "any-tenant-uuid", "api://forge") + tok := f.sign(claims) + + if _, err := p.Verify(context.Background(), tok, auth.Headers{}); err != nil { + t.Errorf("expected multi-tenant success, got %v", err) + } +} + +// --- Review M3: AllowedTenants gate in multi-tenant mode --- + +func TestProvider_MultiTenant_AllowedTenants_AcceptsListed(t *testing.T) { + f := newFakeAAD(t) + wantTID := "33333333-3333-3333-3333-333333333333" + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{ + "22222222-2222-2222-2222-222222222222", + wantTID, + }, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), wantTID, "api://forge") + tok := f.sign(claims) + + if _, err := p.Verify(context.Background(), tok, auth.Headers{}); err != nil { + t.Errorf("expected success for listed tenant, got %v", err) + } +} + +func TestProvider_MultiTenant_AllowedTenants_RejectsUnlisted(t *testing.T) { + f := newFakeAAD(t) + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{"22222222-2222-2222-2222-222222222222"}, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), "deadbeef-dead-beef-dead-beefdeadbeef", "api://forge") + tok := f.sign(claims) + + _, vErr := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(vErr, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected (tid not in allowlist)", vErr) + } + if vErr != nil && !strings.Contains(vErr.Error(), "allowed_tenants") { + t.Errorf("err should mention allowed_tenants; got %v", vErr) + } +} + +func TestProvider_MultiTenant_AllowedTenants_CaseInsensitive(t *testing.T) { + // Entra emits lowercase tids; operators sometimes paste uppercase + // from the portal. Match should tolerate either. + f := newFakeAAD(t) + wantTID := "33333333-3333-3333-3333-333333333333" + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{"33333333-3333-3333-3333-333333333333"}, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), strings.ToUpper(wantTID), "api://forge") + tok := f.sign(claims) + + if _, err := p.Verify(context.Background(), tok, auth.Headers{}); err != nil { + t.Errorf("expected case-insensitive match, got %v", err) + } +} + +func TestProvider_MultiTenant_AllowedTenants_MissingTidRejected(t *testing.T) { + // Multi-tenant + AllowedTenants set + token has no tid → reject. + // In "any-tenant" mode (empty AllowedTenants) a missing tid would + // be allowed (we wouldn't be checking), but the moment the operator + // sets a list they want the gate enforced. + f := newFakeAAD(t) + p, err := newWithIssuer(Config{ + Audience: "api://forge", + AllowMultiTenant: true, + AllowedTenants: []string{"22222222-2222-2222-2222-222222222222"}, + }, f.issuerURL()) + if err != nil { + t.Fatalf("newWithIssuer: %v", err) + } + + claims := validAADClaims(f.issuerURL(), "ignored", "api://forge") + delete(claims, "tid") + tok := f.sign(claims) + + _, vErr := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(vErr, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (missing tid)", vErr) + } +} + +func TestProvider_SingleTenant_WithAllowedTenants_RejectedAtFactory(t *testing.T) { + // Config-level guard: AllowedTenants is meaningless in single-tenant + // mode (TenantID is THE gate). Reject the combo so a typo doesn't + // silently degrade. + _, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "tenant_id": "11111111-1111-1111-1111-111111111111", + "allowed_tenants": []any{"22222222-2222-2222-2222-222222222222"}, + }) + if err == nil { + t.Fatal("expected error: allowed_tenants requires allow_multi_tenant=true") + } +} + +func TestProvider_MissingTid_Invalid(t *testing.T) { + f := newFakeAAD(t) + p := newTestProviderAADSingleTenant(t, f, "11111111-1111-1111-1111-111111111111") + + claims := validAADClaims(f.issuerURL(), "ignored", "api://forge") + delete(claims, "tid") + tok := f.sign(claims) + + _, err := p.Verify(context.Background(), tok, auth.Headers{}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_GraphMode_EnrichesEmptyGroups(t *testing.T) { + f := newFakeAAD(t) + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"value":[{"id":"g1"},{"id":"g2"}]}`) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + // No groups claim → simulate overage. + tok := f.sign(claims) + + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if len(id.Groups) != 2 { + t.Errorf("Groups = %v, want 2 enriched ids", id.Groups) + } +} + +func TestProvider_GraphMode_SoftFailsOn5xx(t *testing.T) { + f := newFakeAAD(t) + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + tok := f.sign(claims) + + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Errorf("expected soft-fail (nil err), got %v", err) + } + if id != nil && len(id.Groups) != 0 { + t.Errorf("Groups should be empty after Graph 5xx, got %v", id.Groups) + } +} + +func TestProvider_GraphMode_401SoftFails(t *testing.T) { + // Even on Graph 401/403, the auth request itself proceeds — soft-fail + // keeps the Identity flowing with empty Groups. The graph-side error + // is surfaced separately if the operator wires audit hooks. Pinning + // this contract here so the behavior doesn't drift. + f := newFakeAAD(t) + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + tok := f.sign(claims) + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Errorf("Graph 401 should soft-fail at provider level, got err=%v", err) + } + if id == nil { + t.Fatal("Identity should be returned even on Graph 401") + } + if len(id.Groups) != 0 { + t.Errorf("Groups should be empty after Graph 401, got %v", id.Groups) + } +} + +func TestProvider_GraphMode_ClaimPresent_SkipsGraph(t *testing.T) { + // When the JWT carries a groups claim, we MUST NOT call Graph (avoid + // extra latency + unneeded Graph permission requirements). + f := newFakeAAD(t) + graphCalls := 0 + graph := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + graphCalls++ + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(graph.Close) + + tid := "11111111-1111-1111-1111-111111111111" + p := newTestProviderAADGraphMode(t, f, tid, graph.URL) + + claims := validAADClaims(f.issuerURL(), tid, "api://forge") + claims["groups"] = []string{"g-from-jwt"} + tok := f.sign(claims) + id, err := p.Verify(context.Background(), tok, auth.Headers{"Authorization": "Bearer " + tok}) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if graphCalls != 0 { + t.Errorf("Graph called %d times, want 0 (claim was present)", graphCalls) + } + if len(id.Groups) != 1 || id.Groups[0] != "g-from-jwt" { + t.Errorf("Groups = %v, want [g-from-jwt]", id.Groups) + } +} + +func TestProvider_RegisteredInRegistry(t *testing.T) { + p, err := auth.Build("azure_ad", map[string]any{ + "audience": "api://forge", + "allow_multi_tenant": true, + }) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "azure_ad" { + t.Errorf("Name = %q", p.Name()) + } +} + +func TestExtractTenantID(t *testing.T) { + cases := []struct { + in map[string]any + want string + }{ + {map[string]any{"tid": "abc"}, "abc"}, + {map[string]any{}, ""}, + {map[string]any{"tid": 123}, ""}, + {nil, ""}, + } + for _, tc := range cases { + if got := ExtractTenantID(tc.in); got != tc.want { + t.Errorf("ExtractTenantID(%v) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestNeedsEnrichment(t *testing.T) { + if !needsEnrichment(nil) { + t.Error("nil groups should need enrichment") + } + if !needsEnrichment([]string{}) { + t.Error("empty groups should need enrichment") + } + if needsEnrichment([]string{"g1"}) { + t.Error("populated groups should not need enrichment") + } +} + +func TestProvider_Name(t *testing.T) { + f := newFakeAAD(t) + p := newTestProviderAADMultiTenant(t, f) + if p.Name() != "azure_ad" { + t.Errorf("Name = %q", p.Name()) + } +} diff --git a/forge-core/auth/providers/gcp_iap/iap_jwks.go b/forge-core/auth/providers/gcp_iap/iap_jwks.go new file mode 100644 index 0000000..0038845 --- /dev/null +++ b/forge-core/auth/providers/gcp_iap/iap_jwks.go @@ -0,0 +1,380 @@ +package gcp_iap + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// IAPClaims is the projected claim set Forge needs from an IAP JWT. +// IAP-emitted JWTs include several other claims (hd, exp, iat, ...); +// callers can retrieve them from the raw token if needed. +type IAPClaims struct { + Issuer string + Audience []string + Subject string + Email string + HD string // Google Workspace domain (optional) +} + +// IAPJWKSCache fetches and caches the IAP public keys. ES256-only — +// dropping non-EC / non-P-256 / non-ES256-labeled keys during parse +// is a defense-in-depth layer on top of the algorithm whitelist in +// VerifyAndParse. +// +// Refresh model (mirrors Phase 1 OIDC review #1): +// - lastSuccessful tracks the TTL window; reuse cached keys within it. +// - lastAttempt + backoffDuration block fetch stampedes during outages. +// - Stale-grace: if backoff blocks fetch AND a key for the requested +// kid is in cache, return the stale key. IAP rotates keys on the +// order of weeks, so freshness matters less than availability. +type IAPJWKSCache struct { + url string + ttl time.Duration + http *http.Client + + mu sync.RWMutex + keys map[string]*ecdsa.PublicKey + lastSuccessful time.Time + lastAttempt time.Time + backoffDuration time.Duration +} + +// NewIAPJWKSCache builds an empty cache pointing at the given URL. +// +// CheckRedirect is pinned to ErrUseLastResponse. The IAP JWKS host is +// hardcoded (decision §9.4) precisely so we never trust any other source +// of public keys — auto-following a 3xx to a foreign URL would let a +// MITM / TLS-inspecting proxy / DNS-hijack scenario substitute attacker +// keys, after which any forged token signed by those keys would verify. +// Refuse redirects outright. +func NewIAPJWKSCache(url string, ttl, timeout time.Duration) *IAPJWKSCache { + return &IAPJWKSCache{ + url: url, + ttl: ttl, + http: &http.Client{ + Timeout: timeout, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + keys: map[string]*ecdsa.PublicKey{}, + } +} + +// VerifyAndParse validates the JWT signature against the cached IAP +// public keys and returns the projected claims. Standard claim checks +// (exp/iat/nbf) are performed by the JWT library; iss/aud are checked +// by the caller (Provider.Verify) against the operator's config. +func (j *IAPJWKSCache) VerifyAndParse(ctx context.Context, raw string) (*IAPClaims, error) { + tok, err := jwt.Parse(raw, func(t *jwt.Token) (any, error) { + // Algorithm whitelist: IAP signs with ES256 only. Any other alg + // is rejected BEFORE key lookup so algorithm-confusion attacks + // can't reach the JWKS. + if t.Method.Alg() != "ES256" { + return nil, fmt.Errorf("unexpected alg %q", t.Method.Alg()) + } + kid, _ := t.Header["kid"].(string) + if kid == "" { + return nil, errors.New("missing kid") + } + key, err := j.keyForKID(ctx, kid) + if err != nil { + return nil, err + } + return key, nil + }, jwt.WithValidMethods([]string{"ES256"})) + + if err != nil { + return nil, classifyJWTErr(err) + } + if !tok.Valid { + return nil, fmt.Errorf("%w: jwt.Valid=false", auth.ErrInvalidToken) + } + + // Extract claims via a small intermediate struct so we can handle + // the "aud as string OR array" shape (IAP currently uses string). + payload, err := json.Marshal(tok.Claims) + if err != nil { + return nil, fmt.Errorf("%w: marshal claims: %v", auth.ErrInvalidToken, err) + } + var rc struct { + Issuer string `json:"iss"` + Audience json.RawMessage `json:"aud"` + Subject string `json:"sub"` + Email string `json:"email"` + HD string `json:"hd"` + } + if err := json.Unmarshal(payload, &rc); err != nil { + return nil, fmt.Errorf("%w: unmarshal claims: %v", auth.ErrInvalidToken, err) + } + aud, err := parseAudience(rc.Audience) + if err != nil { + return nil, fmt.Errorf("%w: aud parse: %v", auth.ErrInvalidToken, err) + } + return &IAPClaims{ + Issuer: rc.Issuer, + Audience: aud, + Subject: rc.Subject, + Email: rc.Email, + HD: rc.HD, + }, nil +} + +// parseAudience handles "aud" being either a JSON string or an array. +// JWT spec allows either; IAP currently uses string. +func parseAudience(raw json.RawMessage) ([]string, error) { + if len(raw) == 0 { + return nil, errors.New("aud claim missing") + } + if raw[0] == '"' { + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return nil, err + } + return []string{s}, nil + } + var arr []string + if err := json.Unmarshal(raw, &arr); err != nil { + return nil, err + } + return arr, nil +} + +// keyForKID returns the ES256 public key for the given kid. Refreshes +// the JWKS in the foreground when the cache is stale or a kid isn't +// known, with backoff + stale-grace per the IAPJWKSCache contract. +func (j *IAPJWKSCache) keyForKID(ctx context.Context, kid string) (*ecdsa.PublicKey, error) { + j.mu.RLock() + cached, hit := j.keys[kid] + stale := time.Since(j.lastSuccessful) > j.ttl + j.mu.RUnlock() + + if hit && !stale { + return cached, nil + } + + if err := j.refresh(ctx); err != nil { + if hit { + // Stale-grace: keep using the cached key during outage. + return cached, nil + } + return nil, err + } + + j.mu.RLock() + k := j.keys[kid] + j.mu.RUnlock() + if k == nil { + return nil, fmt.Errorf("%w: kid %q not found in IAP JWKS", auth.ErrInvalidToken, kid) + } + return k, nil +} + +// refresh fetches and parses the JWKS. Backoff doubles on each failure +// (5s → 60s cap); resets on success. +func (j *IAPJWKSCache) refresh(ctx context.Context) error { + j.mu.Lock() + if !j.lastAttempt.IsZero() && time.Since(j.lastAttempt) < j.backoffDuration { + j.mu.Unlock() + return fmt.Errorf("%w: IAP JWKS in backoff", auth.ErrProviderUnavailable) + } + j.lastAttempt = time.Now() + j.mu.Unlock() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.url, nil) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS build request: %v", auth.ErrProviderUnavailable, err) + } + resp, err := j.http.Do(req) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS fetch: %v", auth.ErrProviderUnavailable, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS HTTP %d", auth.ErrProviderUnavailable, resp.StatusCode) + } + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 256<<10)) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS read: %v", auth.ErrProviderUnavailable, err) + } + + keys, err := parseECJWKSet(raw) + if err != nil { + j.bumpBackoff() + return fmt.Errorf("%w: IAP JWKS parse: %v", auth.ErrProviderUnavailable, err) + } + + // Merge-on-success rather than replace (Review NIT). A partial- + // but-valid JWKS (e.g. one freshly-rotated kid omitted by mistake) + // must not drop kids the stale-grace contract assumes we still + // have. New keys take precedence over old of the same kid; old + // keys that aren't in the new response are kept. + // + // Worst case: a stale key for a kid GCP has actually retired stays + // in our cache. JWT signature verification will fail naturally for + // any token signed with the retired private key, so this can't + // admit forged tokens — it just keeps verification working through + // JWKS-API hiccups. + j.mu.Lock() + if j.keys == nil { + j.keys = map[string]*ecdsa.PublicKey{} + } + for kid, k := range keys { + j.keys[kid] = k + } + j.lastSuccessful = time.Now() + j.backoffDuration = 0 + j.mu.Unlock() + return nil +} + +func (j *IAPJWKSCache) bumpBackoff() { + j.mu.Lock() + switch { + case j.backoffDuration == 0: + j.backoffDuration = 5 * time.Second + case j.backoffDuration < 60*time.Second: + j.backoffDuration *= 2 + } + j.mu.Unlock() +} + +// parseECJWKSet drops keys that aren't EC/P-256/ES256. Defense in depth +// against a compromised JWKS endpoint trying to slip in RSA keys for +// algorithm-confusion attacks. +func parseECJWKSet(raw []byte) (map[string]*ecdsa.PublicKey, error) { + var set struct { + Keys []struct { + Kid string `json:"kid"` + Kty string `json:"kty"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` + Alg string `json:"alg"` + } `json:"keys"` + } + if err := json.Unmarshal(raw, &set); err != nil { + return nil, err + } + out := map[string]*ecdsa.PublicKey{} + for _, k := range set.Keys { + if k.Kty != "EC" || k.Crv != "P-256" { + continue + } + if k.Alg != "" && k.Alg != "ES256" { + continue + } + x, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + continue + } + y, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + continue + } + out[k.Kid] = &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + } + } + if len(out) == 0 { + return nil, errors.New("IAP JWKS contained no usable ES256 keys") + } + return out, nil +} + +// classifyJWTErr maps golang-jwt errors to auth sentinels using v5's +// named sentinels via errors.Is rather than substring matching. The +// library's error wording can shift across patch releases; the sentinels +// are part of its public API and stable. (Review NIT.) +// +// Two-tier mapping: +// +// ErrInvalidToken (malformed / structurally wrong): +// - jwt.ErrTokenMalformed (bad base64, dot-count, etc.) +// - jwt.ErrTokenUnverifiable (no key found, alg mismatch +// detected in keyFunc) +// - jwt.ErrInvalidKey, jwt.ErrInvalidKeyType +// - keyFunc messages we emit ourselves ("unexpected alg", +// "missing kid", "not found in IAP JWKS") +// +// ErrTokenRejected (well-formed but cryptographically/temporally +// invalid — policy-denial shape): +// - jwt.ErrTokenSignatureInvalid +// - jwt.ErrTokenExpired +// - jwt.ErrTokenNotValidYet (nbf in future) +// - jwt.ErrTokenUsedBeforeIssued (iat in future) +// - jwt.ErrTokenInvalidClaims +// +// Default: ErrInvalidToken (conservative — unknown errors are +// classified as malformed, not as policy rejections). +func classifyJWTErr(err error) error { + if errors.Is(err, auth.ErrProviderUnavailable) || + errors.Is(err, auth.ErrInvalidToken) || + errors.Is(err, auth.ErrTokenRejected) { + return err + } + + // Special case: ErrTokenSignatureInvalid wraps BOTH (a) actual + // bad-signature failures (rejected) AND (b) alg-confusion errors + // where golang-jwt's WithValidMethods refused the token's alg + // before signing was even attempted. The latter is a malformed- + // shape failure (alg whitelist tripped), not a policy denial, so + // inspect the wrapped message to distinguish. + if errors.Is(err, jwt.ErrTokenSignatureInvalid) { + s := err.Error() + if strings.Contains(s, "signing method") { + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + return fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) + } + + switch { + case errors.Is(err, jwt.ErrTokenMalformed), + errors.Is(err, jwt.ErrTokenUnverifiable), + errors.Is(err, jwt.ErrInvalidKey), + errors.Is(err, jwt.ErrInvalidKeyType): + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + case errors.Is(err, jwt.ErrTokenExpired), + errors.Is(err, jwt.ErrTokenNotValidYet), + errors.Is(err, jwt.ErrTokenUsedBeforeIssued), + errors.Is(err, jwt.ErrTokenInvalidClaims): + return fmt.Errorf("%w: %v", auth.ErrTokenRejected, err) + } + + // keyFunc errors we emit ourselves get wrapped by jwt.Parse, but + // the unwrap chain doesn't preserve them as distinct sentinels. + // Fall back to substring matching for THESE messages only — they're + // strings WE control, not the library's. + s := err.Error() + switch { + case strings.Contains(s, "unexpected alg"), + strings.Contains(s, "missing kid"), + strings.Contains(s, "not found in IAP JWKS"): + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + + return fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) +} diff --git a/forge-core/auth/providers/gcp_iap/provider.go b/forge-core/auth/providers/gcp_iap/provider.go new file mode 100644 index 0000000..3fa8815 --- /dev/null +++ b/forge-core/auth/providers/gcp_iap/provider.go @@ -0,0 +1,147 @@ +// Package gcp_iap authenticates requests that come through GCP's +// Identity-Aware Proxy. IAP terminates user authentication at the +// HTTPS load balancer and forwards a signed JWT in the +// X-Goog-Iap-Jwt-Assertion header on every authenticated request. +// Forge verifies that JWT against IAP's well-known JWKS and stamps +// an Identity from the claims. +// +// Decision §9.4: the IAP issuer ("https://cloud.google.com/iap") +// and JWKS URL ("https://www.gstatic.com/iap/verify/public_key-jwk") +// are HARDCODED. They are the only stable contract GCP exposes. +// Any override knob would be a footgun (operators could be tricked +// into trusting an attacker's JWKS). +// +// Decision §9.5: reads X-Goog-Iap-Jwt-Assertion from the widened +// Headers map (PR 1). +// +// Audit reason codes (Phase 1 contract): +// +// rejected — iss/aud mismatch, expired, bad signature +// invalid — alg != ES256, missing sub/email, malformed +// provider_unavailable — JWKS fetch failed AND no prior key cached +// not_for_me — header absent → next provider +package gcp_iap + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/initializ/forge/forge-core/auth" +) + +// ProviderName is the registry name. +const ProviderName = "gcp_iap" + +// Hardcoded IAP endpoints (decision §9.4). +const ( + iapIssuer = "https://cloud.google.com/iap" + iapJWKSURL = "https://www.gstatic.com/iap/verify/public_key-jwk" +) + +const ( + defaultJWKSRefreshTTL = time.Hour + defaultHTTPTimeout = 5 * time.Second +) + +// Config controls the gcp_iap provider. +type Config struct { + // Audience is REQUIRED. It is the GCP backend service ID, + // shaped like "/projects/PROJECT_NUMBER/global/backendServices/BACKEND_ID". + // Operators find this in the GCP console under + // Security → Identity-Aware Proxy → Backend Service → "Signed Header JWT Audience". + Audience string `yaml:"audience"` + + // JWKSRefreshTTL bounds how long a cached JWKS is reused before a + // background refresh. Default 1h — IAP rotates keys slowly. + JWKSRefreshTTL time.Duration `yaml:"jwks_refresh_ttl,omitempty"` + + // HTTPTimeout caps the JWKS fetch. Default 5s. + HTTPTimeout time.Duration `yaml:"http_timeout,omitempty"` +} + +// Validate returns ErrProviderNotConfigured when required fields are missing. +func (c Config) Validate() error { + if c.Audience == "" { + return fmt.Errorf("%w: audience required (e.g. /projects/PNUM/global/backendServices/BACKEND_ID)", auth.ErrProviderNotConfigured) + } + return nil +} + +// Provider implements auth.Provider for GCP IAP-fronted callers. +type Provider struct { + cfg Config + jwks *IAPJWKSCache +} + +// New constructs a Provider after validating cfg. +func New(cfg Config) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.JWKSRefreshTTL == 0 { + cfg.JWKSRefreshTTL = defaultJWKSRefreshTTL + } + if cfg.HTTPTimeout == 0 { + cfg.HTTPTimeout = defaultHTTPTimeout + } + return &Provider{ + cfg: cfg, + jwks: NewIAPJWKSCache(iapJWKSURL, cfg.JWKSRefreshTTL, cfg.HTTPTimeout), + }, nil +} + +// Name implements auth.Provider. +func (p *Provider) Name() string { return ProviderName } + +// Verify implements auth.Provider. +// +// Token is unused — the IAP assertion lives in the +// X-Goog-Iap-Jwt-Assertion header, which the middleware delivers +// via the headers map (PR 1). +func (p *Provider) Verify(ctx context.Context, _ string, headers auth.Headers) (*auth.Identity, error) { + raw := headers.Get("X-Goog-Iap-Jwt-Assertion") + if raw == "" { + return nil, auth.ErrTokenNotForMe + } + + claims, err := p.jwks.VerifyAndParse(ctx, raw) + if err != nil { + return nil, err // already wrapped with the right sentinel + } + + if claims.Issuer != iapIssuer { + return nil, fmt.Errorf("%w: iss=%q, want %q", auth.ErrTokenRejected, claims.Issuer, iapIssuer) + } + if !slices.Contains(claims.Audience, p.cfg.Audience) { + return nil, fmt.Errorf("%w: aud mismatch", auth.ErrTokenRejected) + } + if claims.Subject == "" || claims.Email == "" { + // IAP always sets both. Absence implies a malformed/stripped + // token, not a policy denial — return ErrInvalidToken so the + // audit log distinguishes the two cases. + return nil, fmt.Errorf("%w: IAP token missing sub or email", auth.ErrInvalidToken) + } + + return &auth.Identity{ + UserID: claims.Subject, + Email: claims.Email, + Source: ProviderName, + Claims: map[string]any{ + "sub": claims.Subject, + "email": claims.Email, + "hd": claims.HD, + }, + }, nil +} + +func init() { + auth.Register(ProviderName, func(settings map[string]any) (auth.Provider, error) { + var cfg Config + if err := auth.UnmarshalSettings(settings, &cfg); err != nil { + return nil, err + } + return New(cfg) + }) +} diff --git a/forge-core/auth/providers/gcp_iap/provider_test.go b/forge-core/auth/providers/gcp_iap/provider_test.go new file mode 100644 index 0000000..eb74841 --- /dev/null +++ b/forge-core/auth/providers/gcp_iap/provider_test.go @@ -0,0 +1,477 @@ +package gcp_iap + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "io" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/initializ/forge/forge-core/auth" +) + +// --- test signer harness --- + +// es256Signer is a self-contained ES256 signer + JWKS server used by all +// provider/JWKS tests. Lives here (not in a separate test-harness file) +// to keep PR 3 self-contained. +type es256Signer struct { + priv *ecdsa.PrivateKey + pub *ecdsa.PublicKey + kid string + jwksMu *http.ServeMux + srv *httptest.Server +} + +func newES256Signer(t *testing.T) *es256Signer { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("ecdsa key: %v", err) + } + s := &es256Signer{ + priv: priv, + pub: &priv.PublicKey, + kid: "test-kid-1", + } + mux := http.NewServeMux() + mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) { + s.writeJWKS(w) + }) + s.jwksMu = mux + s.srv = httptest.NewServer(mux) + t.Cleanup(s.srv.Close) + return s +} + +func (s *es256Signer) writeJWKS(w http.ResponseWriter) { + x := base64.RawURLEncoding.EncodeToString(s.pub.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(s.pub.Y.Bytes()) + doc := map[string]any{ + "keys": []map[string]any{ + { + "kty": "EC", + "crv": "P-256", + "kid": s.kid, + "alg": "ES256", + "x": x, + "y": y, + }, + }, + } + _ = json.NewEncoder(w).Encode(doc) +} + +func (s *es256Signer) sign(claims map[string]any) string { + tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims(claims)) + tok.Header["kid"] = s.kid + str, err := tok.SignedString(s.priv) + if err != nil { + panic(err) + } + return str +} + +func (s *es256Signer) URL() string { return s.srv.URL + "/jwks" } + +// --- helpers --- + +func newProviderPointingAt(t *testing.T, signer *es256Signer, audience string) *Provider { + t.Helper() + p, err := New(Config{ + Audience: audience, + JWKSRefreshTTL: time.Hour, + HTTPTimeout: 5 * time.Second, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + // Swap the JWKS cache for one pointed at our test signer instead + // of the real IAP URL. This is the per-package equivalent of the + // aws_sigv4 sts_endpoint override. + p.jwks = NewIAPJWKSCache(signer.URL(), time.Hour, 5*time.Second) + return p +} + +func validClaims(audience string) map[string]any { + now := time.Now() + return map[string]any{ + "iss": iapIssuer, + "aud": audience, + "sub": "1234567890", + "email": "alice@example.com", + "hd": "example.com", + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + } +} + +// --- Tests --- + +func TestProvider_NoIAPHeader_YieldsToChain(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + _, err := p.Verify(context.Background(), "", auth.Headers{"Authorization": "Bearer foo"}) + if !errors.Is(err, auth.ErrTokenNotForMe) { + t.Errorf("err = %v, want ErrTokenNotForMe", err) + } +} + +func TestProvider_HappyPath(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + tok := signer.sign(validClaims("test-aud")) + id, err := p.Verify(context.Background(), "", auth.Headers{ + "X-Goog-Iap-Jwt-Assertion": tok, + }) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if id.Source != "gcp_iap" { + t.Errorf("Source = %q", id.Source) + } + if id.UserID != "1234567890" { + t.Errorf("UserID = %q", id.UserID) + } + if id.Email != "alice@example.com" { + t.Errorf("Email = %q", id.Email) + } + if id.Claims["hd"] != "example.com" { + t.Errorf("Claims[hd] = %v", id.Claims["hd"]) + } +} + +func TestProvider_WrongIssuer_Rejected(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + claims := validClaims("test-aud") + claims["iss"] = "https://accounts.google.com" // common bug: regular Google token, not IAP + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_WrongAudience_Rejected(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "expected-aud") + + tok := signer.sign(validClaims("different-aud")) + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_AudienceAsArray(t *testing.T) { + // JWT spec allows aud as []string. Verify we handle that shape too — + // IAP currently uses string but the contract is broader. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "wanted-aud") + claims := validClaims("placeholder") + claims["aud"] = []string{"other", "wanted-aud"} + tok := signer.sign(claims) + if _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}); err != nil { + t.Errorf("expected success with aud=[]string, got %v", err) + } +} + +func TestProvider_MissingSub_Invalid(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + claims := validClaims("test-aud") + delete(claims, "sub") + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_MissingEmail_Invalid(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + claims := validClaims("test-aud") + delete(claims, "email") + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken", err) + } +} + +func TestProvider_ExpiredToken_Rejected(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + claims := validClaims("test-aud") + claims["exp"] = time.Now().Add(-time.Minute).Unix() + tok := signer.sign(claims) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}) + if !errors.Is(err, auth.ErrTokenRejected) { + t.Errorf("err = %v, want ErrTokenRejected", err) + } +} + +func TestProvider_RS256Token_Rejected(t *testing.T) { + // Algorithm-confusion defense: an RS256 token MUST be rejected + // BEFORE key lookup — gcp_iap accepts ES256 only. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + rsaPriv, _ := rsa.GenerateKey(rand.Reader, 2048) + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(validClaims("test-aud"))) + tok.Header["kid"] = signer.kid + str, _ := tok.SignedString(rsaPriv) + + _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": str}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (alg whitelist)", err) + } +} + +func TestProvider_HS256WithECPublicKeyAsSecret_Rejected(t *testing.T) { + // The most dangerous algorithm-confusion shape (Review NIT): attacker + // takes the verifier's PUBLIC key (which the JWKS endpoint publishes + // openly), uses its raw bytes as the HMAC secret to sign an HS256 + // token, and submits it. A verifier that doesn't whitelist the alg + // would happily HMAC-verify the token against the same "secret" = + // the public key bytes. Our keyFunc rejects on alg!="ES256" BEFORE + // even looking up the key — pin that explicitly. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + // Use the signer's public EC point coordinates as the HMAC "secret." + // The attacker only needs the public key, which they can fetch from + // the JWKS endpoint anonymously. + pubBytes := append(signer.pub.X.Bytes(), signer.pub.Y.Bytes()...) + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(validClaims("test-aud"))) + tok.Header["kid"] = signer.kid + str, err := tok.SignedString(pubBytes) + if err != nil { + t.Fatalf("sign HS256: %v", err) + } + + _, err = p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": str}) + if !errors.Is(err, auth.ErrInvalidToken) { + t.Errorf("err = %v, want ErrInvalidToken (HS256-with-public-key alg confusion)", err) + } +} + +func TestFactory_RejectsMissingAudience(t *testing.T) { + if _, err := auth.Build("gcp_iap", map[string]any{}); err == nil { + t.Fatal("expected error when audience is missing") + } +} + +func TestProvider_Name(t *testing.T) { + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + if p.Name() != "gcp_iap" { + t.Errorf("Name = %q", p.Name()) + } +} + +func TestProvider_RegisteredInRegistry(t *testing.T) { + p, err := auth.Build("gcp_iap", map[string]any{"audience": "test"}) + if err != nil { + t.Fatalf("Build: %v", err) + } + if p.Name() != "gcp_iap" { + t.Errorf("Name = %q", p.Name()) + } +} + +// --- JWKS cache tests (use a controllable mux, not the full signer) --- + +func TestJWKSCache_StaleGraceOnOutage(t *testing.T) { + // 1. Bring signer up, prime cache via successful verify. + // 2. Take JWKS endpoint down. + // 3. Verify same token still works — stale-grace. + signer := newES256Signer(t) + p := newProviderPointingAt(t, signer, "test-aud") + + tok := signer.sign(validClaims("test-aud")) + if _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}); err != nil { + t.Fatalf("priming Verify: %v", err) + } + + // Take JWKS down by closing the server. + signer.srv.Close() + // Mark cache stale so we go through refresh path. + p.jwks.mu.Lock() + p.jwks.lastSuccessful = time.Now().Add(-2 * time.Hour) + p.jwks.mu.Unlock() + + if _, err := p.Verify(context.Background(), "", auth.Headers{"X-Goog-Iap-Jwt-Assertion": tok}); err != nil { + t.Errorf("stale-grace failed: %v", err) + } +} + +func TestJWKSCache_DoesNotFollowRedirects(t *testing.T) { + // Review (sibling of B2/B3): the IAP JWKS URL is hardcoded (§9.4) + // precisely so we never trust an alternate key source. Auto- + // following a 302 to attacker bytes would let an attacker substitute + // their own keys — any token forged with those keys would then + // verify. Pin: any 3xx is treated as JWKS-unavailable (we never + // follow). + var hits int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + hits++ + w.Header().Set("Location", "https://attacker.example.com/") + w.WriteHeader(http.StatusFound) + })) + defer srv.Close() + + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + err := c.refresh(context.Background()) + if err == nil { + t.Fatal("expected error on 302; client must not follow") + } + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if hits != 1 { + t.Errorf("JWKS endpoint hit %d times, want 1 (redirect was followed)", hits) + } +} + +func TestJWKSCache_BackoffBumps(t *testing.T) { + // Endpoint always 500 — observe backoffDuration grow on each refresh. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + if err := c.refresh(context.Background()); err == nil { + t.Fatal("expected first refresh to fail") + } + c.mu.RLock() + first := c.backoffDuration + c.mu.RUnlock() + if first != 5*time.Second { + t.Errorf("backoff after 1st failure = %v, want 5s", first) + } + + // Force a second attempt past the backoff window. + c.mu.Lock() + c.lastAttempt = time.Now().Add(-10 * time.Second) + c.mu.Unlock() + if err := c.refresh(context.Background()); err == nil { + t.Fatal("expected second refresh to fail") + } + c.mu.RLock() + second := c.backoffDuration + c.mu.RUnlock() + if second != 10*time.Second { + t.Errorf("backoff after 2nd failure = %v, want 10s", second) + } +} + +func TestJWKSCache_BackoffBlocksRefresh(t *testing.T) { + // During backoff, refresh() returns ErrProviderUnavailable without + // attempting a network call. + var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls++ + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + _ = c.refresh(context.Background()) // failure -> backoff + if calls != 1 { + t.Fatalf("calls after first refresh = %d", calls) + } + err := c.refresh(context.Background()) // should NOT call network + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } + if calls != 1 { + t.Errorf("network was hit during backoff (calls = %d)", calls) + } +} + +func TestJWKSCache_BodyCap(t *testing.T) { + // Serve 1 MiB of garbage; parse must fail without OOM. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // 256 KiB cap + 1 byte → triggers either truncated parse fail or + // LimitReader cut. + buf := make([]byte, 512<<10) + _, _ = w.Write(buf) + })) + defer srv.Close() + c := NewIAPJWKSCache(srv.URL, time.Hour, 5*time.Second) + err := c.refresh(context.Background()) + if !errors.Is(err, auth.ErrProviderUnavailable) { + t.Errorf("err = %v, want ErrProviderUnavailable", err) + } +} + +func TestParseECJWKSet_DropsNonES256(t *testing.T) { + // Build a JWKS with one valid EC key and one RSA-shaped entry. + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + x := base64.RawURLEncoding.EncodeToString(priv.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(priv.Y.Bytes()) + doc := map[string]any{ + "keys": []map[string]any{ + {"kty": "EC", "crv": "P-256", "kid": "good", "alg": "ES256", "x": x, "y": y}, + {"kty": "RSA", "kid": "bad", "alg": "RS256", "n": "ignored", "e": "ignored"}, + {"kty": "EC", "crv": "P-256", "kid": "wrong-alg", "alg": "RS256", "x": x, "y": y}, + }, + } + raw, _ := json.Marshal(doc) + keys, err := parseECJWKSet(raw) + if err != nil { + t.Fatalf("parseECJWKSet: %v", err) + } + if _, ok := keys["good"]; !ok { + t.Error("good key dropped") + } + if _, ok := keys["bad"]; ok { + t.Error("RSA key not dropped") + } + if _, ok := keys["wrong-alg"]; ok { + t.Error("EC key labeled RS256 not dropped") + } +} + +func TestParseAudience_StringAndArray(t *testing.T) { + // "aud":"x" + one, err := parseAudience(json.RawMessage(`"x"`)) + if err != nil || len(one) != 1 || one[0] != "x" { + t.Errorf("string aud: got %v, %v", one, err) + } + // "aud":["x","y"] + two, err := parseAudience(json.RawMessage(`["x","y"]`)) + if err != nil || len(two) != 2 { + t.Errorf("array aud: got %v, %v", two, err) + } + // missing + if _, err := parseAudience(nil); err == nil { + t.Error("missing aud should error") + } +} + +// ensure import not flagged +var ( + _ = io.EOF + _ = big.NewInt +) diff --git a/forge-core/auth/providers/oidc/provider.go b/forge-core/auth/providers/oidc/provider.go index 47c985d..67fa509 100644 --- a/forge-core/auth/providers/oidc/provider.go +++ b/forge-core/auth/providers/oidc/provider.go @@ -86,6 +86,19 @@ type Config struct { // HTTPClient overrides the default client. Injectable for tests. HTTPClient *http.Client `yaml:"-"` + + // SkipIssuerCheck disables the iss-claim equality check. INTERNAL — + // the yaml:"-" tag means this CANNOT be set via forge.yaml; it is + // only reachable when another Go package constructs oidc.Config + // directly (currently only azure_ad's multi-tenant mode). + // + // Reason it exists: AAD's "common" / multi-tenant issuer template + // uses a per-token tenant ID that string-equality can't satisfy. The + // caller (azure_ad) takes responsibility for tenant enforcement via + // the tid claim instead. Surfacing this in forge.yaml would let + // operators disable iss validation by accident — which is exactly + // the "open verifier" footgun this package is designed to prevent. + SkipIssuerCheck bool `yaml:"-"` } // Validate returns ErrProviderNotConfigured when required fields are @@ -255,8 +268,15 @@ func (p *Provider) Verify(ctx context.Context, tokenStr string, headers auth.Hea // IdPs that emit iss with/without a trailing slash interop with // configs that use the opposite form). See Config.normalize and // review finding #2. - if err := p.checkIssuer(claims); err != nil { - return nil, err + // + // SkipIssuerCheck is INTERNAL — only set when another Go package + // (currently azure_ad multi-tenant) takes responsibility for + // tenant/issuer enforcement via a different claim. Never reachable + // via forge.yaml — the field carries a yaml:"-" tag. + if !p.cfg.SkipIssuerCheck { + if err := p.checkIssuer(claims); err != nil { + return nil, err + } } // Audience validation (with azp fallback). diff --git a/forge-core/security/auth_domains.go b/forge-core/security/auth_domains.go index b96b1e0..a5b5fd1 100644 --- a/forge-core/security/auth_domains.go +++ b/forge-core/security/auth_domains.go @@ -55,8 +55,31 @@ func authProviderURLs(p types.AuthProvider) []string { return []string{ settingString(p.Settings, "url"), } + case "aws_sigv4": + // STS at sts..amazonaws.com is the only outbound. + // The test-only sts_endpoint override is honored so dev/test + // runs against a local fake aren't blocked by egress. + region := settingString(p.Settings, "region") + out := []string{} + if region != "" { + out = append(out, "https://sts."+region+".amazonaws.com") + } + if override := settingString(p.Settings, "sts_endpoint"); override != "" { + out = append(out, override) + } + return out + case "gcp_iap": + // Decision §9.4: IAP JWKS host is hardcoded. + return []string{"https://www.gstatic.com/iap/verify/public_key-jwk"} + case "azure_ad": + // AAD authority host is fixed (login.microsoftonline.com). + // Graph host added ONLY when groups_mode=graph. + out := []string{"https://login.microsoftonline.com"} + if mode, _ := p.Settings["groups_mode"].(string); mode == "graph" { + out = append(out, "https://graph.microsoft.com") + } + return out // static_token has no outbound; not listed - // okta (Phase 3) will be added here with issuer + api domain default: return nil } diff --git a/forge-core/security/auth_domains_test.go b/forge-core/security/auth_domains_test.go index c7a1616..4f8e40b 100644 --- a/forge-core/security/auth_domains_test.go +++ b/forge-core/security/auth_domains_test.go @@ -124,6 +124,148 @@ func TestAuthDomains_PortStripped(t *testing.T) { } } +func TestAuthDomains_AWSSigv4(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{"region": "us-east-1"}}, + }, + }) + want := []string{"sts.us-east-1.amazonaws.com"} + if len(got) != 1 || got[0] != want[0] { + t.Errorf("AuthDomains = %v, want %v", got, want) + } +} + +func TestAuthDomains_AWSSigv4_DifferentRegion(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{"region": "eu-west-2"}}, + }, + }) + if len(got) != 1 || got[0] != "sts.eu-west-2.amazonaws.com" { + t.Errorf("AuthDomains = %v, want [sts.eu-west-2.amazonaws.com]", got) + } +} + +func TestAuthDomains_AWSSigv4_TestEndpointOverride(t *testing.T) { + // The sts_endpoint override (test-only escape hatch) must surface in + // the egress allowlist too — otherwise local integration tests are + // blocked by the egress enforcer. + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{ + "region": "us-east-1", + "sts_endpoint": "http://127.0.0.1:8080", + }}, + }, + }) + have := map[string]bool{} + for _, d := range got { + have[d] = true + } + if !have["sts.us-east-1.amazonaws.com"] { + t.Errorf("AuthDomains missing real STS host: %v", got) + } + if !have["127.0.0.1"] { + t.Errorf("AuthDomains missing test override host: %v", got) + } +} + +func TestAuthDomains_AWSSigv4_MissingRegionReturnsEmpty(t *testing.T) { + // Defensive: even though Factory rejects missing region at startup, + // AuthDomains should not panic or emit a malformed host if it's ever + // called with an incomplete config. + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "aws_sigv4", Settings: map[string]any{}}, + }, + }) + if got != nil { + t.Errorf("AuthDomains with missing region = %v, want nil", got) + } +} + +func TestAuthDomains_GCPIAP(t *testing.T) { + // Decision §9.4: IAP JWKS host is hardcoded — same domain returned + // regardless of audience or any other config. + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "gcp_iap", Settings: map[string]any{ + "audience": "/projects/12345/global/backendServices/67890", + }}, + }, + }) + if len(got) != 1 || got[0] != "www.gstatic.com" { + t.Errorf("AuthDomains = %v, want [www.gstatic.com]", got) + } +} + +func TestAuthDomains_GCPIAP_MultipleEntriesDedup(t *testing.T) { + // Even with two IAP entries, the host appears once (dedup). + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "gcp_iap", Settings: map[string]any{"audience": "a"}}, + {Type: "gcp_iap", Settings: map[string]any{"audience": "b"}}, + }, + }) + if len(got) != 1 || got[0] != "www.gstatic.com" { + t.Errorf("AuthDomains = %v, want [www.gstatic.com] (dedup)", got) + } +} + +func TestAuthDomains_AzureAD_ClaimMode(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "azure_ad", Settings: map[string]any{ + "tenant_id": "00000000-...", + "audience": "api://forge", + }}, + }, + }) + if len(got) != 1 || got[0] != "login.microsoftonline.com" { + t.Errorf("AuthDomains = %v, want [login.microsoftonline.com]", got) + } +} + +func TestAuthDomains_AzureAD_GraphModeAddsGraphHost(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "azure_ad", Settings: map[string]any{ + "tenant_id": "00000000-...", + "audience": "api://forge", + "groups_mode": "graph", + }}, + }, + }) + have := map[string]bool{} + for _, d := range got { + have[d] = true + } + if !have["login.microsoftonline.com"] { + t.Errorf("missing login.microsoftonline.com: %v", got) + } + if !have["graph.microsoft.com"] { + t.Errorf("groups_mode=graph should add graph.microsoft.com: %v", got) + } +} + +func TestAuthDomains_AzureAD_ClaimMode_DoesNotAddGraph(t *testing.T) { + got := security.AuthDomains(types.AuthConfig{ + Providers: []types.AuthProvider{ + {Type: "azure_ad", Settings: map[string]any{ + "tenant_id": "00000000-...", + "audience": "api://forge", + "groups_mode": "claim", + }}, + }, + }) + for _, d := range got { + if d == "graph.microsoft.com" { + t.Errorf("claim mode should NOT add graph.microsoft.com: %v", got) + } + } +} + func TestAuthDomains_UnknownProviderTypeReturnsEmpty(t *testing.T) { got := security.AuthDomains(types.AuthConfig{ Providers: []types.AuthProvider{ diff --git a/forge-core/validate/auth.go b/forge-core/validate/auth.go index 78ead5b..4e3c4e4 100644 --- a/forge-core/validate/auth.go +++ b/forge-core/validate/auth.go @@ -17,7 +17,97 @@ var knownAuthProviderTypes = map[string]bool{ "http_verifier": true, "static_token": true, "oidc": true, - // "okta": true, // Phase 3 (v0.11.0) + // Phase 2 (v0.11.0): + "aws_sigv4": true, + "gcp_iap": true, + "azure_ad": true, +} + +// KnownAuthProviderSettings is the closed set of YAML keys each provider +// type accepts. Mirrors the `yaml:` tags on each provider's Config struct; +// must be kept in sync when new fields are added. Internal-only struct +// fields (those carrying `yaml:"-"`) are intentionally absent — they +// can only be set by another Go package, not via forge.yaml or the +// Web UI's create-payload. +// +// Two callers consume this map: +// +// 1. ValidateAuthConfig emits a *warning* per unknown key during +// `forge validate`, so a typo like `aud:` instead of `audience:` +// gets surfaced loudly. +// +// 2. The Web UI handler (forge-ui handlers_create.go) filters its +// incoming Settings map through FilterKnownSettings before +// forwarding to scaffold — closing the exploit chain where a +// malicious POST `{"settings": {"audience": "x", "evil": "y"}}` +// would otherwise drop `evil:` into forge.yaml verbatim. +var KnownAuthProviderSettings = map[string]map[string]bool{ + "http_verifier": { + "url": true, + "default_org": true, + "timeout": true, + }, + "static_token": { + "token": true, + "token_env": true, + }, + "oidc": { + "issuer": true, + "audience": true, + "client_id": true, + "jwks_url": true, + "jwks_cache_ttl": true, + "clock_skew": true, + "claim_map": true, + }, + "aws_sigv4": { + "region": true, + "audience": true, + "allowed_principals": true, + "allowed_accounts": true, + "identity_cache_ttl": true, + "sts_endpoint": true, // documented test override, intentionally YAML-reachable + "http_timeout": true, + "max_token_expires": true, + "clock_skew": true, + }, + "gcp_iap": { + "audience": true, + "jwks_refresh_ttl": true, + "http_timeout": true, + }, + "azure_ad": { + "tenant_id": true, + "audience": true, + "allow_multi_tenant": true, + "allowed_tenants": true, + "groups_mode": true, + "graph_timeout": true, + "jwks_cache_ttl": true, + // graph_endpoint intentionally omitted — yaml:"-" on the Config field + }, +} + +// FilterKnownSettings returns a copy of settings with any keys not in +// the whitelist for providerType dropped. Use this at the boundary +// between untrusted input (Web UI POST) and persistence (forge.yaml +// scaffold) so unknown keys never reach disk. +// +// For unknown providerType (returns nil from the whitelist lookup), the +// input is passed through unchanged — let the ValidateAuthConfig +// "unknown type" error catch that case instead. +func FilterKnownSettings(providerType string, settings map[string]any) map[string]any { + known, ok := KnownAuthProviderSettings[providerType] + if !ok { + return settings + } + out := make(map[string]any, len(settings)) + for k, v := range settings { + if known[k] { + out[k] = v + } + } + return out } // ValidateAuthConfig adds errors and warnings for a forge.yaml auth: block. @@ -39,7 +129,7 @@ func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { continue } if !knownAuthProviderTypes[p.Type] { - r.Errors = append(r.Errors, fmt.Sprintf("%s: unknown type %q (known: http_verifier, static_token, oidc)", prefix, p.Type)) + r.Errors = append(r.Errors, fmt.Sprintf("%s: unknown type %q (known: http_verifier, static_token, oidc, aws_sigv4, gcp_iap, azure_ad)", prefix, p.Type)) continue } @@ -59,6 +149,20 @@ func ValidateAuthConfig(cfg types.AuthConfig, r *ValidationResult) { // runs at runtime construction) so `forge validate` catches errors before // `forge run`. func validateProviderSettings(prefix string, p types.AuthProvider, r *ValidationResult) { + // Warn on any keys the provider doesn't recognize. Loose vs. error + // because some operators stash custom annotations (legacy practice + // from pre-Phase-2 configs); the Web UI handler additionally filters + // at write-time so the actual scaffold-poisoning chain is closed + // there (see forge-ui/handlers_create.go). + if known, ok := KnownAuthProviderSettings[p.Type]; ok { + for k := range p.Settings { + if !known[k] { + r.Warnings = append(r.Warnings, + fmt.Sprintf("%s (%s): unknown settings key %q — typo, or a key from a future provider version?", prefix, p.Type, k)) + } + } + } + switch p.Type { case "http_verifier": if asString(p.Settings, "url") == "" { @@ -78,6 +182,59 @@ func validateProviderSettings(prefix string, p types.AuthProvider, r *Validation if asString(p.Settings, "audience") == "" { r.Errors = append(r.Errors, prefix+" (oidc): settings.audience is required") } + case "aws_sigv4": + if asString(p.Settings, "region") == "" { + r.Errors = append(r.Errors, prefix+" (aws_sigv4): settings.region is required") + } + // allowed_accounts entries must be 12-digit AWS account IDs. + // Catches typos at validate-time so a misconfig doesn't silently + // become an unreachable pattern. + if accts, ok := p.Settings["allowed_accounts"].([]any); ok { + for i, raw := range accts { + s, _ := raw.(string) + if len(s) != 12 || !isAllDigits(s) { + r.Errors = append(r.Errors, + fmt.Sprintf("%s (aws_sigv4): allowed_accounts[%d]=%q must be a 12-digit AWS account ID", prefix, i, s)) + } + } + } + case "gcp_iap": + if asString(p.Settings, "audience") == "" { + r.Errors = append(r.Errors, prefix+" (gcp_iap): settings.audience is required (GCP backend service ID)") + } + case "azure_ad": + if asString(p.Settings, "audience") == "" { + r.Errors = append(r.Errors, prefix+" (azure_ad): settings.audience is required") + } + // tenant_id may be omitted ONLY when allow_multi_tenant is true. + multi, _ := p.Settings["allow_multi_tenant"].(bool) + if !multi && asString(p.Settings, "tenant_id") == "" { + r.Errors = append(r.Errors, prefix+" (azure_ad): settings.tenant_id is required unless allow_multi_tenant=true") + } + // allowed_tenants only makes sense with multi-tenant. + hasAllowed := false + switch v := p.Settings["allowed_tenants"].(type) { + case []any: + hasAllowed = len(v) > 0 + case []string: + hasAllowed = len(v) > 0 + } + if !multi && hasAllowed { + r.Errors = append(r.Errors, + prefix+" (azure_ad): allowed_tenants is only meaningful when allow_multi_tenant=true") + } + // "Any-tenant mode" warning: multi-tenant + empty allowed_tenants + // admits any Entra tenant globally. Documented trade-off, but + // warn so operators don't ship it by accident. + if multi && !hasAllowed { + r.Warnings = append(r.Warnings, + prefix+" (azure_ad): allow_multi_tenant=true with no allowed_tenants list "+ + "admits any Entra tenant globally — set allowed_tenants if you want to "+ + "restrict to specific partner tenants") + } + if mode := asString(p.Settings, "groups_mode"); mode != "" && mode != "claim" && mode != "graph" { + r.Errors = append(r.Errors, fmt.Sprintf("%s (azure_ad): groups_mode must be 'claim' or 'graph', got %q", prefix, mode)) + } } } @@ -91,3 +248,16 @@ func asString(m map[string]any, key string) string { s, _ := v.(string) return s } + +// isAllDigits reports whether s consists only of ASCII digits. +func isAllDigits(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if s[i] < '0' || s[i] > '9' { + return false + } + } + return true +} diff --git a/forge-core/validate/auth_test.go b/forge-core/validate/auth_test.go index a4813fb..ed02b29 100644 --- a/forge-core/validate/auth_test.go +++ b/forge-core/validate/auth_test.go @@ -216,3 +216,114 @@ func containsSubstr(ss []string, substr string) bool { } return false } + +// --- Review M5: unknown-key warning + filter helper --- + +func TestValidateAuthConfig_WarnsOnUnknownSettingsKey(t *testing.T) { + // Typo: `aud` instead of `audience`. The required-keys check would + // fire (audience missing), but we ALSO want a loud warning about + // the unknown key so operators can spot the actual typo, not just + // the symptom. + cfg := types.AuthConfig{ + Required: true, + Providers: []types.AuthProvider{{ + Type: "oidc", + Settings: map[string]any{ + "issuer": "https://login.example.com", + "aud": "api://forge", // typo: should be 'audience' + "audience": "api://forge", // keep the required-check passing + }, + }}, + } + result := &ValidationResult{} + ValidateAuthConfig(cfg, result) + if !containsSubstr(result.Warnings, `unknown settings key "aud"`) { + t.Errorf("expected warning about unknown 'aud' key, got %v", result.Warnings) + } +} + +func TestValidateAuthConfig_NoWarningForKnownKeys(t *testing.T) { + // Every documented oidc key should be silent. + cfg := types.AuthConfig{ + Required: true, + Providers: []types.AuthProvider{{ + Type: "oidc", + Settings: map[string]any{ + "issuer": "https://x", + "audience": "y", + "client_id": "c", + "jwks_url": "https://x/jwks", + "jwks_cache_ttl": "1h", + "clock_skew": "30s", + "claim_map": map[string]any{"groups": "roles"}, + }, + }}, + } + result := &ValidationResult{} + ValidateAuthConfig(cfg, result) + for _, w := range result.Warnings { + if strings.Contains(w, "unknown settings key") { + t.Errorf("unexpected unknown-key warning for known oidc field: %q", w) + } + } +} + +func TestFilterKnownSettings_DropsUnknownKeys(t *testing.T) { + // Defense-in-depth filter that forge-ui's handler runs on + // untrusted Web UI input. Unknown keys must NOT survive. + in := map[string]any{ + "audience": "api://forge", + "issuer": "https://x", + "evil_key": "attacker-value", // unknown for oidc + } + out := FilterKnownSettings("oidc", in) + if out["audience"] != "api://forge" { + t.Errorf("audience dropped: %v", out) + } + if out["issuer"] != "https://x" { + t.Errorf("issuer dropped: %v", out) + } + if _, exists := out["evil_key"]; exists { + t.Error("evil_key must be filtered out for oidc settings") + } +} + +func TestFilterKnownSettings_UnknownProviderTypePassthrough(t *testing.T) { + // If the provider type isn't in the whitelist, pass through — + // validateProviderSettings' "unknown type" error catches that case + // separately. + in := map[string]any{"x": "y"} + out := FilterKnownSettings("future_provider", in) + if out["x"] != "y" { + t.Errorf("unknown provider type should passthrough, got %v", out) + } +} + +func TestFilterKnownSettings_AllPhase2Providers(t *testing.T) { + cases := []struct { + provider string + good string // a key that SHOULD survive + }{ + {"aws_sigv4", "region"}, + {"aws_sigv4", "allowed_accounts"}, + {"aws_sigv4", "sts_endpoint"}, // test-only override, but YAML-reachable + {"gcp_iap", "audience"}, + {"azure_ad", "tenant_id"}, + {"azure_ad", "allowed_tenants"}, + {"azure_ad", "allow_multi_tenant"}, + } + for _, tc := range cases { + t.Run(tc.provider+"/"+tc.good, func(t *testing.T) { + out := FilterKnownSettings(tc.provider, map[string]any{ + tc.good: "x", + "evil_X": "bad", + }) + if _, exists := out[tc.good]; !exists { + t.Errorf("%s: known key %q was dropped", tc.provider, tc.good) + } + if _, exists := out["evil_X"]; exists { + t.Errorf("%s: unknown key 'evil_X' survived filter", tc.provider) + } + }) + } +} diff --git a/forge-ui/handlers_create.go b/forge-ui/handlers_create.go index 7cd1eed..2a21d3c 100644 --- a/forge-ui/handlers_create.go +++ b/forge-ui/handlers_create.go @@ -124,8 +124,11 @@ func (s *UIServer) handleGetWizardMeta(w http.ResponseWriter, _ *http.Request) { // frontend renders the picker from this metadata. meta.AuthProviderTypes = []AuthProviderTypeMeta{ {Type: "none", Label: "None", Description: "Anonymous access — no auth: block written"}, - {Type: "oidc", Label: "OIDC (JWT)", Description: "Auth0, Keycloak, Azure AD, Google, Okta-OIDC, …"}, + {Type: "oidc", Label: "OIDC (JWT)", Description: "Generic OIDC issuer (Keycloak, Auth0, Okta, Google …)"}, {Type: "http_verifier", Label: "HTTP Verifier", Description: "Legacy — POST tokens to your own /verify endpoint"}, + {Type: "aws_sigv4", Label: "AWS Sigv4 (IAM)", Description: "Auth AWS-IAM callers via STS GetCallerIdentity (Phase 2)"}, + {Type: "gcp_iap", Label: "GCP Identity-Aware Proxy", Description: "Forge behind a GCP HTTPS LB+IAP (Phase 2)"}, + {Type: "azure_ad", Label: "Azure AD / Entra ID", Description: "Entra tenant tokens with optional Graph enrichment (Phase 2)"}, {Type: "custom", Label: "Custom", Description: "Write a commented stub, edit forge.yaml manually"}, } @@ -204,6 +207,15 @@ func validateAuthPayload(a *AuthCreateOptions) error { return nil } + // Filter the incoming Settings to the known-keys whitelist BEFORE + // validation OR scaffolding. Closes the exploit chain (review M5): + // without this, a POST with `{"settings": {"audience": "x", + // "evil_key": "y"}}` would drop evil_key into forge.yaml verbatim. + // Today provider Config structs ignore unknown YAML fields, but a + // future field added without a `yaml:"-"` tag would suddenly become + // reachable via untrusted POST. Filtering here gives defense-in-depth. + a.Settings = validate.FilterKnownSettings(a.Mode, a.Settings) + authYAML := types.AuthConfig{ // Required is set true here because the wizard's renderAuthBlock // always emits required:true when a provider is chosen. Keeping diff --git a/forge-ui/handlers_create_test.go b/forge-ui/handlers_create_test.go index 906de6d..8491d58 100644 --- a/forge-ui/handlers_create_test.go +++ b/forge-ui/handlers_create_test.go @@ -388,11 +388,15 @@ func TestHandleGetWizardMeta(t *testing.T) { } // PR6: auth_provider_types is server-driven so the frontend doesn't - // hardcode the list. Must include the four founding types. - if len(meta.AuthProviderTypes) != 4 { - t.Errorf("auth_provider_types len = %d, want 4", len(meta.AuthProviderTypes)) + // hardcode the list. Phase 2 adds aws_sigv4, gcp_iap, azure_ad — the + // founding four still appear. + if len(meta.AuthProviderTypes) != 7 { + t.Errorf("auth_provider_types len = %d, want 7", len(meta.AuthProviderTypes)) + } + wantTypes := map[string]bool{ + "none": false, "oidc": false, "http_verifier": false, "custom": false, + "aws_sigv4": false, "gcp_iap": false, "azure_ad": false, } - wantTypes := map[string]bool{"none": false, "oidc": false, "http_verifier": false, "custom": false} for _, a := range meta.AuthProviderTypes { if _, ok := wantTypes[a.Type]; !ok { t.Errorf("unexpected auth type %q", a.Type) @@ -731,3 +735,48 @@ func TestHandleCreateAgent_WithoutAuthPayload(t *testing.T) { t.Errorf("Auth = %v, want nil for omitted field", captured.Auth) } } + +// --- Review M5: unknown-key filter at the Web UI boundary --- + +func TestHandleCreateAgent_FiltersUnknownAuthSettings(t *testing.T) { + // Defense-in-depth: a POST that carries an unknown settings key + // (typo OR malicious) must NOT survive into what the scaffold writes. + // Today the providers' Config structs ignore unknown YAML fields, so + // even an unfiltered key is harmless; this test pins the filter so + // a future config-struct field can't suddenly become reachable via + // untrusted POST. + srv, captured := setupCreateWithCapture(t) + + body := []byte(`{ + "name": "filter-test", + "model_provider": "openai", + "auth": { + "mode": "oidc", + "settings": { + "issuer": "https://login.example.com", + "audience": "api://forge", + "evil_key": "attacker-supplied-value" + } + } + }`) + req := httptest.NewRequest(http.MethodPost, "/api/agents", bytes.NewReader(body)) + w := httptest.NewRecorder() + srv.handleCreateAgent(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201", w.Code) + } + if captured.Auth == nil { + t.Fatal("expected Auth payload to reach scaffold") + } + if _, leaked := captured.Auth.Settings["evil_key"]; leaked { + t.Errorf("evil_key leaked to scaffold: %v", captured.Auth.Settings) + } + // Known keys must still be there. + if captured.Auth.Settings["issuer"] != "https://login.example.com" { + t.Errorf("issuer dropped or wrong: %v", captured.Auth.Settings) + } + if captured.Auth.Settings["audience"] != "api://forge" { + t.Errorf("audience dropped or wrong: %v", captured.Auth.Settings) + } +} diff --git a/scripts/forge-aws-sign.py b/scripts/forge-aws-sign.py new file mode 100755 index 0000000..8cd9deb --- /dev/null +++ b/scripts/forge-aws-sign.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +forge-aws-sign — reference client for Forge's aws_sigv4 auth provider. + +The aws_sigv4 provider uses the pre-signed URL pattern (same approach as +aws-iam-authenticator for EKS). The client mints a pre-signed STS +GetCallerIdentity URL using its own AWS SDK, then sends it to Forge as +a Bearer token of the form: + + Authorization: Bearer forge-aws-v1. + +Forge invokes the pre-signed URL on STS, which validates the signature +against its own host (because that's what was signed), and returns the +caller's canonical ARN. + +Usage +===== + # Print just the token (use it however you want): + python3 forge-aws-sign.py --token-only --region us-east-1 + + # Make a one-shot call to Forge: + python3 forge-aws-sign.py --region us-east-1 \ + --url http://localhost:9999/tasks/send \ + --body '{"task":"hello"}' + +Reads AWS credentials the same way boto3 does: env vars, profile, SSO, +IRSA, instance profile, etc. + +Exits 0 on HTTP 2xx (or when --token-only succeeds); 1 otherwise. +""" +from __future__ import annotations + +import argparse +import base64 +import sys + +try: + import boto3 + import requests + from botocore.auth import SigV4QueryAuth + from botocore.awsrequest import AWSRequest +except ImportError as e: + print(f"missing dependency: {e}", file=sys.stderr) + print("install with: pip3 install --user boto3 requests", file=sys.stderr) + sys.exit(2) + + +def mint_token(region: str, profile: str | None, expires: int = 900) -> str: + """Mint a forge-aws-v1 token from the current AWS credentials. + + Builds the pre-signed URL via SigV4QueryAuth directly, NOT via + boto3.client('sts').generate_presigned_url('get_caller_identity', ...) + — the latter signs as if the request were a POST to STS and STS + rejects the resulting GET URL with "SignatureDoesNotMatch." Same + quirk aws-iam-authenticator works around by signing the request + explicitly. + + `expires` (seconds) is the TTL baked into the URL; max 900. + """ + session = boto3.Session(profile_name=profile) if profile else boto3.Session() + creds = session.get_credentials().get_frozen_credentials() + + req = AWSRequest( + method="GET", + url=f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + headers={}, + ) + SigV4QueryAuth(creds, "sts", region, expires=expires).add_auth(req) + + encoded = base64.urlsafe_b64encode(req.url.encode()).rstrip(b"=").decode() + return "forge-aws-v1." + encoded + + +def main() -> int: + parser = argparse.ArgumentParser(description="Forge aws_sigv4 reference client") + parser.add_argument("--region", default="us-east-1", help="AWS region used in the Sigv4 scope") + parser.add_argument("--url", default="http://localhost:9999/tasks/send", help="Forge endpoint to POST to") + parser.add_argument("--body", default='{"task":"hello"}', help="JSON body to send to Forge") + parser.add_argument("--profile", default=None, help="AWS profile (default: boto3's default chain)") + parser.add_argument("--expires", type=int, default=900, help="Pre-signed URL TTL in seconds (max 900)") + parser.add_argument("--token-only", action="store_true", help="Print only the token, don't make a request") + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + args = parser.parse_args() + + try: + token = mint_token(args.region, args.profile, args.expires) + except Exception as e: + print(f"failed to mint token: {e}", file=sys.stderr) + return 1 + + if args.token_only: + print(token) + return 0 + + if args.verbose: + print(f"POST {args.url}", file=sys.stderr) + print(f" Authorization: Bearer {token[:60]}...", file=sys.stderr) + + resp = requests.post( + args.url, + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, + data=args.body, + ) + print(f"HTTP {resp.status_code}") + print(resp.text) + return 0 if 200 <= resp.status_code < 300 else 1 + + +if __name__ == "__main__": + sys.exit(main())