Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ env:
value: 0
- name: MAX_CONCURRENT_SESSIONS
value: "5"
# Override PAT_ROTATION_INTERVAL / PAT_TOKEN_LIFETIME here for e2e testing.
# Defaults (10 min / 15 min) live in pat_rotator.py.
24 changes: 20 additions & 4 deletions content_filter_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@
# OpenCode (and this proxy) are separate processes with frozen env snapshots,
# so we read the file on-demand instead of trusting os.environ.

_TOKEN_CACHE: dict = {"token": None, "read_at": 0.0}
_TOKEN_CACHE_TTL = 30 # seconds — short enough to pick up rotations quickly
_TOKEN_CACHE: dict = {"token": None, "read_at": 0.0, "mtime": 0.0}
# Hard ceiling on cache age. With mtime invalidation below, the cache normally
# refreshes the instant the rotator rewrites the file, so this is just a
# defence against an mtime that stops advancing (e.g. clock skew, watched fs
# tools that touch the file without updating contents).
_TOKEN_CACHE_TTL = 30

_HOME = os.environ.get("HOME", "/app/python/source_code")
if not _HOME or _HOME == "/":
Expand All @@ -50,10 +54,21 @@
def _get_fresh_token() -> str | None:
"""Read current token from ~/.databrickscfg (updated by PAT rotator).

Returns cached value if read within the last _TOKEN_CACHE_TTL seconds.
Cache invalidates on file mtime change so a rotation produces a near-zero
window of stale tokens. The TTL is a backstop; mtime is authoritative.
"""
now = time.time()
if _TOKEN_CACHE["token"] and (now - _TOKEN_CACHE["read_at"]) < _TOKEN_CACHE_TTL:
try:
mtime = os.stat(_DATABRICKSCFG_PATH).st_mtime
except OSError:
mtime = 0.0

cache_hot = (
_TOKEN_CACHE["token"]
and mtime <= _TOKEN_CACHE["mtime"]
and (now - _TOKEN_CACHE["read_at"]) < _TOKEN_CACHE_TTL
)
if cache_hot:
return _TOKEN_CACHE["token"]

try:
Expand All @@ -63,6 +78,7 @@ def _get_fresh_token() -> str | None:
if token:
_TOKEN_CACHE["token"] = token
_TOKEN_CACHE["read_at"] = now
_TOKEN_CACHE["mtime"] = mtime
return token
except Exception as e:
log.warning(f"Could not read fresh token from {_DATABRICKSCFG_PATH}: {e}")
Expand Down
29 changes: 25 additions & 4 deletions pat_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

logger = logging.getLogger(__name__)

DEFAULT_TOKEN_LIFETIME = 900 # 15 minutes
DEFAULT_ROTATION_INTERVAL = 600 # 10 minutes
# Env overrides exist so e2e tests can compress the cycle to seconds without
# a code change. Production defaults: 15-min tokens rotated every 10 min.
DEFAULT_TOKEN_LIFETIME = int(os.environ.get("PAT_TOKEN_LIFETIME", "900"))
DEFAULT_ROTATION_INTERVAL = int(os.environ.get("PAT_ROTATION_INTERVAL", "600"))


class PATRotator:
Expand Down Expand Up @@ -80,16 +82,35 @@ def stop(self):
self._stop_event.set()

