diff --git a/app.yaml b/app.yaml index 1a2fbc0..6348fe2 100644 --- a/app.yaml +++ b/app.yaml @@ -21,3 +21,5 @@ env: value: 0 - name: MAX_CONCURRENT_SESSIONS value: "5" + # Override PAT_ROTATION_INTERVAL / PAT_TOKEN_LIFETIME here for e2e testing. + # Defaults (10 min / 15 min) live in pat_rotator.py. diff --git a/content_filter_proxy.py b/content_filter_proxy.py index c20d8d2..39f3391 100644 --- a/content_filter_proxy.py +++ b/content_filter_proxy.py @@ -38,8 +38,12 @@ # OpenCode (and this proxy) are separate processes with frozen env snapshots, # so we read the file on-demand instead of trusting os.environ. -_TOKEN_CACHE: dict = {"token": None, "read_at": 0.0} -_TOKEN_CACHE_TTL = 30 # seconds — short enough to pick up rotations quickly +_TOKEN_CACHE: dict = {"token": None, "read_at": 0.0, "mtime": 0.0} +# Hard ceiling on cache age. With mtime invalidation below, the cache normally +# refreshes the instant the rotator rewrites the file, so this is just a +# defence against an mtime that stops advancing (e.g. clock skew, watched fs +# tools that touch the file without updating contents). +_TOKEN_CACHE_TTL = 30 _HOME = os.environ.get("HOME", "/app/python/source_code") if not _HOME or _HOME == "/": @@ -50,10 +54,21 @@ def _get_fresh_token() -> str | None: """Read current token from ~/.databrickscfg (updated by PAT rotator). - Returns cached value if read within the last _TOKEN_CACHE_TTL seconds. + Cache invalidates on file mtime change so a rotation produces a near-zero + window of stale tokens. The TTL is a backstop; mtime is authoritative. """ now = time.time() - if _TOKEN_CACHE["token"] and (now - _TOKEN_CACHE["read_at"]) < _TOKEN_CACHE_TTL: + try: + mtime = os.stat(_DATABRICKSCFG_PATH).st_mtime + except OSError: + mtime = 0.0 + + cache_hot = ( + _TOKEN_CACHE["token"] + and mtime <= _TOKEN_CACHE["mtime"] + and (now - _TOKEN_CACHE["read_at"]) < _TOKEN_CACHE_TTL + ) + if cache_hot: return _TOKEN_CACHE["token"] try: @@ -63,6 +78,7 @@ def _get_fresh_token() -> str | None: if token: _TOKEN_CACHE["token"] = token _TOKEN_CACHE["read_at"] = now + _TOKEN_CACHE["mtime"] = mtime return token except Exception as e: log.warning(f"Could not read fresh token from {_DATABRICKSCFG_PATH}: {e}") diff --git a/pat_rotator.py b/pat_rotator.py index cc8734f..8b3e42e 100644 --- a/pat_rotator.py +++ b/pat_rotator.py @@ -18,8 +18,10 @@ logger = logging.getLogger(__name__) -DEFAULT_TOKEN_LIFETIME = 900 # 15 minutes -DEFAULT_ROTATION_INTERVAL = 600 # 10 minutes +# Env overrides exist so e2e tests can compress the cycle to seconds without +# a code change. Production defaults: 15-min tokens rotated every 10 min. +DEFAULT_TOKEN_LIFETIME = int(os.environ.get("PAT_TOKEN_LIFETIME", "900")) +DEFAULT_ROTATION_INTERVAL = int(os.environ.get("PAT_ROTATION_INTERVAL", "600")) class PATRotator: @@ -80,16 +82,35 @@ def stop(self): self._stop_event.set() def _rotation_loop(self): - """Background loop: sleep, rotate if sessions exist, repeat.""" + """Background loop: sleep, then rotate if sessions exist OR if the + in-process token is about to expire. Always-rotating-near-expiry + prevents the rotator from deadlocking when an idle skip outruns the + token's lifetime — at that point our own auth would be dead and we + could never mint a replacement. + """ + # Force a refresh once we're inside one rotation interval of expiry. + # That window is the maximum time we can afford to skip a rotation and + # still be sure the next attempt can authenticate. + expiry_grace = max(self._rotation_interval, 60) while not self._stop_event.is_set(): self._stop_event.wait(timeout=self._rotation_interval) if self._stop_event.is_set(): break try: session_count = self._session_count_fn() - if session_count == 0: + token_age = ( + time.time() - self._last_rotation_time + if self._last_rotation_time else float("inf") + ) + token_near_expiry = token_age > (self._token_lifetime - expiry_grace) + if session_count == 0 and not token_near_expiry: logger.info("PAT rotation: no active sessions — skipping rotation") continue + if session_count == 0 and token_near_expiry: + logger.info( + "PAT rotation: no active sessions, but token approaching " + f"expiry (age={int(token_age)}s, lifetime={self._token_lifetime}s) — rotating anyway" + ) self._rotate_once() except Exception as e: logger.error(f"PAT rotation failed unexpectedly: {e}") diff --git a/setup_hermes.py b/setup_hermes.py index 07bb030..37f4309 100644 --- a/setup_hermes.py +++ b/setup_hermes.py @@ -107,14 +107,26 @@ def _run(cmd, **kwargs): print("Warning: AI Gateway resolved but DATABRICKS_TOKEN missing, falling back to DATABRICKS_HOST") gateway_host = "" +# Route Hermes through the local content-filter proxy (127.0.0.1:4000) so +# PAT rotation is transparent: the proxy reads ~/.databrickscfg on every +# request and injects a fresh Bearer token, overriding whatever literal +# api_key is cached in Hermes's in-memory config. Without this, the +# long-running `hermes chat` process holds the startup token and gets +# revoked-token 403s once the rotator swaps PATs (~10 min cadence). +# OpenCode uses the same trick. +# +# upstream_base is recorded for the diagnostic banner below; the proxy +# itself decides the actual upstream from PROXY_UPSTREAM_BASE. if gateway_host: - base_url = f"{gateway_host}/mlflow/v1" + upstream_base = f"{gateway_host}/mlflow/v1" auth_token = gateway_token - print(f"Using Databricks AI Gateway: {gateway_host}") + print(f"Hermes will route via content-filter proxy -> AI Gateway: {gateway_host}") else: - base_url = f"{host}/serving-endpoints" + upstream_base = f"{host}/serving-endpoints" auth_token = token - print(f"Using Databricks Host: {host}") + print(f"Hermes will route via content-filter proxy -> {host}/serving-endpoints") + +base_url = "http://127.0.0.1:4000" # 4. Write ~/.hermes/config.yaml config_path = hermes_home / "config.yaml" @@ -226,7 +238,7 @@ def _run(cmd, **kwargs): print(" hermes model # Select default model") print(" hermes setup # Reconfigure wizard") print(" hermes mcp add # Add MCP server") -print(f"\nEndpoint: {base_url}") +print(f"\nEndpoint: {base_url} (forwards to {upstream_base})") print(f"Primary model: {hermes_model}") print(f"Fallback model: {hermes_fallback_model} (auto-activates on 429/529/503)") print(f"Install: minimal (add extras: uv pip install \"hermes-agent[mcp,messaging,...]\")") diff --git a/tests/test_cli_token_rotation.py b/tests/test_cli_token_rotation.py index 7393299..cfa3655 100644 --- a/tests/test_cli_token_rotation.py +++ b/tests/test_cli_token_rotation.py @@ -88,8 +88,69 @@ def test_skips_missing_file(self, isolated_home): update_cli_tokens("new-token") +class TestUpdateHermes: + def test_updates_both_api_key_lines(self, isolated_home): + """Hermes config has two api_key lines (primary + fallback). Both must rotate.""" + from cli_auth import update_cli_tokens + hermes_dir = isolated_home / ".hermes" + hermes_dir.mkdir() + config = ( + "model:\n" + " default: databricks-claude-opus-4-7\n" + " provider: custom\n" + " base_url: http://127.0.0.1:4000\n" + " api_key: old-token\n" + "\n" + "fallback_providers:\n" + "- provider: custom\n" + " model: databricks-claude-opus-4-6\n" + " base_url: http://127.0.0.1:4000\n" + " api_key: old-token\n" + ) + (hermes_dir / "config.yaml").write_text(config) + + update_cli_tokens("new-token") + + content = (hermes_dir / "config.yaml").read_text() + assert content.count("api_key: new-token") == 2, ( + "Both primary and fallback api_key lines must be rotated. " + f"Content was:\n{content}" + ) + assert "old-token" not in content + # Unrelated lines preserved + assert "default: databricks-claude-opus-4-7" in content + assert "model: databricks-claude-opus-4-6" in content + + def test_preserves_other_indentation(self, isolated_home): + """Regex must match only ` api_key:` with two-space indent, not arbitrary text.""" + from cli_auth import update_cli_tokens + hermes_dir = isolated_home / ".hermes" + hermes_dir.mkdir() + # Decoy: a comment that mentions api_key, plus a 4-space-indented api_key + # that should NOT be touched. + config = ( + "# api_key: this-is-a-comment-not-a-value\n" + "model:\n" + " api_key: old-token\n" + "deep:\n" + " api_key: should-not-change\n" + ) + (hermes_dir / "config.yaml").write_text(config) + + update_cli_tokens("new-token") + + content = (hermes_dir / "config.yaml").read_text() + assert " api_key: new-token" in content + assert "# api_key: this-is-a-comment-not-a-value" in content + assert " api_key: should-not-change" in content + + def test_skips_missing_file(self, isolated_home): + from cli_auth import update_cli_tokens + update_cli_tokens("new-token") # must not raise + + class TestAllCLIsUpdated: - def test_all_four_updated_in_one_call(self, isolated_home): + def test_all_five_updated_in_one_call(self, isolated_home): from cli_auth import update_cli_tokens # Set up all config files @@ -111,6 +172,12 @@ def test_all_four_updated_in_one_call(self, isolated_home): gemini_dir.mkdir() (gemini_dir / ".env").write_text("GEMINI_API_KEY=old\n") + hermes_dir = isolated_home / ".hermes" + hermes_dir.mkdir() + (hermes_dir / "config.yaml").write_text( + "model:\n api_key: old\nfallback_providers:\n- provider: custom\n api_key: old\n" + ) + # One call updates all update_cli_tokens("rotated-token") @@ -118,3 +185,5 @@ def test_all_four_updated_in_one_call(self, isolated_home): assert "OPENAI_API_KEY=rotated-token" in (codex_dir / ".env").read_text() assert json.loads((oc_dir / "auth.json").read_text())["databricks"]["api_key"] == "rotated-token" assert "GEMINI_API_KEY=rotated-token" in (gemini_dir / ".env").read_text() + hermes_content = (hermes_dir / "config.yaml").read_text() + assert hermes_content.count("api_key: rotated-token") == 2 diff --git a/tests/test_content_filter_proxy.py b/tests/test_content_filter_proxy.py new file mode 100644 index 0000000..8cc0d55 --- /dev/null +++ b/tests/test_content_filter_proxy.py @@ -0,0 +1,77 @@ +"""Tests for content_filter_proxy._get_fresh_token cache invalidation. + +The proxy reads ~/.databrickscfg on every forwarded request, with a cache to +avoid filesystem hits in tight request bursts. The cache must invalidate the +moment the rotator rewrites the file, otherwise the proxy serves revoked +tokens to upstream for up to TTL seconds after each rotation. +""" + +import time +from unittest import mock + +import pytest + + +@pytest.fixture +def tmp_cfg(tmp_path, monkeypatch): + """Point the proxy at a temp .databrickscfg, with a clean cache.""" + cfg = tmp_path / ".databrickscfg" + import content_filter_proxy as cfp + monkeypatch.setattr(cfp, "_DATABRICKSCFG_PATH", str(cfg)) + monkeypatch.setattr(cfp, "_TOKEN_CACHE", {"token": None, "read_at": 0.0, "mtime": 0.0}) + return cfg + + +def _write_cfg(path, token): + path.write_text(f"[DEFAULT]\nhost = https://example.databricks.com\ntoken = {token}\n") + + +class TestFreshTokenCacheInvalidation: + def test_cache_invalidates_on_mtime_change(self, tmp_cfg): + from content_filter_proxy import _get_fresh_token + _write_cfg(tmp_cfg, "dapi-old") + assert _get_fresh_token() == "dapi-old" + + # Simulate rotator rewriting the file. utime to a guaranteed-newer mtime + # so the test isn't sensitive to filesystem mtime granularity. + _write_cfg(tmp_cfg, "dapi-new") + import os + st = os.stat(tmp_cfg) + os.utime(tmp_cfg, (st.st_atime, st.st_mtime + 10)) + + assert _get_fresh_token() == "dapi-new", "must re-read after mtime change" + + def test_cache_hits_when_mtime_unchanged(self, tmp_cfg): + from content_filter_proxy import _get_fresh_token + _write_cfg(tmp_cfg, "dapi-stable") + assert _get_fresh_token() == "dapi-stable" + + # Mutate the file contents WITHOUT advancing mtime (force mtime backwards). + # If the cache ignored mtime, it'd happily keep serving "dapi-stable"; + # if it consulted mtime, it'd still serve "dapi-stable" because mtime + # didn't advance. Either way we expect the cached value back, which + # asserts the cache is doing its de-dup job within the TTL. + import os + st = os.stat(tmp_cfg) + _write_cfg(tmp_cfg, "dapi-tampered") + os.utime(tmp_cfg, (st.st_atime, st.st_mtime)) # restore old mtime + + assert _get_fresh_token() == "dapi-stable" + + def test_falls_back_to_cache_on_stat_error(self, tmp_cfg, monkeypatch): + from content_filter_proxy import _get_fresh_token + _write_cfg(tmp_cfg, "dapi-cached") + assert _get_fresh_token() == "dapi-cached" + + # Now make os.stat fail. The cache should still return the last known token. + def boom(_): + raise OSError("stat broken") + monkeypatch.setattr("content_filter_proxy.os.stat", boom) + assert _get_fresh_token() == "dapi-cached" + + def test_returns_none_when_file_missing_and_cache_empty(self, tmp_path, monkeypatch): + import content_filter_proxy as cfp + missing = tmp_path / "does-not-exist" + monkeypatch.setattr(cfp, "_DATABRICKSCFG_PATH", str(missing)) + monkeypatch.setattr(cfp, "_TOKEN_CACHE", {"token": None, "read_at": 0.0, "mtime": 0.0}) + assert cfp._get_fresh_token() is None diff --git a/tests/test_pat_rotator.py b/tests/test_pat_rotator.py index 40b412b..9dc797c 100644 --- a/tests/test_pat_rotator.py +++ b/tests/test_pat_rotator.py @@ -398,3 +398,50 @@ def test_log_pat_rotated_label(self, mock_post, caplog, tmp_path): combined = " ".join(caplog.messages) assert "PAT rotation complete" in combined + + +class TestRotationOnNearExpiry: + """When sessions are idle but the in-process token is near expiry, the loop + must rotate anyway. Otherwise our own auth dies during idle periods and the + rotator deadlocks on subsequent token/create attempts.""" + + def test_skip_when_token_fresh_and_no_sessions(self, caplog): + import time as _time + rotator = _make_rotator( + session_count_fn=lambda: 0, + rotation_interval=1, + token_lifetime=3600, + ) + rotator._current_token = "dapi-fresh" + rotator._current_token_id = "tid-fresh" + rotator._last_rotation_time = _time.time() # just minted + + with mock.patch.object(rotator, "_rotate_once") as mr: + t = threading.Thread(target=rotator._rotation_loop, daemon=True) + t.start() + _time.sleep(1.5) + rotator.stop() + t.join(timeout=2) + + assert mr.call_count == 0 + + def test_rotate_when_token_near_expiry_even_with_no_sessions(self): + import time as _time + rotator = _make_rotator( + session_count_fn=lambda: 0, + rotation_interval=1, + token_lifetime=10, + ) + rotator._current_token = "dapi-near-expiry" + rotator._current_token_id = "tid-near-expiry" + # Token is 9s old of a 10s lifetime — within the rotation interval of expiry. + rotator._last_rotation_time = _time.time() - 9 + + with mock.patch.object(rotator, "_rotate_once", return_value=True) as mr: + t = threading.Thread(target=rotator._rotation_loop, daemon=True) + t.start() + _time.sleep(1.5) + rotator.stop() + t.join(timeout=2) + + assert mr.call_count >= 1