Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions src/paperscout/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
"""Environment-backed runtime configuration (see ``settings`` singleton)."""
"""Environment-backed runtime configuration (see ``settings`` singleton).

The module-level :data:`settings` object is the process-wide singleton loaded at
import time. For temporary field overrides in tests or integration scenarios,
use :func:`override_settings` — it mutates that instance in place so all
``from paperscout.config import settings`` importers observe the change.

``_PAPERSCOUT_TESTING=1`` is an import-time bypass for Slack credential
validation during pytest collection; it is not a substitute for per-test field
overrides (use :func:`override_settings` for those).
"""

from __future__ import annotations

import asyncio
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple
from typing import Any, NamedTuple

from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
Expand Down Expand Up @@ -209,4 +224,78 @@ def _require_slack_credentials_unless_testing(self) -> Settings:
}


@dataclass(frozen=True, slots=True)
class _OverrideFrame:
snapshot: dict[str, Any]
owner: tuple[int, int | None] # (thread_id, asyncio task id or None)


_override_stack: list[_OverrideFrame] = []


def _execution_context() -> tuple[int, int | None]:
"""Identify the current thread and asyncio task (if any)."""
thread_id = threading.get_ident()
try:
task = asyncio.current_task()
except RuntimeError:
task_id = None
else:
task_id = id(task) if task is not None else None
return (thread_id, task_id)


def _restore_snapshot(snapshot: dict[str, Any]) -> None:
for key, value in snapshot.items():
setattr(settings, key, value)


def _apply_validated_overrides(snapshot: dict[str, Any], **kwargs: Any) -> None:
"""Merge *kwargs* into *snapshot*, validate, and apply to ``settings`` in place."""
validated = Settings.model_validate({**snapshot, **kwargs})
for key in kwargs:
setattr(settings, key, getattr(validated, key))


@contextmanager
def override_settings(**kwargs: Any) -> Iterator[Settings]:
"""Temporarily override fields on the module ``settings`` singleton.

Supported mechanism for tests and integration scenarios. Mutates the
existing instance in place so ``from paperscout.config import settings``
importers observe changes. Nested calls on the **same thread/task** are
supported (LIFO); concurrent use from other threads or asyncio tasks raises
:exc:`RuntimeError` on exit rather than restoring out of order.

Objects that copied setting values at construction time (e.g. queue size
captured in ``__init__``) are not updated — only live reads of ``settings``
or references to the singleton itself see the override.
"""
unknown = set(kwargs) - set(Settings.model_fields)
if unknown:
raise TypeError(f"Unknown settings field(s): {sorted(unknown)}")

frame = _OverrideFrame(snapshot=settings.model_dump(), owner=_execution_context())
_override_stack.append(frame)
try:
_apply_validated_overrides(frame.snapshot, **kwargs)
yield settings
Comment thread
coderabbitai[bot] marked this conversation as resolved.
finally:
current = _execution_context()
if frame.owner != current:
raise RuntimeError(
"override_settings: owner context changed during override "
f"(entered as {frame.owner!r}, exiting as {current!r})"
)
if not _override_stack or _override_stack[-1] is not frame:
if frame in _override_stack:
_override_stack.remove(frame)
raise RuntimeError(
"override_settings: non-LIFO or concurrent override detected; "
"nested overrides must exit in reverse order on the same thread/task"
)
_override_stack.pop()
_restore_snapshot(frame.snapshot)


settings = Settings()
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import os

# Tests import ``paperscout.config`` which instantiates ``Settings()`` at module load; keep Slack placeholders.
# Import-time only: ``Settings()`` loads at module import and skips Slack validation
# when ``_PAPERSCOUT_TESTING=1``. Per-test field overrides use
# ``paperscout.config.override_settings``; ``make_test_settings()`` is for explicit
# ``cfg=`` injection into constructors.
os.environ.setdefault("_PAPERSCOUT_TESTING", "1")
os.environ.setdefault("SLACK_BOT_TOKEN", "xoxb-test")
os.environ.setdefault("SLACK_SIGNING_SECRET", "test-secret")
Expand Down
105 changes: 103 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
from __future__ import annotations

import pytest

from paperscout.config import ENV_VAR_MAP, Settings, legacy_env_name, prefixed_env_name
from pydantic import ValidationError

from paperscout.config import (
ENV_VAR_MAP,
Settings,
legacy_env_name,
override_settings,
prefixed_env_name,
settings,
)
from paperscout.errors import ConfigurationError