def _rotation_loop(self):
"""Background loop: sleep, rotate if sessions exist, repeat."""
"""Background loop: sleep, then rotate if sessions exist OR if the
in-process token is about to expire. Always-rotating-near-expiry
prevents the rotator from deadlocking when an idle skip outruns the
token's lifetime — at that point our own auth would be dead and we
could never mint a replacement.
"""
# Force a refresh once we're inside one rotation interval of expiry.
# That window is the maximum time we can afford to skip a rotation and
# still be sure the next attempt can authenticate.
expiry_grace = max(self._rotation_interval, 60)
while not self._stop_event.is_set():
self._stop_event.wait(timeout=self._rotation_interval)
if self._stop_event.is_set():
break
try:
session_count = self._session_count_fn()
if session_count == 0:
token_age = (
time.time() - self._last_rotation_time
if self._last_rotation_time else float("inf")
)
token_near_expiry = token_age > (self._token_lifetime - expiry_grace)
if session_count == 0 and not token_near_expiry:
logger.info("PAT rotation: no active sessions — skipping rotation")
continue
if session_count == 0 and token_near_expiry:
logger.info(
"PAT rotation: no active sessions, but token approaching "
f"expiry (age={int(token_age)}s, lifetime={self._token_lifetime}s) — rotating anyway"
)
self._rotate_once()
except Exception as e:
logger.error(f"PAT rotation failed unexpectedly: {e}")
Expand Down
22 changes: 17 additions & 5 deletions setup_hermes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,26 @@ def _run(cmd, **kwargs):
print("Warning: AI Gateway resolved but DATABRICKS_TOKEN missing, falling back to DATABRICKS_HOST")
gateway_host = ""

# Route Hermes through the local content-filter proxy (127.0.0.1:4000) so
# PAT rotation is transparent: the proxy reads ~/.databrickscfg on every
# request and injects a fresh Bearer token, overriding whatever literal
# api_key is cached in Hermes's in-memory config. Without this, the
# long-running `hermes chat` process holds the startup token and gets
# revoked-token 403s once the rotator swaps PATs (~10 min cadence).
# OpenCode uses the same trick.
#
# upstream_base is recorded for the diagnostic banner below; the proxy
# itself decides the actual upstream from PROXY_UPSTREAM_BASE.
if gateway_host:
base_url = f"{gateway_host}/mlflow/v1"
upstream_base = f"{gateway_host}/mlflow/v1"
auth_token = gateway_token
print(f"Using Databricks AI Gateway: {gateway_host}")
print(f"Hermes will route via content-filter proxy -> AI Gateway: {gateway_host}")
else:
base_url = f"{host}/serving-endpoints"
upstream_base = f"{host}/serving-endpoints"
auth_token = token
print(f"Using Databricks Host: {host}")
print(f"Hermes will route via content-filter proxy -> {host}/serving-endpoints")

base_url = "http://127.0.0.1:4000"

# 4. Write ~/.hermes/config.yaml
config_path = hermes_home / "config.yaml"
Expand Down Expand Up @@ -226,7 +238,7 @@ def _run(cmd, **kwargs):
print(" hermes model # Select default model")
print(" hermes setup # Reconfigure wizard")
print(" hermes mcp add <name> <url> # Add MCP server")
print(f"\nEndpoint: {base_url}")
print(f"\nEndpoint: {base_url} (forwards to {upstream_base})")
print(f"Primary model: {hermes_model}")
print(f"Fallback model: {hermes_fallback_model} (auto-activates on 429/529/503)")
print(f"Install: minimal (add extras: uv pip install \"hermes-agent[mcp,messaging,...]\")")
Expand Down
71 changes: 70 additions & 1 deletion tests/test_cli_token_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,69 @@ def test_skips_missing_file(self, isolated_home):
update_cli_tokens("new-token")


class TestUpdateHermes:
def test_updates_both_api_key_lines(self, isolated_home):
"""Hermes config has two api_key lines (primary + fallback). Both must rotate."""
from cli_auth import update_cli_tokens
hermes_dir = isolated_home / ".hermes"
hermes_dir.mkdir()
config = (
"model:\n"
" default: databricks-claude-opus-4-7\n"
" provider: custom\n"
" base_url: http://127.0.0.1:4000\n"
" api_key: old-token\n"
"\n"
"fallback_providers:\n"
"- provider: custom\n"
" model: databricks-claude-opus-4-6\n"
" base_url: http://127.0.0.1:4000\n"
" api_key: old-token\n"
)
(hermes_dir / "config.yaml").write_text(config)

update_cli_tokens("new-token")

