diff --git a/src/paperscout/config.py b/src/paperscout/config.py index eb45f61..3b825df 100644 --- a/src/paperscout/config.py +++ b/src/paperscout/config.py @@ -1,10 +1,25 @@ -"""Environment-backed runtime configuration (see ``settings`` singleton).""" +"""Environment-backed runtime configuration (see ``settings`` singleton). + +The module-level :data:`settings` object is the process-wide singleton loaded at +import time. For temporary field overrides in tests or integration scenarios, +use :func:`override_settings` — it mutates that instance in place so all +``from paperscout.config import settings`` importers observe the change. + +``_PAPERSCOUT_TESTING=1`` is an import-time bypass for Slack credential +validation during pytest collection; it is not a substitute for per-test field +overrides (use :func:`override_settings` for those). +""" from __future__ import annotations +import asyncio import os +import threading +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path -from typing import NamedTuple +from typing import Any, NamedTuple from pydantic import Field, model_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict @@ -209,4 +224,78 @@ def _require_slack_credentials_unless_testing(self) -> Settings: } +@dataclass(frozen=True, slots=True) +class _OverrideFrame: + snapshot: dict[str, Any] + owner: tuple[int, int | None] # (thread_id, asyncio task id or None) + + +_override_stack: list[_OverrideFrame] = [] + + +def _execution_context() -> tuple[int, int | None]: + """Identify the current thread and asyncio task (if any).""" + thread_id = threading.get_ident() + try: + task = asyncio.current_task() + except RuntimeError: + task_id = None + else: + task_id = id(task) if task is not None else None + return (thread_id, task_id) + + +def _restore_snapshot(snapshot: dict[str, Any]) -> None: + for key, value in snapshot.items(): + setattr(settings, key, value) + + +def _apply_validated_overrides(snapshot: dict[str, Any], **kwargs: Any) -> None: + """Merge *kwargs* into *snapshot*, validate, and apply to ``settings`` in place.""" + validated = Settings.model_validate({**snapshot, **kwargs}) + for key in kwargs: + setattr(settings, key, getattr(validated, key)) + + +@contextmanager +def override_settings(**kwargs: Any) -> Iterator[Settings]: + """Temporarily override fields on the module ``settings`` singleton. + + Supported mechanism for tests and integration scenarios. Mutates the + existing instance in place so ``from paperscout.config import settings`` + importers observe changes. Nested calls on the **same thread/task** are + supported (LIFO); concurrent use from other threads or asyncio tasks raises + :exc:`RuntimeError` on exit rather than restoring out of order. + + Objects that copied setting values at construction time (e.g. queue size + captured in ``__init__``) are not updated — only live reads of ``settings`` + or references to the singleton itself see the override. + """ + unknown = set(kwargs) - set(Settings.model_fields) + if unknown: + raise TypeError(f"Unknown settings field(s): {sorted(unknown)}") + + frame = _OverrideFrame(snapshot=settings.model_dump(), owner=_execution_context()) + _override_stack.append(frame) + try: + _apply_validated_overrides(frame.snapshot, **kwargs) + yield settings + finally: + current = _execution_context() + if frame.owner != current: + raise RuntimeError( + "override_settings: owner context changed during override " + f"(entered as {frame.owner!r}, exiting as {current!r})" + ) + if not _override_stack or _override_stack[-1] is not frame: + if frame in _override_stack: + _override_stack.remove(frame) + raise RuntimeError( + "override_settings: non-LIFO or concurrent override detected; " + "nested overrides must exit in reverse order on the same thread/task" + ) + _override_stack.pop() + _restore_snapshot(frame.snapshot) + + settings = Settings() diff --git a/tests/conftest.py b/tests/conftest.py index dca2179..13249fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,10 @@ import os -# Tests import ``paperscout.config`` which instantiates ``Settings()`` at module load; keep Slack placeholders. +# Import-time only: ``Settings()`` loads at module import and skips Slack validation +# when ``_PAPERSCOUT_TESTING=1``. Per-test field overrides use +# ``paperscout.config.override_settings``; ``make_test_settings()`` is for explicit +# ``cfg=`` injection into constructors. os.environ.setdefault("_PAPERSCOUT_TESTING", "1") os.environ.setdefault("SLACK_BOT_TOKEN", "xoxb-test") os.environ.setdefault("SLACK_SIGNING_SECRET", "test-secret") diff --git a/tests/test_config.py b/tests/test_config.py index 7d5e277..67ceacd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,8 +3,16 @@ from __future__ import annotations import pytest - -from paperscout.config import ENV_VAR_MAP, Settings, legacy_env_name, prefixed_env_name +from pydantic import ValidationError + +from paperscout.config import ( + ENV_VAR_MAP, + Settings, + legacy_env_name, + override_settings, + prefixed_env_name, + settings, +) from paperscout.errors import ConfigurationError @@ -91,3 +99,96 @@ def test_legacy_keys_in_dotenv_file_load(tmp_path, monkeypatch): env_file.write_text("POLL_INTERVAL_MINUTES=66\n") s = Settings(_env_file=env_file) assert s.poll_interval_minutes == 66 + + +class TestOverrideSettings: + def test_override_settings_applies_within_block(self): + original = settings.poll_interval_minutes + with override_settings(poll_interval_minutes=99): + assert settings.poll_interval_minutes == 99 + assert settings.poll_interval_minutes == original + + def test_override_settings_restores_after_block(self): + original = settings.wg21_index_timeout_s + with override_settings(wg21_index_timeout_s=42.0): + assert settings.wg21_index_timeout_s == 42.0 + assert settings.wg21_index_timeout_s == original + + def test_override_settings_restores_on_exception(self): + original = settings.poll_interval_minutes + with pytest.raises(RuntimeError, match="boom"): + with override_settings(poll_interval_minutes=77): + assert settings.poll_interval_minutes == 77 + raise RuntimeError("boom") + assert settings.poll_interval_minutes == original + + def test_override_settings_nested(self): + original = settings.poll_interval_minutes + with override_settings(poll_interval_minutes=10): + assert settings.poll_interval_minutes == 10 + with override_settings(poll_interval_minutes=20): + assert settings.poll_interval_minutes == 20 + assert settings.poll_interval_minutes == 10 + assert settings.poll_interval_minutes == original + + def test_override_settings_rejects_unknown_field(self): + with pytest.raises(TypeError, match="Unknown settings field"): + with override_settings(not_a_real_field=1): + pass + + def test_override_settings_rejects_invalid_value(self): + with pytest.raises(ValidationError): + with override_settings(wg21_index_timeout_s=-1): + pass + + def test_importers_see_override(self): + from paperscout.config import settings as imported_settings + + assert imported_settings is settings + original = settings.enable_iso_probe + with override_settings(enable_iso_probe=not original): + assert imported_settings.enable_iso_probe is not original + assert imported_settings.enable_iso_probe == original + + def test_override_settings_rejects_concurrent_threads(self): + import threading + + import paperscout.config as config_module + + a_inside = threading.Event() + b_inside = threading.Event() + a_may_exit = threading.Event() + errors: list[RuntimeError] = [] + original = settings.poll_interval_minutes + + def thread_a() -> None: + try: + with override_settings(poll_interval_minutes=111): + a_inside.set() + b_inside.wait(timeout=5) + except RuntimeError as exc: + errors.append(exc) + finally: + a_may_exit.set() + + def thread_b() -> None: + with override_settings(poll_interval_minutes=222): + b_inside.set() + a_may_exit.wait(timeout=5) + + try: + t_a = threading.Thread(target=thread_a) + t_b = threading.Thread(target=thread_b) + t_a.start() + assert a_inside.wait(timeout=5) + t_b.start() + assert b_inside.wait(timeout=5) + t_a.join(timeout=5) + t_b.join(timeout=5) + assert not t_a.is_alive() + assert not t_b.is_alive() + assert errors, "expected concurrent override exit to raise RuntimeError" + assert any("non-LIFO or concurrent" in str(exc) for exc in errors) + finally: + config_module._override_stack.clear() + settings.poll_interval_minutes = original diff --git a/tests/test_sources.py b/tests/test_sources.py index 6c1fe09..b8a714e 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -11,6 +11,7 @@ import httpx import pytest +from paperscout.config import override_settings from paperscout.models import CycleStatus, Paper from paperscout.sources import ( ISOProber, @@ -176,16 +177,16 @@ async def test_download_http_status_500_emits_network(self, fake_pool, caplog): assert "failure_category=NETWORK" in caplog.text async def test_download_uses_wg21_index_timeout_from_settings(self, fake_pool): - cfg = make_test_settings(wg21_index_timeout_s=42.0) - index = WG21Index(fake_pool, cfg=cfg) + index = WG21Index(fake_pool) mock_resp = _make_response(200, json_data=SAMPLE_INDEX_DATA) mock_client = _make_async_client(get_resp=mock_resp) - with patch("paperscout.sources.httpx.AsyncClient") as mock_cls: - mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) - mock_cls.return_value.__aexit__ = AsyncMock(return_value=False) - await index._download() - mock_cls.assert_called_once() - arg_timeout = mock_cls.call_args.kwargs["timeout"] + with override_settings(wg21_index_timeout_s=42.0): + with patch("paperscout.sources.httpx.AsyncClient") as mock_cls: + mock_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_cls.return_value.__aexit__ = AsyncMock(return_value=False) + await index._download() + mock_cls.assert_called_once() + arg_timeout = mock_cls.call_args.kwargs["timeout"] assert isinstance(arg_timeout, httpx.Timeout) assert arg_timeout == httpx.Timeout(42.0)