Expand Down Expand Up @@ -91,3 +99,96 @@ def test_legacy_keys_in_dotenv_file_load(tmp_path, monkeypatch):
env_file.write_text("POLL_INTERVAL_MINUTES=66\n")
s = Settings(_env_file=env_file)
assert s.poll_interval_minutes == 66


class TestOverrideSettings:
def test_override_settings_applies_within_block(self):
original = settings.poll_interval_minutes
with override_settings(poll_interval_minutes=99):
assert settings.poll_interval_minutes == 99
assert settings.poll_interval_minutes == original

def test_override_settings_restores_after_block(self):
original = settings.wg21_index_timeout_s
with override_settings(wg21_index_timeout_s=42.0):
assert settings.wg21_index_timeout_s == 42.0
assert settings.wg21_index_timeout_s == original

def test_override_settings_restores_on_exception(self):
original = settings.poll_interval_minutes
with pytest.raises(RuntimeError, match="boom"):
with override_settings(poll_interval_minutes=77):
assert settings.poll_interval_minutes == 77
raise RuntimeError("boom")
assert settings.poll_interval_minutes == original

def test_override_settings_nested(self):
original = settings.poll_interval_minutes
with override_settings(poll_interval_minutes=10):
assert settings.poll_interval_minutes == 10
with override_settings(poll_interval_minutes=20):
assert settings.poll_interval_minutes == 20
assert settings.poll_interval_minutes == 10
assert settings.poll_interval_minutes == original

def test_override_settings_rejects_unknown_field(self):
with pytest.raises(TypeError, match="Unknown settings field"):
with override_settings(not_a_real_field=1):
pass

def test_override_settings_rejects_invalid_value(self):
with pytest.raises(ValidationError):
with override_settings(wg21_index_timeout_s=-1):
pass

def test_importers_see_override(self):
from paperscout.config import settings as imported_settings

assert imported_settings is settings
original = settings.enable_iso_probe
with override_settings(enable_iso_probe=not original):
assert imported_settings.enable_iso_probe is not original
assert imported_settings.enable_iso_probe == original

def test_override_settings_rejects_concurrent_threads(self):
import threading

import paperscout.config as config_module

a_inside = threading.Event()
b_inside = threading.Event()
a_may_exit = threading.Event()
errors: list[RuntimeError] = []
original = settings.poll_interval_minutes

def thread_a() -> None:
try:
with override_settings(poll_interval_minutes=111):
a_inside.set()
b_inside.wait(timeout=5)
except RuntimeError as exc:
errors.append(exc)
finally:
a_may_exit.set()

def thread_b() -> None:
with override_settings(poll_interval_minutes=222):
b_inside.set()
a_may_exit.wait(timeout=5)

try:
t_a = threading.Thread(target=thread_a)
t_b = threading.Thread(target=thread_b)
t_a.start()
assert a_inside.wait(timeout=5)
t_b.start()
assert b_inside.wait(timeout=5)
t_a.join(timeout=5)
t_b.join(timeout=5)
assert not t_a.is_alive()
assert not t_b.is_alive()
assert errors, "expected concurrent override exit to raise RuntimeError"
assert any("non-LIFO or concurrent" in str(exc) for exc in errors)
finally:
config_module._override_stack.clear()
settings.poll_interval_minutes = original
Comment thread
coderabbitai[bot] marked this conversation as resolved.
17 changes: 9 additions & 8 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import httpx
import pytest

from paperscout.config import override_settings
from paperscout.models import CycleStatus, Paper
from paperscout.sources import (
ISOProber,
Expand Down Expand Up @@ -176,16 +177,16 @@ async def test_download_http_status_500_emits_network(self, fake_pool, caplog):
assert "failure_category=NETWORK" in caplog.text

async def test_download_uses_wg21_index_timeout_from_settings(self, fake_pool):
cfg = make_test_settings(wg21_index_timeout_s=42.0)
index = WG21Index(fake_pool, cfg=cfg)
index = WG21Index(fake_pool)
mock_resp = _make_response(200, json_data=SAMPLE_INDEX_DATA)
mock_client = _make_async_client(get_resp=mock_resp)
with patch("paperscout.sources.httpx.AsyncClient") as mock_cls:
mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await index._download()
mock_cls.assert_called_once()
arg_timeout = mock_cls.call_args.kwargs["timeout"]
with override_settings(wg21_index_timeout_s=42.0):
with patch("paperscout.sources.httpx.AsyncClient") as mock_cls:
mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await index._download()
mock_cls.assert_called_once()
arg_timeout = mock_cls.call_args.kwargs["timeout"]
assert isinstance(arg_timeout, httpx.Timeout)
assert arg_timeout == httpx.Timeout(42.0)

Expand Down