content = (hermes_dir / "config.yaml").read_text()
assert content.count("api_key: new-token") == 2, (
"Both primary and fallback api_key lines must be rotated. "
f"Content was:\n{content}"
)
assert "old-token" not in content
# Unrelated lines preserved
assert "default: databricks-claude-opus-4-7" in content
assert "model: databricks-claude-opus-4-6" in content

def test_preserves_other_indentation(self, isolated_home):
"""Regex must match only ` api_key:` with two-space indent, not arbitrary text."""
from cli_auth import update_cli_tokens
hermes_dir = isolated_home / ".hermes"
hermes_dir.mkdir()
# Decoy: a comment that mentions api_key, plus a 4-space-indented api_key
# that should NOT be touched.
config = (
"# api_key: this-is-a-comment-not-a-value\n"
"model:\n"
" api_key: old-token\n"
"deep:\n"
" api_key: should-not-change\n"
)
(hermes_dir / "config.yaml").write_text(config)

update_cli_tokens("new-token")

content = (hermes_dir / "config.yaml").read_text()
assert " api_key: new-token" in content
assert "# api_key: this-is-a-comment-not-a-value" in content
assert " api_key: should-not-change" in content

def test_skips_missing_file(self, isolated_home):
from cli_auth import update_cli_tokens
update_cli_tokens("new-token") # must not raise


class TestAllCLIsUpdated:
def test_all_four_updated_in_one_call(self, isolated_home):
def test_all_five_updated_in_one_call(self, isolated_home):
from cli_auth import update_cli_tokens

# Set up all config files
Expand All @@ -111,10 +172,18 @@ def test_all_four_updated_in_one_call(self, isolated_home):
gemini_dir.mkdir()
(gemini_dir / ".env").write_text("GEMINI_API_KEY=old\n")

hermes_dir = isolated_home / ".hermes"
hermes_dir.mkdir()
(hermes_dir / "config.yaml").write_text(
"model:\n api_key: old\nfallback_providers:\n- provider: custom\n api_key: old\n"
)

# One call updates all
update_cli_tokens("rotated-token")

assert json.loads((claude_dir / "settings.json").read_text())["env"]["ANTHROPIC_AUTH_TOKEN"] == "rotated-token"
assert "OPENAI_API_KEY=rotated-token" in (codex_dir / ".env").read_text()
assert json.loads((oc_dir / "auth.json").read_text())["databricks"]["api_key"] == "rotated-token"
assert "GEMINI_API_KEY=rotated-token" in (gemini_dir / ".env").read_text()
hermes_content = (hermes_dir / "config.yaml").read_text()
assert hermes_content.count("api_key: rotated-token") == 2
77 changes: 77 additions & 0 deletions tests/test_content_filter_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tests for content_filter_proxy._get_fresh_token cache invalidation.

