diff --git a/app/demo_adapter.py b/app/demo_adapter.py index 2d30c0eb..d8a912d2 100644 --- a/app/demo_adapter.py +++ b/app/demo_adapter.py @@ -15,7 +15,6 @@ import uuid from fastapi import HTTPException -from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from .routers.account import facility_adapter as account_adapter @@ -368,8 +367,8 @@ async def list_sites( sites = [s for s in sites if s.last_modified > ms] o = offset or 0 - l = limit or len(sites) - return sites[o : o + l] + limit_count = limit or len(sites) + return sites[o : o + limit_count] async def get_site(self: "DemoAdapter", site_id: str, modified_since: str | None = None) -> facility_models.Site: site = next((s for s in self.sites if s.id == site_id), None) @@ -512,11 +511,25 @@ async def get_current_user_globus( """ return "gtorok" + async def get_current_user_oidc( + self: "DemoAdapter", + api_key: str, + client_ip: str | None, + token_info: dict | None, + ) -> str: + """ + Decode the api_key and return the authenticated user's id from information returned by an OIDC token. + This method is not called directly, rather authorized endpoints "depend" on it. + (https://fastapi.tiangolo.com/tutorial/dependencies/) + """ + return token_info.get("sub", "gtorok") if token_info else "gtorok" + async def get_user( self: "DemoAdapter", user_id: str, api_key: str, client_ip: str | None, + token_info: dict | None, globus_introspect: dict | None, ) -> User: if user_id != self.user.id: diff --git a/app/routers/iri_router.py b/app/routers/iri_router.py index 8abcc4b5..2c4ee942 100644 --- a/app/routers/iri_router.py +++ b/app/routers/iri_router.py @@ -1,15 +1,27 @@ from abc import ABC, abstractmethod +import asyncio import os import logging import importlib +import threading import time +from typing import Any import globus_sdk +import httpx +from authlib.jose import JsonWebKey, JsonWebToken, KeySet +from authlib.jose.errors import JoseError +from cachetools import TTLCache from fastapi import Request, Depends, HTTPException, APIRouter from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from ..types.user import User bearer_scheme = HTTPBearer() +_DISCOVERY_TIMEOUT_SECONDS = float(os.environ.get("OIDC_DISCOVERY_TIMEOUT_SECONDS", "10")) +_DISCOVERY_CACHE_TTL_SECONDS = float(os.environ.get("OIDC_DISCOVERY_CACHE_TTL_SECONDS", "300")) +_oidc_remote_cache_lock = threading.Lock() +_oidc_remote_cache: TTLCache[str, tuple[dict[str, Any], KeySet]] = TTLCache(maxsize=128, ttl=_DISCOVERY_CACHE_TTL_SECONDS) +_oidc_remote_stale_cache: dict[str, tuple[dict[str, Any], KeySet]] = {} GLOBUS_RS_ID = os.environ.get("GLOBUS_RS_ID") @@ -17,6 +29,153 @@ GLOBUS_RS_SCOPE_SUFFIX = os.environ.get("GLOBUS_RS_SCOPE_SUFFIX") +def _env_true(name: str, default: bool = False) -> bool: + """Boolean env var checker.""" + raw = os.environ.get(name) + if raw is None or raw == "": + return default + return raw.strip().lower() not in {"0", "false", "off", "no"} + + +def _amsc_oidc_enabled() -> bool: + """AmSC PingAM OIDC: on if IRI_AUTH_AMSC != off AND OIDC_DISCOVERY_URI/CLIENT_ID configured.""" + return _env_true("IRI_AUTH_AMSC", False) and _oidc_auth_config() is not None + + +def _globus_enabled() -> bool: + """Globus introspection: on if IRI_AUTH_GLOBUS != off AND GLOBUS_RS_ID/SECRET/SCOPE_SUFFIX configured.""" + return bool(_env_true("IRI_AUTH_GLOBUS", False) and GLOBUS_RS_ID + and GLOBUS_RS_SECRET and GLOBUS_RS_SCOPE_SUFFIX) + + +def _oidc_auth_config() -> dict[str, str] | None: + discovery_uri = os.environ.get("OIDC_DISCOVERY_URI") + client_id = os.environ.get("OIDC_CLIENT_ID") + + if not discovery_uri or not client_id: + return None + + required_scopes = tuple( + scope + for scope in ( + os.environ.get("OIDC_REQUIRED_SCOPES") + or os.environ.get("OIDC_REQUIRED_SCOPE") + or "" + ).replace(",", " ").split() + if scope + ) + + return { + "discovery_uri": discovery_uri, + "client_id": client_id, + "required_scopes": required_scopes, + "required_audience": os.environ.get("OIDC_REQUIRED_AUDIENCE") or client_id, + } + + +async def _fetch_oidc_remote_state(discovery_uri: str) -> tuple[dict[str, Any], KeySet]: + """Fetch the OIDC discovery.""" + async with httpx.AsyncClient(timeout=_DISCOVERY_TIMEOUT_SECONDS) as client: + metadata_resp = await client.get(discovery_uri, headers={"Accept": "application/json"}) + metadata_resp.raise_for_status() + metadata = metadata_resp.json() + jwks_uri = metadata.get("jwks_uri") + if not jwks_uri: + raise RuntimeError("OIDC discovery document is missing jwks_uri") + jwks_resp = await client.get(jwks_uri, headers={"Accept": "application/json"}) + jwks_resp.raise_for_status() + return metadata, JsonWebKey.import_key_set(jwks_resp.json()) + + +async def _load_oidc_remote_state(discovery_uri: str) -> tuple[dict[str, Any], KeySet]: + """TTL-cached wrapper around fetching oidc remote state. + On refresh failure we fall back to the last cached state so a transient + IdP outage doesn't take the whole IRI service down. + """ + _log = logging.getLogger(__name__) + cached: tuple[dict[str, Any], KeySet] | None = None + stale: tuple[dict[str, Any], KeySet] | None = None + with _oidc_remote_cache_lock: + cached = _oidc_remote_cache.get(discovery_uri) + stale = _oidc_remote_stale_cache.get(discovery_uri) + if cached: + _log.info("OIDC JWKS cache HIT for %s (TTL %.0fs)", discovery_uri, _DISCOVERY_CACHE_TTL_SECONDS) + return cached + + _log.info("OIDC JWKS cache MISS for %s; fetching discovery + JWKS", discovery_uri) + try: + metadata, key_set = await _fetch_oidc_remote_state(discovery_uri) + except Exception: + if stale: + logging.getLogger(__name__).warning( + "OIDC discovery refresh failed for %s; reusing cached metadata + JWKS", + discovery_uri, + exc_info=True, + ) + return stale + raise + + with _oidc_remote_cache_lock: + _oidc_remote_cache[discovery_uri] = (metadata, key_set) + _oidc_remote_stale_cache[discovery_uri] = (metadata, key_set) + _log.info("OIDC JWKS cache STORED for %s (TTL %.0fs)", discovery_uri, _DISCOVERY_CACHE_TTL_SECONDS) + return metadata, key_set + + +async def _decode_oidc_jwt(api_key: str, discovery_uri: str, required_audience: str) -> dict[str, Any]: + """Verify the JWT signature against the IdP's JWKS and enforce required claims.""" + metadata, key_set = await _load_oidc_remote_state(discovery_uri) + algs_advertised = metadata.get("id_token_signing_alg_values_supported") or [] + algorithms = [alg for alg in algs_advertised if not alg.startswith("HS")] + if not algorithms: + raise RuntimeError("OIDC discovery document advertises no asymmetric signing algorithms") + claims_options = { + "iss": {"essential": True, "value": metadata["issuer"]}, + "aud": {"essential": True, "value": required_audience}, + "exp": {"essential": True}, + "nbf": {"essential": True}, + "iat": {"essential": True}, + } + + def decode_and_validate() -> dict[str, Any]: + claims = JsonWebToken(algorithms).decode(api_key, key_set, claims_options=claims_options) + claims.validate() + return dict(claims) + + return await asyncio.to_thread(decode_and_validate) + + +async def _get_userinfo(bearer_token: str, discovery_uri: str, token_info: dict[str, Any]) -> dict[str, Any]: + """Fetch profile claims from the OIDC UserInfo endpoint when they are not embedded.""" + _log = logging.getLogger(__name__) + + if token_info.get("name") or token_info.get("email"): + return token_info + + metadata, _ = await _load_oidc_remote_state(discovery_uri) + userinfo_endpoint = metadata.get("userinfo_endpoint") + if not userinfo_endpoint: + _log.warning("OIDC discovery document missing userinfo_endpoint; profile claims unavailable") + return token_info + + try: + async with httpx.AsyncClient(timeout=_DISCOVERY_TIMEOUT_SECONDS) as client: + resp = await client.get( + userinfo_endpoint, + headers={"Authorization": f"Bearer {bearer_token}", "Accept": "application/json"}, + ) + resp.raise_for_status() + userinfo = resp.json() + _log.info("OIDC UserInfo returned claims: %s", list(userinfo.keys())) + for key, value in userinfo.items(): + if key not in token_info: + token_info[key] = value + except Exception: + _log.warning("Failed to fetch OIDC UserInfo; proceeding without profile claims", exc_info=True) + + return token_info + + def get_client_ip(request: Request) -> str | None: forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: @@ -80,6 +239,44 @@ def create_adapter(router_name, router_adapter): return AdapterClass() + async def get_oidc_token_info(self, api_key: str) -> dict[str, Any]: + """Validate a bearer JWT against the configured OIDC provider.""" + config = _oidc_auth_config() + if not config: + raise RuntimeError("OIDC auth is not configured") + + try: + token_info = await _decode_oidc_jwt( + api_key, + config["discovery_uri"], + config["required_audience"], + ) + except httpx.HTTPError as exc: + raise RuntimeError(f"OIDC discovery/JWKS request failed: {exc}") from exc + except JoseError as exc: + raise RuntimeError(f"OIDC JWT validation failed: {exc}") from exc + + logging.getLogger().info("PING OIDC JWT VALIDATION CLAIMS:") + logging.getLogger().info(token_info) + + token_info = await _get_userinfo(api_key, config["discovery_uri"], token_info) + + required_scopes = config["required_scopes"] + if required_scopes: + raw_scope = token_info.get("scope") + if isinstance(raw_scope, str): + token_scope = {s for s in raw_scope.split() if s} + elif isinstance(raw_scope, list): + token_scope = {str(s) for s in raw_scope if str(s)} + else: + token_scope = set() + missing_scopes = [s for s in required_scopes if s not in token_scope] + if missing_scopes: + raise Exception(f"Token missing required scopes: {', '.join(missing_scopes)}") + + return token_info + + async def get_globus_info(self, api_key: str) -> dict: """Returns the linked identities and the session info objects""" # Introspect the IRI API token using resource server credentials @@ -129,22 +326,36 @@ async def current_user( token = credentials.credentials ip_address = get_client_ip(request) user_id = None + token_info = None globus_introspect = None exc_msg = "" - try: - if GLOBUS_RS_ID and GLOBUS_RS_SECRET and GLOBUS_RS_SCOPE_SUFFIX: - try: - globus_introspect = await self.get_globus_info(token) - user_id = await self.adapter.get_current_user_globus(token, ip_address, globus_introspect) - except Exception as globus_exc: - logging.getLogger().exception("Globus error:", exc_info=globus_exc) - exc_msg = f"Globus authentication failed: {str(globus_exc)}. || " - if not user_id: + + if _amsc_oidc_enabled(): + try: + token_info = await self.get_oidc_token_info(token) + user_id = await self.adapter.get_current_user_oidc(token, ip_address, token_info) + except Exception as oidc_exc: + logging.getLogger().exception("AmSC OIDC auth error:", exc_info=oidc_exc) + exc_msg += f"AmSC OIDC authentication failed: {str(oidc_exc)}. || " + token_info = None + + if not user_id and _globus_enabled(): + try: + globus_introspect = await self.get_globus_info(token) + user_id = await self.adapter.get_current_user_globus(token, ip_address, globus_introspect) + except Exception as globus_exc: + logging.getLogger().exception("Globus auth error:", exc_info=globus_exc) + exc_msg += f"Globus authentication failed: {str(globus_exc)}. || " + globus_introspect = None + + if not user_id: + try: user_id = await self.adapter.get_current_user(token, ip_address) - except Exception as exc: - logging.getLogger().exception("Facility Specific auth failed: ", exc_info=exc) - exc_msg += f"Facility Specific authentication failed: {str(exc)}" - raise HTTPException(status_code=401, detail=exc_msg) from exc + except Exception as exc: + logging.getLogger().exception("Facility Specific auth failed: ", exc_info=exc) + exc_msg += f"Facility Specific authentication failed: {str(exc)}" + raise HTTPException(status_code=401, detail=exc_msg) from exc + if not user_id: raise HTTPException(status_code=403, detail="Authentication succeeded but no user ID was identified. Contact Facility Admin.") @@ -152,6 +363,7 @@ async def current_user( user_id=user_id, api_key=token, client_ip=ip_address, + token_info=token_info, globus_introspect=globus_introspect, ) @@ -170,6 +382,15 @@ async def get_current_user(self: "AuthenticatedAdapter", api_key: str, client_ip """ pass + @abstractmethod + async def get_current_user_oidc(self: "AuthenticatedAdapter", api_key: str, client_ip: str | None, token_info: dict | None) -> str: + """ + Decode the api_key and return the authenticated user's id from information returned by an OIDC token. + This method is not called directly, rather authorized endpoints "depend" on it. + (https://fastapi.tiangolo.com/tutorial/dependencies/) + """ + pass + @abstractmethod async def get_current_user_globus(self: "AuthenticatedAdapter", api_key: str, client_ip: str | None, globus_introspect: dict | None) -> str: """ @@ -180,8 +401,10 @@ async def get_current_user_globus(self: "AuthenticatedAdapter", api_key: str, cl pass @abstractmethod - async def get_user(self: "AuthenticatedAdapter", user_id: str, api_key: str, client_ip: str | None, globus_introspect: dict | None) -> User: + async def get_user(self: "AuthenticatedAdapter", user_id: str, api_key: str, client_ip: str | None, token_info: dict | None, globus_introspect: dict | None) -> User: """ Retrieve additional user information (name, email, etc.) for the given user_id. + ``token_info`` is populated when OIDC validation produced it; + ``globus_introspect`` is populated when Globus introspection produced it. """ pass diff --git a/pyproject.toml b/pyproject.toml index 4e34d7e4..d2f2d684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,9 @@ dependencies = [ "opentelemetry-instrumentation-fastapi>=0.60b1,<0.61b0", "opentelemetry-exporter-otlp>=1.39.1,<1.40.0", "globus-sdk>=4.3.1", + "authlib>=1.3.0", + "httpx>=0.27.0", + "cachetools>=5.3.0", "typer>=0.24.1", ] [tool.ruff] diff --git a/test/test_oidc_auth.py b/test/test_oidc_auth.py new file mode 100644 index 00000000..10ec36c7 --- /dev/null +++ b/test/test_oidc_auth.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +"""Focused tests for OIDC remote state caching in the IRI router.""" + +import os +import unittest +from unittest.mock import patch + +os.environ.setdefault("IRI_SHOW_MISSING_ROUTES", "true") + +from app.routers import iri_router + + +class _FakeHttpxResponse: + def __init__(self, payload: dict): + self._payload = payload + + def json(self) -> dict: + return self._payload + + def raise_for_status(self) -> None: + return None + + +class _FakeAsyncClient: + def __init__(self, responses: dict[str, dict], requests_seen: list[str], *args, **kwargs): + self._responses = responses + self._requests_seen = requests_seen + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url: str, headers: dict | None = None): + self._requests_seen.append(url) + if url not in self._responses: + raise AssertionError(f"unexpected URL opened in test: {url}") + return _FakeHttpxResponse(self._responses[url]) + + +class OidcAuthTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + iri_router._oidc_remote_cache.clear() + iri_router._oidc_remote_stale_cache.clear() + + async def test_load_oidc_remote_state_fetches_and_caches_with_async_httpx(self): + discovery_uri = "https://identity.example.test/.well-known/openid-configuration" + jwks_uri = "https://identity.example.test/oauth2/jwks" + requests_seen = [] + responses = { + discovery_uri: { + "issuer": "https://identity.example.test/oauth2", + "jwks_uri": jwks_uri, + }, + jwks_uri: {"keys": []}, + } + + def fake_async_client(*args, **kwargs): + return _FakeAsyncClient(responses, requests_seen, *args, **kwargs) + + with patch("app.routers.iri_router.httpx.AsyncClient", side_effect=fake_async_client), \ + patch("app.routers.iri_router.JsonWebKey.import_key_set", return_value="fake-key-set"): + metadata, key_set = await iri_router._load_oidc_remote_state(discovery_uri) + cached_metadata, cached_key_set = await iri_router._load_oidc_remote_state(discovery_uri) + + self.assertEqual(metadata["jwks_uri"], jwks_uri) + self.assertEqual(key_set, "fake-key-set") + self.assertEqual(cached_metadata, metadata) + self.assertEqual(cached_key_set, key_set) + self.assertEqual(requests_seen, [discovery_uri, jwks_uri]) + + async def test_load_oidc_remote_state_reuses_stale_cache_on_refresh_failure(self): + discovery_uri = "https://identity.example.test/.well-known/openid-configuration" + cached_metadata = {"issuer": "https://identity.example.test/oauth2", "jwks_uri": "cached"} + cached_key_set = object() + iri_router._oidc_remote_stale_cache[discovery_uri] = (cached_metadata, cached_key_set) + + async def fail_fetch(uri: str): + raise RuntimeError("temporary IdP outage") + + with patch("app.routers.iri_router._fetch_oidc_remote_state", side_effect=fail_fetch): + metadata, key_set = await iri_router._load_oidc_remote_state(discovery_uri) + + self.assertIs(metadata, cached_metadata) + self.assertIs(key_set, cached_key_set) + + +if __name__ == "__main__": + unittest.main()