The proxy reads ~/.databrickscfg on every forwarded request, with a cache to
avoid filesystem hits in tight request bursts. The cache must invalidate the
moment the rotator rewrites the file, otherwise the proxy serves revoked
tokens to upstream for up to TTL seconds after each rotation.
"""

import time
from unittest import mock

import pytest


@pytest.fixture
def tmp_cfg(tmp_path, monkeypatch):
"""Point the proxy at a temp .databrickscfg, with a clean cache."""
cfg = tmp_path / ".databrickscfg"
import content_filter_proxy as cfp
monkeypatch.setattr(cfp, "_DATABRICKSCFG_PATH", str(cfg))
monkeypatch.setattr(cfp, "_TOKEN_CACHE", {"token": None, "read_at": 0.0, "mtime": 0.0})
return cfg


def _write_cfg(path, token):
path.write_text(f"[DEFAULT]\nhost = https://example.databricks.com\ntoken = {token}\n")


class TestFreshTokenCacheInvalidation:
def test_cache_invalidates_on_mtime_change(self, tmp_cfg):
from content_filter_proxy import _get_fresh_token
_write_cfg(tmp_cfg, "dapi-old")
assert _get_fresh_token() == "dapi-old"

# Simulate rotator rewriting the file. utime to a guaranteed-newer mtime
# so the test isn't sensitive to filesystem mtime granularity.
_write_cfg(tmp_cfg, "dapi-new")
import os
st = os.stat(tmp_cfg)
os.utime(tmp_cfg, (st.st_atime, st.st_mtime + 10))

assert _get_fresh_token() == "dapi-new", "must re-read after mtime change"

def test_cache_hits_when_mtime_unchanged(self, tmp_cfg):
from content_filter_proxy import _get_fresh_token
_write_cfg(tmp_cfg, "dapi-stable")
assert _get_fresh_token() == "dapi-stable"

# Mutate the file contents WITHOUT advancing mtime (force mtime backwards).
# If the cache ignored mtime, it'd happily keep serving "dapi-stable";
# if it consulted mtime, it'd still serve "dapi-stable" because mtime
# didn't advance. Either way we expect the cached value back, which
# asserts the cache is doing its de-dup job within the TTL.
import os
st = os.stat(tmp_cfg)
_write_cfg(tmp_cfg, "dapi-tampered")
os.utime(tmp_cfg, (st.st_atime, st.st_mtime)) # restore old mtime

assert _get_fresh_token() == "dapi-stable"

def test_falls_back_to_cache_on_stat_error(self, tmp_cfg, monkeypatch):
from content_filter_proxy import _get_fresh_token
_write_cfg(tmp_cfg, "dapi-cached")
assert _get_fresh_token() == "dapi-cached"

# Now make os.stat fail. The cache should still return the last known token.
def boom(_):
raise OSError("stat broken")
monkeypatch.setattr("content_filter_proxy.os.stat", boom)
assert _get_fresh_token() == "dapi-cached"

def test_returns_none_when_file_missing_and_cache_empty(self, tmp_path, monkeypatch):
import content_filter_proxy as cfp
missing = tmp_path / "does-not-exist"
monkeypatch.setattr(cfp, "_DATABRICKSCFG_PATH", str(missing))
monkeypatch.setattr(cfp, "_TOKEN_CACHE", {"token": None, "read_at": 0.0, "mtime": 0.0})
assert cfp._get_fresh_token() is None
47 changes: 47 additions & 0 deletions tests/test_pat_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,50 @@ def test_log_pat_rotated_label(self, mock_post, caplog, tmp_path):

combined = " ".join(caplog.messages)
assert "PAT rotation complete" in combined


class TestRotationOnNearExpiry:
"""When sessions are idle but the in-process token is near expiry, the loop
must rotate anyway. Otherwise our own auth dies during idle periods and the
rotator deadlocks on subsequent token/create attempts."""

def test_skip_when_token_fresh_and_no_sessions(self, caplog):
import time as _time
rotator = _make_rotator(
session_count_fn=lambda: 0,
rotation_interval=1,
token_lifetime=3600,
)
rotator._current_token = "dapi-fresh"
rotator._current_token_id = "tid-fresh"
rotator._last_rotation_time = _time.time() # just minted

with mock.patch.object(rotator, "_rotate_once") as mr:
t = threading.Thread(target=rotator._rotation_loop, daemon=True)
t.start()
_time.sleep(1.5)
rotator.stop()
t.join(timeout=2)

assert mr.call_count == 0

def test_rotate_when_token_near_expiry_even_with_no_sessions(self):
import time as _time
rotator = _make_rotator(
session_count_fn=lambda: 0,
rotation_interval=1,
token_lifetime=10,
)
rotator._current_token = "dapi-near-expiry"
rotator._current_token_id = "tid-near-expiry"
# Token is 9s old of a 10s lifetime — within the rotation interval of expiry.
rotator._last_rotation_time = _time.time() - 9

with mock.patch.object(rotator, "_rotate_once", return_value=True) as mr:
t = threading.Thread(target=rotator._rotation_loop, daemon=True)
t.start()
_time.sleep(1.5)
rotator.stop()
t.join(timeout=2)

assert mr.call_count >= 1