diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 04a689e4..0e03f2a3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,8 @@ on: - 'src/backend-api/**/*.py' - 'src/backend-api/pyproject.toml' - 'src/backend-api/pytest.ini' + - 'src/processor/**/*.py' + - 'src/processor/pyproject.toml' - '.github/workflows/test.yml' pull_request: types: @@ -23,6 +25,8 @@ on: - 'src/backend-api/**/*.py' - 'src/backend-api/pyproject.toml' - 'src/backend-api/pytest.ini' + - 'src/processor/**/*.py' + - 'src/processor/pyproject.toml' - '.github/workflows/test.yml' permissions: @@ -48,7 +52,7 @@ jobs: python -m pip install --upgrade pip cd src/backend-api pip install -e . - pip install pytest pytest-cov + pip install pytest pytest-cov pytest-asyncio - name: Check if Backend Test Files Exist id: check_backend_tests @@ -71,9 +75,26 @@ jobs: --cov=src/app \ --cov-report=term-missing \ --cov-report=xml:reports/coverage.xml \ + --cov-fail-under=82 \ --junitxml=pytest.xml \ -v + - name: Prefix coverage XML filenames with repo-root path + if: env.skip_backend_tests == 'false' + run: | + python <<'PY' + import xml.etree.ElementTree as ET + path = "src/backend-api/reports/coverage.xml" + prefix = "src/backend-api/src/app/" + tree = ET.parse(path) + root = tree.getroot() + for cls in root.iter("class"): + fname = cls.attrib.get("filename", "") + if fname and not fname.startswith(prefix): + cls.attrib["filename"] = prefix + fname + tree.write(path, xml_declaration=True, encoding="utf-8") + PY + - name: Pytest Coverage Comment if: | always() && @@ -90,3 +111,80 @@ jobs: if: env.skip_backend_tests == 'true' run: | echo "Skipping backend tests because no test files were found." 
+ + processor_tests: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install Processor Dependencies + run: | + python -m pip install --upgrade pip + cd src/processor + pip install -e . + pip install pytest pytest-cov pytest-asyncio + + - name: Check if Processor Test Files Exist + id: check_processor_tests + run: | + if [ -z "$(find src/processor/src/tests -type f -name 'test_*.py' 2>/dev/null)" ]; then + echo "No processor test files found, skipping processor tests." + echo "skip_processor_tests=true" >> $GITHUB_ENV + else + echo "Processor test files found, running tests." + echo "skip_processor_tests=false" >> $GITHUB_ENV + fi + + - name: Run Processor Tests with Coverage + if: env.skip_processor_tests == 'false' + run: | + cd src/processor + pytest src/tests \ + --cov=src \ + --cov-report=term-missing \ + --cov-report=xml:reports/coverage.xml \ + --cov-fail-under=82 \ + --junitxml=pytest.xml \ + -v + + - name: Prefix coverage XML filenames with repo-root path + if: env.skip_processor_tests == 'false' + run: | + python <<'PY' + import xml.etree.ElementTree as ET + path = "src/processor/reports/coverage.xml" + prefix = "src/processor/src/" + tree = ET.parse(path) + root = tree.getroot() + for cls in root.iter("class"): + fname = cls.attrib.get("filename", "") + if fname and not fname.startswith(prefix): + cls.attrib["filename"] = prefix + fname + tree.write(path, xml_declaration=True, encoding="utf-8") + PY + + - name: Pytest Coverage Comment (Processor) + if: | + always() && + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.fork == false && + env.skip_processor_tests == 'false' + uses: MishaKav/pytest-coverage-comment@26f986d2599c288bb62f623d29c2da98609e9cd4 # v1.6.0 + with: + pytest-xml-coverage-path: src/processor/reports/coverage.xml + junitxml-path: src/processor/pytest.xml + title: 
Processor Coverage Report + unique-id-for-comment: processor + report-only-changed-files: true + + - name: Skip Processor Tests + if: env.skip_processor_tests == 'true' + run: | + echo "Skipping processor tests because no test files were found." diff --git a/src/backend-api/pyproject.toml b/src/backend-api/pyproject.toml index 6da77728..ba23b7eb 100644 --- a/src/backend-api/pyproject.toml +++ b/src/backend-api/pyproject.toml @@ -25,7 +25,10 @@ dependencies = [ ] [dependency-groups] -dev = ["pytest>=9.0.3", "pytest-cov>=6.2.1"] +dev = ["pytest>=9.0.3", "pytest-cov>=6.2.1", "pytest-asyncio>=0.23.0"] + +[tool.coverage.run] +omit = ["src/tests/*"] [tool.uv] override-dependencies = [ diff --git a/src/backend-api/src/tests/application/test_application.py b/src/backend-api/src/tests/application/test_application.py new file mode 100644 index 00000000..24c85106 --- /dev/null +++ b/src/backend-api/src/tests/application/test_application.py @@ -0,0 +1,38 @@ +"""Tests for application.Application bootstrap.""" + +from application import Application +from libs.base.typed_fastapi import TypedFastAPI +from libs.services.interfaces import IDataService, IHttpService, ILoggerService +from libs.services.process_services import ProcessService + + +def test_application_initializes_typed_fastapi(): + app = Application() + assert isinstance(app.app, TypedFastAPI) + assert app.app.title == "FastAPI Application" + assert app.app.version == "1.0.0" + + +def test_application_sets_app_context_on_app(): + app = Application() + assert app.app.app_context is app.application_context + + +def test_application_registers_core_services(): + app = Application() + ctx = app.application_context + assert ctx.get_service(ILoggerService) is not None + assert ctx.get_service(IHttpService) is not None + assert ctx.get_service(IDataService) is not None + assert ctx.get_service(ProcessService) is not None + + +def test_application_includes_routers(): + app = Application() + paths = {route.path for route in 
app.app.routes} + # router_files + assert "/api/file/upload" in paths + # router_process + assert "/api/process/create" in paths + # http_probes + assert "/health" in paths diff --git a/src/backend-api/src/tests/application/test_application_context_extra.py b/src/backend-api/src/tests/application/test_application_context_extra.py new file mode 100644 index 00000000..a540acec --- /dev/null +++ b/src/backend-api/src/tests/application/test_application_context_extra.py @@ -0,0 +1,233 @@ +import asyncio + +import pytest + +from libs.application.application_context import ( + AppContext, + ServiceDescriptor, + ServiceLifetime, + ServiceScope, +) + + +class _DummyService: + def __init__(self): + self.created = True + + +class _AsyncResource: + def __init__(self): + self.entered = False + self.exited = False + self.closed = False + + async def __aenter__(self): + self.entered = True + return self + + async def __aexit__(self, exc_type, exc, tb): + self.exited = True + + async def close(self): + self.closed = True + + +class TestServiceLifetimeConstants: + def test_constants_exist(self): + assert ServiceLifetime.SINGLETON == "singleton" + assert ServiceLifetime.TRANSIENT == "transient" + assert ServiceLifetime.SCOPED == "scoped" + assert ServiceLifetime.ASYNC_SINGLETON == "async_singleton" + assert ServiceLifetime.ASYNC_SCOPED == "async_scoped" + + +class TestServiceDescriptor: + def test_defaults(self): + d = ServiceDescriptor( + service_type=_DummyService, + implementation=_DummyService, + lifetime=ServiceLifetime.SINGLETON, + ) + assert d.is_async is False + assert d.cleanup_method == "close" + assert d.instance is None + + def test_custom_cleanup_method(self): + d = ServiceDescriptor( + service_type=_DummyService, + implementation=_DummyService, + lifetime=ServiceLifetime.ASYNC_SINGLETON, + is_async=True, + cleanup_method="dispose", + ) + assert d.cleanup_method == "dispose" + assert d.is_async is True + + +class TestScopedServices: + def 
test_scoped_service_requires_active_scope(self): + ctx = AppContext() + ctx.add_scoped(_DummyService, _DummyService) + with pytest.raises(ValueError, match="requires an active scope"): + ctx.get_service(_DummyService) + + def test_scoped_service_returns_same_instance_within_scope(self): + ctx = AppContext() + ctx.add_scoped(_DummyService, _DummyService) + + async def run(): + async with ctx.create_scope() as scope: + a = scope.get_service(_DummyService) + b = scope.get_service(_DummyService) + assert a is b + + asyncio.run(run()) + + def test_scoped_service_returns_different_instances_in_separate_scopes(self): + ctx = AppContext() + ctx.add_scoped(_DummyService, _DummyService) + + async def run(): + async with ctx.create_scope() as scope1: + a = scope1.get_service(_DummyService) + async with ctx.create_scope() as scope2: + b = scope2.get_service(_DummyService) + assert a is not b + + asyncio.run(run()) + + +class TestAsyncSingleton: + def test_async_singleton_returns_same_instance(self): + ctx = AppContext() + ctx.add_async_singleton(_AsyncResource, _AsyncResource) + + async def run(): + a = await ctx.get_service_async(_AsyncResource) + b = await ctx.get_service_async(_AsyncResource) + assert a is b + assert a.entered is True + + asyncio.run(run()) + + def test_get_service_async_raises_for_unregistered(self): + ctx = AppContext() + + async def run(): + with pytest.raises(KeyError): + await ctx.get_service_async(_DummyService) + + asyncio.run(run()) + + def test_get_service_async_raises_for_non_async_service(self): + ctx = AppContext() + ctx.add_singleton(_DummyService, _DummyService) + + async def run(): + with pytest.raises(ValueError, match="not registered as an async service"): + await ctx.get_service_async(_DummyService) + + asyncio.run(run()) + + +class TestAsyncScoped: + def test_async_scoped_requires_active_scope(self): + ctx = AppContext() + ctx.add_async_scoped(_AsyncResource, _AsyncResource) + + async def run(): + with pytest.raises(ValueError, 
match="requires an active scope"): + await ctx.get_service_async(_AsyncResource) + + asyncio.run(run()) + + def test_async_scoped_same_instance_in_scope(self): + ctx = AppContext() + ctx.add_async_scoped(_AsyncResource, _AsyncResource) + + async def run(): + async with ctx.create_scope() as scope: + a = await scope.get_service_async(_AsyncResource) + b = await scope.get_service_async(_AsyncResource) + assert a is b + # After scope exit, __aexit__ should be called + assert a.exited is True + + asyncio.run(run()) + + +class TestCreateInstance: + def test_create_instance_supports_pre_created_instance(self): + ctx = AppContext() + existing = _DummyService() + ctx.add_singleton(_DummyService, existing) + assert ctx.get_service(_DummyService) is existing + + def test_create_instance_supports_callable(self): + ctx = AppContext() + ctx.add_transient(_DummyService, _DummyService) + a = ctx.get_service(_DummyService) + b = ctx.get_service(_DummyService) + assert a is not b + assert isinstance(a, _DummyService) + + +class TestShutdownAsync: + def test_shutdown_calls_cleanup_method(self): + ctx = AppContext() + ctx.add_async_singleton(_AsyncResource, _AsyncResource) + + async def run(): + instance = await ctx.get_service_async(_AsyncResource) + await ctx.shutdown_async() + return instance + + instance = asyncio.run(run()) + # After shutdown, internal caches are cleared + assert ctx._instances == {} + assert ctx._scoped_instances == {} + assert instance.closed is True + + def test_shutdown_with_no_services_is_noop(self): + ctx = AppContext() + asyncio.run(ctx.shutdown_async()) # should not raise + + +class TestCreateAsyncInstance: + def test_async_factory_returning_coroutine(self): + ctx = AppContext() + + async def factory(): + return _AsyncResource() + + ctx.add_async_singleton(_AsyncResource, factory) + + async def run(): + instance = await ctx.get_service_async(_AsyncResource) + assert isinstance(instance, _AsyncResource) + assert instance.entered is True + + 
asyncio.run(run()) + + def test_async_instance_passthrough(self): + ctx = AppContext() + existing = _AsyncResource() + # Pre-created instance (not callable, not a class) + ctx.add_async_singleton(_AsyncResource, existing) + + async def run(): + instance = await ctx.get_service_async(_AsyncResource) + assert instance is existing + + asyncio.run(run()) + + +class TestServiceScope: + def test_scope_restores_previous_scope_id(self): + ctx = AppContext() + ctx._current_scope_id = "outer" + + ctx.add_singleton(_DummyService, _DummyService) + scope = ServiceScope(ctx, "inner") + scope.get_service(_DummyService) + assert ctx._current_scope_id == "outer" diff --git a/src/backend-api/src/tests/azure/test_app_configuration_helper.py b/src/backend-api/src/tests/azure/test_app_configuration_helper.py new file mode 100644 index 00000000..d8696ec9 --- /dev/null +++ b/src/backend-api/src/tests/azure/test_app_configuration_helper.py @@ -0,0 +1,75 @@ +import os +from unittest.mock import MagicMock, patch + +import pytest + +from libs.azure.app_configuration import AppConfigurationHelper + + +def _patch_client(): + return patch("libs.azure.app_configuration.AzureAppConfigurationClient") + + +class TestInitialization: + def test_uses_provided_credential(self): + cred = MagicMock() + with _patch_client() as MockClient: + helper = AppConfigurationHelper( + "https://example.azconfig.io", credential=cred + ) + MockClient.assert_called_once_with("https://example.azconfig.io", cred) + assert helper.credential is cred + assert helper.app_config_endpoint == "https://example.azconfig.io" + assert helper.app_config_client is MockClient.return_value + + def test_raises_value_error_when_endpoint_is_none(self): + with pytest.raises(ValueError, match="App Configuration Endpoint is not set"): + AppConfigurationHelper(None, credential=MagicMock()) + + def test_creates_default_credential_when_none_provided(self): + with patch( + "libs.azure.app_configuration.DefaultAzureCredential" + ) as MockCred, 
_patch_client(): + MockCred.return_value = MagicMock() + helper = AppConfigurationHelper("https://example.azconfig.io") + MockCred.assert_called_once() + assert helper.credential is MockCred.return_value + + +class TestReadAndSetEnvironmentalVariables: + def test_sets_environment_variables_from_settings(self): + with _patch_client(): + helper = AppConfigurationHelper( + "https://example.azconfig.io", credential=MagicMock() + ) + + item1 = MagicMock() + item1.key = "TEST_KEY_ONE" + item1.value = "value-one" + item2 = MagicMock() + item2.key = "TEST_KEY_TWO" + item2.value = "value-two" + + helper.app_config_client = MagicMock() + helper.app_config_client.list_configuration_settings.return_value = iter( + [item1, item2] + ) + + try: + result = helper.read_and_set_environmental_variables() + assert os.environ["TEST_KEY_ONE"] == "value-one" + assert os.environ["TEST_KEY_TWO"] == "value-two" + assert result is os.environ + finally: + os.environ.pop("TEST_KEY_ONE", None) + os.environ.pop("TEST_KEY_TWO", None) + + def test_read_configuration_delegates_to_client(self): + with _patch_client(): + helper = AppConfigurationHelper( + "https://example.azconfig.io", credential=MagicMock() + ) + helper.app_config_client = MagicMock() + helper.app_config_client.list_configuration_settings.return_value = "settings" + assert helper.read_configuration() == "settings" + helper.app_config_client.list_configuration_settings.assert_called_once() diff --git a/src/backend-api/src/tests/base/test_kernel_agent.py b/src/backend-api/src/tests/base/test_kernel_agent.py new file mode 100644 index 00000000..b5edf2f9 --- /dev/null +++ b/src/backend-api/src/tests/base/test_kernel_agent.py @@ -0,0 +1,260 @@ +"""Tests for libs/base/kernel_agent.py.""" + +import importlib +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import BaseModel, ValidationError +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError + +# 
The source module imports SKBaseModel from libs.base.SKBase, but that module +# is empty in the repository. Inject a minimal stand-in before importing +# kernel_agent so tests can exercise the file without touching source. +import libs.base.SKBase as _skbase_mod # noqa: E402 + +if not hasattr(_skbase_mod, "SKBaseModel"): + + class _SKBaseModelStub(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + _skbase_mod.SKBaseModel = _SKBaseModelStub # type: ignore[attr-defined] + +importlib.import_module("libs.base.kernel_agent") # noqa: E402 + + +@pytest.fixture +def patched(): + with ( + patch("libs.base.kernel_agent.Kernel") as kernel_cls, + patch("libs.base.kernel_agent.Configuration") as cfg_cls, + patch("libs.base.kernel_agent.AzureChatCompletion") as chat_cls, + patch("libs.base.kernel_agent.AzureTextCompletion") as text_cls, + ): + kernel = MagicMock() + kernel.plugins = {} + kernel.services = {} + kernel_cls.return_value = kernel + cfg = SimpleNamespace(global_llm_service="AzureOpenAI", env_file_path=None) + cfg_cls.return_value = cfg + yield { + "kernel_cls": kernel_cls, + "kernel": kernel, + "cfg_cls": cfg_cls, + "cfg": cfg, + "chat_cls": chat_cls, + "text_cls": text_cls, + } + + +class TestInit: + def test_init_sets_kernel_and_settings(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + assert a.kernel is patched["kernel"] + assert a._settings is patched["cfg"] + + def test_init_with_env_file_path(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + semantic_kernel_agent(env_file_path="some/path.env") + patched["cfg_cls"].assert_called_with(env_file_path="some/path.env") + + def test_init_default_global_llm_service(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + patched["cfg"].global_llm_service = None + a = semantic_kernel_agent() + assert a._settings.global_llm_service == "AzureOpenAI" + + def test_init_validation_error_wraps(self, 
patched): + from libs.base.kernel_agent import semantic_kernel_agent + + # Construct a real ValidationError via a Pydantic model + try: + from pydantic import BaseModel + + class _M(BaseModel): + x: int + + _M(x="not-int") + except ValidationError as ve: + patched["cfg_cls"].side_effect = ve + with pytest.raises(ServiceInitializationError): + semantic_kernel_agent() + + +class TestPlugins: + def test_get_plugin_present(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {"p": MagicMock(name="plug")} + a.kernel.get_plugin = MagicMock(return_value="plug-obj") + assert a.get_plugin("p") == "plug-obj" + + def test_get_plugin_missing(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {} + assert a.get_plugin("nope") is None + + def test_add_plugin_when_present_returns_existing(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {"p": "x"} + a.kernel.get_plugin = MagicMock(return_value="existing") + result = a.add_plugin(plugin=MagicMock(), plugin_name="p") + assert result == "existing" + a.kernel.add_plugin.assert_not_called() + + def test_add_plugin_when_absent_adds(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {} + a.kernel.get_plugin = MagicMock(return_value="newly") + plug = MagicMock() + result = a.add_plugin(plugin=plug, plugin_name="p") + a.kernel.add_plugin.assert_called_once_with(plugin=plug, plugin_name="p") + assert result == "newly" + + def test_add_plugin_from_directory_present(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {"p": MagicMock()} + a.kernel.get_plugin = MagicMock(return_value="existing") + assert a.add_plugin_from_directory("/dir", "p") == "existing" + 
a.kernel.add_plugin.assert_not_called() + + def test_add_plugin_from_directory_absent(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {} + a.kernel.get_plugin = MagicMock(return_value="newly") + result = a.add_plugin_from_directory("/dir", "p") + a.kernel.add_plugin.assert_called_once_with( + parent_directory="/dir", plugin_name="p" + ) + assert result == "newly" + + +class TestFunctions: + def test_get_function_no_plugin(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.plugins = {} + assert a.get_function("p", "f") is None + + def test_get_function_present(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + plug = MagicMock() + plug.functions = {"f": "func-obj"} + a.kernel.plugins = {"p": plug} + a.kernel.get_plugin = MagicMock(return_value=plug) + assert a.get_function("p", "f") == "func-obj" + + def test_get_function_function_missing(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + plug = MagicMock() + plug.functions = {} + a.kernel.plugins = {"p": plug} + a.kernel.get_plugin = MagicMock(return_value=plug) + assert a.get_function("p", "f") is None + + def test_add_function_existing(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + plug = MagicMock() + plug.functions = {"f": "existing-func"} + a.kernel.plugins = {"p": plug} + a.kernel.get_plugin = MagicMock(return_value=plug) + result = a.add_function(plugin_name="p", function_name="f", function=MagicMock()) + assert result == "existing-func" + a.kernel.add_function.assert_not_called() + + def test_add_function_new(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + # First call: no plugin -> add_plugin path; then get_function returns 
None; then add + a.kernel.plugins = {} + + # Track plugin presence dynamically + state = {"plugin": None} + + def get_plugin_side(name): + return state["plugin"] + + a.kernel.get_plugin = MagicMock(side_effect=get_plugin_side) + + def add_plugin_side(plugin, plugin_name): + plug = MagicMock() + plug.functions = {} + state["plugin"] = plug + a.kernel.plugins[plugin_name] = plug + + a.kernel.add_plugin.side_effect = add_plugin_side + a.kernel.get_function = MagicMock(return_value="new-func") + + fn = MagicMock() + fn.name = "f" + result = a.add_function(plugin_name="p", function=fn) + a.kernel.add_function.assert_called_once() + assert result == "new-func" + + +class TestGetKernel: + def test_get_kernel_chat_adds_service(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent, service_type + + a = semantic_kernel_agent() + a.kernel.services = {} + result = a.get_kernel(service_id="default", service_type=service_type.Chat_Completion) + a.kernel.add_service.assert_called_once() + assert result is a.kernel + + def test_get_kernel_already_present(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.services = {"default": object()} + result = a.get_kernel(service_id="default") + a.kernel.add_service.assert_not_called() + assert result is a.kernel + + def test_get_kernel_non_azure_raises(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a._settings.global_llm_service = "OpenAI" + with pytest.raises(ServiceInitializationError): + a.get_kernel() + + def test_get_prompt_execution_settings(self, patched): + from libs.base.kernel_agent import semantic_kernel_agent + + a = semantic_kernel_agent() + a.kernel.get_prompt_execution_settings_from_service_id = MagicMock( + return_value="settings" + ) + assert a.get_prompt_execution_settings_from_service_id("svc") == "settings" + + +class TestServiceTypeEnum: + def test_enum_values(self, 
patched): + from libs.base.kernel_agent import service_type + + assert service_type.Chat_Completion.value == "ChatCompletion" + assert service_type.Text_Completion.value == "TextCompletion" diff --git a/src/backend-api/src/tests/base/test_sk_logic_base.py b/src/backend-api/src/tests/base/test_sk_logic_base.py new file mode 100644 index 00000000..e12e6395 --- /dev/null +++ b/src/backend-api/src/tests/base/test_sk_logic_base.py @@ -0,0 +1,197 @@ +"""Tests for libs/base/SKLogicBase.py. + +The production module imports ``SKBaseModel`` from ``libs.base.SKBase`` (which +is empty in this repo) and ``semantic_kernel_agent`` from +``libs.base.KernelAgent`` (a module that does not exist on disk; the actual +file is ``kernel_agent.py``). Stub both before importing so the module is +exercised without modifying production source. +""" + +import importlib +import sys +import types +from typing import Type +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + + +# --------------------------------------------------------------------------- +# Module-load helpers (stub SKBaseModel and the misspelled KernelAgent module) +# --------------------------------------------------------------------------- +import libs.base.SKBase as _skbase_mod # noqa: E402 + + +class _SKBaseModelStub(BaseModel): + # ``extra="allow"`` is required because SKLogicBase.__init__ sets + # attributes (``response_format``, ``system_prompt``) that are not + # declared as Pydantic fields. + model_config = { + "arbitrary_types_allowed": True, + "extra": "allow", + } + + +# Always force-set our stub so it overrides whatever a previously-loaded +# test module installed (e.g. test_kernel_agent.py uses a stricter stub +# without ``extra="allow"`` which prevents SKLogicBase from being constructed). 
+_skbase_mod.SKBaseModel = _SKBaseModelStub # type: ignore[attr-defined] + +# Ensure libs.base.kernel_agent has been imported (creates real +# semantic_kernel_agent symbol used by the stub below). +import libs.base.kernel_agent as _kernel_agent_mod # noqa: E402 + +if "libs.base.KernelAgent" not in sys.modules: + _stub_kernel_agent_module = types.ModuleType("libs.base.KernelAgent") + _stub_kernel_agent_module.semantic_kernel_agent = ( + _kernel_agent_mod.semantic_kernel_agent + ) + sys.modules["libs.base.KernelAgent"] = _stub_kernel_agent_module + +# Now safe to import the SUT. +sk_logic_base = importlib.import_module("libs.base.SKLogicBase") +SKLogicBase = sk_logic_base.SKLogicBase + + +class _Resp(BaseModel): + name: str = "x" + + +class _ConcreteLogic(SKLogicBase): + """A concrete subclass that satisfies abstract methods without doing real work.""" + + def _init_agent(self): # override: skip real agent setup + return None + + async def _init_agent_async(self): + return None + + async def execute_thread( # type: ignore[override] + self, user_input, response_format=None, thread=None + ): + return ("answer", thread) + + +class _BareLogic(SKLogicBase): + """Subclass that delegates _init_agent / _init_agent_async to base. + + Used to exercise the NotImplementedError branches without reaching + the abstract execute_thread. 
+ """ + + def _init_agent(self): # call up to base to hit raise + return super()._init_agent() + + async def _init_agent_async(self): + return await super()._init_agent_async() + + async def execute_thread(self, *a, **kw): # satisfy abstractmethod + return None + + +def _make_kernel_agent_stub(): + """Build a MagicMock that behaves enough like a semantic_kernel_agent.""" + return MagicMock() + + +class TestValidateResponseFormat: + def test_returns_true_when_none(self): + assert SKLogicBase._validate_response_format(None) is True + + def test_returns_true_for_basemodel_subclass(self): + assert SKLogicBase._validate_response_format(_Resp) is True + + def test_raises_typeerror_when_not_a_class(self): + with pytest.raises(TypeError): + SKLogicBase._validate_response_format("not-a-class") + + def test_raises_typeerror_when_not_basemodel(self): + class _Plain: + pass + + with pytest.raises(TypeError): + SKLogicBase._validate_response_format(_Plain) + + +class TestConstruction: + def test_concrete_subclass_instantiates(self): + ka = _make_kernel_agent_stub() + instance = _ConcreteLogic(kernel_agent=ka) + assert instance.kernel_agent is ka + + def test_constructor_passes_through_response_format_and_prompt(self): + ka = _make_kernel_agent_stub() + instance = _ConcreteLogic( + kernel_agent=ka, + system_prompt="be helpful", + response_format=_Resp, + ) + assert instance.kernel_agent is ka + # response_format / system_prompt are set as instance attrs (not declared + # Pydantic fields) — confirm they exist via the underlying dict. 
+ assert getattr(instance, "response_format") is _Resp + assert getattr(instance, "system_prompt") == "be helpful" + + +class TestNotImplementedBranches: + def test_init_agent_raises_in_base(self): + ka = _make_kernel_agent_stub() + with pytest.raises(NotImplementedError): + _BareLogic(kernel_agent=ka) + + @pytest.mark.asyncio + async def test_init_agent_async_raises_in_base(self): + # Build with no-op init to skip _init_agent failure, then invoke + # _init_agent_async via super() to hit the base raise. + ka = _make_kernel_agent_stub() + + class _OnlyAsyncRaises(SKLogicBase): + def _init_agent(self): # no-op so __init__ works + return None + + async def _init_agent_async(self): + return await super()._init_agent_async() + + async def execute_thread(self, *a, **kw): + return None + + instance = _OnlyAsyncRaises(kernel_agent=ka) + with pytest.raises(NotImplementedError): + await instance._init_agent_async() + + @pytest.mark.asyncio + async def test_execute_raises_not_implemented(self): + ka = _make_kernel_agent_stub() + instance = _ConcreteLogic(kernel_agent=ka) + with pytest.raises(NotImplementedError): + await instance.execute({"x": 1}) + + +class TestCreateClassMethod: + @pytest.mark.asyncio + async def test_create_calls_init_agent_async(self): + ka = _make_kernel_agent_stub() + + captured = {"called": False} + + class _Tracking(SKLogicBase): + def _init_agent(self): + return None + + async def _init_agent_async(self): + captured["called"] = True + + async def execute_thread(self, *a, **kw): + return None + + instance = await _Tracking.create(kernel_agent=ka) + assert isinstance(instance, _Tracking) + assert captured["called"] is True + + +class TestAbstractContract: + def test_cannot_instantiate_abstract_directly(self): + ka = _make_kernel_agent_stub() + with pytest.raises(TypeError): + SKLogicBase(kernel_agent=ka) diff --git a/src/backend-api/src/tests/base/test_typed_fastapi.py b/src/backend-api/src/tests/base/test_typed_fastapi.py new file mode 100644 index 
00000000..75289899 --- /dev/null +++ b/src/backend-api/src/tests/base/test_typed_fastapi.py @@ -0,0 +1,46 @@ +from fastapi import FastAPI + +from libs.application.application_context import AppContext +from libs.base.fastapi_protocol import ( + FastAPIWithContext, + add_app_context_to_fastapi, +) +from libs.base.typed_fastapi import TypedFastAPI + + +class TestTypedFastAPI: + def test_initial_app_context_is_none(self): + app = TypedFastAPI() + assert app.app_context is None + + def test_set_app_context_assigns_value(self): + app = TypedFastAPI() + ctx = AppContext() + app.set_app_context(ctx) + assert app.app_context is ctx + + def test_inherits_from_fastapi(self): + assert isinstance(TypedFastAPI(), FastAPI) + + +class TestAddAppContextToFastAPI: + def test_adds_app_context_attribute(self): + app = FastAPI() + ctx = AppContext() + result = add_app_context_to_fastapi(app, ctx) + assert result is app + assert app.app_context is ctx + + def test_replaces_existing_app_context(self): + app = FastAPI() + first = AppContext() + second = AppContext() + add_app_context_to_fastapi(app, first) + add_app_context_to_fastapi(app, second) + assert app.app_context is second + + +class TestFastAPIWithContextProtocol: + def test_typed_fastapi_satisfies_protocol(self): + # Runtime-checkable not required; just confirm symbols exist. 
+ assert hasattr(FastAPIWithContext, "include_router") diff --git a/src/backend-api/src/tests/model/test_entities.py b/src/backend-api/src/tests/model/test_entities.py new file mode 100644 index 00000000..78610cf0 --- /dev/null +++ b/src/backend-api/src/tests/model/test_entities.py @@ -0,0 +1,64 @@ +from datetime import datetime + +from libs.models.entities import AgentStatus, File, Process, ProcessStatus +from libs.models.messages import ProcessStartQueueMessage + + +class TestProcess: + def test_default_field_values(self): + process = Process(id="p1", user_id="u1") + assert process.id == "p1" + assert process.user_id == "u1" + assert process.source_file_count == 0 + assert process.result_file_count == 0 + assert process.status == "initialized" + assert isinstance(process.created_at, datetime) + assert isinstance(process.updated_at, datetime) + + def test_overrides(self): + process = Process( + id="p1", + user_id="u1", + source_file_count=2, + result_file_count=3, + status="ready_to_process", + ) + assert process.source_file_count == 2 + assert process.result_file_count == 3 + assert process.status == "ready_to_process" + + +class TestFile: + def test_default_counts_zero(self): + file_ = File(id="f1", process_id="p1", name="n", blob_path="b") + assert file_.error_count == 0 + assert file_.syntax_count == 0 + assert isinstance(file_.created_at, datetime) + + +class TestAgentStatus: + def test_time_stamp_default_is_iso_string(self): + status = AgentStatus(name="agent", role="role", status="ok") + # datetime.fromisoformat will raise if not ISO + assert isinstance(status.time_stamp, str) + datetime.fromisoformat(status.time_stamp) + + +class TestProcessStatus: + def test_status_list_assignment(self): + agents = [AgentStatus(name="a", role="r", status="s")] + ps = ProcessStatus(id="ps", process_id="p", phase="ph", status=agents) + assert ps.process_id == "p" + assert ps.phase == "ph" + assert len(ps.status) == 1 + + +class TestProcessStartQueueMessage: + def 
test_to_base64_roundtrip(self): + import base64 + import json + + msg = ProcessStartQueueMessage(process_id="p", user_id="u") + encoded = msg.to_base64() + decoded = json.loads(base64.b64decode(encoded).decode()) + assert decoded == {"process_id": "p", "user_id": "u"} diff --git a/src/backend-api/src/tests/model/test_router_models.py b/src/backend-api/src/tests/model/test_router_models.py new file mode 100644 index 00000000..1e82bc9f --- /dev/null +++ b/src/backend-api/src/tests/model/test_router_models.py @@ -0,0 +1,106 @@ +import base64 +import json + +import pytest +from pydantic import ValidationError + +from routers.models.files import Batch, File, FileInfo, FileUploadResult +from routers.models.processes import ( + FileContentResponse, + FileInfo as ProcessFileInfo, + ProcessCreateResponse, + ProcessInfo, + ProcessSummaryFileInfo, + ProcessSummaryResponse, + enlist_process_queue_response, +) + + +class TestFileEntity: + def test_attributes_assigned(self): + f = File(file_id="fid", original_name="orig.txt") + assert f.file_id == "fid" + assert f.original_name == "orig.txt" + + +class TestBatch: + def test_batch_id_assigned(self): + b = Batch(batch_id="bid") + assert b.batch_id == "bid" + + +class TestFileUploadResult: + def test_composes_batch_and_file(self): + result = FileUploadResult(batch_id="b1", file_id="f1", file_name="x.yaml") + assert isinstance(result.batch, Batch) + assert isinstance(result.file, File) + assert result.batch.batch_id == "b1" + assert result.file.file_id == "f1" + assert result.file.original_name == "x.yaml" + + +class TestFileInfo: + def test_excludes_content_from_serialization(self): + info = FileInfo( + filename="a.txt", + content=b"secret", + content_type="text/plain", + size=6, + ) + dumped = info.model_dump() + assert "content" not in dumped + assert dumped["filename"] == "a.txt" + + def test_validation_requires_filename(self): + with pytest.raises(ValidationError): + FileInfo(content_type="text/plain", size=0) + + +class 
TestProcessSchemas: + def test_process_create_response(self): + assert ProcessCreateResponse(process_id="x").process_id == "x" + + def test_process_info_requires_fields(self): + with pytest.raises(ValidationError): + ProcessInfo(process_id="x") # missing created_at/file_count + + def test_process_summary_response_round_trip(self): + from datetime import datetime, timezone + + info = ProcessInfo( + process_id="p", + created_at=datetime.now(timezone.utc), + file_count=2, + ) + summary = ProcessSummaryResponse( + Process=info, + files=[ProcessSummaryFileInfo(filename="a")], + ) + assert summary.Process.process_id == "p" + assert summary.files[0].filename == "a" + + def test_file_content_response(self): + assert FileContentResponse(content="hello").content == "hello" + + def test_process_file_info_validation(self): + info = ProcessFileInfo(filename="f", content_type="text/plain", size=1) + assert info.filename == "f" + + +class TestEnlistProcessQueueResponse: + def test_to_base64_round_trip(self): + resp = enlist_process_queue_response( + user_id="u", + process_id="p", + message="m", + files=[ProcessFileInfo(filename="a", content_type="text/plain", size=1)], + ) + decoded = json.loads(base64.b64decode(resp.to_base64()).decode()) + assert decoded["user_id"] == "u" + assert decoded["process_id"] == "p" + assert decoded["files"][0]["filename"] == "a" + + def test_optional_fields_default_to_none(self): + resp = enlist_process_queue_response(user_id="u", process_id="p") + assert resp.message is None + assert resp.files is None diff --git a/src/backend-api/src/tests/repositories/__init__.py b/src/backend-api/src/tests/repositories/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend-api/src/tests/repositories/test_process_status_repository.py b/src/backend-api/src/tests/repositories/test_process_status_repository.py new file mode 100644 index 00000000..6dd4b2f4 --- /dev/null +++ 
b/src/backend-api/src/tests/repositories/test_process_status_repository.py @@ -0,0 +1,312 @@ +"""Tests for libs/repositories/process_status_repository.py.""" + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from libs.repositories import process_status_repository as psr_module +from libs.repositories.process_status_repository import ( + ProcessStatusRepository, + analyze_agent_velocity, + calculate_activity_duration, + get_agent_relationship_status, +) + + +class TestCalculateActivityDuration: + def test_returns_zero_for_empty_input(self): + assert calculate_activity_duration("") == (0, "0s") + assert calculate_activity_duration(None) == (0, "0s") + + def test_returns_seconds_for_under_minute(self): + ts = (datetime.now(UTC) - timedelta(seconds=5)).isoformat() + secs, formatted = calculate_activity_duration(ts) + assert 4 <= secs <= 7 + assert formatted.endswith("s") + + def test_returns_minutes_for_under_hour(self): + ts = (datetime.now(UTC) - timedelta(minutes=5)).isoformat() + secs, formatted = calculate_activity_duration(ts) + assert 290 <= secs <= 320 + assert "m" in formatted and "s" in formatted + + def test_returns_hours_for_long_durations(self): + ts = (datetime.now(UTC) - timedelta(hours=2, minutes=15)).isoformat() + secs, formatted = calculate_activity_duration(ts) + assert secs >= 2 * 3600 + assert "h" in formatted and "m" in formatted + + def test_handles_utc_suffix(self): + ts = (datetime.now(UTC) - timedelta(seconds=10)).strftime( + "%Y-%m-%dT%H:%M:%S UTC" + ) + secs, _ = calculate_activity_duration(ts) + assert secs >= 9 + + def test_returns_zero_on_parse_error(self): + assert calculate_activity_duration("not-a-date") == (0, "0s") + + +class TestAnalyzeAgentVelocity: + def test_idle_when_no_history(self): + assert analyze_agent_velocity([]) == "idle" + + def test_slow_when_no_recent_activity(self): + old = (datetime.now(UTC) - 
timedelta(hours=1)).isoformat() + history = [{"timestamp": old}] + assert analyze_agent_velocity(history) == "slow" + + @pytest.mark.parametrize( + "count,expected", + [(1, "normal"), (3, "fast"), (5, "very_fast"), (7, "very_fast")], + ) + def test_velocity_thresholds(self, count, expected): + ts = datetime.now(UTC).isoformat() + history = [{"timestamp": ts} for _ in range(count)] + assert analyze_agent_velocity(history) == expected + + def test_skips_invalid_timestamps(self): + history = [{"timestamp": "broken"}, {"timestamp": "broken"}] + assert analyze_agent_velocity(history) == "slow" + + +class TestGetAgentRelationshipStatus: + def test_empty_relationships_for_unknown(self): + rels = get_agent_relationship_status({"name": "x"}, {}) + assert rels == { + "waiting_for": [], + "blocking": [], + "collaborating_with": [], + "dependency_chain": [], + } + + def test_standby_agent_waits_for_active_ready(self): + agent = {"name": "me", "participation_status": "standby"} + all_agents = { + "other": { + "name": "other", + "is_active": True, + "participation_status": "ready", + } + } + rels = get_agent_relationship_status(agent, all_agents) + assert "other" in rels["waiting_for"] + + def test_active_agent_blocks_standby(self): + agent = {"name": "me", "is_active": True} + all_agents = { + "other": {"name": "other", "participation_status": "standby"}, + "me": {"name": "me", "is_active": True}, + } + rels = get_agent_relationship_status(agent, all_agents) + assert "other" in rels["blocking"] + assert "me" not in rels["blocking"] + + +class _NoOpRepoBase: + def __init__(self, *a, **kw): + pass + + +@pytest.fixture +def repo(): + with patch.object( + psr_module.RepositoryBase, "__init__", _NoOpRepoBase.__init__ + ): + r = ProcessStatusRepository("u", "db", "c") + r.get_async = AsyncMock() + return r + + +def _agent_obj(**overrides): + base = { + "name": "alpha", + "is_currently_speaking": False, + "is_active": True, + "current_action": "thinking", + "current_speaking_content": 
"", + "last_message_preview": "hello", + "participation_status": "ready", + "current_reasoning": "", + "last_reasoning": "", + "thinking_about": "", + "reasoning_steps": [], + "last_activity_summary": "", + "is_currently_thinking": False, + "last_update_time": "", + "last_full_message": "", + "activity_history": [], + "message_word_count": 0, + } + base.update(overrides) + return SimpleNamespace(**base) + + +def _process_status(agents=None, **overrides): + base = { + "id": "proc-1", + "step": "Analysis", + "phase": "Analysis", + "status": "running", + "last_update_time": "", + "started_at_time": "", + "failure_agent": "", + "failure_reason": "", + "failure_details": "", + "failure_step": "", + "failure_timestamp": "", + "stack_trace": "", + "agents": agents if agents is not None else {}, + "step_timings": {}, + "step_results": {}, + "generated_files": [], + "conversion_metrics": {}, + } + base.update(overrides) + return SimpleNamespace(**base) + + +class TestRepositoryGetters: + @pytest.mark.asyncio + async def test_get_process_agent_activities_returns_status(self, repo): + ps = _process_status() + repo.get_async.return_value = ps + result = await repo.get_process_agent_activities_by_process_id("proc-1") + assert result is ps + + @pytest.mark.asyncio + async def test_get_process_agent_activities_returns_none(self, repo): + repo.get_async.return_value = None + assert ( + await repo.get_process_agent_activities_by_process_id("proc-1") is None + ) + + @pytest.mark.asyncio + async def test_get_process_status_by_process_id_returns_none(self, repo): + repo.get_async.return_value = None + assert await repo.get_process_status_by_process_id("p") is None + + @pytest.mark.asyncio + async def test_get_process_status_by_process_id_builds_snapshot(self, repo): + agent = _agent_obj() + ps = _process_status(agents={"alpha": agent}) + repo.get_async.return_value = ps + snap = await repo.get_process_status_by_process_id("proc-1") + assert snap is not None + assert snap.process_id == 
"proc-1" + assert len(snap.agents) == 1 + assert snap.agents[0].name == "alpha" + + +class TestRenderAgentStatus: + @pytest.mark.asyncio + async def test_returns_not_found_when_no_process(self, repo): + repo.get_async.return_value = None + result = await repo.render_agent_status("missing") + assert result["status"] == "not_found" + assert result["agents"] == [] + + @pytest.mark.asyncio + async def test_renders_with_full_data(self, repo): + agent = _agent_obj( + name="Chief_Architect", + participation_status="ready", + current_action="reviewing", + ) + ps = _process_status(agents={"Chief_Architect": agent}) + repo.get_async.return_value = ps + result = await repo.render_agent_status("proc-1") + assert result["process_id"] == "proc-1" + assert result["total_agents"] == 1 + assert "agents" in result + assert isinstance(result["agents"], list) + assert result["health_status"].startswith("🟢") or result[ + "health_status" + ].startswith("🟡") + + @pytest.mark.asyncio + async def test_renders_failed_process(self, repo): + agent = _agent_obj(name="system", is_active=True) + ps = _process_status(agents={"system": agent}, status="failed") + repo.get_async.return_value = ps + result = await repo.render_agent_status("proc-1") + assert "system" in result["failed_agents"] + assert result["health_status"] == "🔴 CRITICAL" + + @pytest.mark.asyncio + async def test_renders_speaking_agent(self, repo): + agent = _agent_obj( + is_currently_speaking=True, + current_speaking_content="hello world", + message_word_count=2, + ) + ps = _process_status(agents={"alpha": agent}) + repo.get_async.return_value = ps + result = await repo.render_agent_status("proc-1") + assert any("hello world" in line for line in result["agents"]) + + @pytest.mark.asyncio + async def test_returns_empty_when_no_agents_data(self, repo): + ps = _process_status(agents={}) + repo.get_async.return_value = ps + result = await repo.render_agent_status("proc-1") + assert result["agents"] == [] + + +class 
TestRenderAgentStatusOld: + @pytest.mark.asyncio + async def test_returns_not_found_when_no_snapshot(self, repo): + repo.get_async.return_value = None + result = await repo.render_agent_status_old("nope") + assert result["status"] == "not_found" + + @pytest.mark.asyncio + async def test_renders_old_with_snapshot(self, repo): + agent = _agent_obj(name="system") + ps = _process_status(agents={"system": agent}) + repo.get_async.return_value = ps + result = await repo.render_agent_status_old("proc-1") + assert result["process_id"] == "proc-1" + assert isinstance(result["agents"], list) + + +class TestReadyStatusMessage: + @pytest.fixture + def r(self, repo): + return repo + + @pytest.mark.parametrize( + "agent,step,expected_substring", + [ + ("Chief_Architect", "Analysis", "analyze architecture"), + ("EKS_Expert", "Design", "EKS"), + ("GKS_Expert", "YAML", "AKS"), + ("Azure_Expert", "Documentation", "document Azure"), + ("Technical_Writer", "Analysis", "document"), + ("QA_Engineer", "YAML", "validate YAML"), + ], + ) + def test_known_agent_messages(self, r, agent, step, expected_substring): + msg = r._get_ready_status_message(agent, step, "Analysis", "ready") + assert expected_substring.lower() in msg.lower() + + def test_unknown_agent_default(self, r): + msg = r._get_ready_status_message( + "Chief_Architect", "UnknownStep", "Analysis", "ready" + ) + assert "Ready" in msg + + @pytest.mark.parametrize( + "status,expected", + [ + ("standby", "Standing by"), + ("waiting", "Waiting"), + ("completed", "Completed"), + ("other", "Ready for"), + ], + ) + def test_unknown_agent_status_messages(self, r, status, expected): + msg = r._get_ready_status_message("Unknown", "Analysis", "Analysis", status) + assert expected in msg diff --git a/src/backend-api/src/tests/repositories/test_repositories_extra.py b/src/backend-api/src/tests/repositories/test_repositories_extra.py new file mode 100644 index 00000000..0c5588ce --- /dev/null +++ 
b/src/backend-api/src/tests/repositories/test_repositories_extra.py @@ -0,0 +1,108 @@ +"""Tests for libs/repositories/file_repository.py and process_repository.py.""" + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from libs.repositories import file_repository as fr_module +from libs.repositories import process_repository as pr_module +from libs.repositories.file_repository import FileRepository +from libs.repositories.process_repository import ProcessRepository + + +def _no_init(self, *args, **kwargs): # pragma: no cover - helper + return None + + +class TestFileRepository: + @pytest.mark.asyncio + async def test_update_async_sets_updated_at_and_calls_super(self): + with patch.object(fr_module.RepositoryBase, "__init__", _no_init): + repo = FileRepository( + account_url="https://x", database_name="db", container_name="c" + ) + + entity = SimpleNamespace(id="f1", updated_at=None) + before = datetime.now(UTC) - timedelta(seconds=1) + + with patch.object( + fr_module.RepositoryBase, + "update_async", + new=AsyncMock(return_value=entity), + ) as mock_super: + result = await repo.update_async(entity) + + after = datetime.now(UTC) + timedelta(seconds=1) + assert result is entity + assert isinstance(entity.updated_at, datetime) + assert before <= entity.updated_at <= after + mock_super.assert_awaited_once_with(entity) + + def test_init_calls_super_with_proper_args(self): + captured = {} + + def fake_init(self, account_url, database_name, container_name): + captured["account_url"] = account_url + captured["database_name"] = database_name + captured["container_name"] = container_name + + with patch.object(fr_module.RepositoryBase, "__init__", fake_init): + FileRepository( + account_url="https://acct.documents.azure.com", + database_name="mydb", + container_name="files", + ) + + assert captured == { + "account_url": "https://acct.documents.azure.com", + "database_name": 
"mydb", + "container_name": "files", + } + + +class TestProcessRepository: + @pytest.mark.asyncio + async def test_update_async_sets_updated_at_and_calls_super(self): + with patch.object(pr_module.RepositoryBase, "__init__", _no_init): + repo = ProcessRepository( + account_url="https://x", database_name="db", container_name="c" + ) + + entity = SimpleNamespace(id="p1", updated_at=None) + before = datetime.now(UTC) - timedelta(seconds=1) + + with patch.object( + pr_module.RepositoryBase, + "update_async", + new=AsyncMock(return_value=entity), + ) as mock_super: + result = await repo.update_async(entity) + + after = datetime.now(UTC) + timedelta(seconds=1) + assert result is entity + assert isinstance(entity.updated_at, datetime) + assert before <= entity.updated_at <= after + mock_super.assert_awaited_once_with(entity) + + def test_init_calls_super_with_proper_args(self): + captured = {} + + def fake_init(self, account_url, database_name, container_name): + captured["account_url"] = account_url + captured["database_name"] = database_name + captured["container_name"] = container_name + + with patch.object(pr_module.RepositoryBase, "__init__", fake_init): + ProcessRepository( + account_url="https://acct.documents.azure.com", + database_name="mydb", + container_name="processes", + ) + + assert captured == { + "account_url": "https://acct.documents.azure.com", + "database_name": "mydb", + "container_name": "processes", + } diff --git a/src/backend-api/src/tests/routers/__init__.py b/src/backend-api/src/tests/routers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend-api/src/tests/routers/test_http_probes.py b/src/backend-api/src/tests/routers/test_http_probes.py new file mode 100644 index 00000000..129a67c5 --- /dev/null +++ b/src/backend-api/src/tests/routers/test_http_probes.py @@ -0,0 +1,51 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from routers.http_probes import router + + +def _build_client(): + 
app = FastAPI() + app.include_router(router) + return TestClient(app) + + +class TestRoot: + def test_returns_200(self): + client = _build_client() + res = client.get("/") + assert res.status_code == 200 + + def test_response_body_contains_expected_fields(self): + client = _build_client() + body = client.get("/").json() + assert body["message"] == "Code Migration Code converting process API" + assert body["version"] == "1.0.0" + assert body["status"] == "running" + assert "timestamp" in body + assert "uptime_seconds" in body + assert isinstance(body["uptime_seconds"], (int, float)) + + +class TestHealth: + def test_returns_200(self): + client = _build_client() + res = client.get("/health") + assert res.status_code == 200 + + def test_response_includes_message(self): + client = _build_client() + body = client.get("/health").json() + assert body == {"message": "I'm alive!"} + + +class TestStartup: + def test_returns_200(self): + client = _build_client() + res = client.get("/startup") + assert res.status_code == 200 + + def test_response_includes_running_message(self): + client = _build_client() + body = client.get("/startup").json() + assert body["message"].startswith("Running for") diff --git a/src/backend-api/src/tests/routers/test_router_debug.py b/src/backend-api/src/tests/routers/test_router_debug.py new file mode 100644 index 00000000..5eb34163 --- /dev/null +++ b/src/backend-api/src/tests/routers/test_router_debug.py @@ -0,0 +1,58 @@ +from types import SimpleNamespace + +from fastapi.testclient import TestClient + +from libs.application.application_context import AppContext +from libs.base.typed_fastapi import TypedFastAPI +from routers.router_debug import router + + +def _make_configuration(**overrides): + base = { + "app_logging_enable": True, + "app_logging_level": "INFO", + "azure_package_logging_level": "WARNING", + "azure_logging_packages": None, + "cosmos_db_account_url": "https://cosmos.example.com", + "cosmos_db_database_name": "db", + 
"cosmos_db_process_container": "processes", + "cosmos_db_process_log_container": "logs", + "storage_account_name": "stg", + "storage_account_blob_url": "https://blob.example.com", + "storage_account_queue_url": "https://queue.example.com", + "storage_account_process_container": "container", + "storage_account_process_queue": "queue", + } + base.update(overrides) + return SimpleNamespace(**base) + + +def _build_client(config=None): + app = TypedFastAPI() + ctx = AppContext() + ctx.set_configuration(config or _make_configuration()) + app.set_app_context(ctx) + app.include_router(router) + return TestClient(app) + + +class TestGetConfigDebug: + def test_returns_200(self): + client = _build_client() + res = client.get("/debug/config") + assert res.status_code == 200 + + def test_returns_configuration_payload(self): + client = _build_client() + body = client.get("/debug/config").json() + assert "configuration" in body + cfg = body["configuration"] + assert cfg["cosmos_db_database_name"] == "db" + assert cfg["storage_account_name"] == "stg" + assert cfg["app_logging_level"] == "INFO" + + def test_reflects_overridden_values(self): + config = _make_configuration(storage_account_name="custom-storage") + client = _build_client(config=config) + body = client.get("/debug/config").json() + assert body["configuration"]["storage_account_name"] == "custom-storage" diff --git a/src/backend-api/src/tests/routers/test_router_files.py b/src/backend-api/src/tests/routers/test_router_files.py new file mode 100644 index 00000000..8ef44ca5 --- /dev/null +++ b/src/backend-api/src/tests/routers/test_router_files.py @@ -0,0 +1,161 @@ +"""Tests for routers/router_files.py.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +from fastapi.testclient import TestClient + +from libs.base.typed_fastapi import TypedFastAPI +from libs.services.interfaces import ILoggerService +from routers.router_files import router + + +def _make_async_cm(yielded): + cm = 
MagicMock() + cm.__aenter__ = AsyncMock(return_value=yielded) + cm.__aexit__ = AsyncMock(return_value=False) + return cm + + +def _make_app(*, process_record=None, file_count=1, blob_helper=None): + app = TypedFastAPI() + + logger = MagicMock(spec=ILoggerService) + process_repo = MagicMock() + process_repo.get_async = AsyncMock( + return_value=process_record + or SimpleNamespace(id="p-1", source_file_count=0, status="initialized") + ) + process_repo.update_async = AsyncMock(return_value=None) + + file_repo = MagicMock() + file_repo.add_async = AsyncMock(return_value=None) + file_repo.count_async = AsyncMock(return_value=file_count) + + if blob_helper is None: + blob_helper = MagicMock() + blob_helper.upload_blob = AsyncMock(return_value=None) + + blob_cm = _make_async_cm(blob_helper) + + def scope_get_service(t): + from libs.repositories.file_repository import FileRepository + from libs.repositories.process_repository import ProcessRepository + from libs.sas.storage import AsyncStorageBlobHelper + + if t is ProcessRepository: + return process_repo + if t is FileRepository: + return file_repo + if t is AsyncStorageBlobHelper: + return blob_cm + return MagicMock() + + scope = MagicMock() + scope.get_service.side_effect = scope_get_service + scope_cm = _make_async_cm(scope) + + ctx = MagicMock() + ctx.configuration = SimpleNamespace(storage_account_process_container="container") + ctx.create_scope = MagicMock(return_value=scope_cm) + + def app_get_service(t): + if t is ILoggerService: + return logger + return MagicMock() + + ctx.get_service.side_effect = app_get_service + app.app_context = ctx + app.include_router(router) + return app, { + "logger": logger, + "process_repo": process_repo, + "file_repo": file_repo, + "blob_helper": blob_helper, + } + + +VALID_PROCESS_ID = "123e4567-e89b-42d3-a456-426614174000" +AUTH_HEADERS = {"x-ms-client-principal-id": "user-1"} + + +class TestUploadOptions: + def test_returns_200(self): + app, _ = _make_app() + client = 
TestClient(app) + res = client.options("/api/file/upload") + assert res.status_code == 200 + + def test_returns_cors_headers(self): + app, _ = _make_app() + client = TestClient(app) + res = client.options("/api/file/upload") + assert res.headers["Access-Control-Allow-Origin"] == "*" + assert "POST" in res.headers["Access-Control-Allow-Methods"] + + +class TestUploadFile: + def test_uploads_file_successfully(self): + app, mocks = _make_app() + client = TestClient(app) + res = client.post( + "/api/file/upload", + files={"file": ("hello.txt", b"hi", "text/plain")}, + data={"process_id": VALID_PROCESS_ID}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 200 + body = res.json() + assert body["file"]["original_name"] == "hello.txt" + assert body["batch"]["batch_id"] == "p-1" + mocks["blob_helper"].upload_blob.assert_awaited() + mocks["process_repo"].update_async.assert_awaited() + + def test_returns_400_on_invalid_process_id(self): + app, _ = _make_app() + client = TestClient(app) + res = client.post( + "/api/file/upload", + files={"file": ("x.txt", b"x", "text/plain")}, + data={"process_id": "not-a-uuid"}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 400 + + def test_sanitizes_filename_to_blob_path(self): + app, mocks = _make_app() + client = TestClient(app) + client.post( + "/api/file/upload", + files={"file": ("a b!c.txt", b"x", "text/plain")}, + data={"process_id": VALID_PROCESS_ID}, + headers=AUTH_HEADERS, + ) + kwargs = mocks["blob_helper"].upload_blob.await_args.kwargs + assert kwargs["blob_name"].endswith("/source/a_b_c.txt") + + def test_marks_status_ready_to_process_when_files_exist(self): + app, mocks = _make_app(file_count=3) + client = TestClient(app) + client.post( + "/api/file/upload", + files={"file": ("a.txt", b"x", "text/plain")}, + data={"process_id": VALID_PROCESS_ID}, + headers=AUTH_HEADERS, + ) + updated = mocks["process_repo"].update_async.await_args.args[0] + assert updated.source_file_count == 3 + assert updated.status == 
"ready_to_process" + + def test_returns_500_when_blob_upload_fails(self): + bad_blob = MagicMock() + bad_blob.upload_blob = AsyncMock(side_effect=RuntimeError("boom")) + app, _ = _make_app(blob_helper=bad_blob) + client = TestClient(app) + res = client.post( + "/api/file/upload", + files={"file": ("a.txt", b"x", "text/plain")}, + data={"process_id": VALID_PROCESS_ID}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 500 diff --git a/src/backend-api/src/tests/routers/test_router_process.py b/src/backend-api/src/tests/routers/test_router_process.py new file mode 100644 index 00000000..55b16341 --- /dev/null +++ b/src/backend-api/src/tests/routers/test_router_process.py @@ -0,0 +1,394 @@ +"""Tests for routers/router_process.py.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi.testclient import TestClient + +from libs.base.typed_fastapi import TypedFastAPI +from libs.services.interfaces import ILoggerService +from libs.services.process_services import ProcessService +from libs.repositories.process_repository import ProcessRepository +from routers.router_process import router + + +AUTH_HEADERS = {"x-ms-client-principal-id": "user-1"} + + +def _make_async_cm(yielded): + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=yielded) + cm.__aexit__ = AsyncMock(return_value=False) + return cm + + +def _build(process_service=None, process_repo=None, configuration=None): + app = TypedFastAPI() + logger = MagicMock(spec=ILoggerService) + + process_service = process_service or MagicMock(spec=ProcessService) + process_repo = process_repo or MagicMock() + if not hasattr(process_repo, "add_async") or not isinstance( + process_repo.add_async, AsyncMock + ): + process_repo.add_async = AsyncMock(return_value=None) + + scope = MagicMock() + scope.get_service.side_effect = lambda t: ( + process_repo if t is ProcessRepository else MagicMock() + ) + scope_cm = _make_async_cm(scope) + + ctx = MagicMock() + 
ctx.configuration = configuration or SimpleNamespace( + processor_control_url="http://proc:8080", + processor_control_token="tok", + ) + ctx.create_scope = MagicMock(return_value=scope_cm) + + def app_get(t): + if t is ILoggerService: + return logger + if t is ProcessService: + return process_service + return MagicMock() + + ctx.get_service.side_effect = app_get + app.app_context = ctx + app.include_router(router) + return app, process_service, process_repo + + +class TestCreateProcess: + def test_returns_process_id(self): + app, _svc, repo = _build() + client = TestClient(app) + res = client.post("/api/process/create", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert "process_id" in res.json() + repo.add_async.assert_awaited() + + def test_returns_500_on_repo_error(self): + repo = MagicMock() + repo.add_async = AsyncMock(side_effect=RuntimeError("db down")) + app, _, _ = _build(process_repo=repo) + client = TestClient(app) + res = client.post("/api/process/create", headers=AUTH_HEADERS) + assert res.status_code == 500 + + +class TestStatus: + def test_returns_service_payload(self): + svc = MagicMock(spec=ProcessService) + svc.get_current_process = AsyncMock(return_value={"phase": "x"}) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/status/abc/", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert res.json() == {"phase": "x"} + + def test_render_status(self): + svc = MagicMock(spec=ProcessService) + svc.render_current_process = AsyncMock(return_value=["a", "b"]) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/status/abc/render/", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert res.json() == ["a", "b"] + + +class TestUploadFiles: + def test_uploads_and_returns_files(self): + svc = MagicMock(spec=ProcessService) + svc.save_files_to_blob = AsyncMock(return_value=None) + svc.get_all_uploaded_files = AsyncMock(return_value=[]) 
+ app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.post( + "/api/process/upload", + data={"process_id": "p-1"}, + files={"files": ("a.txt", b"hi", "text/plain")}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 200 + svc.save_files_to_blob.assert_awaited() + + def test_returns_500_on_service_error(self): + svc = MagicMock(spec=ProcessService) + svc.save_files_to_blob = AsyncMock(side_effect=RuntimeError("fail")) + svc.get_all_uploaded_files = AsyncMock(return_value=[]) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.post( + "/api/process/upload", + data={"process_id": "p-1"}, + files={"files": ("a.txt", b"x", "text/plain")}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 500 + + +class TestDeleteFile: + def test_returns_200(self): + svc = MagicMock(spec=ProcessService) + svc.delete_file_from_blob = AsyncMock(return_value=None) + svc.get_all_uploaded_files = AsyncMock(return_value=[]) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.request( + "DELETE", + "/api/process/delete-file/foo.txt", + data={"process_id": "p-1"}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 200 + + def test_returns_404_when_file_missing(self): + svc = MagicMock(spec=ProcessService) + svc.delete_file_from_blob = AsyncMock(side_effect=FileNotFoundError("x")) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.request( + "DELETE", + "/api/process/delete-file/foo.txt", + data={"process_id": "p-1"}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 404 + + +class TestDeleteProcess: + def test_returns_200_with_deleted_count_message(self): + svc = MagicMock(spec=ProcessService) + svc.delete_all_files_from_blob = AsyncMock(return_value=3) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.delete("/api/process/delete-process/p-1", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert "3 files removed" 
in res.json()["message"] + + def test_returns_500_on_error(self): + svc = MagicMock(spec=ProcessService) + svc.delete_all_files_from_blob = AsyncMock(side_effect=RuntimeError("boom")) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.delete("/api/process/delete-process/p-1", headers=AUTH_HEADERS) + assert res.status_code == 500 + + +class TestStartProcessing: + def test_returns_202(self): + svc = MagicMock(spec=ProcessService) + svc.process_enqueue = AsyncMock(return_value=None) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.post( + "/api/process/start-processing", + data={"process_id": "p-1"}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 202 + assert res.json()["status"] == "queued" + + def test_returns_500_on_service_error(self): + svc = MagicMock(spec=ProcessService) + svc.process_enqueue = AsyncMock(side_effect=RuntimeError("nope")) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.post( + "/api/process/start-processing", + data={"process_id": "p-1"}, + headers=AUTH_HEADERS, + ) + assert res.status_code == 500 + + +class TestDownload: + def test_returns_zip(self): + from routers.models.files import FileInfo + + files = [FileInfo(filename="a.txt", content=b"hello", content_type="text/plain", size=5)] + svc = MagicMock(spec=ProcessService) + svc.get_converted_files = AsyncMock(return_value=files) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/p-1/download", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert res.headers["content-type"] == "application/zip" + + def test_returns_404_when_no_files(self): + svc = MagicMock(spec=ProcessService) + svc.get_converted_files = AsyncMock(return_value=[]) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/p-1/download", headers=AUTH_HEADERS) + assert res.status_code == 404 + + +class TestProcessSummary: + 
def test_returns_summary(self): + from datetime import datetime, timezone + + entity = SimpleNamespace(id="p-1", created_at=datetime.now(timezone.utc)) + svc = MagicMock(spec=ProcessService) + svc.get_process_summary = AsyncMock(return_value=(entity, ["a.txt", "b.txt"])) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/process-summary/p-1", headers=AUTH_HEADERS) + assert res.status_code == 200 + body = res.json() + assert body["Process"]["file_count"] == 2 + assert len(body["files"]) == 2 + + def test_returns_500_on_error(self): + svc = MagicMock(spec=ProcessService) + svc.get_process_summary = AsyncMock(side_effect=RuntimeError("x")) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/process-summary/p-1", headers=AUTH_HEADERS) + assert res.status_code == 500 + + +class TestGetFileContent: + def test_returns_content(self): + svc = MagicMock(spec=ProcessService) + svc.get_converted_file_content = AsyncMock(return_value="hello") + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/p-1/file/a.txt", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert res.json() == {"content": "hello"} + + def test_returns_404_when_missing(self): + svc = MagicMock(spec=ProcessService) + svc.get_converted_file_content = AsyncMock(side_effect=FileNotFoundError()) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/p-1/file/a.txt", headers=AUTH_HEADERS) + assert res.status_code == 404 + + def test_returns_400_on_unicode_error(self): + svc = MagicMock(spec=ProcessService) + svc.get_converted_file_content = AsyncMock( + side_effect=UnicodeDecodeError("utf-8", b"", 0, 1, "x") + ) + app, *_ = _build(process_service=svc) + client = TestClient(app) + res = client.get("/api/process/p-1/file/a.bin", headers=AUTH_HEADERS) + assert res.status_code == 400 + + +def 
_make_httpx_response(status_code=200, json_data=None, text=""): + resp = MagicMock() + resp.status_code = status_code + resp.json = MagicMock(return_value=json_data or {}) + resp.text = text + return resp + + +def _patch_httpx_async_client(method, response): + """Return a patcher that replaces httpx.AsyncClient with a context manager + whose `.` AsyncMock yields the given response.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + setattr(mock_client, method, AsyncMock(return_value=response)) + return patch("routers.router_process.httpx.AsyncClient", return_value=mock_client) + + +class TestCancelProcess: + def test_returns_202_on_success(self): + app, *_ = _build() + client = TestClient(app) + resp = _make_httpx_response( + 200, + json_data={ + "kill_requested": True, + "kill_state": "pending", + "kill_requested_at": "2025-01-01", + }, + ) + with _patch_httpx_async_client("post", resp): + res = client.post("/api/process/cancel/p-1", headers=AUTH_HEADERS) + assert res.status_code == 202 + assert res.json()["kill_state"] == "pending" + + def test_returns_502_on_processor_401(self): + app, *_ = _build() + client = TestClient(app) + resp = _make_httpx_response(401, text="unauth") + with _patch_httpx_async_client("post", resp): + res = client.post("/api/process/cancel/p-1", headers=AUTH_HEADERS) + assert res.status_code == 502 + + def test_returns_502_on_processor_500(self): + app, *_ = _build() + client = TestClient(app) + resp = _make_httpx_response(500, text="boom") + with _patch_httpx_async_client("post", resp): + res = client.post("/api/process/cancel/p-1", headers=AUTH_HEADERS) + assert res.status_code == 502 + + def test_returns_504_on_timeout(self): + import httpx + + app, *_ = _build() + client = TestClient(app) + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + 
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("t")) + with patch( + "routers.router_process.httpx.AsyncClient", return_value=mock_client + ): + res = client.post("/api/process/cancel/p-1", headers=AUTH_HEADERS) + assert res.status_code == 504 + + def test_returns_503_on_connect_error(self): + import httpx + + app, *_ = _build() + client = TestClient(app) + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(side_effect=httpx.ConnectError("c")) + with patch( + "routers.router_process.httpx.AsyncClient", return_value=mock_client + ): + res = client.post("/api/process/cancel/p-1", headers=AUTH_HEADERS) + assert res.status_code == 503 + + +class TestCancelStatus: + def test_returns_200_on_success(self): + app, *_ = _build() + client = TestClient(app) + resp = _make_httpx_response(200, json_data={"kill_state": "pending"}) + with _patch_httpx_async_client("get", resp): + res = client.get("/api/process/cancel/p-1/status", headers=AUTH_HEADERS) + assert res.status_code == 200 + assert res.json() == {"kill_state": "pending"} + + def test_returns_502_on_processor_401(self): + app, *_ = _build() + client = TestClient(app) + resp = _make_httpx_response(401, text="nope") + with _patch_httpx_async_client("get", resp): + res = client.get("/api/process/cancel/p-1/status", headers=AUTH_HEADERS) + assert res.status_code == 502 + + def test_returns_504_on_timeout(self): + import httpx + + app, *_ = _build() + client = TestClient(app) + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(side_effect=httpx.TimeoutException("t")) + with patch( + "routers.router_process.httpx.AsyncClient", return_value=mock_client + ): + res = client.get("/api/process/cancel/p-1/status", headers=AUTH_HEADERS) + assert res.status_code == 504 
diff --git a/src/backend-api/src/tests/sas/__init__.py b/src/backend-api/src/tests/sas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend-api/src/tests/sas/storage/__init__.py b/src/backend-api/src/tests/sas/storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend-api/src/tests/sas/storage/blob/__init__.py b/src/backend-api/src/tests/sas/storage/blob/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend-api/src/tests/sas/storage/blob/test_async_helper.py b/src/backend-api/src/tests/sas/storage/blob/test_async_helper.py new file mode 100644 index 00000000..374a30e4 --- /dev/null +++ b/src/backend-api/src/tests/sas/storage/blob/test_async_helper.py @@ -0,0 +1,618 @@ +"""Tests for libs/sas/storage/blob/async_helper.py.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError + + +class _AsyncIter: + def __init__(self, items): + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +@pytest.fixture +def blob_service_mock(): + with patch( + "libs.sas.storage.blob.async_helper.BlobServiceClient" + ) as svc_cls: + yield svc_cls + + +def _wire(svc_cls): + """Wire BlobServiceClient -> container_client -> blob_client.""" + svc_instance = MagicMock() + svc_cls.from_connection_string.return_value = svc_instance + svc_cls.return_value = svc_instance + svc_instance.close = AsyncMock() + + container_client = MagicMock() + container_client.create_container = AsyncMock() + container_client.delete_container = AsyncMock() + container_client.get_container_properties = AsyncMock() + blob_client = MagicMock() + blob_client.upload_blob = AsyncMock(return_value={"etag": "e"}) + blob_client.download_blob = AsyncMock() + blob_client.delete_blob = AsyncMock() + 
blob_client.get_blob_properties = AsyncMock() + blob_client.set_blob_metadata = AsyncMock() + container_client.get_blob_client.return_value = blob_client + svc_instance.get_container_client.return_value = container_client + return svc_instance, container_client, blob_client + + +def _blob_obj(name="f.txt", metadata=None): + b = MagicMock() + b.name = name + b.size = 5 + b.last_modified = None + b.etag = "e" + b.content_settings = None + b.blob_tier = None + b.blob_type = None + b.metadata = metadata or {} + return b + + +class TestInit: + @pytest.mark.asyncio + async def test_async_with_init_and_close(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc, _, _ = _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert h._blob_service_client is svc + svc.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_init_with_account_and_credential(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper( + account_name="acct", credential=MagicMock() + ): + pass + blob_service_mock.assert_called() + + @pytest.mark.asyncio + async def test_init_with_account_only(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + with patch( + "libs.sas.storage.blob.async_helper.DefaultAzureCredential" + ) as cred: + async with AsyncStorageBlobHelper(account_name="acct"): + pass + cred.assert_called_once() + + @pytest.mark.asyncio + async def test_init_no_args_raises(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + h = AsyncStorageBlobHelper() + with pytest.raises(ValueError): + await h._initialize_client() + + @pytest.mark.asyncio + async def test_init_failure_propagates(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import 
AsyncStorageBlobHelper + + blob_service_mock.from_connection_string.side_effect = RuntimeError("x") + h = AsyncStorageBlobHelper(connection_string="c") + with pytest.raises(RuntimeError): + await h._initialize_client() + + @pytest.mark.asyncio + async def test_property_raises_when_uninitialized(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + h = AsyncStorageBlobHelper(connection_string="c") + with pytest.raises(RuntimeError): + _ = h.blob_service_client + + def test_init_with_dict_config(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + with patch( + "libs.sas.storage.blob.async_helper.create_config" + ) as cc: + cc.return_value = {"logging_level": "INFO"} + AsyncStorageBlobHelper(connection_string="c", config={"x": 1}) + cc.assert_called_once() + + @pytest.mark.asyncio + async def test_close_no_client_noop(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + h = AsyncStorageBlobHelper(connection_string="c") + await h.close() + + +class TestContainerOps: + @pytest.mark.asyncio + async def test_create_container(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.create_container("c") is True + + @pytest.mark.asyncio + async def test_create_container_exists(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.create_container.side_effect = ResourceExistsError("e") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.create_container("c") is False + + @pytest.mark.asyncio + async def test_create_container_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = 
_wire(blob_service_mock) + cc.create_container.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.create_container("c") + + @pytest.mark.asyncio + async def test_delete_container_empty(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.list_blobs = MagicMock(return_value=_AsyncIter([])) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.delete_container("c") is True + + @pytest.mark.asyncio + async def test_delete_container_nonempty_no_force(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.list_blobs = MagicMock(return_value=_AsyncIter([_blob_obj("x")])) + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(ValueError): + await h.delete_container("c") + + @pytest.mark.asyncio + async def test_delete_container_force(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + # call 1: check empty (force=True; first list_blobs) + # call 2: iterate blobs to delete + cc.list_blobs = MagicMock( + side_effect=[_AsyncIter([_blob_obj("x")]), _AsyncIter([_blob_obj("x")])] + ) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.delete_container("c", force_delete=True) is True + + @pytest.mark.asyncio + async def test_delete_container_not_found(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.list_blobs = MagicMock(return_value=_AsyncIter([])) + cc.delete_container.side_effect = ResourceNotFoundError("nf") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.delete_container("c") is False + + @pytest.mark.asyncio 
+ async def test_container_exists_true(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.container_exists("c") is True + + @pytest.mark.asyncio + async def test_container_exists_false(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.get_container_properties.side_effect = ResourceNotFoundError("nf") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.container_exists("c") is False + + @pytest.mark.asyncio + async def test_container_exists_other_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.get_container_properties.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.container_exists("c") + + @pytest.mark.asyncio + async def test_list_containers(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc, _, _ = _wire(blob_service_mock) + c = MagicMock() + c.name = "x" + c.last_modified = None + c.metadata = {"k": "v"} + c.lease = None + c.public_access = None + svc.list_containers = MagicMock(return_value=_AsyncIter([c])) + async with AsyncStorageBlobHelper(connection_string="c") as h: + result = await h.list_containers() + assert result[0]["name"] == "x" + + @pytest.mark.asyncio + async def test_list_containers_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc, _, _ = _wire(blob_service_mock) + svc.list_containers = MagicMock(side_effect=RuntimeError("x")) + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.list_containers() + + 
+class TestBlobOps: + @pytest.mark.asyncio + async def test_upload_blob_bytes(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + await h.upload_blob("c", "f.txt", b"x") + bc.upload_blob.assert_awaited() + + @pytest.mark.asyncio + async def test_upload_blob_string_converts(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, _ = _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + await h.upload_blob("c", "f.txt", "hi", content_type="text/plain") + + @pytest.mark.asyncio + async def test_upload_blob_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.upload_blob.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.upload_blob("c", "f.txt", b"x") + + @pytest.mark.asyncio + async def test_download_blob(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + stream = MagicMock() + stream.readall = AsyncMock(return_value=b"data") + bc.download_blob.return_value = stream + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.download_blob("c", "f") == b"data" + + @pytest.mark.asyncio + async def test_download_blob_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.download_blob.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.download_blob("c", "f") + + @pytest.mark.asyncio + async def test_download_blob_to_file(self, blob_service_mock, tmp_path): 
+ from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + stream = MagicMock() + stream.readall = AsyncMock(return_value=b"abc") + bc.download_blob.return_value = stream + out = tmp_path / "x.bin" + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.download_blob_to_file("c", "f", str(out)) is True + assert out.read_bytes() == b"abc" + + @pytest.mark.asyncio + async def test_download_blob_to_file_error(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.download_blob.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.download_blob_to_file("c", "f", str(tmp_path / "x")) + + @pytest.mark.asyncio + async def test_upload_blob_from_text(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + await h.upload_blob_from_text("c", "f", "hello") + + @pytest.mark.asyncio + async def test_upload_blob_from_text_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.upload_blob.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.upload_blob_from_text("c", "f", "hi") + + @pytest.mark.asyncio + async def test_upload_file(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + f = tmp_path / "src.txt" + f.write_text("hello") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.upload_file("c", "f.txt", str(f)) is True + + @pytest.mark.asyncio + async def test_upload_file_error(self, 
blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.upload_blob.side_effect = RuntimeError("x") + f = tmp_path / "src.txt" + f.write_text("a") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.upload_file("c", "f.txt", str(f)) + + @pytest.mark.asyncio + async def test_download_file(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + + class _Stream: + def chunks(self): + async def _gen(): + for c in (b"a", b"b"): + yield c + + return _gen() + + bc.download_blob.return_value = _Stream() + out = tmp_path / "sub" / "x.bin" + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.download_file("c", "f", str(out)) is True + assert out.read_bytes() == b"ab" + + @pytest.mark.asyncio + async def test_download_file_error(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.download_blob.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.download_file("c", "f", str(tmp_path / "x")) + + @pytest.mark.asyncio + async def test_blob_exists_true(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.blob_exists("c", "f") is True + + @pytest.mark.asyncio + async def test_blob_exists_false(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.get_blob_properties.side_effect = ResourceNotFoundError("nf") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await 
h.blob_exists("c", "f") is False + + @pytest.mark.asyncio + async def test_blob_exists_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.get_blob_properties.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.blob_exists("c", "f") + + @pytest.mark.asyncio + async def test_delete_blob(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.delete_blob("c", "f") is True + + @pytest.mark.asyncio + async def test_delete_blob_not_found(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.delete_blob.side_effect = ResourceNotFoundError("nf") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.delete_blob("c", "f") is False + + @pytest.mark.asyncio + async def test_delete_blob_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.delete_blob.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.delete_blob("c", "f") + + @pytest.mark.asyncio + async def test_list_blobs(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.list_blobs = MagicMock( + return_value=_AsyncIter([_blob_obj("a.txt", {"k": "v"})]) + ) + async with AsyncStorageBlobHelper(connection_string="c") as h: + result = await h.list_blobs("c", include_metadata=True) + assert result[0]["name"] == "a.txt" + + @pytest.mark.asyncio + async def test_list_blobs_error(self, blob_service_mock): + 
from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, cc, _ = _wire(blob_service_mock) + cc.list_blobs = MagicMock(side_effect=RuntimeError("x")) + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.list_blobs("c") + + +class TestPropsAndSearch: + @pytest.mark.asyncio + async def test_get_blob_properties(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + props = MagicMock() + props.size = 1 + props.last_modified = None + props.etag = "e" + props.content_settings = None + props.metadata = {} + props.blob_tier = None + props.blob_type = "BlockBlob" + props.lease = None + props.creation_time = None + bc.get_blob_properties.return_value = props + async with AsyncStorageBlobHelper(connection_string="c") as h: + result = await h.get_blob_properties("c", "f") + assert result["size"] == 1 + + @pytest.mark.asyncio + async def test_get_blob_properties_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.get_blob_properties.side_effect = RuntimeError("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.get_blob_properties("c", "f") + + @pytest.mark.asyncio + async def test_set_blob_metadata(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.set_blob_metadata("c", "f", {"k": "v"}) is True + + @pytest.mark.asyncio + async def test_set_blob_metadata_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _, _, bc = _wire(blob_service_mock) + bc.set_blob_metadata.side_effect = RuntimeError("x") + async with 
AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(RuntimeError): + await h.set_blob_metadata("c", "f", {}) + + @pytest.mark.asyncio + async def test_search_blobs_by_name(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.list_blobs = AsyncMock( + return_value=[ + {"name": "alpha.txt", "metadata": {"tag": "x"}}, + {"name": "beta.txt", "metadata": {"tag": "alpha"}}, + ] + ) + result = await h.search_blobs("c", "alpha", search_in_metadata=True) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_search_blobs_case_sensitive_no_match(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.list_blobs = AsyncMock( + return_value=[{"name": "ALPHA.txt", "metadata": {}}] + ) + result = await h.search_blobs("c", "alpha", case_sensitive=True) + assert result == [] + + @pytest.mark.asyncio + async def test_search_blobs_error(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.list_blobs = AsyncMock(side_effect=RuntimeError("x")) + with pytest.raises(RuntimeError): + await h.search_blobs("c", "x") + + +class TestBatch: + @pytest.mark.asyncio + async def test_upload_multiple_files(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + f = tmp_path / "a.txt" + f.write_text("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.upload_file = AsyncMock(return_value=True) + results = await h.upload_multiple_files("c", [str(f)]) + assert results[str(f)] is True + + @pytest.mark.asyncio + async def 
test_upload_multiple_files_with_failure(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + f = tmp_path / "a.txt" + f.write_text("x") + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.upload_file = AsyncMock(side_effect=RuntimeError("x")) + results = await h.upload_multiple_files("c", [str(f)]) + assert results[str(f)] is False + + @pytest.mark.asyncio + async def test_download_multiple_blobs(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.download_file = AsyncMock(return_value=True) + results = await h.download_multiple_blobs( + "c", ["a.txt"], str(tmp_path) + ) + assert results["a.txt"] is True + + @pytest.mark.asyncio + async def test_download_multiple_blobs_failure(self, blob_service_mock, tmp_path): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h.download_file = AsyncMock(side_effect=RuntimeError("x")) + results = await h.download_multiple_blobs( + "c", ["a.txt"], str(tmp_path) + ) + assert results["a.txt"] is False diff --git a/src/backend-api/src/tests/sas/storage/blob/test_async_helper_extra.py b/src/backend-api/src/tests/sas/storage/blob/test_async_helper_extra.py new file mode 100644 index 00000000..9d686158 --- /dev/null +++ b/src/backend-api/src/tests/sas/storage/blob/test_async_helper_extra.py @@ -0,0 +1,334 @@ +"""Additional tests for libs/sas/storage/blob/async_helper.py. + +Targets the previously uncovered SAS URL generators, credential / account +helpers, and the inner failure branches of delete_container. 
+""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class _AsyncIter: + def __init__(self, items): + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +@pytest.fixture +def blob_service_mock(): + with patch( + "libs.sas.storage.blob.async_helper.BlobServiceClient" + ) as svc_cls: + yield svc_cls + + +def _wire(svc_cls): + svc_instance = MagicMock() + svc_cls.from_connection_string.return_value = svc_instance + svc_cls.return_value = svc_instance + svc_instance.close = AsyncMock() + return svc_instance + + +def _blob_obj(name="f.txt"): + b = MagicMock() + b.name = name + return b + + +class TestInitWithObjectConfig: + def test_object_config_kept_as_is(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + cfg = MagicMock() + cfg.get = MagicMock(return_value="INFO") + h = AsyncStorageBlobHelper(connection_string="c", config=cfg) + assert h.config is cfg + + +class TestDeleteContainerInnerFailures: + @pytest.mark.asyncio + async def test_force_delete_inner_blob_error_continues(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire(blob_service_mock) + cc = MagicMock() + svc.get_container_client.return_value = cc + # Two iterations of list_blobs: existence + delete pass + b1, b2 = _blob_obj("x"), _blob_obj("y") + cc.list_blobs = MagicMock( + side_effect=[_AsyncIter([b1, b2]), _AsyncIter([b1, b2])] + ) + ok = MagicMock() + ok.delete_blob = AsyncMock(return_value=None) + bad = MagicMock() + bad.delete_blob = AsyncMock(side_effect=RuntimeError("boom")) + cc.get_blob_client.side_effect = [bad, ok] + cc.delete_container = AsyncMock(return_value=None) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h.delete_container("c", force_delete=True) is True + + +def _wire_credential(svc_cls, 
*, account_key=None, account_name="myacct", + credential_cls_name="DefaultAzureCredential"): + svc = _wire(svc_cls) + svc.account_name = account_name + if credential_cls_name == "AccountKey": + cred = MagicMock() + cred.account_key = account_key + type(cred).__name__ = "StorageSharedKeyCredential" + else: + cred = MagicMock(spec=[]) + type(cred).__name__ = credential_cls_name + svc.credential = cred + return svc + + +class TestAccountAndCredentialHelpers: + @pytest.mark.asyncio + async def test_get_account_name_returns_value(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire_credential(blob_service_mock, account_name="abc") + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h._get_account_name() == "abc" + + @pytest.mark.asyncio + async def test_get_account_key_from_credential(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire_credential( + blob_service_mock, + account_key="key123", + credential_cls_name="AccountKey", + ) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h._get_account_key() == "key123" + + @pytest.mark.asyncio + async def test_get_account_key_from_connection_string(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire(blob_service_mock) + svc.credential = object() + conn = ( + "DefaultEndpointsProtocol=https;AccountName=x;AccountKey=k=v;" + "EndpointSuffix=core.windows.net" + ) + async with AsyncStorageBlobHelper(connection_string=conn) as h: + assert await h._get_account_key() == "k=v" + + @pytest.mark.asyncio + async def test_get_account_key_returns_none_when_no_match(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire(blob_service_mock) + svc.credential = object() + async with AsyncStorageBlobHelper(connection_string="AccountName=x") as h: + assert await 
h._get_account_key() is None + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "name,expected", + [ + ("StorageSharedKeyCredential", "Storage Account Key"), + ("DefaultAzureCredential", "DefaultAzureCredential"), + ("ManagedIdentityCredential", "Managed Identity"), + ("AzureCliCredential", "Azure CLI"), + ("EnvironmentCredential", "Environment Variables"), + ("WorkloadIdentityCredential", "Workload Identity"), + ("ChainedTokenCredential", "Chained Token Credential"), + ("WeirdCustomCredential", "Azure AD (WeirdCustomCredential)"), + ], + ) + async def test_credential_type_mappings( + self, blob_service_mock, name, expected + ): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire_credential(blob_service_mock, credential_cls_name=name) + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h._get_credential_type() == expected + + @pytest.mark.asyncio + async def test_credential_type_unknown_when_no_credential_attr( + self, blob_service_mock + ): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = MagicMock(spec=["close", "get_container_client"]) + svc.close = AsyncMock() + blob_service_mock.from_connection_string.return_value = svc + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h._get_credential_type() == "unknown" + + @pytest.mark.asyncio + async def test_credential_type_unknown_when_credential_is_none( + self, blob_service_mock + ): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire(blob_service_mock) + svc.credential = None + async with AsyncStorageBlobHelper(connection_string="c") as h: + assert await h._get_credential_type() == "unknown" + + +class TestGenerateBlobSasUrlAsync: + @pytest.mark.asyncio + async def test_account_key_path(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire_credential( + blob_service_mock, + account_name="acct", + 
account_key="abc", + credential_cls_name="AccountKey", + ) + with patch( + "azure.storage.blob.generate_blob_sas", return_value="sig=token" + ): + async with AsyncStorageBlobHelper(connection_string="c") as h: + url = await h.generate_blob_sas_url("ctn", "blob") + assert "sig=token" in url + + @pytest.mark.asyncio + async def test_user_delegation_path(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire_credential( + blob_service_mock, + account_name="acct", + account_key=None, + credential_cls_name="DefaultAzureCredential", + ) + svc.get_user_delegation_key = AsyncMock(return_value="udkey") + with patch( + "azure.storage.blob.generate_blob_sas", return_value="sig=ud" + ): + async with AsyncStorageBlobHelper(connection_string="c") as h: + url = await h.generate_blob_sas_url("ctn", "blob") + assert "sig=ud" in url + + @pytest.mark.asyncio + async def test_unknown_credential_raises(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire(blob_service_mock) + svc.account_name = "acct" + svc.credential = None + async with AsyncStorageBlobHelper(connection_string="AccountName=acct") as h: + with pytest.raises(ValueError): + await h.generate_blob_sas_url("c", "b") + + @pytest.mark.asyncio + async def test_no_account_name_raises(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h._get_account_name = AsyncMock(return_value=None) + with pytest.raises(ValueError): + await h.generate_blob_sas_url("c", "b") + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "msg", + ["403 Forbidden", "401 Unauthorized", "network down"], + ) + async def test_delegation_key_errors_wrapped(self, blob_service_mock, msg): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire_credential( + 
blob_service_mock, + account_name="acct", + credential_cls_name="DefaultAzureCredential", + ) + svc.get_user_delegation_key = AsyncMock(side_effect=RuntimeError(msg)) + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(ValueError): + await h.generate_blob_sas_url("c", "b") + + +class TestGenerateContainerSasUrlAsync: + @pytest.mark.asyncio + async def test_account_key_path(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + _wire_credential( + blob_service_mock, + account_name="acct", + account_key="abc", + credential_cls_name="AccountKey", + ) + with patch( + "azure.storage.blob.generate_container_sas", return_value="sig=ctk" + ): + async with AsyncStorageBlobHelper(connection_string="c") as h: + url = await h.generate_container_sas_url("ctn") + assert "sig=ctk" in url + + @pytest.mark.asyncio + async def test_user_delegation_path(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire_credential( + blob_service_mock, + account_name="acct", + credential_cls_name="DefaultAzureCredential", + ) + svc.get_user_delegation_key = AsyncMock(return_value="udkey") + with patch( + "azure.storage.blob.generate_container_sas", + return_value="sig=udc", + ): + async with AsyncStorageBlobHelper(connection_string="c") as h: + url = await h.generate_container_sas_url("ctn") + assert "sig=udc" in url + + @pytest.mark.asyncio + async def test_unknown_credential_raises(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire(blob_service_mock) + svc.account_name = "acct" + svc.credential = None + async with AsyncStorageBlobHelper(connection_string="AccountName=acct") as h: + with pytest.raises(ValueError): + await h.generate_container_sas_url("c") + + @pytest.mark.asyncio + async def test_no_account_name_raises(self, blob_service_mock): + from libs.sas.storage.blob.async_helper import 
AsyncStorageBlobHelper + + _wire(blob_service_mock) + async with AsyncStorageBlobHelper(connection_string="c") as h: + h._get_account_name = AsyncMock(return_value=None) + with pytest.raises(ValueError): + await h.generate_container_sas_url("c") + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "msg", + ["403 Forbidden", "401 Unauthorized", "transient error"], + ) + async def test_delegation_key_errors_wrapped(self, blob_service_mock, msg): + from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper + + svc = _wire_credential( + blob_service_mock, + account_name="acct", + credential_cls_name="DefaultAzureCredential", + ) + svc.get_user_delegation_key = AsyncMock(side_effect=RuntimeError(msg)) + async with AsyncStorageBlobHelper(connection_string="c") as h: + with pytest.raises(ValueError): + await h.generate_container_sas_url("c") diff --git a/src/backend-api/src/tests/sas/storage/blob/test_config.py b/src/backend-api/src/tests/sas/storage/blob/test_config.py new file mode 100644 index 00000000..49f25ea5 --- /dev/null +++ b/src/backend-api/src/tests/sas/storage/blob/test_config.py @@ -0,0 +1,123 @@ +"""Tests for libs/sas/storage/blob/config.py.""" + +import pytest + +from libs.sas.storage.blob import config as blob_config_module +from libs.sas.storage.blob.config import ( + BlobHelperConfig, + create_config, + default_config, + get_config, + set_config, +) + + +class TestBlobHelperConfigDefaults: + def test_inherits_shared_defaults(self): + cfg = BlobHelperConfig() + assert cfg.get("retry_attempts") == 3 + assert cfg.get("timeout_seconds") == 30 + assert cfg.get("logging_level") == "INFO" + + def test_blob_specific_defaults(self): + cfg = BlobHelperConfig() + assert cfg.get("max_single_upload_size") == 64 * 1024 * 1024 + assert cfg.get("max_block_size") == 4 * 1024 * 1024 + assert cfg.get("default_blob_tier") == "Hot" + assert "*.tmp" in cfg.get("sync_exclude_patterns") + + def test_init_with_overrides(self): + cfg = 
BlobHelperConfig({"max_block_size": 999, "custom_key": "v"}) + assert cfg.get("max_block_size") == 999 + assert cfg.get("custom_key") == "v" + # Other defaults preserved + assert cfg.get("default_blob_tier") == "Hot" + + +class TestBlobHelperConfigEnvironment: + def test_loads_env_vars_with_correct_types(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_MAX_UPLOAD_SIZE", "12345") + monkeypatch.setenv("AZURE_STORAGE_MAX_BLOCK_SIZE", "678") + monkeypatch.setenv("AZURE_STORAGE_DEFAULT_TIER", "Cool") + cfg = BlobHelperConfig() + assert cfg.get("max_single_upload_size") == 12345 + assert cfg.get("max_block_size") == 678 + assert cfg.get("default_blob_tier") == "Cool" + + def test_skips_invalid_int_env_var(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_MAX_UPLOAD_SIZE", "not-a-number") + cfg = BlobHelperConfig() + # Falls back to default when conversion fails + assert cfg.get("max_single_upload_size") == 64 * 1024 * 1024 + + def test_inherits_shared_env_loading(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_RETRY_ATTEMPTS", "9") + cfg = BlobHelperConfig() + assert cfg.get("retry_attempts") == 9 + + +class TestBlobHelperConfigGetSet: + def test_get_returns_default_for_unknown_key(self): + cfg = BlobHelperConfig() + assert cfg.get("missing_key") is None + assert cfg.get("missing_key", "fallback") == "fallback" + + def test_set_then_get(self): + cfg = BlobHelperConfig() + cfg.set("new_key", "new_val") + assert cfg.get("new_key") == "new_val" + + def test_get_all_returns_copy(self): + cfg = BlobHelperConfig() + all_cfg = cfg.get_all() + assert isinstance(all_cfg, dict) + all_cfg["mutate"] = "x" + assert cfg.get("mutate") is None # original untouched + + def test_update_multiple_keys(self): + cfg = BlobHelperConfig() + cfg.update({"a": 1, "b": 2}) + assert cfg.get("a") == 1 + assert cfg.get("b") == 2 + + def test_reset_to_defaults_restores_defaults(self): + cfg = BlobHelperConfig() + cfg.set("max_block_size", 99) + cfg.reset_to_defaults() + assert 
cfg.get("max_block_size") == 4 * 1024 * 1024 + + +class TestGetContentType: + def test_known_extension(self): + cfg = BlobHelperConfig() + assert cfg.get_content_type(".txt") == "text/plain" + assert cfg.get_content_type(".PDF") == "application/pdf" # case-insensitive + + def test_unknown_extension_returns_octet_stream(self): + cfg = BlobHelperConfig() + assert cfg.get_content_type(".xyz123") == "application/octet-stream" + + +class TestModuleLevelHelpers: + def test_get_config_returns_default(self): + assert get_config() is blob_config_module.default_config + + def test_set_config_replaces_default(self): + original = get_config() + try: + new_cfg = BlobHelperConfig({"flag": True}) + set_config(new_cfg) + assert get_config() is new_cfg + assert get_config().get("flag") is True + finally: + set_config(original) + + def test_create_config_returns_new_instance(self): + cfg = create_config({"x": 1}) + assert isinstance(cfg, BlobHelperConfig) + assert cfg.get("x") == 1 + + def test_create_config_no_overrides(self): + cfg = create_config() + assert isinstance(cfg, BlobHelperConfig) + assert cfg.get("default_blob_tier") == "Hot" diff --git a/src/backend-api/src/tests/sas/storage/blob/test_helper.py b/src/backend-api/src/tests/sas/storage/blob/test_helper.py new file mode 100644 index 00000000..0a9ec897 --- /dev/null +++ b/src/backend-api/src/tests/sas/storage/blob/test_helper.py @@ -0,0 +1,515 @@ +"""Tests for libs/sas/storage/blob/helper.py.""" + +from unittest.mock import MagicMock, patch + +import pytest +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError + + +@pytest.fixture +def blob_service_mock(): + with patch("libs.sas.storage.blob.helper.BlobServiceClient") as svc_cls: + yield svc_cls + + +def _make_helper(blob_service_mock, **kwargs): + from libs.sas.storage.blob.helper import StorageBlobHelper + + return StorageBlobHelper(connection_string="conn", **kwargs) + + +def _container_client(helper): + cc = MagicMock() + 
helper.blob_service_client.get_container_client.return_value = cc + return cc + + +def _blob_client(container_client): + bc = MagicMock() + container_client.get_blob_client.return_value = bc + return bc + + +class TestInit: + def test_init_with_connection_string(self, blob_service_mock): + h = _make_helper(blob_service_mock) + blob_service_mock.from_connection_string.assert_called_once() + assert h._connection_string == "conn" + + def test_init_with_account_and_credential(self, blob_service_mock): + from libs.sas.storage.blob.helper import StorageBlobHelper + + StorageBlobHelper(account_name="acct", credential=MagicMock()) + blob_service_mock.assert_called() + + def test_init_with_account_only_uses_default_credential(self, blob_service_mock): + from libs.sas.storage.blob.helper import StorageBlobHelper + + with patch("libs.sas.storage.blob.helper.DefaultAzureCredential") as cred: + StorageBlobHelper(account_name="acct") + cred.assert_called_once() + + def test_init_no_args_raises(self, blob_service_mock): + from libs.sas.storage.blob.helper import StorageBlobHelper + + with pytest.raises(ValueError): + StorageBlobHelper() + + def test_init_with_dict_config(self, blob_service_mock): + from libs.sas.storage.blob.helper import StorageBlobHelper + + with patch("libs.sas.storage.blob.config.create_config") as cc: + cc.return_value = {"logging_level": "INFO"} + StorageBlobHelper(connection_string="c", config={"x": 1}) + cc.assert_called_once() + + def test_init_failure_propagates(self, blob_service_mock): + from libs.sas.storage.blob.helper import StorageBlobHelper + + blob_service_mock.from_connection_string.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError): + StorageBlobHelper(connection_string="c") + + +class TestContainerOps: + def test_create_container_success(self, blob_service_mock): + h = _make_helper(blob_service_mock) + _container_client(h) + assert h.create_container("c") is True + + def test_create_container_exists(self, 
blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.create_container.side_effect = ResourceExistsError("e") + assert h.create_container("c") is False + + def test_create_container_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.create_container.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.create_container("c") + + def test_delete_container_empty(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.list_blobs.return_value = iter([]) + assert h.delete_container("c") is True + + def test_delete_container_non_empty_without_force(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.list_blobs.return_value = iter([MagicMock(name="b1")]) + with pytest.raises(ValueError): + h.delete_container("c", force_delete=False) + + def test_delete_container_force_with_blobs(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + b1 = MagicMock() + b1.name = "x" + cc.list_blobs.side_effect = [iter([b1]), iter([b1])] + bc = MagicMock() + cc.get_blob_client.return_value = bc + assert h.delete_container("c", force_delete=True) is True + + def test_delete_container_not_found(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.list_blobs.return_value = iter([]) + cc.delete_container.side_effect = ResourceNotFoundError("nf") + assert h.delete_container("c") is False + + def test_list_containers(self, blob_service_mock): + h = _make_helper(blob_service_mock) + c = MagicMock() + c.name = "x" + c.last_modified = None + c.etag = "e" + c.public_access = None + c.metadata = {"k": "v"} + h.blob_service_client.list_containers.return_value = iter([c]) + result = h.list_containers(include_metadata=True) + assert result[0]["name"] == "x" + assert result[0]["metadata"] == {"k": "v"} + + def 
test_list_containers_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.blob_service_client.list_containers.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.list_containers() + + def test_container_exists_true(self, blob_service_mock): + h = _make_helper(blob_service_mock) + _container_client(h) + assert h.container_exists("c") is True + + def test_container_exists_false(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.get_container_properties.side_effect = ResourceNotFoundError("nf") + assert h.container_exists("c") is False + + def test_container_exists_other_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.get_container_properties.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.container_exists("c") + + +class TestBlobUpload: + def test_upload_blob_success(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + _blob_client(cc) + assert h.upload_blob("c", "b", b"data") is True + + def test_upload_blob_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.upload_blob.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.upload_blob("c", "b", b"data") + + def test_upload_file_not_found(self, blob_service_mock): + h = _make_helper(blob_service_mock) + with pytest.raises(FileNotFoundError): + h.upload_file("c", "b", "Z:/nope/missing.txt") + + def test_upload_file_success(self, blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + _blob_client(cc) + f = tmp_path / "x.txt" + f.write_text("hi") + assert h.upload_file("c", "b", str(f)) is True + + +class TestBlobDownload: + def test_download_blob(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + stream = 
MagicMock() + stream.readall.return_value = b"data" + bc.download_blob.return_value = stream + assert h.download_blob("c", "b") == b"data" + + def test_download_blob_not_found(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.download_blob.side_effect = ResourceNotFoundError("nf") + with pytest.raises(ResourceNotFoundError): + h.download_blob("c", "b") + + def test_download_blob_other_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.download_blob.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.download_blob("c", "b") + + def test_download_blob_to_file(self, blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + stream = MagicMock() + stream.readall.return_value = b"data" + bc.download_blob.return_value = stream + out = tmp_path / "sub" / "f.bin" + assert h.download_blob_to_file("c", "b", str(out)) is True + assert out.read_bytes() == b"data" + + def test_download_blob_to_file_error(self, blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.download_blob.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.download_blob_to_file("c", "b", str(tmp_path / "x.bin")) + + +class TestBlobMgmt: + def test_delete_blob(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + _blob_client(cc) + assert h.delete_blob("c", "b") is True + + def test_delete_blob_not_found(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.delete_blob.side_effect = ResourceNotFoundError("nf") + assert h.delete_blob("c", "b") is False + + def test_delete_blob_other_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = 
_blob_client(cc) + bc.delete_blob.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.delete_blob("c", "b") + + def test_copy_blob(self, blob_service_mock): + h = _make_helper(blob_service_mock) + src_bc = MagicMock() + src_bc.url = "src" + dest_bc = MagicMock() + dest_bc.start_copy_from_url.return_value = {"copy_status": "success"} + h.blob_service_client.get_blob_client.side_effect = [src_bc, dest_bc] + assert h.copy_blob("a", "b", "c", "d", metadata={"k": "v"}) is True + + def test_copy_blob_pending(self, blob_service_mock): + h = _make_helper(blob_service_mock) + src_bc = MagicMock() + dest_bc = MagicMock() + dest_bc.start_copy_from_url.return_value = {"copy_status": "pending"} + h.blob_service_client.get_blob_client.side_effect = [src_bc, dest_bc] + assert h.copy_blob("a", "b", "c", "d") is True + + def test_copy_blob_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.blob_service_client.get_blob_client.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.copy_blob("a", "b", "c", "d") + + def test_move_blob_success(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.copy_blob = MagicMock(return_value=True) + h.delete_blob = MagicMock(return_value=True) + assert h.move_blob("a", "b", "c", "d") is True + + def test_move_blob_copy_failed(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.copy_blob = MagicMock(return_value=False) + assert h.move_blob("a", "b", "c", "d") is False + + def test_move_blob_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.copy_blob = MagicMock(side_effect=RuntimeError("x")) + with pytest.raises(RuntimeError): + h.move_blob("a", "b", "c", "d") + + def test_blob_exists_true(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + _blob_client(cc) + assert h.blob_exists("c", "b") is True + + def test_blob_exists_false(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc 
= _container_client(h) + bc = _blob_client(cc) + bc.get_blob_properties.side_effect = ResourceNotFoundError("nf") + assert h.blob_exists("c", "b") is False + + def test_blob_exists_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.get_blob_properties.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.blob_exists("c", "b") + + +def _blob_obj(name="f.txt"): + b = MagicMock() + b.name = name + b.size = 10 + b.last_modified = None + b.etag = "e" + b.content_settings = None + b.blob_tier = None + b.blob_type = None + b.metadata = {"k": "v"} + b.snapshot = None + return b + + +class TestListAndProps: + def test_list_blobs(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.list_blobs.return_value = iter([_blob_obj()]) + result = h.list_blobs("c", include_metadata=True) + assert result[0]["name"] == "f.txt" + assert result[0]["metadata"] == {"k": "v"} + + def test_list_blobs_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.list_blobs.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.list_blobs("c") + + def test_list_blobs_hierarchical(self, blob_service_mock): + from azure.storage.blob import BlobPrefix + + h = _make_helper(blob_service_mock) + cc = _container_client(h) + prefix = MagicMock(spec=BlobPrefix) + prefix.name = "dir/" + cc.walk_blobs.return_value = iter([prefix, _blob_obj("dir/file.txt")]) + result = h.list_blobs_hierarchical("c", prefix="dir/") + assert len(result["prefixes"]) == 1 + assert len(result["blobs"]) == 1 + + def test_list_blobs_hierarchical_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.walk_blobs.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.list_blobs_hierarchical("c") + + def test_get_blob_properties(self, blob_service_mock): + h = 
_make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + props = MagicMock() + props.size = 7 + props.last_modified = None + props.etag = "e" + props.content_settings = None + props.blob_tier = None + props.blob_type = None + props.metadata = {} + props.creation_time = None + props.lease = None + bc.get_blob_properties.return_value = props + result = h.get_blob_properties("c", "b") + assert result["size"] == 7 + assert result["lease_status"] is None + + def test_get_blob_properties_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.get_blob_properties.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.get_blob_properties("c", "b") + + def test_set_blob_metadata(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + _blob_client(cc) + assert h.set_blob_metadata("c", "b", {"k": "v"}) is True + + def test_set_blob_metadata_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.set_blob_metadata.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.set_blob_metadata("c", "b", {}) + + +class TestBatch: + def test_upload_multiple_files_mixed(self, blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + good = tmp_path / "g.txt" + good.write_text("x") + h.upload_file = MagicMock(return_value=True) + results = h.upload_multiple_files("c", [str(good), "Z:/nope/missing.txt"]) + assert results[str(good)] is True + assert results["Z:/nope/missing.txt"] is False + + def test_upload_multiple_files_upload_error(self, blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + f = tmp_path / "g.txt" + f.write_text("x") + h.upload_file = MagicMock(side_effect=RuntimeError("x")) + results = h.upload_multiple_files("c", [str(f)]) + assert results[str(f)] is False + + def test_download_multiple_blobs(self, 
blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + h.download_blob_to_file = MagicMock(return_value=True) + results = h.download_multiple_blobs("c", ["a.txt", "b.txt"], str(tmp_path)) + assert all(results.values()) + + def test_download_multiple_blobs_error(self, blob_service_mock, tmp_path): + h = _make_helper(blob_service_mock) + h.download_blob_to_file = MagicMock(side_effect=RuntimeError("x")) + results = h.download_multiple_blobs("c", ["a.txt"], str(tmp_path)) + assert results["a.txt"] is False + + def test_delete_multiple_blobs(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.delete_blob = MagicMock(return_value=True) + results = h.delete_multiple_blobs("c", ["a", "b"]) + assert all(results.values()) + + def test_delete_multiple_blobs_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.delete_blob = MagicMock(side_effect=RuntimeError("x")) + results = h.delete_multiple_blobs("c", ["a"]) + assert results["a"] is False + + +class TestAdvanced: + def test_set_blob_tier(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + _blob_client(cc) + assert h.set_blob_tier("c", "b", "Cool") is True + + def test_set_blob_tier_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.set_standard_blob_tier.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.set_blob_tier("c", "b", "Cool") + + def test_create_snapshot(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.create_snapshot.return_value = {"snapshot": "2024-01-01"} + assert h.create_snapshot("c", "b") == "2024-01-01" + + def test_create_snapshot_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + bc = _blob_client(cc) + bc.create_snapshot.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + 
h.create_snapshot("c", "b") + + def test_list_blob_snapshots(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + snap = _blob_obj("b") + snap.snapshot = "ts" + cc.list_blobs.return_value = iter([snap]) + result = h.list_blob_snapshots("c", "b") + assert result[0]["snapshot"] == "ts" + + def test_list_blob_snapshots_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = _container_client(h) + cc.list_blobs.side_effect = RuntimeError("x") + with pytest.raises(RuntimeError): + h.list_blob_snapshots("c", "b") + + def test_search_blobs(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.list_blobs = MagicMock( + return_value=[ + {"name": "alpha.txt", "metadata": {"tag": "x"}}, + {"name": "beta.txt", "metadata": {"tag": "alpha-tag"}}, + ] + ) + result = h.search_blobs("c", "alpha", search_in_metadata=True) + assert len(result) == 2 + + def test_search_blobs_error(self, blob_service_mock): + h = _make_helper(blob_service_mock) + h.list_blobs = MagicMock(side_effect=RuntimeError("x")) + with pytest.raises(RuntimeError): + h.search_blobs("c", "alpha") diff --git a/src/backend-api/src/tests/sas/storage/blob/test_helper_extra.py b/src/backend-api/src/tests/sas/storage/blob/test_helper_extra.py new file mode 100644 index 00000000..e4a672c5 --- /dev/null +++ b/src/backend-api/src/tests/sas/storage/blob/test_helper_extra.py @@ -0,0 +1,395 @@ +"""Additional tests for libs/sas/storage/blob/helper.py. + +Targets the previously uncovered branches: SAS URL generation (account-key +and user-delegation paths), sync_directory, credential / account name +helpers, and miscellaneous URL builders. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def blob_service_mock(): + with patch("libs.sas.storage.blob.helper.BlobServiceClient") as svc_cls: + yield svc_cls + + +def _make_helper(blob_service_mock=None, **kwargs): + from libs.sas.storage.blob.helper import StorageBlobHelper + + return StorageBlobHelper(connection_string="conn", **kwargs) + + +class TestInitWithConfigObject: + def test_init_with_object_config(self, blob_service_mock): + from libs.sas.storage.blob.helper import StorageBlobHelper + + cfg = MagicMock() + cfg.get = MagicMock(return_value="INFO") + h = StorageBlobHelper(connection_string="c", config=cfg) + # Object configs are kept as-is (line 57 branch). + assert h.config is cfg + + +class TestDeleteContainerForceErrorBranches: + def test_force_delete_inner_blob_error_continues(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = MagicMock() + h.blob_service_client.get_container_client.return_value = cc + # Two list_blobs calls: existence check + iteration for deletion + b1 = MagicMock() + b1.name = "x" + b2 = MagicMock() + b2.name = "y" + cc.list_blobs.side_effect = [iter([b1, b2]), iter([b1, b2])] + bc_ok = MagicMock() + bc_fail = MagicMock() + bc_fail.delete_blob.side_effect = RuntimeError("blob-err") + # First blob fails, second succeeds + cc.get_blob_client.side_effect = [bc_fail, bc_ok] + # delete_container final call still succeeds + assert h.delete_container("c", force_delete=True) is True + + def test_delete_container_blobs_present_message_no_force(self, blob_service_mock): + h = _make_helper(blob_service_mock) + cc = MagicMock() + h.blob_service_client.get_container_client.return_value = cc + cc.list_blobs.return_value = iter([]) # Initially empty for first check + cc.delete_container.side_effect = RuntimeError( + "Container has blobs and cannot be deleted" + ) + with pytest.raises(ValueError): + h.delete_container("c", force_delete=False) + + +def 
def _wire_credential(blob_service_mock, *, account_key=None, account_name="myacct",
                     credential_cls_name="DefaultAzureCredential"):
    """Make the helper's blob_service_client respond like an Azure SDK client.

    The production code dispatches on ``type(credential).__name__``, so the
    mock credential's class is renamed to ``credential_cls_name``.  Renaming
    ``type(cred)`` is safe: every MagicMock instance gets its own subclass,
    so the change cannot leak into other mocks.
    """
    helper = _make_helper(blob_service_mock)
    helper.blob_service_client.account_name = account_name
    if credential_cls_name == "AccountKey":
        cred = MagicMock()
        cred.account_key = account_key
        type(cred).__name__ = "StorageSharedKeyCredential"
    else:
        # spec=[] -> credential exposes no attributes (no account_key).
        cred = MagicMock(spec=[])
        type(cred).__name__ = credential_cls_name
    helper.blob_service_client.credential = cred
    return helper


class TestAccountAndCredentialHelpers:
    """Unit tests for the private account/credential introspection helpers."""

    def test_get_account_name_returns_value(self, blob_service_mock):
        helper = _wire_credential(blob_service_mock, account_name="abc")
        assert helper._get_account_name() == "abc"

    def test_get_account_name_handles_exception(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        # Make attribute access raise by installing a throwing property on the
        # mock's per-instance class.
        type(helper.blob_service_client).account_name = property(
            lambda self: (_ for _ in ()).throw(RuntimeError("boom"))
        )
        assert helper._get_account_name() is None

    def test_get_account_key_from_credential(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock, account_key="key123", credential_cls_name="AccountKey"
        )
        assert helper._get_account_key() == "key123"

    def test_get_account_key_from_connection_string(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        # Credential object without an account_key attribute forces the
        # connection-string fallback; note the key itself contains '='.
        helper.blob_service_client.credential = object()
        helper._connection_string = (
            "DefaultEndpointsProtocol=https;AccountName=x;AccountKey=k=y;EndpointSuffix=core"
        )
        assert helper._get_account_key() == "k=y"

    def test_get_account_key_returns_none_when_missing(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        helper.blob_service_client.credential = object()
        # Connection string deliberately lacks an AccountKey section.
        helper._connection_string = "DefaultEndpointsProtocol=https;AccountName=x"
        assert helper._get_account_key() is None

    @pytest.mark.parametrize(
        "name,expected",
        [
            ("StorageSharedKeyCredential", "Storage Account Key"),
            ("DefaultAzureCredential", "DefaultAzureCredential"),
            ("ManagedIdentityCredential", "Managed Identity"),
            ("AzureCliCredential", "Azure CLI"),
            ("EnvironmentCredential", "Environment Variables"),
            ("WorkloadIdentityCredential", "Workload Identity"),
            ("ChainedTokenCredential", "Chained Token Credential"),
            ("SomeOtherCredential", "Azure AD (SomeOtherCredential)"),
        ],
    )
    def test_get_credential_type_mappings(self, blob_service_mock, name, expected):
        helper = _wire_credential(blob_service_mock, credential_cls_name=name)
        assert helper._get_credential_type() == expected

    def test_get_credential_type_unknown_when_no_credential_attr(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        # Swap in a client that exposes no attributes at all.
        helper.blob_service_client = MagicMock(spec=[])
        assert helper._get_credential_type() == "unknown"

    def test_get_credential_type_unknown_when_credential_is_none(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        helper.blob_service_client.credential = None
        assert helper._get_credential_type() == "unknown"


class TestUrlBuilders:
    """URL construction and content-type lookup."""

    def test_get_blob_url(self, blob_service_mock):
        helper = _wire_credential(blob_service_mock, account_name="acc")
        assert helper.get_blob_url("c", "b") == "https://acc.blob.core.windows.net/c/b"

    def test_get_container_url(self, blob_service_mock):
        helper = _wire_credential(blob_service_mock, account_name="acc")
        assert helper.get_container_url("ctn") == "https://acc.blob.core.windows.net/ctn"

    def test_get_content_type_uses_config(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        # config is a real BlobHelperConfig instance — exercise the lookup.
        assert helper._get_content_type("README.txt") == "text/plain"


class TestGenerateBlobSasUrl:
    """Blob-level SAS URL generation across credential types."""

    # NOTE(review): these patch "azure.storage.blob.generate_blob_sas" at the
    # SDK module.  That only intercepts the call if the helper resolves the
    # name at call time (e.g. a function-local import) — confirm against the
    # helper's import style; otherwise the patch target should be the helper
    # module's own namespace.

    def test_account_key_path(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key="abc",
            credential_cls_name="AccountKey",
        )
        with patch(
            "azure.storage.blob.generate_blob_sas", return_value="sig=token"
        ) as gen:
            url = helper.generate_blob_sas_url("ctn", "blob", expiry_hours=1)
        gen.assert_called_once()
        assert url.startswith("https://acct.blob.core.windows.net/ctn/blob?")
        assert "sig=token" in url

    def test_user_delegation_path(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(return_value="udkey")
        with patch(
            "azure.storage.blob.generate_blob_sas", return_value="sig=ud"
        ) as gen:
            url = helper.generate_blob_sas_url("ctn", "blob")
        gen.assert_called_once()
        assert "sig=ud" in url

    def test_unknown_credential_raises(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        helper.blob_service_client.account_name = "acct"
        helper.blob_service_client.credential = None  # -> credential_type 'unknown'
        helper._connection_string = "DefaultEndpointsProtocol=https;AccountName=acct"
        with pytest.raises(ValueError):
            helper.generate_blob_sas_url("c", "b")

    def test_no_account_name_raises(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        # Force account_name extraction to return None.
        helper._get_account_name = MagicMock(return_value=None)
        with pytest.raises(ValueError):
            helper.generate_blob_sas_url("c", "b")

    def test_user_delegation_key_403_raises_value_error(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(
            side_effect=RuntimeError("403 Forbidden")
        )
        with pytest.raises(ValueError):
            helper.generate_blob_sas_url("c", "b")

    def test_user_delegation_key_401_raises_value_error(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(
            side_effect=RuntimeError("401 Unauthorized")
        )
        with pytest.raises(ValueError):
            helper.generate_blob_sas_url("c", "b")

    def test_user_delegation_key_other_error_wrapped(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(
            side_effect=RuntimeError("network down")
        )
        with pytest.raises(ValueError):
            helper.generate_blob_sas_url("c", "b")


class TestGenerateContainerSasUrl:
    """Container-level SAS URL generation — mirrors the blob-level cases."""

    def test_account_key_path(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key="abc",
            credential_cls_name="AccountKey",
        )
        with patch(
            "azure.storage.blob.generate_container_sas", return_value="sig=ctk"
        ):
            url = helper.generate_container_sas_url("ctn", expiry_hours=2)
        assert url.startswith("https://acct.blob.core.windows.net/ctn?")
        assert "sig=ctk" in url

    def test_user_delegation_path(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(return_value="udkey")
        with patch(
            "azure.storage.blob.generate_container_sas", return_value="sig=udc"
        ):
            url = helper.generate_container_sas_url("ctn")
        assert "sig=udc" in url

    def test_unknown_credential_raises(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        helper.blob_service_client.account_name = "acct"
        helper.blob_service_client.credential = None
        helper._connection_string = "DefaultEndpointsProtocol=https;AccountName=acct"
        with pytest.raises(ValueError):
            helper.generate_container_sas_url("c")

    def test_no_account_name_raises(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        helper._get_account_name = MagicMock(return_value=None)
        with pytest.raises(ValueError):
            helper.generate_container_sas_url("c")

    def test_user_delegation_key_403_raises_value_error(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(
            side_effect=RuntimeError("403 Forbidden")
        )
        with pytest.raises(ValueError):
            helper.generate_container_sas_url("c")

    def test_user_delegation_key_401_raises_value_error(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(
            side_effect=RuntimeError("401 Unauthorized")
        )
        with pytest.raises(ValueError):
            helper.generate_container_sas_url("c")

    def test_user_delegation_key_other_error_wrapped(self, blob_service_mock):
        helper = _wire_credential(
            blob_service_mock,
            account_name="acct",
            account_key=None,
            credential_cls_name="DefaultAzureCredential",
        )
        helper.blob_service_client.get_user_delegation_key = MagicMock(
            side_effect=RuntimeError("oops")
        )
        with pytest.raises(ValueError):
            helper.generate_container_sas_url("c")


class TestSyncDirectory:
    """Local-directory -> container sync behaviour."""

    def test_missing_local_directory_raises(self, blob_service_mock):
        helper = _make_helper(blob_service_mock)
        with pytest.raises(FileNotFoundError):
            helper.sync_directory("Z:/no/such/dir", "c")

    def test_uploads_new_files(self, blob_service_mock, tmp_path):
        helper = _make_helper(blob_service_mock)
        (tmp_path / "a.txt").write_text("hello")
        helper.blob_exists = MagicMock(return_value=False)
        helper.upload_file = MagicMock(return_value=True)
        result = helper.sync_directory(str(tmp_path), "c", blob_prefix="pre/")
        assert result["total_files"] == 1
        assert "a.txt" in result["uploaded"]
        helper.upload_file.assert_called_once()

    def test_skips_excluded_patterns(self, blob_service_mock, tmp_path):
        helper = _make_helper(blob_service_mock)
        (tmp_path / "x.tmp").write_text("a")
        (tmp_path / "y.txt").write_text("b")
        helper.blob_exists = MagicMock(return_value=False)
        helper.upload_file = MagicMock(return_value=True)
        result = helper.sync_directory(str(tmp_path), "c", exclude_patterns=["*.tmp"])
        assert "x.tmp" in result["skipped"]
        assert "y.txt" in result["uploaded"]

    def test_skips_when_blob_newer(self, blob_service_mock, tmp_path):
        from datetime import datetime, timedelta

        helper = _make_helper(blob_service_mock)
        (tmp_path / "a.txt").write_text("x")
        helper.blob_exists = MagicMock(return_value=True)
        # NOTE(review): naive datetime.now() here — if the helper compares it
        # against a tz-aware blob timestamp this would raise; confirm the
        # helper normalises timezones.
        helper.get_blob_properties = MagicMock(
            return_value={"last_modified": datetime.now() + timedelta(hours=1)}
        )
        result = helper.sync_directory(str(tmp_path), "c")
        assert "a.txt" in result["skipped"]

    def test_collects_errors(self, blob_service_mock, tmp_path):
        helper = _make_helper(blob_service_mock)
        (tmp_path / "a.txt").write_text("x")
        helper.blob_exists = MagicMock(side_effect=RuntimeError("network"))
        helper.upload_file = MagicMock()
        result = helper.sync_directory(str(tmp_path), "c")
        assert any("a.txt" in err for err in result["errors"])

    def test_records_failed_upload(self, blob_service_mock, tmp_path):
        helper = _make_helper(blob_service_mock)
        (tmp_path / "a.txt").write_text("x")
        helper.blob_exists = MagicMock(return_value=False)
        helper.upload_file = MagicMock(return_value=False)
        result = helper.sync_directory(str(tmp_path), "c")
        assert any("a.txt" in err for err in result["errors"])
"""Tests for libs/sas/storage/queue/async_helper.py."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError


class _AsyncIter:
    """Minimal async iterator over a fixed list of items."""

    def __init__(self, items):
        self._items = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._items:
            raise StopAsyncIteration
        return self._items.pop(0)


@pytest.fixture
def queue_service_mock():
    # Patch the class where the async helper module looks it up.
    with patch("libs.sas.storage.queue.async_helper.QueueServiceClient") as svc_cls:
        yield svc_cls


def _make_message(
    id="m1", pop_receipt="pr", content="c", inserted_on=None, expires_on=None,
    next_visible_on=None, dequeue_count=0,
):
    """Build a mock queue message with the attributes the helper reads."""
    msg = MagicMock()
    msg.id = id
    msg.pop_receipt = pop_receipt
    msg.content = content
    msg.inserted_on = inserted_on
    msg.expires_on = expires_on
    msg.next_visible_on = next_visible_on
    msg.dequeue_count = dequeue_count
    return msg


def _wire_qc(svc_cls):
    """Wire QueueServiceClient -> queue_client mock with AsyncMock'd methods."""
    svc_instance = MagicMock()
    svc_cls.from_connection_string.return_value = svc_instance
    svc_cls.return_value = svc_instance
    svc_instance.close = AsyncMock()

    qc = MagicMock()
    qc.create_queue = AsyncMock()
    qc.delete_queue = AsyncMock()
    qc.get_queue_properties = AsyncMock()
    qc.send_message = AsyncMock()
    qc.delete_message = AsyncMock()
    qc.update_message = AsyncMock()
    qc.set_queue_metadata = AsyncMock()
    qc.clear_messages = AsyncMock()
    qc.peek_messages = AsyncMock()
    svc_instance.get_queue_client.return_value = qc
    return svc_instance, qc


# NOTE(review): this helper is never used — the tests below construct
# AsyncStorageQueueHelper directly — and wrapping a plain constructor in an
# async def adds nothing.  Kept for fidelity; consider deleting it.
async def _make_helper(queue_service_mock):
    from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

    return AsyncStorageQueueHelper(connection_string="conn-str")


class TestInitAndContext:
    """Construction, async-context lifecycle, and lazy client initialisation."""

    @pytest.mark.asyncio
    async def test_async_with_initializes_and_closes(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        svc_instance, _ = _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert h._queue_service_client is svc_instance
        svc_instance.close.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_init_with_account_and_credential(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(account_name="acct", credential=MagicMock()):
            pass
        queue_service_mock.assert_called()

    @pytest.mark.asyncio
    async def test_init_with_account_only_uses_default_credential(
        self, queue_service_mock
    ):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        with patch(
            "libs.sas.storage.queue.async_helper.DefaultAzureCredential"
        ) as cred:
            async with AsyncStorageQueueHelper(account_name="acct"):
                pass
            cred.assert_called_once()

    @pytest.mark.asyncio
    async def test_init_no_args_raises(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        h = AsyncStorageQueueHelper()
        with pytest.raises(ValueError):
            await h._initialize_client()

    @pytest.mark.asyncio
    async def test_init_failure_propagates(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        queue_service_mock.from_connection_string.side_effect = RuntimeError("boom")
        h = AsyncStorageQueueHelper(connection_string="c")
        with pytest.raises(RuntimeError):
            await h._initialize_client()

    @pytest.mark.asyncio
    async def test_property_raises_when_uninitialized(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        h = AsyncStorageQueueHelper(connection_string="c")
        with pytest.raises(RuntimeError):
            _ = h.queue_service_client

    def test_init_with_dict_config(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        h = AsyncStorageQueueHelper(connection_string="c", config={"logging_level": "INFO"})
        assert h.config == {"logging_level": "INFO"}


class TestQueueOps:
    """create/delete/exists/list at the queue level."""

    @pytest.mark.asyncio
    async def test_create_queue_success(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.create_queue("q") is True

    @pytest.mark.asyncio
    async def test_create_queue_already_exists(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.create_queue.side_effect = ResourceExistsError("e")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.create_queue("q") is False

    @pytest.mark.asyncio
    async def test_create_queue_other_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.create_queue.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.create_queue("q")

    @pytest.mark.asyncio
    async def test_delete_queue_success(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.delete_queue("q") is True

    @pytest.mark.asyncio
    async def test_delete_queue_not_found(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.delete_queue.side_effect = ResourceNotFoundError("nf")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.delete_queue("q") is False

    @pytest.mark.asyncio
    async def test_delete_queue_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.delete_queue.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.delete_queue("q")

    @pytest.mark.asyncio
    async def test_queue_exists_true(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.queue_exists("q") is True

    @pytest.mark.asyncio
    async def test_queue_exists_false(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.get_queue_properties.side_effect = ResourceNotFoundError("nf")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.queue_exists("q") is False

    @pytest.mark.asyncio
    async def test_queue_exists_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.get_queue_properties.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.queue_exists("q")

    @pytest.mark.asyncio
    async def test_list_queues(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        svc_instance, _ = _wire_qc(queue_service_mock)
        q = MagicMock()
        q.name = "x"
        q.metadata = {"k": "v"}
        svc_instance.list_queues = MagicMock(return_value=_AsyncIter([q]))
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            result = await h.list_queues()
        assert result[0]["name"] == "x"

    @pytest.mark.asyncio
    async def test_list_queues_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        svc_instance, _ = _wire_qc(queue_service_mock)

        def boom(**kw):
            raise RuntimeError("x")

        svc_instance.list_queues = MagicMock(side_effect=boom)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.list_queues()


class TestMessageOps:
    """send/receive/delete/update of individual messages."""

    @pytest.mark.asyncio
    async def test_send_message_dict(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.send_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            info = await h.send_message("q", {"k": "v"})
        assert info["message_id"] == "m1"

    @pytest.mark.asyncio
    async def test_send_message_string(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.send_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            await h.send_message("q", "hi")

    @pytest.mark.asyncio
    async def test_send_message_other_type(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.send_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            await h.send_message("q", 123)

    @pytest.mark.asyncio
    async def test_send_message_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.send_message.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.send_message("q", "hi")

    @pytest.mark.asyncio
    async def test_receive_message_returns_one(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(return_value=_AsyncIter([_make_message()]))
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            msg = await h.receive_message("q")
        assert msg["id"] == "m1"

    @pytest.mark.asyncio
    async def test_receive_message_returns_none_when_empty(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(return_value=_AsyncIter([]))
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.receive_message("q") is None

    @pytest.mark.asyncio
    async def test_receive_message_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(side_effect=RuntimeError("x"))
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.receive_message("q")

    @pytest.mark.asyncio
    async def test_receive_messages(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(
            return_value=_AsyncIter([_make_message(), _make_message("m2")])
        )
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            msgs = await h.receive_messages("q", max_messages=2)
        assert len(msgs) == 2

    @pytest.mark.asyncio
    async def test_receive_messages_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(side_effect=RuntimeError("x"))
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.receive_messages("q")

    @pytest.mark.asyncio
    async def test_delete_message(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.delete_message("q", "id", "pr") is True

    @pytest.mark.asyncio
    async def test_delete_message_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.delete_message.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.delete_message("q", "id", "pr")

    @pytest.mark.asyncio
    async def test_update_message_dict(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.update_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            info = await h.update_message("q", "id", "pr", {"k": "v"})
        assert info["pop_receipt"] == "pr"

    @pytest.mark.asyncio
    async def test_update_message_string(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.update_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            await h.update_message("q", "id", "pr", "hi")

    @pytest.mark.asyncio
    async def test_update_message_other_type(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.update_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            await h.update_message("q", "id", "pr", 99)

    @pytest.mark.asyncio
    async def test_update_message_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.update_message.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.update_message("q", "id", "pr", "hi")


class TestBatch:
    """Batched send/process helpers."""

    @pytest.mark.asyncio
    async def test_send_messages_batch(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.send_message.return_value = _make_message()
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            results = await h.send_messages_batch("q", ["a", "b"])
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_send_messages_batch_filters_failures(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        # Second send raises; only the successful result should be returned.
        qc.send_message.side_effect = [_make_message(), RuntimeError("x")]
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            results = await h.send_messages_batch("q", ["a", "b"])
        assert len(results) == 1

    @pytest.mark.asyncio
    async def test_process_messages_batch_success(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(return_value=_AsyncIter([_make_message()]))

        async def proc(_msg):
            return "ok"

        async with AsyncStorageQueueHelper(connection_string="c") as h:
            results = await h.process_messages_batch("q", proc)
        assert results[0]["success"] is True

    @pytest.mark.asyncio
    async def test_process_messages_batch_no_messages(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(return_value=_AsyncIter([]))

        async def proc(_msg):
            return "ok"

        async with AsyncStorageQueueHelper(connection_string="c") as h:
            results = await h.process_messages_batch("q", proc)
        assert results == []

    @pytest.mark.asyncio
    async def test_process_messages_batch_processor_fails(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.receive_messages = MagicMock(return_value=_AsyncIter([_make_message()]))

        async def proc(_msg):
            raise RuntimeError("nope")

        async with AsyncStorageQueueHelper(connection_string="c") as h:
            results = await h.process_messages_batch("q", proc)
        assert results[0]["success"] is False


class TestPropsAndMisc:
    """Properties, metadata, clear/peek, and close()."""

    @pytest.mark.asyncio
    async def test_get_queue_properties(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        props = MagicMock()
        props.metadata = {"k": "v"}
        props.approximate_message_count = 5
        qc.get_queue_properties.return_value = props
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            result = await h.get_queue_properties("q")
        assert result["approximate_message_count"] == 5

    @pytest.mark.asyncio
    async def test_get_queue_properties_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.get_queue_properties.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.get_queue_properties("q")

    @pytest.mark.asyncio
    async def test_set_queue_metadata(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.set_queue_metadata("q", {"k": "v"}) is True

    @pytest.mark.asyncio
    async def test_set_queue_metadata_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.set_queue_metadata.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.set_queue_metadata("q", {})

    @pytest.mark.asyncio
    async def test_clear_queue(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _wire_qc(queue_service_mock)
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            assert await h.clear_queue("q") is True

    @pytest.mark.asyncio
    async def test_clear_queue_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.clear_messages.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.clear_queue("q")

    @pytest.mark.asyncio
    async def test_peek_messages(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.peek_messages.return_value = [_make_message()]
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            result = await h.peek_messages("q")
        assert result[0]["id"] == "m1"

    @pytest.mark.asyncio
    async def test_peek_messages_error(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        _, qc = _wire_qc(queue_service_mock)
        qc.peek_messages.side_effect = RuntimeError("x")
        async with AsyncStorageQueueHelper(connection_string="c") as h:
            with pytest.raises(RuntimeError):
                await h.peek_messages("q")

    @pytest.mark.asyncio
    async def test_close_no_client(self, queue_service_mock):
        from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper

        h = AsyncStorageQueueHelper(connection_string="c")
        # Never initialized; close should be a no-op.
        await h.close()


# ---------------------------------------------------------------------------
# NOTE(review): in the original patch everything below this banner is a
# SEPARATE module, src/tests/sas/storage/queue/test_helper.py ("Tests for
# libs/sas/storage/queue/helper.py").  It is reproduced here only because the
# source chunk fuses both files; if kept in one module, the duplicated
# fixture/helper names below shadow the async ones above — keep the files
# split when applying.
# ---------------------------------------------------------------------------

from unittest.mock import MagicMock, patch

import pytest
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError


@pytest.fixture
def queue_service_mock():
    # Patch the class where the sync helper module looks it up.
    with patch("libs.sas.storage.queue.helper.QueueServiceClient") as svc_cls:
        yield svc_cls


def _make_helper(queue_service_mock, **kwargs):
    """Build a StorageQueueHelper while the fixture's patch is active.

    The fixture argument is unused directly; passing it documents the
    dependency on the active patch.
    """
    from libs.sas.storage.queue.helper import StorageQueueHelper

    return StorageQueueHelper(connection_string="conn-str", **kwargs)


def _make_message(
    id="m1", pop_receipt="pr", content="c", inserted_on=None, expires_on=None,
    next_visible_on=None, dequeue_count=0,
):
    """Build a mock queue message with the attributes the helper reads."""
    msg = MagicMock()
    msg.id = id
    msg.pop_receipt = pop_receipt
    msg.content = content
    msg.inserted_on = inserted_on
    msg.expires_on = expires_on
    msg.next_visible_on = next_visible_on
    msg.dequeue_count = dequeue_count
    return msg


class TestInit:
    """Constructor argument handling."""

    def test_init_with_connection_string(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        queue_service_mock.from_connection_string.assert_called_once()
        assert h._connection_string == "conn-str"

    def test_init_with_account_name_and_credential(self, queue_service_mock):
        from libs.sas.storage.queue.helper import StorageQueueHelper

        cred = MagicMock()
        StorageQueueHelper(account_name="acct", credential=cred)
        queue_service_mock.assert_called()

    def test_init_with_account_name_only_uses_default_credential(
        self, queue_service_mock
    ):
        from libs.sas.storage.queue.helper import StorageQueueHelper

        with patch(
            "libs.sas.storage.queue.helper.DefaultAzureCredential"
        ) as cred_cls:
            StorageQueueHelper(account_name="acct")
            cred_cls.assert_called_once()

    def test_init_no_args_raises(self, queue_service_mock):
        from libs.sas.storage.queue.helper import StorageQueueHelper

        with pytest.raises(ValueError):
            StorageQueueHelper()

    def test_init_with_dict_config(self, queue_service_mock):
        from libs.sas.storage.queue.helper import StorageQueueHelper

        # NOTE(review): this patches create_config in shared_config's own
        # namespace — only effective if helper.py resolves it there at call
        # time; confirm against the helper's import style.
        with patch("libs.sas.storage.shared_config.create_config") as cc:
            cc.return_value = {"logging_level": "INFO"}
            StorageQueueHelper(connection_string="c", config={"x": 1})
            cc.assert_called_once()

    def test_init_failure_propagates(self, queue_service_mock):
        from libs.sas.storage.queue.helper import StorageQueueHelper

        queue_service_mock.from_connection_string.side_effect = RuntimeError("boom")
        with pytest.raises(RuntimeError):
            StorageQueueHelper(connection_string="c")


class TestQueueOperations:
    """Queue-level CRUD on the synchronous helper."""

    def test_create_queue_success(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.create_queue("q") is True

    def test_create_queue_already_exists(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.create_queue.side_effect = ResourceExistsError("exists")
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.create_queue("q") is False

    def test_create_queue_other_error_raises(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.create_queue.side_effect = RuntimeError("x")
        h.queue_service_client.get_queue_client.return_value = qc
        with pytest.raises(RuntimeError):
            h.create_queue("q")

    def test_delete_queue_success(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.delete_queue("q") is True

    def test_delete_queue_not_found(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.delete_queue.side_effect = ResourceNotFoundError("nf")
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.delete_queue("q") is False

    def test_delete_queue_error_raises(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.delete_queue.side_effect = RuntimeError("x")
        h.queue_service_client.get_queue_client.return_value = qc
        with pytest.raises(RuntimeError):
            h.delete_queue("q")

    def test_list_queues(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        q1 = MagicMock()
        q1.name = "a"
        q1.metadata = {"k": "v"}
        h.queue_service_client.list_queues.return_value = iter([q1])
        result = h.list_queues(include_metadata=True)
        assert result[0]["name"] == "a"
        assert result[0]["metadata"] == {"k": "v"}

    def test_list_queues_error(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        h.queue_service_client.list_queues.side_effect = RuntimeError("x")
        with pytest.raises(RuntimeError):
            h.list_queues()

    def test_queue_exists_true(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.queue_exists("q") is True

    def test_queue_exists_false(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.get_queue_properties.side_effect = ResourceNotFoundError("nf")
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.queue_exists("q") is False

    def test_queue_exists_other_error(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.get_queue_properties.side_effect = RuntimeError("x")
        h.queue_service_client.get_queue_client.return_value = qc
        with pytest.raises(RuntimeError):
            h.queue_exists("q")

    def test_clear_queue(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        h.queue_service_client.get_queue_client.return_value = qc
        assert h.clear_queue("q") is True

    def test_clear_queue_error(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.clear_messages.side_effect = RuntimeError("x")
        h.queue_service_client.get_queue_client.return_value = qc
        with pytest.raises(RuntimeError):
            h.clear_queue("q")


class TestMessageOperations:
    """Message-level operations on the synchronous helper.

    NOTE(review): the original patch continues with test_receive_messages and
    further tests beyond this chunk; only the fully visible tests are
    reproduced here.
    """

    def test_send_message_dict(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.send_message.return_value = _make_message()
        h.queue_service_client.get_queue_client.return_value = qc
        info = h.send_message("q", {"key": "val"})
        assert info["message_id"] == "m1"

    def test_send_message_bytes(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.send_message.return_value = _make_message()
        h.queue_service_client.get_queue_client.return_value = qc
        info = h.send_message("q", b"bytes-data")
        assert info["message_id"] == "m1"

    def test_send_message_string(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.send_message.return_value = _make_message()
        h.queue_service_client.get_queue_client.return_value = qc
        h.send_message("q", "hello")
        qc.send_message.assert_called_once()

    def test_send_message_failure(self, queue_service_mock):
        h = _make_helper(queue_service_mock)
        qc = MagicMock()
        qc.send_message.side_effect = RuntimeError("x")
        h.queue_service_client.get_queue_client.return_value = qc
        with pytest.raises(RuntimeError):
            h.send_message("q", "msg")
queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.receive_messages.return_value = iter([_make_message(), _make_message("m2")]) + h.queue_service_client.get_queue_client.return_value = qc + msgs = h.receive_messages("q", max_messages=2) + assert len(msgs) == 2 + assert msgs[0]["message_id"] == "m1" + + def test_receive_messages_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.receive_messages.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.receive_messages("q") + + def test_peek_messages(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.peek_messages.return_value = iter([_make_message()]) + h.queue_service_client.get_queue_client.return_value = qc + msgs = h.peek_messages("q") + assert msgs[0]["message_id"] == "m1" + + def test_peek_messages_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.peek_messages.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.peek_messages("q") + + def test_delete_message(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + h.queue_service_client.get_queue_client.return_value = qc + assert h.delete_message("q", "id", "pr") is True + + def test_delete_message_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.delete_message.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.delete_message("q", "id", "pr") + + def test_update_message_dict(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.update_message.return_value = _make_message() + h.queue_service_client.get_queue_client.return_value = qc + info = 
h.update_message("q", "id", "pr", {"a": 1}) + assert info["pop_receipt"] == "pr" + + def test_update_message_bytes(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.update_message.return_value = _make_message() + h.queue_service_client.get_queue_client.return_value = qc + h.update_message("q", "id", "pr", b"bytes") + + def test_update_message_no_content(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.update_message.return_value = _make_message() + h.queue_service_client.get_queue_client.return_value = qc + h.update_message("q", "id", "pr", content=None) + + def test_update_message_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.update_message.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.update_message("q", "id", "pr", "msg") + + +class TestBatchAndProcessing: + def test_send_multiple_messages_mixed_results(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.send_message.side_effect = [_make_message(), RuntimeError("x")] + h.queue_service_client.get_queue_client.return_value = qc + results = h.send_multiple_messages("q", ["a", "b"]) + assert results[0]["success"] is True + assert results[1]["success"] is False + + def test_process_messages_success_with_delete(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.receive_messages.return_value = iter([_make_message()]) + h.queue_service_client.get_queue_client.return_value = qc + results = h.process_messages("q", lambda m: {"success": True}) + assert results[0]["deleted"] is True + + def test_process_messages_processor_raises(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.receive_messages.return_value = iter([_make_message()]) + h.queue_service_client.get_queue_client.return_value = qc + + 
def boom(_): + raise RuntimeError("nope") + + results = h.process_messages("q", boom) + assert results[0]["processing_result"]["success"] is False + + def test_process_messages_top_level_failure(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.receive_messages.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.process_messages("q", lambda m: {"success": True}) + + +class TestProperties: + def test_get_queue_properties(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + props = MagicMock() + props.metadata = {"k": "v"} + props.approximate_message_count = 7 + qc.get_queue_properties.return_value = props + h.queue_service_client.get_queue_client.return_value = qc + result = h.get_queue_properties("q") + assert result["approximate_message_count"] == 7 + + def test_get_queue_properties_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.get_queue_properties.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.get_queue_properties("q") + + def test_set_queue_metadata(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + h.queue_service_client.get_queue_client.return_value = qc + assert h.set_queue_metadata("q", {"k": "v"}) is True + + def test_set_queue_metadata_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.set_queue_metadata.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.set_queue_metadata("q", {}) + + def test_get_queue_statistics(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + props = MagicMock() + props.metadata = {} + props.approximate_message_count = 3 + 
qc.get_queue_properties.return_value = props + h.queue_service_client.get_queue_client.return_value = qc + stats = h.get_queue_statistics("q") + assert stats["approximate_message_count"] == 3 + assert "last_updated" in stats + + def test_get_queue_statistics_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + qc = MagicMock() + qc.get_queue_properties.side_effect = RuntimeError("x") + h.queue_service_client.get_queue_client.return_value = qc + with pytest.raises(RuntimeError): + h.get_queue_statistics("q") + + +class TestUtilities: + def test_get_queue_url(self, queue_service_mock): + h = _make_helper(queue_service_mock) + h.queue_service_client.account_name = "acct" + url = h.get_queue_url("q") + assert "acct" in url and "q" in url + + def test_get_account_name_returns_none_on_error(self, queue_service_mock): + h = _make_helper(queue_service_mock) + type(h.queue_service_client).account_name = property( + lambda s: (_ for _ in ()).throw(RuntimeError("x")) + ) + assert h._get_account_name() is None + + def test_encode_message_dict(self, queue_service_mock): + h = _make_helper(queue_service_mock) + out = h.encode_message({"a": 1}) + assert "a" in out + + def test_encode_message_string(self, queue_service_mock): + h = _make_helper(queue_service_mock) + assert h.encode_message("plain") == "plain" diff --git a/src/backend-api/src/tests/sas/storage/test_shared_config.py b/src/backend-api/src/tests/sas/storage/test_shared_config.py new file mode 100644 index 00000000..e4f2d9af --- /dev/null +++ b/src/backend-api/src/tests/sas/storage/test_shared_config.py @@ -0,0 +1,108 @@ +"""Tests for libs/sas/storage/shared_config.py.""" + +import pytest + +from libs.sas.storage import shared_config as shared_config_module +from libs.sas.storage.shared_config import ( + StorageConfig, + create_config, + get_config, + set_config, +) + + +class TestStorageConfigDefaults: + def test_default_values(self): + cfg = StorageConfig() + assert cfg.get("retry_attempts") == 3 
+ assert cfg.get("timeout_seconds") == 30 + assert cfg.get("logging_level") == "INFO" + + def test_init_with_overrides(self): + cfg = StorageConfig({"retry_attempts": 7, "extra": "v"}) + assert cfg.get("retry_attempts") == 7 + assert cfg.get("extra") == "v" + # Other defaults preserved + assert cfg.get("logging_level") == "INFO" + + def test_init_none_overrides(self): + cfg = StorageConfig(None) + assert cfg.get("retry_attempts") == 3 + + +class TestStorageConfigEnvironment: + def test_loads_env_vars_with_correct_types(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_RETRY_ATTEMPTS", "10") + monkeypatch.setenv("AZURE_STORAGE_TIMEOUT_SECONDS", "60") + monkeypatch.setenv("AZURE_STORAGE_LOGGING_LEVEL", "DEBUG") + cfg = StorageConfig() + assert cfg.get("retry_attempts") == 10 + assert cfg.get("timeout_seconds") == 60 + assert cfg.get("logging_level") == "DEBUG" + + def test_skips_invalid_int_env_var(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_RETRY_ATTEMPTS", "garbage") + cfg = StorageConfig() + # Falls back to default + assert cfg.get("retry_attempts") == 3 + + def test_skips_invalid_timeout_env_var(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_TIMEOUT_SECONDS", "not-an-int") + cfg = StorageConfig() + assert cfg.get("timeout_seconds") == 30 + + +class TestStorageConfigGetSet: + def test_get_with_default(self): + cfg = StorageConfig() + assert cfg.get("missing") is None + assert cfg.get("missing", "fallback") == "fallback" + + def test_set_then_get(self): + cfg = StorageConfig() + cfg.set("foo", "bar") + assert cfg.get("foo") == "bar" + + def test_get_all_returns_copy(self): + cfg = StorageConfig() + snapshot = cfg.get_all() + assert isinstance(snapshot, dict) + snapshot["new_key"] = "x" + assert cfg.get("new_key") is None # underlying dict not affected + + def test_update_multiple(self): + cfg = StorageConfig() + cfg.update({"a": 1, "b": 2}) + assert cfg.get("a") == 1 + assert cfg.get("b") == 2 + + def 
test_reset_to_defaults_restores(self): + cfg = StorageConfig() + cfg.set("retry_attempts", 99) + cfg.reset_to_defaults() + assert cfg.get("retry_attempts") == 3 + + +class TestModuleLevelHelpers: + def test_get_config_returns_default(self): + assert get_config() is shared_config_module.default_config + + def test_set_config_replaces_default(self): + original = get_config() + try: + new_cfg = StorageConfig({"x": 1}) + set_config(new_cfg) + assert get_config() is new_cfg + assert get_config().get("x") == 1 + finally: + set_config(original) + + def test_create_config_returns_new_instance(self): + cfg = create_config({"y": "z"}) + assert isinstance(cfg, StorageConfig) + assert cfg.get("y") == "z" + + def test_create_config_no_overrides(self): + cfg = create_config() + assert isinstance(cfg, StorageConfig) + assert cfg.get("retry_attempts") == 3 diff --git a/src/backend-api/src/tests/services/__init__.py b/src/backend-api/src/tests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backend-api/src/tests/services/test_auth.py b/src/backend-api/src/tests/services/test_auth.py new file mode 100644 index 00000000..31dec71e --- /dev/null +++ b/src/backend-api/src/tests/services/test_auth.py @@ -0,0 +1,112 @@ +import base64 +import json +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException + +from libs.services.auth import ( + UserDetails, + get_authenticated_user, + get_tenant_id, + sample_user, +) + + +def _make_request(headers: dict): + request = MagicMock() + request.headers = headers + return request + + +class TestUserDetails: + def test_basic_fields_assigned(self): + details = UserDetails( + { + "user_principal_id": "pid-1", + "user_name": "alice@example.com", + "auth_provider": "aad", + "auth_token": "tok", + } + ) + assert details.user_principal_id == "pid-1" + assert details.user_name == "alice@example.com" + assert details.auth_provider == "aad" + assert details.auth_token == "tok" + assert 
details.tenant_id is None + + def test_missing_keys_default_to_none(self): + details = UserDetails({}) + assert details.user_principal_id is None + assert details.user_name is None + assert details.auth_provider is None + assert details.auth_token is None + assert details.tenant_id is None + + def test_tenant_id_extracted_from_client_principal(self): + principal = {"tid": "tenant-xyz", "oid": "obj"} + encoded = base64.b64encode(json.dumps(principal).encode()).decode() + details = UserDetails( + {"user_principal_id": "pid", "client_principal_b64": encoded} + ) + assert details.tenant_id == "tenant-xyz" + + def test_placeholder_principal_value_does_not_decode(self): + details = UserDetails( + { + "user_principal_id": "pid", + "client_principal_b64": "your_base_64_encoded_token", + } + ) + assert details.tenant_id is None + + def test_invalid_client_principal_returns_empty_tenant(self): + details = UserDetails( + {"user_principal_id": "pid", "client_principal_b64": "@@@not-base64@@@"} + ) + assert details.tenant_id == "" + + +class TestGetTenantId: + def test_returns_tid_when_present(self): + principal = {"tid": "abc-123"} + encoded = base64.b64encode(json.dumps(principal).encode()).decode() + assert get_tenant_id(encoded) == "abc-123" + + def test_returns_empty_string_when_tid_missing(self): + encoded = base64.b64encode(json.dumps({}).encode()).decode() + assert get_tenant_id(encoded) == "" + + def test_returns_empty_string_on_decode_failure(self): + assert get_tenant_id("not-valid-base64-!!!") == "" + + def test_returns_empty_string_on_non_json_payload(self): + encoded = base64.b64encode(b"not-json-content").decode() + assert get_tenant_id(encoded) == "" + + +class TestGetAuthenticatedUser: + def test_uses_sample_user_when_no_principal_header(self): + request = _make_request({"some-other-header": "value"}) + user = get_authenticated_user(request) + assert ( + user.user_principal_id + == sample_user["x-ms-client-principal-id"] + ) + + def 
test_uses_request_headers_when_principal_present(self): + request = _make_request( + { + "x-ms-client-principal-id": "real-user-id", + "x-ms-client-principal-name": "real@example.com", + } + ) + user = get_authenticated_user(request) + assert user.user_principal_id == "real-user-id" + + def test_raises_401_when_principal_id_empty(self): + request = _make_request({"x-ms-client-principal-id": ""}) + with pytest.raises(HTTPException) as exc_info: + get_authenticated_user(request) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "User not authenticated" diff --git a/src/backend-api/src/tests/services/test_implementations.py b/src/backend-api/src/tests/services/test_implementations.py new file mode 100644 index 00000000..71e1623b --- /dev/null +++ b/src/backend-api/src/tests/services/test_implementations.py @@ -0,0 +1,134 @@ +import asyncio +import logging +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from libs.services.implementations import ( + ConsoleLoggerService, + HttpClientService, + InMemoryDataService, +) +from libs.services.interfaces import IDataService, IHttpService, ILoggerService + + +class TestInMemoryDataService: + def test_get_returns_empty_dict_for_unknown_key(self): + service = InMemoryDataService() + assert service.get_data("missing") == {} + + def test_save_then_get_roundtrip(self): + service = InMemoryDataService() + assert service.save_data("k", {"a": 1}) is True + assert service.get_data("k") == {"a": 1} + + def test_save_overwrites_existing_value(self): + service = InMemoryDataService() + service.save_data("k", {"a": 1}) + service.save_data("k", {"b": 2}) + assert service.get_data("k") == {"b": 2} + + def test_implements_interface(self): + assert isinstance(InMemoryDataService(), IDataService) + + +class TestConsoleLoggerService: + def test_log_info_writes_to_underlying_logger(self, caplog): + service = ConsoleLoggerService() + with caplog.at_level(logging.INFO, logger="ConsoleLoggerService"): + 
service.log_info("hello") + assert any("hello" in r.message for r in caplog.records) + + def test_log_error_with_exception_includes_exception(self, caplog): + service = ConsoleLoggerService() + with caplog.at_level(logging.ERROR, logger="ConsoleLoggerService"): + service.log_error("boom", ValueError("bad")) + assert any("boom" in r.message and "bad" in r.message for r in caplog.records) + + def test_log_error_without_exception(self, caplog): + service = ConsoleLoggerService() + with caplog.at_level(logging.ERROR, logger="ConsoleLoggerService"): + service.log_error("only message") + assert any("only message" in r.message for r in caplog.records) + + def test_implements_interface(self): + assert isinstance(ConsoleLoggerService(), ILoggerService) + + +def _build_response(json_data=None, text="ok", content_type="application/json"): + response = MagicMock() + response.headers = {"content-type": content_type} + response.json.return_value = json_data or {} + response.text = text + response.raise_for_status = MagicMock() + return response + + +class TestHttpClientService: + def test_implements_interface(self): + assert isinstance(HttpClientService(), IHttpService) + + def test_get_returns_json_when_content_type_is_json(self): + service = HttpClientService() + response = _build_response(json_data={"ok": True}) + service._client.get = AsyncMock(return_value=response) + result = asyncio.run(service.get("http://x")) + assert result == {"ok": True} + + def test_get_returns_text_when_content_type_not_json(self): + service = HttpClientService() + response = _build_response(text="plain", content_type="text/plain") + service._client.get = AsyncMock(return_value=response) + result = asyncio.run(service.get("http://x")) + assert result == {"text": "plain"} + + def test_get_returns_error_dict_on_exception(self): + service = HttpClientService() + service._client.get = AsyncMock(side_effect=RuntimeError("boom")) + result = asyncio.run(service.get("http://x")) + assert result == 
{"error": "boom"} + + def test_post_returns_json_when_content_type_is_json(self): + service = HttpClientService() + response = _build_response(json_data={"created": 1}) + service._client.post = AsyncMock(return_value=response) + result = asyncio.run(service.post("http://x", {"a": 1})) + assert result == {"created": 1} + + def test_post_returns_text_when_content_type_not_json(self): + service = HttpClientService() + response = _build_response(text="done", content_type="text/plain") + service._client.post = AsyncMock(return_value=response) + result = asyncio.run(service.post("http://x", {"a": 1})) + assert result == {"text": "done"} + + def test_post_returns_error_dict_on_exception(self): + service = HttpClientService() + service._client.post = AsyncMock(side_effect=RuntimeError("nope")) + result = asyncio.run(service.post("http://x", {})) + assert result == {"error": "nope"} + + def test_async_context_manager_closes_client(self): + service = HttpClientService() + service._client.aclose = AsyncMock() + + async def run(): + async with service as s: + assert s is service + + asyncio.run(run()) + service._client.aclose.assert_awaited_once() + + +class TestInterfacesAreAbstract: + def test_idata_service_cannot_be_instantiated(self): + with pytest.raises(TypeError): + IDataService() + + def test_ilogger_service_cannot_be_instantiated(self): + with pytest.raises(TypeError): + ILoggerService() + + def test_ihttp_service_cannot_be_instantiated(self): + with pytest.raises(TypeError): + IHttpService() diff --git a/src/backend-api/src/tests/services/test_input_validation.py b/src/backend-api/src/tests/services/test_input_validation.py new file mode 100644 index 00000000..418e8032 --- /dev/null +++ b/src/backend-api/src/tests/services/test_input_validation.py @@ -0,0 +1,33 @@ +from uuid import uuid4 + +import pytest + +from libs.services.input_validation import is_valid_uuid + + +class TestIsValidUuid: + """Test cases for is_valid_uuid""" + + def 
test_returns_true_for_valid_uuid4(self): + assert is_valid_uuid(str(uuid4())) is True + + def test_returns_true_for_known_valid_uuid4(self): + assert is_valid_uuid("123e4567-e89b-42d3-a456-426614174000") is True + + @pytest.mark.parametrize( + "value", + [ + "not-a-uuid", + "", + "123", + "00000000-0000-0000-0000-00000000000Z", + "g23e4567-e89b-42d3-a456-426614174000", + ], + ) + def test_returns_false_for_invalid_strings(self, value): + assert is_valid_uuid(value) is False + + def test_returns_false_for_none(self): + # is_valid_uuid only catches ValueError; non-string input raises TypeError + with pytest.raises(TypeError): + is_valid_uuid(None) diff --git a/src/backend-api/src/tests/services/test_interfaces.py b/src/backend-api/src/tests/services/test_interfaces.py new file mode 100644 index 00000000..fccc2177 --- /dev/null +++ b/src/backend-api/src/tests/services/test_interfaces.py @@ -0,0 +1,80 @@ +"""Tests for libs/services/interfaces.py. + +These cover the abstract `pass` bodies of the interfaces by subclassing the +ABCs and calling the parent abstract methods via super(). The bodies are no-op +``pass`` statements, so the parent calls return ``None``; the goal is purely +to exercise those lines for coverage. 
+""" + +from typing import Any, Dict + +import pytest + +from libs.services.interfaces import IDataService, IHttpService, ILoggerService + + +class _DataImpl(IDataService): + def get_data(self, key: str) -> Dict[str, Any]: + return super().get_data(key) + + def save_data(self, key: str, data: Dict[str, Any]) -> bool: + return super().save_data(key, data) + + +class _LoggerImpl(ILoggerService): + def log_info(self, message: str) -> None: + return super().log_info(message) + + def log_error(self, message: str, exception: Exception = None) -> None: + return super().log_error(message, exception) + + +class _HttpImpl(IHttpService): + async def get(self, url: str) -> Dict[str, Any]: + return await self._get_super(url) + + async def post(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: + return await self._post_super(url, data) + + async def _get_super(self, url): + return await IHttpService.get(self, url) + + async def _post_super(self, url, data): + return await IHttpService.post(self, url, data) + + +class TestIDataService: + def test_subclass_super_calls_return_none(self): + impl = _DataImpl() + assert impl.get_data("k") is None + assert impl.save_data("k", {"a": 1}) is None + + +class TestILoggerService: + def test_subclass_super_calls_return_none(self): + impl = _LoggerImpl() + assert impl.log_info("hello") is None + assert impl.log_error("oops", ValueError("x")) is None + assert impl.log_error("oops") is None + + +class TestIHttpService: + @pytest.mark.asyncio + async def test_subclass_super_calls_return_none(self): + impl = _HttpImpl() + assert await impl.get("https://x") is None + assert await impl.post("https://x", {"a": 1}) is None + + +class TestAbstractInstantiation: + def test_idataservice_cannot_be_instantiated_directly(self): + with pytest.raises(TypeError): + IDataService() # type: ignore[abstract] + + def test_iloggerservice_cannot_be_instantiated_directly(self): + with pytest.raises(TypeError): + ILoggerService() # type: ignore[abstract] + + def 
test_ihttpservice_cannot_be_instantiated_directly(self): + with pytest.raises(TypeError): + IHttpService() # type: ignore[abstract] diff --git a/src/backend-api/src/tests/services/test_process_services.py b/src/backend-api/src/tests/services/test_process_services.py new file mode 100644 index 00000000..0c83f3b5 --- /dev/null +++ b/src/backend-api/src/tests/services/test_process_services.py @@ -0,0 +1,280 @@ +"""Tests for libs/services/process_services.py.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from libs.base.typed_fastapi import TypedFastAPI +from libs.services.interfaces import ILoggerService +from libs.services.process_services import ProcessService +from libs.repositories.process_repository import ProcessRepository +from libs.repositories.process_status_repository import ProcessStatusRepository +from libs.sas.storage.blob.async_helper import AsyncStorageBlobHelper +from libs.sas.storage.queue.async_helper import AsyncStorageQueueHelper +from routers.models.files import FileInfo +from routers.models.processes import enlist_process_queue_response + + +def _make_async_cm(yielded): + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=yielded) + cm.__aexit__ = AsyncMock(return_value=False) + return cm + + +def _make_service(*, blob_helper=None, queue_helper=None, scope_services=None): + app = TypedFastAPI() + logger = MagicMock(spec=ILoggerService) + + blob_helper = blob_helper or MagicMock() + queue_helper = queue_helper or MagicMock() + blob_cm = _make_async_cm(blob_helper) + queue_cm = _make_async_cm(queue_helper) + + scope = MagicMock() + scope.get_service.side_effect = lambda t: ( + (scope_services or {}).get(t, MagicMock()) + ) + scope_cm = _make_async_cm(scope) + + ctx = MagicMock() + ctx.configuration = SimpleNamespace( + storage_account_process_container="container", + storage_account_process_queue="queue", + ) + ctx.create_scope = MagicMock(return_value=scope_cm) + + def app_get(t): 
+ if t is ILoggerService: + return logger + if t is AsyncStorageBlobHelper: + return blob_cm + if t is AsyncStorageQueueHelper: + return queue_cm + return MagicMock() + + ctx.get_service.side_effect = app_get + app.app_context = ctx + return ProcessService(app), { + "blob": blob_helper, + "queue": queue_helper, + "logger": logger, + "scope": scope, + } + + +@pytest.mark.asyncio +class TestSaveFilesToBlob: + async def test_creates_container_when_missing(self): + blob = MagicMock() + blob.container_exists = AsyncMock(return_value=False) + blob.create_container = AsyncMock(return_value=None) + blob.upload_blob = AsyncMock(return_value=None) + svc, _ = _make_service(blob_helper=blob) + await svc.save_files_to_blob( + "p1", [FileInfo(filename="a.txt", content=b"x", content_type="t", size=1)] + ) + blob.create_container.assert_awaited_once() + blob.upload_blob.assert_awaited_once() + + async def test_skips_create_when_container_exists(self): + blob = MagicMock() + blob.container_exists = AsyncMock(return_value=True) + blob.create_container = AsyncMock(return_value=None) + blob.upload_blob = AsyncMock(return_value=None) + svc, _ = _make_service(blob_helper=blob) + await svc.save_files_to_blob( + "p1", [FileInfo(filename="a.txt", content=b"x", content_type="t", size=1)] + ) + blob.create_container.assert_not_awaited() + + +@pytest.mark.asyncio +class TestGetAllUploadedFiles: + async def test_returns_files(self): + blob = MagicMock() + blob.list_blobs = AsyncMock( + return_value=[{"name": "p1/source/a.txt"}, {"name": "p1/source/"}] + ) + blob.get_blob_properties = AsyncMock( + return_value={"content_type": "text/plain", "size": 7} + ) + svc, _ = _make_service(blob_helper=blob) + files = await svc.get_all_uploaded_files("p1") + assert len(files) == 1 + assert files[0].filename == "a.txt" + assert files[0].size == 7 + + async def test_propagates_error(self): + blob = MagicMock() + blob.list_blobs = AsyncMock(side_effect=RuntimeError("x")) + svc, _ = 
_make_service(blob_helper=blob) + with pytest.raises(RuntimeError): + await svc.get_all_uploaded_files("p1") + + +@pytest.mark.asyncio +class TestDeleteFileFromBlob: + async def test_deletes_existing(self): + blob = MagicMock() + blob.blob_exists = AsyncMock(return_value=True) + blob.delete_blob = AsyncMock(return_value=None) + svc, _ = _make_service(blob_helper=blob) + await svc.delete_file_from_blob("p1", "a.txt") + blob.delete_blob.assert_awaited_once() + + async def test_raises_filenotfound_when_missing(self): + blob = MagicMock() + blob.blob_exists = AsyncMock(return_value=False) + svc, _ = _make_service(blob_helper=blob) + with pytest.raises(FileNotFoundError): + await svc.delete_file_from_blob("p1", "missing.txt") + + async def test_propagates_other_errors(self): + blob = MagicMock() + blob.blob_exists = AsyncMock(return_value=True) + blob.delete_blob = AsyncMock(side_effect=RuntimeError("x")) + svc, _ = _make_service(blob_helper=blob) + with pytest.raises(RuntimeError): + await svc.delete_file_from_blob("p1", "a.txt") + + +@pytest.mark.asyncio +class TestDeleteAllFilesFromBlob: + async def test_deletes_each_and_returns_count(self): + blob = MagicMock() + blob.list_blobs = AsyncMock( + return_value=[ + {"name": "p1/source/a.txt"}, + {"name": "p1/source/b.txt"}, + {"name": "p1/source/"}, + ] + ) + blob.delete_blob = AsyncMock(return_value=None) + svc, _ = _make_service(blob_helper=blob) + count = await svc.delete_all_files_from_blob("p1") + assert count == 2 + + async def test_continues_when_one_fails(self): + blob = MagicMock() + blob.list_blobs = AsyncMock( + return_value=[{"name": "p1/source/a.txt"}, {"name": "p1/source/b.txt"}] + ) + blob.delete_blob = AsyncMock(side_effect=[RuntimeError("x"), None]) + svc, _ = _make_service(blob_helper=blob) + count = await svc.delete_all_files_from_blob("p1") + assert count == 1 + + +@pytest.mark.asyncio +class TestProcessEnqueue: + async def test_creates_queue_and_sends_message(self): + queue = MagicMock() + 
queue.queue_exists = AsyncMock(return_value=False) + queue.create_queue = AsyncMock(return_value=None) + queue.send_message = AsyncMock(return_value=None) + svc, _ = _make_service(queue_helper=queue) + msg = enlist_process_queue_response(user_id="u", process_id="p", message="hi") + await svc.process_enqueue(msg) + queue.create_queue.assert_awaited_once() + queue.send_message.assert_awaited_once() + + async def test_skips_create_when_queue_exists(self): + queue = MagicMock() + queue.queue_exists = AsyncMock(return_value=True) + queue.create_queue = AsyncMock(return_value=None) + queue.send_message = AsyncMock(return_value=None) + svc, _ = _make_service(queue_helper=queue) + msg = enlist_process_queue_response(user_id="u", process_id="p") + await svc.process_enqueue(msg) + queue.create_queue.assert_not_awaited() + queue.send_message.assert_awaited_once() + + +@pytest.mark.asyncio +class TestGetCurrentProcess: + async def test_returns_repo_value(self): + repo = MagicMock() + repo.get_process_status_by_process_id = AsyncMock(return_value="snapshot") + svc, _ = _make_service(scope_services={ProcessStatusRepository: repo}) + assert await svc.get_current_process("p1") == "snapshot" + + +@pytest.mark.asyncio +class TestRenderCurrentProcess: + async def test_returns_repo_value(self): + repo = MagicMock() + repo.render_agent_status = AsyncMock(return_value=["a", "b"]) + svc, _ = _make_service(scope_services={ProcessStatusRepository: repo}) + assert await svc.render_current_process("p1") == ["a", "b"] + + +@pytest.mark.asyncio +class TestGetConvertedFiles: + async def test_downloads_and_returns_files(self): + blob = MagicMock() + blob.list_blobs = AsyncMock(return_value=[{"name": "p1/converted/a.txt"}]) + blob.download_blob = AsyncMock(return_value=b"hello") + svc, _ = _make_service(blob_helper=blob) + files = await svc.get_converted_files("p1") + assert files[0].filename == "a.txt" + assert files[0].content == b"hello" + assert files[0].size == 5 + + async def 
test_propagates_error(self): + blob = MagicMock() + blob.list_blobs = AsyncMock(side_effect=RuntimeError("x")) + svc, _ = _make_service(blob_helper=blob) + with pytest.raises(RuntimeError): + await svc.get_converted_files("p1") + + +@pytest.mark.asyncio +class TestGetProcessSummary: + async def test_returns_entity_and_filenames(self): + repo = MagicMock() + entity = SimpleNamespace(id="p1") + repo.get_async = AsyncMock(return_value=entity) + blob = MagicMock() + blob.list_blobs = AsyncMock( + return_value=[ + {"name": "p1/converted/a.txt"}, + {"name": "p1/converted/"}, + ] + ) + svc, _ = _make_service( + blob_helper=blob, scope_services={ProcessRepository: repo} + ) + result_entity, names = await svc.get_process_summary("p1") + assert result_entity is entity + assert names == ["a.txt"] + + async def test_raises_when_process_missing(self): + repo = MagicMock() + repo.get_async = AsyncMock(return_value=None) + svc, _ = _make_service(scope_services={ProcessRepository: repo}) + with pytest.raises(ValueError): + await svc.get_process_summary("p1") + + +@pytest.mark.asyncio +class TestGetConvertedFileContent: + async def test_returns_decoded_content(self): + blob = MagicMock() + blob.download_blob = AsyncMock(return_value="hello".encode()) + svc, _ = _make_service(blob_helper=blob) + assert await svc.get_converted_file_content("p1", "a.txt") == "hello" + + async def test_returns_empty_string_when_blob_empty(self): + blob = MagicMock() + blob.download_blob = AsyncMock(return_value=None) + svc, _ = _make_service(blob_helper=blob) + assert await svc.get_converted_file_content("p1", "a.txt") == "" + + async def test_propagates_error(self): + blob = MagicMock() + blob.download_blob = AsyncMock(side_effect=RuntimeError("x")) + svc, _ = _make_service(blob_helper=blob) + with pytest.raises(RuntimeError): + await svc.get_converted_file_content("p1", "a.txt") diff --git a/src/backend-api/src/tests/test_app_init.py b/src/backend-api/src/tests/test_app_init.py new file mode 100644 
index 00000000..8635896b --- /dev/null +++ b/src/backend-api/src/tests/test_app_init.py @@ -0,0 +1,16 @@ +"""Trivial coverage for `app/__init__.py` (sys.path bootstrap).""" + +import importlib +import os +import sys + + +def test_importing_app_package_inserts_source_root_into_syspath(): + # Import (or re-import) the app package + if "app" in sys.modules: + importlib.reload(sys.modules["app"]) + else: + importlib.import_module("app") + + expected = os.path.dirname(os.path.abspath(sys.modules["app"].__file__)) + assert expected in sys.path diff --git a/src/backend-api/src/tests/test_main.py b/src/backend-api/src/tests/test_main.py new file mode 100644 index 00000000..0c673fe1 --- /dev/null +++ b/src/backend-api/src/tests/test_main.py @@ -0,0 +1,21 @@ +"""Tests for main.get_app() factory.""" + +from unittest.mock import MagicMock, patch + + +def test_get_app_returns_app_and_caches_singleton(): + """get_app should call Application() once and reuse the cached instance.""" + fake_app = MagicMock(name="FastAPIApp") + fake_application = MagicMock() + fake_application.app = fake_app + + # Reset module-level singleton, then patch Application before reload + import main as main_module + + main_module._app_instance = None + with patch("main.Application", return_value=fake_application) as MockApp: + first = main_module.get_app() + second = main_module.get_app() + assert first is fake_app + assert second is fake_app + assert MockApp.call_count == 1 diff --git a/src/backend-api/uv.lock b/src/backend-api/uv.lock index 9e38fe25..efb8b59f 100644 --- a/src/backend-api/uv.lock +++ b/src/backend-api/uv.lock @@ -193,6 +193,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, ] @@ -220,6 +221,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=9.0.3" }, + { name = "pytest-asyncio", specifier = ">=0.23.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, ] @@ 
-2488,6 +2490,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0" diff --git a/src/processor/pyproject.toml b/src/processor/pyproject.toml index 2a97e23e..0c84d7d0 100644 --- a/src/processor/pyproject.toml +++ b/src/processor/pyproject.toml @@ -54,5 +54,8 @@ indent-style = "space" testpaths = ["src/tests"] pythonpath = ["src"] +[tool.coverage.run] +omit = ["src/tests/*"] + [tool.uv] prerelease = "allow" diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py b/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py new file mode 100644 index 00000000..26fcbfe5 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from libs.agent_framework.agent_builder import AgentBuilder + + +def _builder(): + return AgentBuilder(chat_client=MagicMock()) + + +class TestFluentSetters: + def test_default_state(self): + b = _builder() + assert b._instructions is None + assert b._tools is None + assert b._tool_choice == "auto" + assert b._kwargs == {} + + def test_with_instructions(self): + b = _builder().with_instructions("hello") + assert b._instructions == "hello" + + def test_with_id(self): + b = _builder().with_id("agent-1") + assert b._id == "agent-1" + + def test_with_name(self): + b = _builder().with_name("MyAgent") + assert b._name == "MyAgent" + + def test_with_description(self): + b = _builder().with_description("desc") + assert b._description == "desc" + + def test_with_temperature(self): + b = _builder().with_temperature(0.7) + assert b._temperature == 0.7 + + def test_with_max_tokens(self): + b = _builder().with_max_tokens(123) + assert b._max_tokens == 123 + + def test_with_tools(self): + tools = [lambda: None] + b = _builder().with_tools(tools) + assert b._tools is tools + + def test_with_tool_choice(self): + b = _builder().with_tool_choice("required") + assert b._tool_choice == "required" + + def test_with_middleware(self): + m = [MagicMock()] + b = _builder().with_middleware(m) + assert b._middleware is m + + def test_with_context_providers(self): + cp = MagicMock() + b = _builder().with_context_providers(cp) + assert b._context_providers is cp + + def test_with_conversation_id(self): + b = _builder().with_conversation_id("conv-1") + assert b._conversation_id == "conv-1" + + def test_with_model_id(self): + b = _builder().with_model_id("gpt-4") + assert b._model_id == "gpt-4" + + def test_with_top_p(self): + b = _builder().with_top_p(0.9) + assert b._top_p == 0.9 + + def test_with_frequency_penalty(self): + b = _builder().with_frequency_penalty(-0.2) + 
assert b._frequency_penalty == -0.2 + + def test_with_presence_penalty(self): + b = _builder().with_presence_penalty(0.5) + assert b._presence_penalty == 0.5 + + def test_with_seed(self): + b = _builder().with_seed(42) + assert b._seed == 42 + + def test_with_stop(self): + b = _builder().with_stop(["X", "Y"]) + assert b._stop == ["X", "Y"] + + def test_with_response_format(self): + class Resp: + pass + + b = _builder().with_response_format(Resp) + assert b._response_format is Resp + + def test_with_metadata(self): + b = _builder().with_metadata({"k": "v"}) + assert b._metadata == {"k": "v"} + + def test_with_user(self): + b = _builder().with_user("alice") + assert b._user == "alice" + + def test_with_additional_chat_options(self): + b = _builder().with_additional_chat_options({"x": 1}) + assert b._additional_chat_options == {"x": 1} + + def test_with_store(self): + b = _builder().with_store(True) + assert b._store is True + + def test_with_message_store_factory(self): + def f(): + return MagicMock() + b = _builder().with_message_store_factory(f) + assert b._chat_message_store_factory is f + + def test_with_logit_bias(self): + b = _builder().with_logit_bias({"1": 0.5}) + assert b._logit_bias == {"1": 0.5} + + def test_with_kwargs_merges(self): + b = _builder().with_kwargs(a=1).with_kwargs(b=2) + assert b._kwargs == {"a": 1, "b": 2} + + def test_chaining_returns_self_each_step(self): + b = _builder() + out = ( + b.with_name("n") + .with_id("i") + .with_temperature(0.1) + .with_max_tokens(10) + .with_top_p(0.5) + ) + assert out is b + + +class TestBuild: + def test_build_passes_all_state_to_chat_agent(self): + chat_client = MagicMock() + with patch("libs.agent_framework.agent_builder.ChatAgent") as mock_chat: + agent = ( + AgentBuilder(chat_client) + .with_instructions("inst") + .with_id("id1") + .with_name("name1") + .with_description("desc1") + .with_temperature(0.3) + .with_max_tokens(100) + .with_kwargs(extra=42) + .build() + ) + assert agent is 
mock_chat.return_value + kwargs = mock_chat.call_args.kwargs + assert kwargs["chat_client"] is chat_client + assert kwargs["instructions"] == "inst" + assert kwargs["id"] == "id1" + assert kwargs["name"] == "name1" + assert kwargs["description"] == "desc1" + assert kwargs["temperature"] == 0.3 + assert kwargs["max_tokens"] == 100 + assert kwargs["tool_choice"] == "auto" + assert kwargs["extra"] == 42 + + +class TestStaticFactories: + def test_create_agent_invokes_chat_agent(self): + chat_client = MagicMock() + with patch("libs.agent_framework.agent_builder.ChatAgent") as mock_chat: + agent = AgentBuilder.create_agent( + chat_client=chat_client, + instructions="i", + name="n", + temperature=0.4, + ) + assert agent is mock_chat.return_value + kwargs = mock_chat.call_args.kwargs + assert kwargs["chat_client"] is chat_client + assert kwargs["instructions"] == "i" + assert kwargs["name"] == "n" + assert kwargs["temperature"] == 0.4 + + def test_create_agent_by_agentinfo_uses_helper_and_creates_client(self): + # Build a fake AgentInfo with the minimum surface used by the method + helper = MagicMock() + helper.settings.get_service_config.return_value = SimpleNamespace( + endpoint="https://x", + chat_deployment_name="gpt", + api_version="2024-02-01", + ) + helper.create_client.return_value = "client-instance" + agent_info = SimpleNamespace( + agent_framework_helper=helper, + agent_type="azure_openai", + agent_instruction="instr", + agent_system_prompt=None, + agent_name="A", + agent_description="D", + ) + with patch( + "libs.agent_framework.agent_builder.get_bearer_token_provider", + return_value="token-provider", + ), patch("libs.agent_framework.agent_builder.ChatAgent") as mock_chat: + agent = AgentBuilder.create_agent_by_agentinfo( + service_id="default", + agent_info=agent_info, + temperature=0.2, + ) + assert agent is mock_chat.return_value + helper.settings.get_service_config.assert_called_once_with("default") + helper.create_client.assert_called_once() + ck = 
mock_chat.call_args.kwargs + assert ck["chat_client"] == "client-instance" + assert ck["instructions"] == "instr" + assert ck["name"] == "A" + assert ck["description"] == "D" + assert ck["temperature"] == 0.2 + + def test_create_agent_by_agentinfo_falls_back_to_system_prompt(self): + helper = MagicMock() + helper.settings.get_service_config.return_value = SimpleNamespace( + endpoint="https://x", + chat_deployment_name="gpt", + api_version="2024-02-01", + ) + helper.create_client.return_value = "client" + agent_info = SimpleNamespace( + agent_framework_helper=helper, + agent_type="azure_openai", + agent_instruction=None, + agent_system_prompt="fallback", + agent_name="A", + agent_description="D", + ) + with patch( + "libs.agent_framework.agent_builder.get_bearer_token_provider", + return_value="tp", + ), patch("libs.agent_framework.agent_builder.ChatAgent") as mock_chat: + AgentBuilder.create_agent_by_agentinfo( + service_id="default", agent_info=agent_info + ) + assert mock_chat.call_args.kwargs["instructions"] == "fallback" + + def test_create_agent_by_agentinfo_raises_when_service_config_missing(self): + helper = MagicMock() + helper.settings.get_service_config.return_value = None + agent_info = SimpleNamespace( + agent_framework_helper=helper, + agent_type="azure_openai", + agent_instruction="x", + agent_system_prompt=None, + agent_name="A", + agent_description="D", + ) + import pytest + + with pytest.raises(ValueError, match="Service config"): + AgentBuilder.create_agent_by_agentinfo( + service_id="missing", agent_info=agent_info + ) diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_agent_framework_helper.py b/src/processor/src/tests/unit/libs/agent_framework/test_agent_framework_helper.py new file mode 100644 index 00000000..64a8d415 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_agent_framework_helper.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from libs.agent_framework.agent_framework_helper import ( + AgentFrameworkHelper, + ClientType, +) + + +def _run(coro): + return asyncio.run(coro) + + +class TestInitialization: + def test_init_creates_empty_registry(self): + h = AgentFrameworkHelper() + assert h.ai_clients == {} + + def test_initialize_requires_settings(self): + h = AgentFrameworkHelper() + with pytest.raises(ValueError): + h.initialize(None) + + def test_initialize_all_clients_skips_invalid(self): + h = AgentFrameworkHelper() + settings = MagicMock() + settings.get_available_services.return_value = ["default", "broken"] + cfg_default = MagicMock( + endpoint="https://x", chat_deployment_name="gpt-4", api_version="v1" + ) + # broken returns None to exercise the warning path + settings.get_service_config.side_effect = lambda sid: cfg_default if sid == "default" else None + + with patch( + "libs.agent_framework.agent_framework_helper.get_bearer_token_provider", + return_value="token", + ), patch.object( + AgentFrameworkHelper, "create_client", return_value="client_obj" + ) as mock_create: + h.initialize(settings) + assert h.ai_clients == {"default": "client_obj"} + assert mock_create.call_count == 1 + + def test_get_client_async_returns_cached(self): + h = AgentFrameworkHelper() + h.ai_clients["default"] = "cached_client" + result = _run(h.get_client_async("default")) + assert result == "cached_client" + + def test_get_client_async_returns_none_for_missing(self): + h = AgentFrameworkHelper() + result = _run(h.get_client_async("nope")) + assert result is None + + +class TestCreateClient: + def test_not_implemented_openai_chat(self): + with pytest.raises(NotImplementedError): + AgentFrameworkHelper.create_client(ClientType.OpenAIChatCompletion) + + def test_not_implemented_openai_assistant(self): + with pytest.raises(NotImplementedError): + 
AgentFrameworkHelper.create_client(ClientType.OpenAIAssistant) + + def test_not_implemented_openai_response(self): + with pytest.raises(NotImplementedError): + AgentFrameworkHelper.create_client(ClientType.OpenAIResponse) + + def test_unsupported_client_type_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + AgentFrameworkHelper.create_client("garbage") # type: ignore[arg-type] + + def test_azure_openai_response_with_retry(self): + with patch( + "libs.agent_framework.agent_framework_helper.AzureOpenAIResponseClientWithRetry" + ) as mock_cls: + client = AgentFrameworkHelper.create_client( + ClientType.AzureOpenAIResponseWithRetry, + endpoint="https://x", + deployment_name="gpt-4", + ad_token_provider="token", + ) + assert client is mock_cls.return_value + kwargs = mock_cls.call_args.kwargs + assert kwargs["endpoint"] == "https://x" + assert kwargs["deployment_name"] == "gpt-4" + assert kwargs["ad_token_provider"] == "token" + + def test_default_token_provider_when_no_credential(self): + with patch( + "libs.agent_framework.agent_framework_helper.AzureOpenAIResponseClientWithRetry" + ) as mock_cls, patch( + "libs.agent_framework.agent_framework_helper.get_bearer_token_provider", + return_value="default-token", + ): + AgentFrameworkHelper.create_client( + ClientType.AzureOpenAIResponseWithRetry, + endpoint="https://x", + deployment_name="gpt-4", + ) + assert mock_cls.call_args.kwargs["ad_token_provider"] == "default-token" + + def test_azure_openai_chat_completion(self): + # Patch the lazily imported module + fake_module = types.ModuleType("agent_framework.azure") + fake_module.AzureOpenAIChatClient = MagicMock(return_value="chat_client") + with patch.dict(sys.modules, {"agent_framework.azure": fake_module}): + client = AgentFrameworkHelper.create_client( + ClientType.AzureOpenAIChatCompletion, + endpoint="https://x", + deployment_name="gpt-4", + ad_token_provider="t", + ) + assert client == "chat_client" + + def test_azure_openai_assistant(self): + 
fake_module = types.ModuleType("agent_framework.azure") + fake_module.AzureOpenAIAssistantsClient = MagicMock(return_value="asst_client") + with patch.dict(sys.modules, {"agent_framework.azure": fake_module}): + client = AgentFrameworkHelper.create_client( + ClientType.AzureOpenAIAssistant, + endpoint="https://x", + deployment_name="gpt-4", + ad_token_provider="t", + ) + assert client == "asst_client" + + def test_azure_openai_response(self): + fake_module = types.ModuleType("agent_framework.azure") + fake_module.AzureOpenAIResponsesClient = MagicMock(return_value="resp_client") + with patch.dict(sys.modules, {"agent_framework.azure": fake_module}): + client = AgentFrameworkHelper.create_client( + ClientType.AzureOpenAIResponse, + endpoint="https://x", + deployment_name="gpt-4", + ad_token_provider="t", + ) + assert client == "resp_client" + + def test_azure_openai_agent(self): + fake_module = types.ModuleType("agent_framework.azure") + fake_module.AzureAIAgentClient = MagicMock(return_value="agent_client") + with patch.dict(sys.modules, {"agent_framework.azure": fake_module}): + client = AgentFrameworkHelper.create_client( + ClientType.AzureOpenAIAgent, + project_endpoint="https://proj", + model_deployment_name="gpt-4", + ad_token_provider="t", + ) + assert client == "agent_client" diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_agent_framework_settings.py b/src/processor/src/tests/unit/libs/agent_framework/test_agent_framework_settings.py new file mode 100644 index 00000000..8e732547 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_agent_framework_settings.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import pytest + +from libs.agent_framework.agent_framework_settings import AgentFrameworkSettings + + +@pytest.fixture +def clear_azure_env(monkeypatch): + """Wipe AZURE_OPENAI_* env vars before each test so service discovery is deterministic.""" + for key in [ + "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", + "AZURE_OPENAI_API_VERSION", + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_BASE_URL", + "AZURE_OPENAI_TEXT_DEPLOYMENT_NAME", + "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME", + ]: + monkeypatch.delenv(key, raising=False) + + +def test_init_with_no_custom_prefixes(clear_azure_env): + s = AgentFrameworkSettings() + assert s.use_entra_id is True + assert s.custom_service_prefixes == {} + assert s.service_configs == {} + + +def test_init_with_custom_prefixes(monkeypatch, clear_azure_env): + monkeypatch.setenv("CUSTOM_ENDPOINT", "https://x.openai.azure.com/") + monkeypatch.setenv("CUSTOM_CHAT_DEPLOYMENT_NAME", "gpt-4") + s = AgentFrameworkSettings(custom_service_prefixes={"alt": "CUSTOM"}) + assert "alt" in s.service_configs + assert s.has_service("alt") is True + + +def test_get_service_config_returns_none_for_unknown(clear_azure_env): + s = AgentFrameworkSettings() + assert s.get_service_config("unknown") is None + + +def test_discovers_default_when_env_present(monkeypatch, clear_azure_env): + monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://x.openai.azure.com/") + monkeypatch.setenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4") + s = AgentFrameworkSettings() + assert "default" in s.service_configs + assert s.get_available_services() == ["default"] + cfg = s.get_service_config("default") + assert cfg is not None + assert cfg.endpoint == "https://x.openai.azure.com/" + + +def test_refresh_services(monkeypatch, clear_azure_env): + s = AgentFrameworkSettings() + monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://x.openai.azure.com/") + monkeypatch.setenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4") + 
s.refresh_services() + assert s.has_service("default") is True + + +def test_load_env_file_loads_values(tmp_path, monkeypatch, clear_azure_env): + f = tmp_path / "test.env" + f.write_text( + '# comment line\n' + 'AZURE_OPENAI_ENDPOINT="https://from-file.openai.azure.com/"\n' + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME='gpt-4'\n" + "EMPTY_LINE_BELOW=\n" + "\n" + ) + monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False) + monkeypatch.delenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", raising=False) + s = AgentFrameworkSettings(env_file_path=str(f)) + # env loaded → discover_services should pick it up + assert s.has_service("default") is True + + +def test_load_env_file_missing_path_is_ignored(clear_azure_env): + # Non-existent file path fails the os.path.exists check and is silently ignored + s = AgentFrameworkSettings(env_file_path="/nope/does/not/exist.env") + assert s.use_entra_id is True + + +def test_load_env_file_unreadable_raises(tmp_path, monkeypatch, clear_azure_env): + f = tmp_path / "bad.env" + f.write_text("ok\n") + s = AgentFrameworkSettings() + # Open _load_env_file directly to test error wrapping + monkeypatch.setattr( + "builtins.open", + lambda *a, **k: (_ for _ in ()).throw(OSError("permission denied")), + ) + with pytest.raises(ValueError, match="Error loading environment file"): + s._load_env_file(str(f)) + + +def test_load_env_file_not_found_raises(clear_azure_env): + s = AgentFrameworkSettings() + with pytest.raises(ValueError, match="Environment file not found"): + # Bypass os.path.exists check by calling private method directly + s._load_env_file("/definitely/does/not/exist.env") diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_agent_speaking_capture.py b/src/processor/src/tests/unit/libs/agent_framework/test_agent_speaking_capture.py new file mode 100644 index 00000000..2def3012 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_agent_speaking_capture.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft 
Corporation. +# Licensed under the MIT License. + +import asyncio +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +from libs.agent_framework.agent_speaking_capture import AgentSpeakingCaptureMiddleware + + +def _ctx(agent_name="A1", is_streaming=False, result=None, messages=None): + agent = SimpleNamespace(name=agent_name) + return SimpleNamespace( + agent=agent, + is_streaming=is_streaming, + result=result, + messages=messages or [], + ) + + +def _result_with_messages(*texts): + msgs = [SimpleNamespace(text=t) for t in texts] + return SimpleNamespace(messages=msgs) + + +def _run(coro): + return asyncio.get_event_loop().run_until_complete(coro) if False else asyncio.run(coro) + + +class TestAgentSpeakingCaptureMiddleware: + def test_captures_non_streaming_response_with_messages(self): + mw = AgentSpeakingCaptureMiddleware() + + async def _next(_ctx_): + return None + + ctx = _ctx(result=_result_with_messages("hello", "world")) + _run(mw.process(ctx, _next)) + + all_responses = mw.get_all_responses() + assert len(all_responses) == 1 + assert all_responses[0]["agent_name"] == "A1" + assert "hello" in all_responses[0]["response"] + assert "world" in all_responses[0]["response"] + assert all_responses[0]["is_streaming"] is False + + def test_captures_response_with_text_attr(self): + mw = AgentSpeakingCaptureMiddleware() + result = SimpleNamespace(text="just text") + + async def _next(_): + return None + + ctx = _ctx(result=result) + _run(mw.process(ctx, _next)) + assert mw.get_all_responses()[0]["response"] == "just text" + + def test_captures_response_falls_back_to_str(self): + mw = AgentSpeakingCaptureMiddleware() + # No messages, no text -> str(result) + result = "raw-string-value" + + async def _next(_): + return None + + ctx = _ctx(result=result) + _run(mw.process(ctx, _next)) + assert mw.get_all_responses()[0]["response"] == "raw-string-value" + + def test_streaming_records_placeholder(self): + 
mw = AgentSpeakingCaptureMiddleware() + + async def _next(c): + c.result = None # generator already consumed + return None + + ctx = _ctx(is_streaming=True, result=None) + _run(mw.process(ctx, _next)) + responses = mw.get_all_responses() + assert responses[0]["is_streaming"] is True + assert "Streaming response" in responses[0]["response"] + + def test_no_storage_returns_empty(self): + mw = AgentSpeakingCaptureMiddleware(store_responses=False) + + async def _next(_): + return None + + ctx = _ctx(result=_result_with_messages("x")) + _run(mw.process(ctx, _next)) + assert mw.get_all_responses() == [] + assert mw.get_responses_by_agent("A1") == [] + + def test_clear_resets_storage(self): + mw = AgentSpeakingCaptureMiddleware() + + async def _next(_): + return None + + ctx = _ctx(result=_result_with_messages("hi")) + _run(mw.process(ctx, _next)) + assert mw.get_all_responses() + mw.clear() + assert mw.get_all_responses() == [] + + def test_get_responses_by_agent_filters(self): + mw = AgentSpeakingCaptureMiddleware() + + async def _next(_): + return None + + for name in ("A", "B", "A"): + _run(mw.process(_ctx(agent_name=name, result=_result_with_messages("x")), _next)) + + assert len(mw.get_responses_by_agent("A")) == 2 + assert len(mw.get_responses_by_agent("B")) == 1 + + def test_async_callback_invoked(self): + cb = AsyncMock() + mw = AgentSpeakingCaptureMiddleware(callback=cb) + + async def _next(_): + return None + + _run(mw.process(_ctx(result=_result_with_messages("hi")), _next)) + cb.assert_awaited_once() + + def test_sync_callback_invoked(self): + seen = [] + + def cb(data): + seen.append(data["agent_name"]) + + mw = AgentSpeakingCaptureMiddleware(callback=cb) + + async def _next(_): + return None + + _run(mw.process(_ctx(agent_name="X", result=_result_with_messages("h")), _next)) + assert seen == ["X"] + + def test_callback_exception_swallowed(self, capsys): + def cb(_): + raise RuntimeError("boom") + + mw = AgentSpeakingCaptureMiddleware(callback=cb) + + async 
def _next(_): + return None + + _run(mw.process(_ctx(result=_result_with_messages("h")), _next)) + captured = capsys.readouterr() + assert "WARNING" in captured.out + + def test_stream_complete_callback_invoked(self): + cb = AsyncMock() + mw = AgentSpeakingCaptureMiddleware(on_stream_response_complete=cb) + + async def _next(_): + return None + + _run(mw.process(_ctx(is_streaming=True), _next)) + cb.assert_awaited_once() + + def test_agent_without_name_uses_str(self): + mw = AgentSpeakingCaptureMiddleware() + + async def _next(_): + return None + + # Use an object that does not have 'name' + class A: + def __str__(self): + return "AGENT_STR" + + ctx = SimpleNamespace( + agent=A(), is_streaming=False, result=_result_with_messages("x"), messages=[] + ) + _run(mw.process(ctx, _next)) + assert mw.get_all_responses()[0]["agent_name"] == "AGENT_STR" diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_extras.py b/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_extras.py new file mode 100644 index 00000000..211fef00 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_extras.py @@ -0,0 +1,345 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +from libs.agent_framework.azure_openai_response_retry import ( + ContextTrimConfig, + RateLimitRetryConfig, + _estimate_message_text, + _format_exc_brief, + _get_message_role, + _looks_like_context_length, + _looks_like_rate_limit, + _looks_like_save_blob_call, + _looks_like_tool_result, + _safe_str, + _set_message_text, + _summarize_save_blob, + _trim_messages, + _truncate_text, + _try_get_retry_after_seconds, +) + + +class TestFormatExcBrief: + def test_with_message(self): + assert _format_exc_brief(ValueError("boom")) == "ValueError: boom" + + def test_no_message(self): + assert _format_exc_brief(ValueError("")) == "ValueError" + + +class TestRateLimitRetryConfig: + def test_from_env_defaults(self, monkeypatch): + for k in ("AOAI_429_MAX_RETRIES", "AOAI_429_BASE_DELAY_SECONDS", "AOAI_429_MAX_DELAY_SECONDS"): + monkeypatch.delenv(k, raising=False) + cfg = RateLimitRetryConfig.from_env() + assert cfg.max_retries == 8 + assert cfg.base_delay_seconds == 5.0 + + def test_from_env_with_values(self, monkeypatch): + monkeypatch.setenv("AOAI_429_MAX_RETRIES", "3") + monkeypatch.setenv("AOAI_429_BASE_DELAY_SECONDS", "2.5") + monkeypatch.setenv("AOAI_429_MAX_DELAY_SECONDS", "60.0") + cfg = RateLimitRetryConfig.from_env() + assert cfg.max_retries == 3 + assert cfg.base_delay_seconds == 2.5 + assert cfg.max_delay_seconds == 60.0 + + def test_from_env_with_invalid_int(self, monkeypatch): + monkeypatch.setenv("AOAI_429_MAX_RETRIES", "abc") + cfg = RateLimitRetryConfig.from_env() + assert cfg.max_retries == 8 + + def test_from_env_negative_clamped(self, monkeypatch): + monkeypatch.setenv("AOAI_429_MAX_RETRIES", "-3") + monkeypatch.setenv("AOAI_429_BASE_DELAY_SECONDS", "-1") + cfg = RateLimitRetryConfig.from_env() + assert cfg.max_retries == 0 + assert cfg.base_delay_seconds == 0.0 + + +class TestLooksLikeRateLimit: + def test_text_indicator(self): + assert _looks_like_rate_limit(Exception("Too Many Requests")) is True + assert 
_looks_like_rate_limit(Exception("rate limit exceeded")) is True + + def test_status_429(self): + e = Exception("anything") + e.status_code = 429 + assert _looks_like_rate_limit(e) is True + + def test_status_500(self): + e = Exception("server") + e.status_code = 503 + assert _looks_like_rate_limit(e) is True + + def test_empty_message_treated_transient(self): + assert _looks_like_rate_limit(Exception("")) is True + + def test_chained_cause(self): + inner = Exception("rate limit") + outer = Exception("wrapper") + outer.__cause__ = inner + assert _looks_like_rate_limit(outer) is True + + def test_returns_false_for_unrelated(self): + e = Exception("validation failed: bad input") + e.status_code = 400 + assert _looks_like_rate_limit(e) is False + + +class TestLooksLikeContextLength: + def test_text_indicator(self): + assert _looks_like_context_length(Exception("maximum context length exceeded")) + + def test_400_with_context_keyword(self): + e = Exception("token limit exceeded") + e.status_code = 400 + assert _looks_like_context_length(e) is True + + def test_400_without_context_keyword(self): + e = Exception("invalid argument") + e.status_code = 400 + assert _looks_like_context_length(e) is False + + def test_chained_cause(self): + inner = Exception("maximum context length exceeded") + outer = Exception("oops") + outer.__cause__ = inner + assert _looks_like_context_length(outer) is True + + +class TestSafeStr: + def test_none(self): + assert _safe_str(None) == "" + + def test_str_passthrough(self): + assert _safe_str("hi") == "hi" + + def test_int_converted(self): + assert _safe_str(123) == "123" + + +class TestToolResultDetection: + def test_short_text_returns_false(self): + assert _looks_like_tool_result("short") is False + + def test_blob_indicator_returns_true(self): + text = '{"blob_name": "x.txt", ' + "x" * 100 + '}' + assert _looks_like_tool_result(text) is True + + def test_no_indicators(self): + assert _looks_like_tool_result("a" * 100) is False + + +class 
TestSaveBlobCallDetection: + def test_empty_returns_false(self): + assert _looks_like_save_blob_call("") is False + + def test_short_returns_false(self): + assert _looks_like_save_blob_call("save_content_to_blob(short)") is False + + def test_long_call_returns_true(self): + text = "save_content_to_blob(" + "x" * 1500 + ")" + assert _looks_like_save_blob_call(text) is True + + +class TestSummarizeSaveBlob: + def test_extracts_blob_name(self): + text = '{"blob_name": "report.pdf", "data": "x"}' + result = _summarize_save_blob(text, max_chars=200) + assert "report.pdf" in result + + def test_unknown_when_no_blob_name(self): + text = '{"other": "data"}' + result = _summarize_save_blob(text, max_chars=200) + assert "unknown" in result + + +class TestTruncateText: + def test_zero_max(self): + assert _truncate_text("x" * 100, max_chars=0, keep_head_chars=0, keep_tail_chars=0) == "" + + def test_empty(self): + assert _truncate_text("", max_chars=10, keep_head_chars=5, keep_tail_chars=5) == "" + + def test_short_passthrough(self): + assert _truncate_text("hi", max_chars=100, keep_head_chars=5, keep_tail_chars=5) == "hi" + + def test_truncates_with_marker(self): + text = "A" * 500 + "B" * 500 + result = _truncate_text(text, max_chars=200, keep_head_chars=50, keep_tail_chars=50) + assert "TRUNCATED" in result + + def test_no_tail_when_remaining_zero(self): + text = "X" * 100 + result = _truncate_text(text, max_chars=20, keep_head_chars=20, keep_tail_chars=10) + assert len(result) <= 20 + + +class TestEstimateMessageText: + def test_none(self): + assert _estimate_message_text(None) == "" + + def test_dict_with_content(self): + assert _estimate_message_text({"content": "hello"}) == "hello" + + def test_dict_with_text(self): + assert _estimate_message_text({"text": "hi"}) == "hi" + + def test_object_with_content(self): + class M: + content = "msg" + + assert _estimate_message_text(M()) == "msg" + + def test_dict_fallback(self): + result = _estimate_message_text({"role": "user"}) 
+ assert "user" in result + + +class TestMessageRole: + def test_dict(self): + assert _get_message_role({"role": "user"}) == "user" + + def test_dict_no_role(self): + assert _get_message_role({}) is None + + def test_object(self): + class M: + role = "system" + + assert _get_message_role(M()) == "system" + + def test_none(self): + assert _get_message_role(None) is None + + +class TestSetMessageText: + def test_dict_with_content(self): + result = _set_message_text({"content": "old"}, "new") + assert result["content"] == "new" + + def test_dict_with_text(self): + result = _set_message_text({"text": "old"}, "new") + assert result["text"] == "new" + + def test_dict_with_no_known_keys(self): + result = _set_message_text({"role": "user"}, "new") + assert result["content"] == "new" + + def test_object_with_content(self): + class M: + content = "old" + + m = M() + result = _set_message_text(m, "new") + assert result.content == "new" + + +class TestContextTrimConfigFromEnv: + def test_defaults_when_unset(self, monkeypatch): + for k in [ + "AOAI_CTX_TRIM_ENABLED", + "AOAI_CTX_MAX_TOTAL_CHARS", + "AOAI_CTX_MAX_MESSAGE_CHARS", + "AOAI_CTX_KEEP_LAST_MESSAGES", + "AOAI_CTX_KEEP_HEAD_CHARS", + "AOAI_CTX_KEEP_TAIL_CHARS", + "AOAI_CTX_KEEP_SYSTEM_MESSAGES", + "AOAI_CTX_RETRY_ON_CONTEXT_ERROR", + ]: + monkeypatch.delenv(k, raising=False) + cfg = ContextTrimConfig.from_env() + assert cfg.enabled is True + + def test_disabled_via_env(self, monkeypatch): + monkeypatch.setenv("AOAI_CTX_TRIM_ENABLED", "0") + cfg = ContextTrimConfig.from_env() + assert cfg.enabled is False + + def test_invalid_int_falls_back(self, monkeypatch): + monkeypatch.setenv("AOAI_CTX_MAX_TOTAL_CHARS", "abc") + cfg = ContextTrimConfig.from_env() + assert cfg.max_total_chars == 240_000 + + +class TestTrimMessages: + def test_disabled_returns_copy(self): + cfg = ContextTrimConfig(enabled=False) + msgs = [{"role": "user", "content": "hi"}] + out = _trim_messages(list(msgs), cfg=cfg) + assert out == msgs + + def 
test_keeps_last_n(self): + cfg = ContextTrimConfig( + enabled=True, + max_total_chars=10_000, + max_message_chars=0, + keep_last_messages=2, + keep_system_messages=False, + ) + msgs = [ + {"role": "user", "content": f"msg {i}"} for i in range(10) + ] + out = _trim_messages(list(msgs), cfg=cfg) + assert len(out) == 2 + assert "msg 9" in out[-1]["content"] + + def test_summarizes_save_blob_call(self): + cfg = ContextTrimConfig(enabled=True, max_total_chars=100_000, keep_last_messages=10) + big = ( + 'save_content_to_blob {"blob_name": "report.json", "content": "' + + "x" * 2000 + + '"}' + ) + msgs = [{"role": "user", "content": big}] + out = _trim_messages(list(msgs), cfg=cfg) + assert "report.json" in out[-1]["content"] + + def test_drops_old_when_over_budget(self): + cfg = ContextTrimConfig( + enabled=True, + max_total_chars=200, + max_message_chars=0, + keep_last_messages=20, + keep_system_messages=False, + ) + msgs = [{"role": "user", "content": "y" * 100} for _ in range(10)] + out = _trim_messages(list(msgs), cfg=cfg) + assert sum(len(m["content"]) for m in out) <= 200 + + +class TestTryGetRetryAfter: + def test_int_attribute(self): + e = Exception("x") + e.retry_after = 7 + assert _try_get_retry_after_seconds(e) == 7.0 + + def test_string_attribute(self): + e = Exception("x") + e.retry_after = "12.5" + assert _try_get_retry_after_seconds(e) == 12.5 + + def test_invalid_string_returns_none(self): + e = Exception("x") + e.retry_after = "not-a-number" + assert _try_get_retry_after_seconds(e) is None + + def test_headers_dict(self): + e = Exception("x") + e.retry_after = None + e.headers = {"retry-after": "42"} + assert _try_get_retry_after_seconds(e) == 42.0 + + def test_no_attributes_returns_none(self): + assert _try_get_retry_after_seconds(Exception("x")) is None + + def test_inner_exception(self): + inner = Exception("inner") + inner.retry_after = 5 + outer = Exception("outer") + outer.inner_exception = inner + assert _try_get_retry_after_seconds(outer) == 5.0 
diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_utils.py b/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_utils.py index 95125db6..aba664fa 100644 --- a/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_utils.py +++ b/src/processor/src/tests/unit/libs/agent_framework/test_azure_openai_response_retry_utils.py @@ -20,8 +20,8 @@ def test_rate_limit_retry_config_from_env_clamps_invalid_values(monkeypatch) -> cfg = RateLimitRetryConfig.from_env() assert cfg.max_retries == 0 assert cfg.base_delay_seconds == 0.0 - # Falls back to default (30.0) on parse failure, then clamped. - assert cfg.max_delay_seconds == 30.0 + # Falls back to default (120.0) on parse failure, then clamped (max(0, 120.0)). + assert cfg.max_delay_seconds == 120.0 def test_looks_like_rate_limit_detects_common_signals() -> None: @@ -42,7 +42,7 @@ def test_looks_like_context_length_detects_common_signals() -> None: class E(Exception): pass - e = E("something") + e = E("prompt is too long") e.status = 413 assert _looks_like_context_length(e) @@ -81,6 +81,7 @@ def test_trim_messages_keeps_system_and_tails_and_truncates_long_messages() -> N assert trimmed[0]["role"] == "system" assert len(trimmed) == 3 - # Each long message should be truncated to <= max_message_chars. + # Non-last long messages are truncated to <= max_message_chars. + # The last message is intentionally never truncated (agent needs full context). 
assert len(trimmed[1]["content"]) <= 50 - assert len(trimmed[2]["content"]) <= 50 + assert len(trimmed[2]["content"]) == 100 diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_cosmos_checkpoint_storage.py b/src/processor/src/tests/unit/libs/agent_framework/test_cosmos_checkpoint_storage.py new file mode 100644 index 00000000..241b1471 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_cosmos_checkpoint_storage.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from libs.agent_framework import cosmos_checkpoint_storage as ccs +from libs.agent_framework.cosmos_checkpoint_storage import ( + CosmosCheckpointStorage, + CosmosWorkflowCheckpoint, +) + + +class TestCosmosWorkflowCheckpoint: + def test_init_uses_checkpoint_id_as_id(self): + cp = CosmosWorkflowCheckpoint(checkpoint_id="abc-123", workflow_id="wf-1") + assert cp.checkpoint_id == "abc-123" + assert cp.workflow_id == "wf-1" + assert getattr(cp, "id", None) == "abc-123" + + def test_init_keeps_explicit_id(self): + cp = CosmosWorkflowCheckpoint(checkpoint_id="x", id="custom-id") + assert cp.id == "custom-id" + + def test_defaults_populated(self): + cp = CosmosWorkflowCheckpoint(checkpoint_id="x") + assert cp.iteration_count == 0 + assert cp.version == "1.0" + assert cp.messages == {} + + +class TestCosmosCheckpointStorage: + def _make(self): + repo = MagicMock() + repo.save_checkpoint = AsyncMock() + repo.load_checkpoint = AsyncMock() + repo.list_checkpoint_ids = AsyncMock() + repo.list_checkpoints = AsyncMock() + repo.delete_checkpoint = AsyncMock() + return CosmosCheckpointStorage(repository=repo), repo + + def test_save_converts_workflow_to_cosmos(self): + storage, repo = self._make() + wf = MagicMock() + wf.to_dict.return_value = {"checkpoint_id": "cp1", "workflow_id": "wf1"} + asyncio.run(storage.save_checkpoint(wf)) + 
repo.save_checkpoint.assert_awaited_once() + passed = repo.save_checkpoint.await_args.args[0] + assert isinstance(passed, CosmosWorkflowCheckpoint) + assert passed.checkpoint_id == "cp1" + + def test_load_delegates_to_repository(self): + storage, repo = self._make() + repo.load_checkpoint.return_value = "loaded" + result = asyncio.run(storage.load_checkpoint("id-1")) + assert result == "loaded" + repo.load_checkpoint.assert_awaited_once_with("id-1") + + def test_list_ids_delegates(self): + storage, repo = self._make() + repo.list_checkpoint_ids.return_value = ["a", "b"] + result = asyncio.run(storage.list_checkpoint_ids("wf")) + assert result == ["a", "b"] + repo.list_checkpoint_ids.assert_awaited_once_with("wf") + + def test_list_checkpoints_delegates(self): + storage, repo = self._make() + repo.list_checkpoints.return_value = [] + asyncio.run(storage.list_checkpoints()) + repo.list_checkpoints.assert_awaited_once_with(None) + + def test_delete_delegates(self): + storage, repo = self._make() + asyncio.run(storage.delete_checkpoint("id-9")) + repo.delete_checkpoint.assert_awaited_once_with("id-9") diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py b/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py new file mode 100644 index 00000000..a95d9623 --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py @@ -0,0 +1,977 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Coverage for GroupChatOrchestrator helpers, dataclass model_dump/to_json, +loop detection, tool-call processing, conversation truncation and final-result +building. 
Avoids running the full async workflow (which requires the real +agent_framework GroupChat runtime).""" + +from __future__ import annotations + +import asyncio +import json +from dataclasses import dataclass +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from libs.agent_framework.groupchat_orchestrator import ( + AgentResponse, + AgentResponseStream, + GroupChatOrchestrator, + OrchestrationResult, +) + + +def _run(coro): + return asyncio.run(coro) + + +@dataclass +class _Msg: + """Lightweight stand-in for a ChatMessage.""" + + source: str = "" + content: str = "" + text: str = "" + role: object = None + author_name: str | None = None + contents: object = None + + +def _make_orch(participants=None, result_format=None): + return GroupChatOrchestrator( + name="t", + process_id="p1", + participants=participants or {"Coordinator": object()}, + memory_client=None, + coordinator_name="Coordinator", + result_output_format=result_format, + ) + + +# ----------------------------------------------------------------------------- +# AgentResponse / OrchestrationResult dataclasses +# ----------------------------------------------------------------------------- + + +class TestAgentResponseDump: + def test_model_dump_with_datetime(self): + ts = datetime(2024, 1, 1, 12, 0, 0) + r = AgentResponse(agent_id="a", agent_name="A", message="m", timestamp=ts) + d = r.model_dump() + assert d["timestamp"] == ts.isoformat() + assert d["agent_id"] == "a" + + def test_model_dump_with_string_timestamp(self): + r = AgentResponse( + agent_id="a", agent_name="A", message="m", timestamp="not a datetime" + ) + d = r.model_dump() + assert d["timestamp"] == "not a datetime" + + +class TestOrchestrationResultJsonable: + def test_to_jsonable_primitives(self): + assert OrchestrationResult._to_jsonable(None) is None + assert OrchestrationResult._to_jsonable("hi") == "hi" + assert OrchestrationResult._to_jsonable(1) 
== 1 + assert OrchestrationResult._to_jsonable(1.5) == 1.5 + assert OrchestrationResult._to_jsonable(True) is True + + def test_to_jsonable_datetime(self): + ts = datetime(2024, 1, 1) + assert OrchestrationResult._to_jsonable(ts) == ts.isoformat() + + def test_to_jsonable_dict_and_list(self): + out = OrchestrationResult._to_jsonable({"a": [1, 2], "b": (3, 4)}) + assert out == {"a": [1, 2], "b": [3, 4]} + + def test_to_jsonable_pydantic_v2(self): + m = MagicMock() + m.model_dump = MagicMock(return_value={"x": 1}) + m.dict = MagicMock(return_value={"y": 2}) + out = OrchestrationResult._to_jsonable(m) + assert out == {"x": 1} + + def test_to_jsonable_pydantic_v1_fallback(self): + class Obj: + def dict(self): + return {"y": 2} + + out = OrchestrationResult._to_jsonable(Obj()) + assert out == {"y": 2} + + def test_to_jsonable_dataclass(self): + @dataclass + class D: + x: int = 5 + + out = OrchestrationResult._to_jsonable(D()) + assert out == {"x": 5} + + def test_to_jsonable_vars_fallback(self): + class Anon: + def __init__(self): + self.k = "v" + + out = OrchestrationResult._to_jsonable(Anon()) + assert out == {"k": "v"} + + def test_to_jsonable_str_fallback(self): + # Object with no __dict__ falls back to str() + out = OrchestrationResult._to_jsonable(object.__new__(object)) + # Either a dict or str; must be a string for slot-only objects + assert isinstance(out, (dict, str)) + + def test_model_dump_and_to_json(self): + r = OrchestrationResult( + success=True, + conversation=[], + agent_responses=[ + AgentResponse(agent_id="a", agent_name="A", message="m", timestamp=datetime(2024, 1, 1)) + ], + tool_usage={}, + result=None, + error=None, + execution_time_seconds=1.5, + ) + d = r.model_dump() + assert d["success"] is True + assert d["execution_time_seconds"] == 1.5 + s = r.to_json(indent=0) + assert isinstance(s, str) + assert '"success"' in s + + +# ----------------------------------------------------------------------------- +# Forced termination + 
try_build_forced_result +# ----------------------------------------------------------------------------- + + +class TestForcedTermination: + def test_request_forced_termination_sets_state(self): + orch = _make_orch() + orch._request_forced_termination(reason="r", termination_type="hard_timeout") + assert orch._forced_termination_requested is True + assert orch._forced_termination_reason == "r" + + def test_request_forced_termination_noop_when_already_set(self): + orch = _make_orch() + orch._termination_requested = True + orch._request_forced_termination(reason="r", termination_type="t") + assert orch._forced_termination_requested is False + + def test_try_build_forced_result_no_format_returns_none(self): + orch = _make_orch(result_format=None) + assert orch._try_build_forced_result(reason="r", termination_type="t") is None + + def test_try_build_forced_result_populates_known_fields(self): + from pydantic import BaseModel + + class Model(BaseModel): + result: bool = False + reason: str = "" + is_hard_terminated: bool = False + termination_type: str = "" + blocking_issues: list[str] = [] + process_id: str = "" + + orch = _make_orch(result_format=Model) + m = orch._try_build_forced_result(reason="boom", termination_type="hard_timeout") + assert m.is_hard_terminated is True + assert m.reason == "boom" + assert m.termination_type == "hard_timeout" + assert m.blocking_issues == ["boom"] + assert m.process_id == "p1" + + def test_try_build_forced_result_handles_optional_fields(self): + from pydantic import BaseModel + + class Model(BaseModel): + output: str | None = None + termination_output: str | None = None + reason: str = "" + + orch = _make_orch(result_format=Model) + m = orch._try_build_forced_result(reason="r", termination_type="hard_blocked") + assert m.output is None + assert m.termination_output is None + + +# ----------------------------------------------------------------------------- +# get_result_generator_name +# 
----------------------------------------------------------------------------- + + +class TestGetResultGeneratorName: + def test_default(self): + assert _make_orch().get_result_generator_name() == "ResultGenerator" + + +# ----------------------------------------------------------------------------- +# _validate_sign_offs +# ----------------------------------------------------------------------------- + + +class TestValidateSignOffs: + def test_all_pass(self): + orch = _make_orch() + msgs = [ + _Msg(source="A", content="SIGN-OFF: PASS"), + _Msg(source="B", content="SIGN-OFF:PASS"), + ] + ok, reason = orch._validate_sign_offs(msgs) + assert ok is True + + def test_pending_blocks(self): + orch = _make_orch() + msgs = [_Msg(source="A", content="SIGN-OFF: PENDING")] + ok, reason = orch._validate_sign_offs(msgs) + assert ok is False + assert "PENDING" in reason + + def test_fail_blocks(self): + orch = _make_orch() + msgs = [_Msg(source="A", content="SIGN-OFF: FAIL")] + ok, reason = orch._validate_sign_offs(msgs) + assert ok is False + assert "FAIL" in reason + + def test_missing_blocks(self): + orch = _make_orch() + msgs = [_Msg(source="A", content="some text without signoff")] + ok, reason = orch._validate_sign_offs(msgs) + assert ok is False + assert "missing" in reason + + def test_excludes_coordinator_and_resultgenerator(self): + orch = _make_orch() + msgs = [ + _Msg(source="Coordinator", content="ignored"), + _Msg(source="ResultGenerator", content="ignored"), + ] + ok, _ = orch._validate_sign_offs(msgs) + assert ok is True + + +# ----------------------------------------------------------------------------- +# _extract_first_json_payload +# ----------------------------------------------------------------------------- + + +class TestExtractFirstJsonPayload: + def test_pure_json_object(self): + out = GroupChatOrchestrator._extract_first_json_payload('{"a":1}') + assert out == '{"a":1}' + + def test_json_with_trailing_text(self): + out = 
GroupChatOrchestrator._extract_first_json_payload('{"a":1} SIGN-OFF: PASS') + assert out == '{"a":1}' + + def test_json_with_leading_text(self): + out = GroupChatOrchestrator._extract_first_json_payload('prefix {"a":1}') + assert '{"a":1}' in out + + def test_empty_returns_empty(self): + assert GroupChatOrchestrator._extract_first_json_payload("") == "" + assert GroupChatOrchestrator._extract_first_json_payload(" ") == "" + + def test_no_json_returns_input(self): + out = GroupChatOrchestrator._extract_first_json_payload("plain text") + assert out == "plain text" + + def test_unparsable_after_position_returns_input(self): + out = GroupChatOrchestrator._extract_first_json_payload("text {not json") + assert "text {not json" in out + + def test_non_string_raises(self): + with pytest.raises(TypeError): + GroupChatOrchestrator._extract_first_json_payload(123) # type: ignore[arg-type] + + +# ----------------------------------------------------------------------------- +# initialize +# ----------------------------------------------------------------------------- + + +class TestInitialize: + def test_initialize_sets_initialized(self): + orch = _make_orch() + _run(orch.initialize()) + assert orch._initialized is True + + def test_initialize_skipped_if_already_done(self): + orch = _make_orch() + orch._initialized = True + _run(orch.initialize()) # no error + + +# ----------------------------------------------------------------------------- +# _normalize_executor_id +# ----------------------------------------------------------------------------- + + +class TestNormalizeExecutorId: + def test_strips_prefix(self): + orch = _make_orch() + assert orch._normalize_executor_id("groupchat_agent:Coordinator") == "Coordinator" + + def test_no_prefix(self): + orch = _make_orch() + assert orch._normalize_executor_id("Bare") == "Bare" + + +# ----------------------------------------------------------------------------- +# _append_text_chunk +# 
----------------------------------------------------------------------------- + + +class TestAppendTextChunk: + def test_no_text_attr(self): + orch = _make_orch() + ev = SimpleNamespace(data=SimpleNamespace()) # no `text` attr + orch._current_agent_response = [] + orch._append_text_chunk(ev) # noop + assert orch._current_agent_response == [] + + def test_falsy_text(self): + orch = _make_orch() + ev = SimpleNamespace(data=SimpleNamespace(text="")) + orch._current_agent_response = [] + orch._append_text_chunk(ev) + assert orch._current_agent_response == [] + + def test_text_object_with_text_attr(self): + orch = _make_orch() + text_obj = SimpleNamespace(text="hello") + ev = SimpleNamespace(data=SimpleNamespace(text=text_obj)) + orch._current_agent_response = [] + orch._append_text_chunk(ev) + assert orch._current_agent_response == ["hello"] + + def test_text_string(self): + orch = _make_orch() + ev = SimpleNamespace(data=SimpleNamespace(text="raw")) + orch._current_agent_response = [] + orch._append_text_chunk(ev) + assert orch._current_agent_response == ["raw"] + + +# ----------------------------------------------------------------------------- +# _start_agent_if_needed +# ----------------------------------------------------------------------------- + + +class TestStartAgentIfNeeded: + def test_same_executor_noop(self): + orch = _make_orch() + orch._last_executor_id = "A" + orch._current_agent_response = ["x"] + _run(orch._start_agent_if_needed("A", None, None)) + # no change + assert orch._current_agent_response == ["x"] + + def test_switch_completes_previous(self): + orch = _make_orch() + orch._last_executor_id = "A" + orch._current_agent_response = ["msg"] + completed = [] + + async def _cb(resp): + completed.append(resp) + + _run(orch._start_agent_if_needed("B", None, _cb)) + assert orch._last_executor_id == "B" + assert orch._current_agent_response == [] + assert len(completed) == 1 + + def test_stream_callback_invoked_on_switch(self): + orch = _make_orch() + 
orch._last_executor_id = None + captured = [] + + async def _stream_cb(s): + captured.append(s) + + _run(orch._start_agent_if_needed("X", _stream_cb, None)) + assert len(captured) == 1 + assert captured[0].response_type == "message" + + def test_stream_callback_failure_is_swallowed(self): + orch = _make_orch() + orch._last_executor_id = None + + async def _bad_stream(_): + raise RuntimeError("boom") + + _run(orch._start_agent_if_needed("X", _bad_stream, None)) + + +# ----------------------------------------------------------------------------- +# _process_tool_calls + helpers +# ----------------------------------------------------------------------------- + + +class TestProcessToolCalls: + def test_no_tool_calls_returns_immediately(self): + orch = _make_orch() + ev = SimpleNamespace(data=SimpleNamespace(contents=None)) + _run(orch._process_tool_calls(ev, "A", None)) + + def test_records_complete_dict_args(self): + orch = _make_orch() + item = SimpleNamespace(name="search", call_id="c1", arguments={"q": "x"}) + ev = SimpleNamespace(data=SimpleNamespace(contents=[item])) + _run(orch._process_tool_calls(ev, "A", None)) + assert "search" in {tc["tool_name"] for tc in orch.agent_tool_usage["A"]} + + def test_skips_when_already_recorded(self): + orch = _make_orch() + item = SimpleNamespace(name="search", call_id="c1", arguments={"q": "x"}) + ev = SimpleNamespace(data=SimpleNamespace(contents=[item])) + _run(orch._process_tool_calls(ev, "A", None)) + # second pass should be skipped + _run(orch._process_tool_calls(ev, "A", None)) + assert len(orch.agent_tool_usage["A"]) == 1 + + def test_skips_invalid_calls(self): + orch = _make_orch() + item = SimpleNamespace(name=None, call_id=None, arguments=None) + ev = SimpleNamespace(data=SimpleNamespace(contents=[item])) + _run(orch._process_tool_calls(ev, "A", None)) + assert orch.agent_tool_usage == {} + + def test_streamed_string_args_buffer_until_complete(self): + orch = _make_orch() + + # Send incomplete JSON args, then 
complete + item1 = SimpleNamespace(name="t", call_id="c", arguments='{"q":"hel') + ev1 = SimpleNamespace(data=SimpleNamespace(contents=[item1])) + _run(orch._process_tool_calls(ev1, "A", None)) + # not yet recorded + assert "A" not in orch.agent_tool_usage or not orch.agent_tool_usage["A"] + + item2 = SimpleNamespace(name="t", call_id="c", arguments='{"q":"hello"}') + ev2 = SimpleNamespace(data=SimpleNamespace(contents=[item2])) + _run(orch._process_tool_calls(ev2, "A", None)) + assert orch.agent_tool_usage["A"][0]["arguments"] == {"q": "hello"} + + +class TestParseOrBufferToolArgs: + def test_dict_passthrough(self): + orch = _make_orch() + parsed, raw = orch._parse_or_buffer_tool_args(("A", "c"), {"k": 1}) + assert parsed == {"k": 1} + assert raw == {"k": 1} + + def test_string_buffered(self): + orch = _make_orch() + parsed, raw = orch._parse_or_buffer_tool_args(("A", "c"), '{"k":1}') + assert parsed == {"k": 1} + + def test_string_invalid_returns_none(self): + orch = _make_orch() + parsed, raw = orch._parse_or_buffer_tool_args(("A", "c"), '{"k":') + assert parsed is None + + def test_other_returns_none(self): + orch = _make_orch() + parsed, raw = orch._parse_or_buffer_tool_args(("A", "c"), 123) + assert parsed is None and raw == 123 + + +class TestMergeStreamedArgs: + def test_existing_none(self): + orch = _make_orch() + assert orch._merge_streamed_args(None, "abc") == "abc" + + def test_incoming_starts_with_existing(self): + orch = _make_orch() + assert orch._merge_streamed_args("ab", "abcde") == "abcde" + + def test_existing_starts_with_incoming(self): + orch = _make_orch() + assert orch._merge_streamed_args("abcde", "ab") == "abcde" + + def test_concatenates(self): + orch = _make_orch() + assert orch._merge_streamed_args("abc", "xyz") == "abcxyz" + + +class TestArgsComplete: + def test_dict_args(self): + assert _make_orch()._args_complete({}, None) is True + + def test_string_with_parsed(self): + assert _make_orch()._args_complete("x", {"k": 1}) is True + + 
def test_string_no_parsed(self): + assert _make_orch()._args_complete("x", None) is False + + def test_none(self): + assert _make_orch()._args_complete(None, None) is True + + +class TestRecordToolCall: + def test_appends_when_new(self): + orch = _make_orch() + info = {"tool_name": "t", "call_id": "c", "arguments": {}, "timestamp": "x"} + orch._record_tool_call("A", ("A", "c"), info) + assert orch.agent_tool_usage["A"] == [info] + assert ("A", "c") in orch._tool_call_recorded + + def test_updates_existing_index(self): + orch = _make_orch() + info1 = {"tool_name": "t", "call_id": "c", "arguments": {}, "timestamp": "1"} + info2 = {"tool_name": "t", "call_id": "c", "arguments": {"x": 1}, "timestamp": "2"} + orch._record_tool_call("A", ("A", "c"), info1) + orch._record_tool_call("A", ("A", "c"), info2) + assert orch.agent_tool_usage["A"][0]["timestamp"] == "2" + + +class TestEmitToolCallOnce: + def test_no_callback_noop(self): + orch = _make_orch() + _run( + orch._emit_tool_call_once( + agent_name="A", call_key=("A", "c"), tool_name="t", + parsed_args={"x": 1}, stream_callback=None, + ) + ) + assert ("A", "c") not in orch._tool_call_emitted + + def test_only_emits_once(self): + orch = _make_orch() + captured = [] + + async def _cb(s): + captured.append(s) + + _run(orch._emit_tool_call_once("A", ("A", "c"), "t", {"x": 1}, _cb)) + _run(orch._emit_tool_call_once("A", ("A", "c"), "t", {"x": 1}, _cb)) + assert len(captured) == 1 + + def test_swallows_callback_exception(self): + orch = _make_orch() + + async def _bad(_): + raise RuntimeError("nope") + + _run(orch._emit_tool_call_once("A", ("A", "c"), "t", {"x": 1}, _bad)) + + +# ----------------------------------------------------------------------------- +# _extract_function_calls +# ----------------------------------------------------------------------------- + + +class TestExtractFunctionCalls: + def test_empty_returns_empty(self): + orch = _make_orch() + assert orch._extract_function_calls(None) == [] + assert 
orch._extract_function_calls([]) == [] + + def test_object_path(self): + orch = _make_orch() + items = [SimpleNamespace(name="t", call_id="c", arguments={"x": 1})] + out = orch._extract_function_calls(items) + assert out == [{"name": "t", "call_id": "c", "arguments": {"x": 1}}] + + def test_dict_path(self): + orch = _make_orch() + items = [{"type": "function_call", "name": "t", "call_id": "c", "arguments": {}}] + out = orch._extract_function_calls(items) + assert out == [{"name": "t", "call_id": "c", "arguments": {}}] + + def test_skips_unrelated(self): + orch = _make_orch() + items = [{"type": "text", "name": "t", "call_id": "c"}] + # name+call_id present on dict but matched as object first; falls through to dict path with non-tool-call type → skipped + out = orch._extract_function_calls(items) + # dict path only matches when type ∈ {function_call, tool_call}; here type='text' so skipped + assert out == [] + + +# ----------------------------------------------------------------------------- +# _backfill_tool_usage_from_conversation +# ----------------------------------------------------------------------------- + + +class TestBackfillToolUsage: + def test_skips_non_assistant(self): + from agent_framework import Role + orch = _make_orch() + msg = SimpleNamespace(role=Role.USER, contents=[]) + orch._backfill_tool_usage_from_conversation([msg]) + assert orch.agent_tool_usage == {} + + def test_records_calls_from_assistant(self): + from agent_framework import Role + orch = _make_orch() + item = SimpleNamespace(name="t", call_id="c", arguments={"x": 1}) + msg = SimpleNamespace( + role=Role.ASSISTANT, author_name="A", contents=[item] + ) + orch._backfill_tool_usage_from_conversation([msg]) + assert orch.agent_tool_usage["A"][0]["tool_name"] == "t" + + def test_dedup_already_recorded(self): + from agent_framework import Role + orch = _make_orch() + # Pre-mark this call as already recorded + orch._tool_call_recorded.add(("A", "c")) + item = SimpleNamespace(name="t", 
call_id="c", arguments={}) + msg = SimpleNamespace( + role=Role.ASSISTANT, author_name="A", contents=[item] + ) + orch._backfill_tool_usage_from_conversation([msg]) + assert "A" in orch.agent_tool_usage + assert orch.agent_tool_usage["A"] == [] + + def test_swallows_exceptions(self): + orch = _make_orch() + # Invalid msg causes attribute access to raise — swallowed by `except Exception` + broken = MagicMock() + broken.role = MagicMock(side_effect=RuntimeError("x")) + orch._backfill_tool_usage_from_conversation([broken]) # no raise + + +# ----------------------------------------------------------------------------- +# _complete_agent_response (additional paths) +# ----------------------------------------------------------------------------- + + +class TestCompleteAgentResponse: + def test_no_pending_response_returns_early(self): + orch = _make_orch() + orch._current_agent_response = [] + _run(orch._complete_agent_response("A", None)) + + def test_callback_swallows_exception(self): + orch = _make_orch() + orch._current_agent_response = ["msg"] + orch._current_agent_start_time = datetime.now() + + async def _bad(_): + raise RuntimeError("cb err") + + _run(orch._complete_agent_response("A", _bad)) + # response was still recorded + assert orch.agent_responses[-1].agent_name == "A" + + def test_records_invocation_for_non_termination_selection(self): + orch = _make_orch() + orch._current_agent_response = [ + json.dumps( + { + "selected_participant": "Architect", + "instruction": "do", + "finish": False, + "final_message": "", + } + ) + ] + orch._current_agent_start_time = datetime.now() + orch._conversation = [] + _run(orch._complete_agent_response("Coordinator", None)) + assert "Architect" in orch._agent_invoked_at + + def test_loop_breaker_triggered_after_3_repeats_without_progress(self): + orch = _make_orch() + orch._conversation = [] + + def _select(participant: str, instruction: str = "do"): + orch._current_agent_response = [ + json.dumps( + { + 
"selected_participant": participant, + "instruction": instruction, + "finish": False, + "final_message": "", + } + ) + ] + orch._current_agent_start_time = datetime.now() + + _select("A") + _run(orch._complete_agent_response("Coordinator", None)) + _select("A") + _run(orch._complete_agent_response("Coordinator", None)) + _select("A") + _run(orch._complete_agent_response("Coordinator", None)) + + assert orch._forced_termination_requested is True + + +# ----------------------------------------------------------------------------- +# _build_groupchat +# ----------------------------------------------------------------------------- + + +class TestBuildGroupchat: + def test_build_groupchat_invokes_builder(self): + orch = _make_orch(participants={ + "Coordinator": "coord", + "Architect": "arch", + "ResultGenerator": "rg", + }) + with patch("libs.agent_framework.groupchat_orchestrator.GroupChatBuilder") as MockBuilder: + built = MagicMock() + built.set_manager.return_value = built + built.participants.return_value = built + built.build.return_value = "wf" + MockBuilder.return_value = built + wf = _run(orch._build_groupchat()) + assert wf == "wf" + # ResultGenerator excluded from participants + kwargs = built.participants.call_args.args[0] + assert "arch" in kwargs + assert "rg" not in kwargs + + +# ----------------------------------------------------------------------------- +# _truncate_text + _build_result_generator_conversation +# ----------------------------------------------------------------------------- + + +class TestTruncateText: + def test_zero_max_returns_empty(self): + out = GroupChatOrchestrator._truncate_text( + "x" * 100, max_chars=0, keep_head_chars=10, keep_tail_chars=10 + ) + assert out == "" + + def test_empty(self): + assert GroupChatOrchestrator._truncate_text("", max_chars=10, keep_head_chars=5, keep_tail_chars=5) == "" + + def test_short_passthrough(self): + out = GroupChatOrchestrator._truncate_text( + "hi", max_chars=10, keep_head_chars=5, 
keep_tail_chars=5 + ) + assert out == "hi" + + def test_long_truncated_with_marker(self): + text = "A" * 200 + "B" * 200 + out = GroupChatOrchestrator._truncate_text( + text, max_chars=100, keep_head_chars=20, keep_tail_chars=20 + ) + assert "TRUNCATED" in out + + def test_remaining_zero_returns_head(self): + text = "X" * 100 + out = GroupChatOrchestrator._truncate_text( + text, max_chars=20, keep_head_chars=20, keep_tail_chars=10 + ) + assert len(out) <= 20 + + def test_tail_zero_returns_head(self): + text = "Y" * 100 + out = GroupChatOrchestrator._truncate_text( + text, max_chars=15, keep_head_chars=15, keep_tail_chars=0 + ) + assert out == "Y" * 15 + + +class TestBuildResultGeneratorConversation: + def test_excludes_named_authors(self): + from agent_framework import Role + from agent_framework import ChatMessage + + orch = _make_orch() + msgs = [ + ChatMessage(role=Role.ASSISTANT, text="from coord", author_name="Coordinator"), + ChatMessage(role=Role.ASSISTANT, text="from architect", author_name="Architect"), + ] + out = orch._build_result_generator_conversation( + msgs, + exclude_authors={"Coordinator"}, + max_messages=10, + max_total_chars=10_000, + max_chars_per_message=10_000, + keep_head_chars=100, + keep_tail_chars=50, + ) + assert any("Architect" == m.author_name for m in out) + assert all("Coordinator" != m.author_name for m in out) + + def test_dedupes_identical_payloads(self): + from agent_framework import Role + from agent_framework import ChatMessage + + orch = _make_orch() + big = "X" * 1000 + msgs = [ + ChatMessage(role=Role.ASSISTANT, text=big, author_name="A"), + ChatMessage(role=Role.ASSISTANT, text=big, author_name="A"), + ] + out = orch._build_result_generator_conversation( + msgs, + exclude_authors=None, + max_messages=10, + max_total_chars=100_000, + max_chars_per_message=10_000, + keep_head_chars=100, + keep_tail_chars=50, + ) + assert len(out) == 1 + + def test_truncates_messages_to_per_message_budget(self): + from agent_framework import 
Role + from agent_framework import ChatMessage + + orch = _make_orch() + msgs = [ + ChatMessage(role=Role.ASSISTANT, text="A" * 500, author_name="X"), + ] + out = orch._build_result_generator_conversation( + msgs, + exclude_authors=None, + max_messages=10, + max_total_chars=10_000, + max_chars_per_message=100, + keep_head_chars=20, + keep_tail_chars=20, + ) + assert len(out[-1].text) <= 100 + + def test_total_budget_enforced(self): + from agent_framework import Role + from agent_framework import ChatMessage + + orch = _make_orch() + msgs = [ + ChatMessage(role=Role.ASSISTANT, text="A" * 100, author_name=str(i)) + for i in range(20) + ] + out = orch._build_result_generator_conversation( + msgs, + exclude_authors=None, + max_messages=20, + max_total_chars=200, + max_chars_per_message=0, # disabled per-message budget + keep_head_chars=50, + keep_tail_chars=10, + ) + total = sum(len(m.text) for m in out) + assert total <= 200 + + def test_max_messages_caps_count(self): + from agent_framework import Role + from agent_framework import ChatMessage + + orch = _make_orch() + msgs = [ + ChatMessage(role=Role.ASSISTANT, text=f"m{i}", author_name=str(i)) + for i in range(20) + ] + out = orch._build_result_generator_conversation( + msgs, + exclude_authors=None, + max_messages=3, + max_total_chars=10_000, + max_chars_per_message=0, + keep_head_chars=10, + keep_tail_chars=10, + ) + assert len(out) == 3 + + +# ----------------------------------------------------------------------------- +# get_tool_usage_summary +# ----------------------------------------------------------------------------- + + +class TestToolUsageSummary: + def test_empty(self): + orch = _make_orch() + out = orch.get_tool_usage_summary() + assert out["total_tool_calls"] == 0 + + def test_aggregates(self): + orch = _make_orch() + orch.agent_tool_usage = { + "A": [{"tool_name": "search"}, {"tool_name": "search"}], + "B": [{"tool_name": "open"}], + } + out = orch.get_tool_usage_summary() + assert 
out["total_tool_calls"] == 3 + assert out["calls_by_agent"] == {"A": 2, "B": 1} + assert out["calls_by_tool"] == {"search": 2, "open": 1} + + def test_unknown_tool_name(self): + orch = _make_orch() + orch.agent_tool_usage = {"A": [{}]} + out = orch.get_tool_usage_summary() + assert out["calls_by_tool"] == {"unknown": 1} + + +# ----------------------------------------------------------------------------- +# _generate_final_result +# ----------------------------------------------------------------------------- + + +class TestGenerateFinalResult: + def test_parses_valid_json(self): + from pydantic import BaseModel + from agent_framework import Role + from agent_framework import ChatMessage + + class Model(BaseModel): + x: int + + rg = MagicMock() + run_result = SimpleNamespace(messages=[SimpleNamespace(text='{"x":5}')]) + rg.run = AsyncMock(return_value=run_result) + orch = _make_orch(participants={"Coordinator": object(), "ResultGenerator": rg}, result_format=Model) + out = _run( + orch._generate_final_result( + conversation=[ChatMessage(role=Role.ASSISTANT, text="x", author_name="A")], + result_format=Model, + result_generator_name="ResultGenerator", + ) + ) + assert out.x == 5 + + def test_retry_on_validation_error(self): + from pydantic import BaseModel + from agent_framework import Role + from agent_framework import ChatMessage + + class Model(BaseModel): + x: int + + rg = MagicMock() + # First run returns invalid JSON; second returns valid. 
+ first = SimpleNamespace(messages=[SimpleNamespace(text='{"x":"not_int"}')]) + second = SimpleNamespace(messages=[SimpleNamespace(text='{"x":7}')]) + rg.run = AsyncMock(side_effect=[first, second]) + orch = _make_orch(participants={"Coordinator": object(), "ResultGenerator": rg}, result_format=Model) + out = _run( + orch._generate_final_result( + conversation=[ChatMessage(role=Role.ASSISTANT, text="x", author_name="A")], + result_format=Model, + result_generator_name="ResultGenerator", + ) + ) + assert out.x == 7 + assert rg.run.await_count == 2 + + +# ----------------------------------------------------------------------------- +# _handle_agent_update high-level pipeline +# ----------------------------------------------------------------------------- + + +class TestHandleAgentUpdate: + def test_invokes_subroutines(self): + orch = _make_orch() + ev = SimpleNamespace( + executor_id="groupchat_agent:A", + data=SimpleNamespace(text="chunk", contents=None), + ) + _run(orch._handle_agent_update(ev, None, None)) + assert orch._last_executor_id == "A" + assert orch._current_agent_response == ["chunk"] diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_mem0_async_memory.py b/src/processor/src/tests/unit/libs/agent_framework/test_mem0_async_memory.py new file mode 100644 index 00000000..ca3929ac --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_mem0_async_memory.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import asyncio +from unittest.mock import AsyncMock, patch + +from libs.agent_framework import mem0_async_memory as mam + + +class TestMem0AsyncMemoryManager: + def test_lazy_initialization_caches_instance(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://example.openai.azure.com") + with patch.object(mam, "AsyncMemory") as mem: + mem.from_config = AsyncMock(return_value="memory-instance") + + mgr = mam.Mem0AsyncMemoryManager() + first = asyncio.run(mgr.get_memory()) + second = asyncio.run(mgr.get_memory()) + + assert first == "memory-instance" + assert first is second + mem.from_config.assert_awaited_once() + + def test_uses_default_deployments_when_env_missing(self, monkeypatch): + for k in [ + "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", + "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME", + "AZURE_OPENAI_API_VERSION", + ]: + monkeypatch.delenv(k, raising=False) + with patch.object(mam, "AsyncMemory") as mem: + mem.from_config = AsyncMock(return_value="m") + asyncio.run(mam.Mem0AsyncMemoryManager().get_memory()) + cfg = mem.from_config.await_args.args[0] + assert cfg["llm"]["config"]["model"] == "gpt-5.1" + assert cfg["embedder"]["config"]["model"] == "text-embedding-3-large" + assert cfg["llm"]["config"]["azure_kwargs"]["api_version"] == "2024-12-01-preview" diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_middlewares_extras.py b/src/processor/src/tests/unit/libs/agent_framework/test_middlewares_extras.py new file mode 100644 index 00000000..c4c32f5a --- /dev/null +++ b/src/processor/src/tests/unit/libs/agent_framework/test_middlewares_extras.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +from agent_framework import ChatMessage, Role + +from libs.agent_framework.middlewares import ( + DebuggingMiddleware, + LoggingFunctionMiddleware, +) + + +def _run(coro): + return asyncio.run(coro) + + +class TestDebuggingMiddleware: + def test_process_sets_metadata_and_calls_next(self, capsys): + ctx = MagicMock() + ctx.messages = [MagicMock(), MagicMock()] + ctx.is_streaming = True + ctx.metadata = {"existing": "value"} + next_fn = AsyncMock() + mw = DebuggingMiddleware() + _run(mw.process(ctx, next_fn)) + assert ctx.metadata["debug_enabled"] is True + next_fn.assert_awaited_once_with(ctx) + + def test_process_with_empty_metadata(self): + ctx = MagicMock() + ctx.messages = [] + ctx.is_streaming = False + ctx.metadata = {} + next_fn = AsyncMock() + mw = DebuggingMiddleware() + _run(mw.process(ctx, next_fn)) + next_fn.assert_awaited_once() + + +class TestLoggingFunctionMiddleware: + def _make_ctx(self, args=None, result=None): + ctx = MagicMock() + ctx.function = MagicMock() + ctx.function.name = "do_thing" + if args is not None: + ctx.arguments = MagicMock() + ctx.arguments.model_dump.return_value = args + else: + ctx.arguments = None + ctx.result = result + return ctx + + def test_process_with_no_args_no_result(self): + ctx = self._make_ctx() + next_fn = AsyncMock() + _run(LoggingFunctionMiddleware().process(ctx, next_fn)) + next_fn.assert_awaited_once_with(ctx) + + def test_process_with_args_and_string_result(self): + ctx = self._make_ctx(args={"x": 1, "y": "z"}, result="hello") + next_fn = AsyncMock() + _run(LoggingFunctionMiddleware().process(ctx, next_fn)) + next_fn.assert_awaited_once() + + def test_process_with_long_string_result_truncated(self): + ctx = self._make_ctx(args={"x": 1}, result="A" * 2000) + _run(LoggingFunctionMiddleware().process(ctx, AsyncMock())) + + def 
test_process_with_list_result_with_raw_representation(self): + item = SimpleNamespace(raw_representation={"data": "ok"}, is_error=False) + ctx = self._make_ctx(args={"x": 1}, result=[item]) + _run(LoggingFunctionMiddleware().process(ctx, AsyncMock())) + + def test_process_with_long_raw_representation_truncated(self): + item = SimpleNamespace(raw_representation="B" * 2000, is_error=True) + ctx = self._make_ctx(args={"x": 1}, result=[item]) + _run(LoggingFunctionMiddleware().process(ctx, AsyncMock())) + + +class TestInputObserverMiddleware: + def test_replaces_user_messages_when_replacement_set(self): + from libs.agent_framework.middlewares import InputObserverMiddleware + + msg_user = ChatMessage(role=Role.USER, text="orig user") + msg_assistant = ChatMessage(role=Role.ASSISTANT, text="hi") + ctx = MagicMock() + ctx.messages = [msg_user, msg_assistant] + next_fn = AsyncMock() + mw = InputObserverMiddleware(replacement="REDACTED") + _run(mw.process(ctx, next_fn)) + # First message replaced, second untouched + assert ctx.messages[0].text == "REDACTED" + assert ctx.messages[1].text == "hi" + next_fn.assert_awaited_once() + + def test_no_replacement_keeps_text(self): + from libs.agent_framework.middlewares import InputObserverMiddleware + + msg = ChatMessage(role=Role.USER, text="keep me") + ctx = MagicMock() + ctx.messages = [msg] + mw = InputObserverMiddleware(replacement=None) + _run(mw.process(ctx, AsyncMock())) + assert ctx.messages[0].text == "keep me" diff --git a/src/processor/src/tests/unit/libs/application/test_application_context_extras.py b/src/processor/src/tests/unit/libs/application/test_application_context_extras.py new file mode 100644 index 00000000..b51b8b2a --- /dev/null +++ b/src/processor/src/tests/unit/libs/application/test_application_context_extras.py @@ -0,0 +1,189 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio + +import pytest + +from libs.application.application_context import ( + AppContext, + ServiceDescriptor, + ServiceLifetime, +) + + +class _S: + pass + + +class _AsyncSvc: + def __init__(self) -> None: + self.closed = False + self.entered = False + + async def __aenter__(self): + self.entered = True + return self + + async def __aexit__(self, *exc): + self.closed = True + + async def close(self) -> None: + self.closed = True + + +class _SyncCleanup: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +def test_set_configuration_and_credential(): + ctx = AppContext() + ctx.set_configuration(object()) # type: ignore[arg-type] + ctx.set_credential(object()) # type: ignore[arg-type] + assert ctx.configuration is not None + assert ctx.credential is not None + + +def test_is_registered_and_get_registered_services(): + ctx = AppContext().add_singleton(_S) + assert ctx.is_registered(_S) is True + assert ctx.is_registered(int) is False + services = ctx.get_registered_services() + assert _S in services + assert services[_S] == ServiceLifetime.SINGLETON + + +def test_async_singleton_caches(): + async def _run(): + ctx = AppContext().add_async_singleton(_AsyncSvc) + a = await ctx.get_service_async(_AsyncSvc) + b = await ctx.get_service_async(_AsyncSvc) + assert a is b + assert a.entered is True + + asyncio.run(_run()) + + +def test_get_service_async_raises_for_unregistered(): + async def _run(): + ctx = AppContext() + with pytest.raises(KeyError): + await ctx.get_service_async(_S) + + asyncio.run(_run()) + + +def test_get_service_async_raises_for_non_async(): + async def _run(): + ctx = AppContext().add_singleton(_S) + with pytest.raises(ValueError): + await ctx.get_service_async(_S) + + asyncio.run(_run()) + + +def test_async_scoped_requires_scope(): + async def _run(): + ctx = AppContext().add_async_scoped(_AsyncSvc) + with pytest.raises(ValueError): + await 
ctx.get_service_async(_AsyncSvc) + + asyncio.run(_run()) + + +def test_async_transient_creates_new_instances(): + async def _run(): + ctx = AppContext() + # register as async singleton type but resolve via direct descriptor injection + # to exercise non-singleton, non-scoped async path. + descriptor = ServiceDescriptor( + service_type=_AsyncSvc, + implementation=_AsyncSvc, + lifetime=ServiceLifetime.TRANSIENT, + is_async=True, + ) + ctx._services[_AsyncSvc] = descriptor + a = await ctx.get_service_async(_AsyncSvc) + b = await ctx.get_service_async(_AsyncSvc) + assert a is not b + + asyncio.run(_run()) + + +def test_create_async_instance_with_callable_factory(): + async def _run(): + ctx = AppContext().add_async_singleton(_AsyncSvc, lambda: _AsyncSvc()) + a = await ctx.get_service_async(_AsyncSvc) + assert isinstance(a, _AsyncSvc) + assert a.entered is True + + asyncio.run(_run()) + + +def test_create_async_instance_with_async_factory(): + async def _run(): + async def factory(): + return _AsyncSvc() + + ctx = AppContext().add_async_singleton(_AsyncSvc, factory) + a = await ctx.get_service_async(_AsyncSvc) + assert isinstance(a, _AsyncSvc) + + asyncio.run(_run()) + + +def test_create_async_instance_with_pre_built_instance(): + async def _run(): + instance = _AsyncSvc() + ctx = AppContext().add_async_singleton(_AsyncSvc, instance) + a = await ctx.get_service_async(_AsyncSvc) + # add_async_singleton path: implementation is callable when passing class but + # passing instance bypasses callable check and is returned as-is. 
+ assert a is instance + + asyncio.run(_run()) + + +def test_create_instance_with_factory_callable(): + ctx = AppContext().add_singleton(_S, lambda: _S()) + a = ctx.get_service(_S) + assert isinstance(a, _S) + + +def test_create_instance_with_pre_built_instance(): + instance = _S() + ctx = AppContext().add_singleton(_S, instance) + assert ctx.get_service(_S) is instance + + +def test_shutdown_async_clears_caches_and_calls_cleanup(): + async def _run(): + ctx = AppContext().add_async_singleton( + _SyncCleanup, _SyncCleanup, cleanup_method="close" + ) + instance = await ctx.get_service_async(_SyncCleanup) + assert instance.closed is False + await ctx.shutdown_async() + assert instance.closed is True + # caches cleared + assert ctx._instances == {} + + asyncio.run(_run()) + + +def test_async_scoped_cleanup_via_aexit(): + async def _run(): + ctx = AppContext().add_async_scoped(_AsyncSvc) + async with ctx.create_scope() as scope: + svc = await scope.get_service_async(_AsyncSvc) + assert svc.closed is False + # __aexit__ should have been called via _cleanup_scope + assert svc.closed is True + + asyncio.run(_run()) diff --git a/src/processor/src/tests/unit/libs/base/test_application_base_init.py b/src/processor/src/tests/unit/libs/base/test_application_base_init.py new file mode 100644 index 00000000..90454bfe --- /dev/null +++ b/src/processor/src/tests/unit/libs/base/test_application_base_init.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Coverage for libs/base/application_base.py.""" + +from __future__ import annotations + +import inspect +from unittest.mock import MagicMock, patch + +import pytest + + +def _build_concrete(): + """Define a minimal concrete subclass to satisfy abstractmethods.""" + from libs.base.application_base import ApplicationBase + + class _App(ApplicationBase): + def initialize(self): + return None + + async def run(self): + return None + + return _App + + +@pytest.fixture +def patches_chain(): + """Patch every external dep of ApplicationBase.__init__.""" + with patch("libs.base.application_base.load_dotenv") as load_dotenv, \ + patch("libs.base.application_base.DefaultAzureCredential") as cred, \ + patch("libs.base.application_base._envConfiguration") as env_cfg, \ + patch("libs.base.application_base.AppConfigurationHelper") as ac_helper, \ + patch("libs.base.application_base.Configuration") as config, \ + patch("libs.base.application_base.AgentFrameworkSettings") as afs, \ + patch("libs.base.application_base.logging.basicConfig") as basic_config: + env_cfg_inst = env_cfg.return_value + env_cfg_inst.app_configuration_url = None + cfg_instance = MagicMock() + cfg_instance.app_logging_enable = False + config.return_value = cfg_instance + yield { + "load_dotenv": load_dotenv, + "cred": cred, + "env_cfg": env_cfg, + "ac_helper": ac_helper, + "config": config, + "afs": afs, + "basic_config": basic_config, + "config_instance": cfg_instance, + } + + +class TestApplicationBaseInit: + def test_init_with_explicit_env_path_skips_app_config(self, patches_chain, tmp_path): + env_file = tmp_path / ".env" + env_file.write_text("X=1") + _App = _build_concrete() + app = _App(env_file_path=str(env_file)) + # load_dotenv called with explicit path + patches_chain["load_dotenv"].assert_called_once() + # AppConfigurationHelper not used (URL is None) + patches_chain["ac_helper"].assert_not_called() + # Settings + credential set + assert app.application_context is not None + 
patches_chain["afs"].assert_called_once() + + def test_init_loads_app_config_when_url_set(self, patches_chain, tmp_path): + patches_chain["env_cfg"].return_value.app_configuration_url = "https://x.azconfig.io" + env_file = tmp_path / ".env" + env_file.write_text("X=1") + _App = _build_concrete() + app = _App(env_file_path=str(env_file)) + patches_chain["ac_helper"].assert_called_once() + # The helper instance had its method invoked + helper_instance = patches_chain["ac_helper"].return_value + helper_instance.read_and_set_environmental_variables.assert_called_once() + + def test_init_enables_logging(self, patches_chain, tmp_path): + patches_chain["config_instance"].app_logging_enable = True + patches_chain["config_instance"].app_logging_level = "INFO" + env_file = tmp_path / ".env" + env_file.write_text("X=1") + _App = _build_concrete() + _App(env_file_path=str(env_file)) + patches_chain["basic_config"].assert_called_once() + + def test_init_without_env_path_derives_location(self, patches_chain): + _App = _build_concrete() + # Without explicit path, _load_env -> _get_derived_class_location() -> inspect.getfile(self.__class__) + # On _App defined here, inspect.getfile returns this test's path. load_dotenv gets that adjacent .env. 
+ with patch("libs.base.application_base.os.path.join", return_value="/tmp/derived/.env"), \ + patch("libs.base.application_base.os.path.dirname", return_value="/tmp/derived"): + _App() + patches_chain["load_dotenv"].assert_called_once() + + +class TestLoadEnvDirect: + def test_explicit_path_returns_path(self, patches_chain, tmp_path): + _App = _build_concrete() + app = _App.__new__(_App) + result = app._load_env(env_file_path=str(tmp_path / ".env")) + assert result == str(tmp_path / ".env") + + def test_no_path_derives_via_class(self, patches_chain): + _App = _build_concrete() + app = _App.__new__(_App) + result = app._load_env() + # Should return derived .env path + assert result.endswith(".env") + + +class TestDerivedClassLocation: + def test_returns_file_path(self): + _App = _build_concrete() + app = _App.__new__(_App) + result = app._get_derived_class_location() + # inspect.getfile returns this test module path + assert result.endswith(".py") diff --git a/src/processor/src/tests/unit/libs/base/test_orchestrator_base.py b/src/processor/src/tests/unit/libs/base/test_orchestrator_base.py new file mode 100644 index 00000000..25f1d5dc --- /dev/null +++ b/src/processor/src/tests/unit/libs/base/test_orchestrator_base.py @@ -0,0 +1,279 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio +import json +import logging +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from libs.base.orchestrator_base import OrchestratorBase + + +def _run(coro): + return asyncio.run(coro) + + +class _ConcreteOrchestrator(OrchestratorBase): + async def execute(self, task_param=None): + return None + + async def prepare_mcp_tools(self): + return None + + async def prepare_agent_infos(self): + return [] + + +def _make_orchestrator(memory_store=None, framework_helper=None): + """Build an OrchestratorBase via __new__ — sidestep ABC + Azure SDK init.""" + obj = _ConcreteOrchestrator.__new__(_ConcreteOrchestrator) + obj.initialized = False + obj.memory_store = memory_store + obj.step_name = "test_step" + obj.app_context = MagicMock() + obj.agent_framework_helper = framework_helper or MagicMock() + obj._client_cache = {} + obj.task_param = SimpleNamespace(process_id="proc-1") + return obj + + +class TestSimpleHelpers: + def test_console_summarization_disabled_by_default(self): + o = _make_orchestrator() + assert o.is_console_summarization_enabled() is False + + def test_read_prompt_file(self, tmp_path): + p = tmp_path / "prompt.txt" + p.write_text("hello world", encoding="utf-8") + o = _make_orchestrator() + assert o.read_prompt_file(str(p)) == "hello world" + + def test_load_platform_registry_valid(self, tmp_path): + p = tmp_path / "reg.json" + p.write_text(json.dumps({"experts": [{"name": "a"}, {"name": "b"}]}), encoding="utf-8") + o = _make_orchestrator() + result = o.load_platform_registry(str(p)) + assert len(result) == 2 + + def test_load_platform_registry_missing_experts(self, tmp_path): + p = tmp_path / "reg.json" + p.write_text(json.dumps({"other": "data"}), encoding="utf-8") + o = _make_orchestrator() + with pytest.raises(ValueError, match="Invalid platform registry"): + o.load_platform_registry(str(p)) + + def 
test_load_platform_registry_experts_not_list(self, tmp_path): + p = tmp_path / "reg.json" + p.write_text(json.dumps({"experts": "nope"}), encoding="utf-8") + o = _make_orchestrator() + with pytest.raises(ValueError): + o.load_platform_registry(str(p)) + + +class TestFlushAgentMemories: + def test_flush_with_no_agents(self): + o = _make_orchestrator() + o.agents = {} + _run(o.flush_agent_memories()) # no error + + def test_flush_skips_agent_without_provider(self): + o = _make_orchestrator() + agent = MagicMock(spec=[]) # no context_provider attribute + o.agents = {"a": agent} + _run(o.flush_agent_memories()) + + def test_flush_skips_provider_with_no_inner(self): + o = _make_orchestrator() + agent = MagicMock() + agent.context_provider = MagicMock() + agent.context_provider.providers = None + o.agents = {"a": agent} + _run(o.flush_agent_memories()) + + def test_flush_calls_inner_provider_flush(self): + o = _make_orchestrator() + flush_mock = AsyncMock() + provider = MagicMock() + provider.flush = flush_mock + agent = MagicMock() + agent.context_provider = MagicMock() + agent.context_provider.providers = [provider] + o.agents = {"a": agent} + _run(o.flush_agent_memories()) + flush_mock.assert_awaited_once() + + def test_flush_swallows_provider_errors(self): + o = _make_orchestrator() + provider = MagicMock() + provider.flush = AsyncMock(side_effect=RuntimeError("boom")) + agent = MagicMock() + agent.context_provider = MagicMock() + agent.context_provider.providers = [provider] + o.agents = {"a": agent} + _run(o.flush_agent_memories()) # no raise + + +class TestGetClient: + def test_get_client_cache_hit(self): + o = _make_orchestrator() + o._client_cache["proc-1"] = "cached" + result = _run(o.get_client(thread_id="proc-1")) + assert result == "cached" + + def test_get_client_cache_miss_creates_and_caches(self): + helper = MagicMock() + helper.create_client = MagicMock(return_value="new_client") + cfg = MagicMock(endpoint="https://x", chat_deployment_name="gpt-4", 
api_version="v1") + helper.settings.get_service_config.return_value = cfg + o = _make_orchestrator(framework_helper=helper) + result = _run(o.get_client(thread_id="proc-9")) + assert result == "new_client" + assert o._client_cache["proc-9"] == "new_client" + + +class TestGetSummarizer: + def test_summarizer_uses_cached_client(self): + helper = MagicMock() + o = _make_orchestrator(framework_helper=helper) + o._client_cache["summarizer"] = "cached_chat_client" + with patch("libs.base.orchestrator_base.AgentBuilder") as mock_builder_cls: + built = MagicMock() + built.with_name.return_value = built + built.with_instructions.return_value = built + built.build.return_value = "summarizer_agent" + mock_builder_cls.return_value = built + result = _run(o.get_summarizer()) + assert result == "summarizer_agent" + mock_builder_cls.assert_called_once_with("cached_chat_client") + + def test_summarizer_fetches_async_when_not_cached(self): + helper = MagicMock() + helper.get_client_async = AsyncMock(return_value="fresh_client") + o = _make_orchestrator(framework_helper=helper) + with patch("libs.base.orchestrator_base.AgentBuilder") as mock_builder_cls: + built = MagicMock() + built.with_name.return_value = built + built.with_instructions.return_value = built + built.build.return_value = "summarizer_agent" + mock_builder_cls.return_value = built + _run(o.get_summarizer()) + assert o._client_cache["summarizer"] == "fresh_client" + + +class TestOnAgentResponse: + def _make_response(self, agent_name, message, elapsed=1.5): + return SimpleNamespace( + agent_name=agent_name, + message=message, + elapsed_time=elapsed, + timestamp="2024-01-01", + ) + + def test_result_generator_logs_only(self, caplog): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + with caplog.at_level(logging.INFO): + _run(o.on_agent_response(self._make_response("ResultGenerator", "x"))) + + def 
test_other_agent_uses_format_path(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + _run(o.on_agent_response(self._make_response("Expert", "hello"))) + telemetry.update_agent_activity.assert_awaited_once() + kwargs = telemetry.update_agent_activity.await_args.kwargs + assert kwargs["action"] == "responded" + assert kwargs["agent_name"] == "Expert" + + def test_coordinator_with_valid_payload(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_phase = AsyncMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + payload = json.dumps({ + "selected_participant": "Architect", + "instruction": "Phase 6 : Re-Check - verify outputs", + "finish": False, + }) + _run(o.on_agent_response(self._make_response("Coordinator", payload))) + telemetry.update_phase.assert_awaited_once() + telemetry.update_agent_activity.assert_awaited_once() + + def test_coordinator_with_invalid_payload_swallowed(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_phase = AsyncMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + # Bad JSON triggers the broad except path + _run(o.on_agent_response(self._make_response("Coordinator", "not json {{"))) + telemetry.update_phase.assert_not_awaited() + + +class TestOnAgentResponseStream: + def test_stream_message_event(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + resp = SimpleNamespace( + response_type="message", agent_name="Expert", tool_name=None, arguments=None + ) + _run(o.on_agent_response_stream(resp)) + kwargs = telemetry.update_agent_activity.await_args.kwargs + assert kwargs["action"] == 
"thinking" + + def test_stream_tool_call_event_with_args(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + resp = SimpleNamespace( + response_type="tool_call", + agent_name="Expert", + tool_name="search", + arguments={"q": "hello"}, + ) + _run(o.on_agent_response_stream(resp)) + kwargs = telemetry.update_agent_activity.await_args.kwargs + assert kwargs["action"] == "analyzing" + assert "search" in kwargs["tool_name"] + assert kwargs["tool_used"] is True + + def test_stream_tool_call_event_without_args(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + resp = SimpleNamespace( + response_type="tool_call", + agent_name="Expert", + tool_name=None, + arguments=None, + ) + _run(o.on_agent_response_stream(resp)) + telemetry.update_agent_activity.assert_awaited_once() + + def test_stream_tool_call_with_long_args_truncates(self): + o = _make_orchestrator() + telemetry = MagicMock() + telemetry.update_agent_activity = AsyncMock() + o.app_context.get_service_async = AsyncMock(return_value=telemetry) + resp = SimpleNamespace( + response_type="tool_call", + agent_name="Expert", + tool_name="search", + arguments={"q": "x" * 200}, + ) + _run(o.on_agent_response_stream(resp)) + kwargs = telemetry.update_agent_activity.await_args.kwargs + assert "..." in kwargs["tool_name"] diff --git a/src/processor/src/tests/unit/libs/mcp_server/test_mermaid_internals.py b/src/processor/src/tests/unit/libs/mcp_server/test_mermaid_internals.py new file mode 100644 index 00000000..7d72fccc --- /dev/null +++ b/src/processor/src/tests/unit/libs/mcp_server/test_mermaid_internals.py @@ -0,0 +1,361 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Coverage for mermaid validation/fix helpers + MCP tool wrappers.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from libs.mcp_server.mermaid import mcp_mermaid as mod +from libs.mcp_server.mermaid.mcp_mermaid import ( + _balance_check, + _detect_diagram_type, + _first_nonempty_line, + _mermaid_render_check, + _normalize_text, + _strip_fences_if_present, + basic_fix_mermaid, + basic_validate_mermaid, + extract_mermaid_blocks_from_markdown, +) + + +# ----------------------------------------------------------------------------- +# _normalize_text +# ----------------------------------------------------------------------------- + + +class TestNormalizeText: + def test_none_input(self): + out, fixes = _normalize_text(None) + assert out == "" + assert "input_was_none" in fixes + + def test_normalize_crlf(self): + out, fixes = _normalize_text("a\r\nb\rc") + assert out == "a\nb\nc" + assert "normalize_newlines" in fixes + + def test_replace_smart_quotes(self): + out, fixes = _normalize_text("\u201chello\u201d") + assert out == '"hello"' + assert "replace_smart_quotes" in fixes + + def test_strip_outer_newlines(self): + out, fixes = _normalize_text("\n\nfoo\n") + assert "strip_outer_newlines" in fixes + + def test_passthrough(self): + out, fixes = _normalize_text("plain") + assert out == "plain" + assert fixes == [] + + +# ----------------------------------------------------------------------------- +# extract_mermaid_blocks_from_markdown +# ----------------------------------------------------------------------------- + + +class TestExtractMermaidBlocks: + def test_empty_returns_empty(self): + assert extract_mermaid_blocks_from_markdown("") == [] + + def test_extracts_multiple_blocks(self): + md = """```mermaid +graph TD +A-->B +``` +```mermaid +sequenceDiagram +A->>B: x +```""" + blocks = extract_mermaid_blocks_from_markdown(md) + assert len(blocks) == 2 + + +# 
----------------------------------------------------------------------------- +# _strip_fences_if_present +# ----------------------------------------------------------------------------- + + +class TestStripFences: + def test_empty(self): + out, fixes = _strip_fences_if_present("") + assert out == "" + assert fixes == [] + + def test_no_fences(self): + out, fixes = _strip_fences_if_present("plain") + assert out == "plain" + assert fixes == [] + + def test_strips_full_fence_block(self): + out, fixes = _strip_fences_if_present("```mermaid\ngraph TD\nA-->B\n```") + assert "graph TD" in out + assert "strip_code_fences" in fixes + + def test_unmatched_fence_returned_unchanged(self): + out, fixes = _strip_fences_if_present("```mermaid\nno close") + assert "```" in out + assert fixes == [] + + +# ----------------------------------------------------------------------------- +# _first_nonempty_line / _detect_diagram_type +# ----------------------------------------------------------------------------- + + +class TestDetectDiagramType: + def test_first_nonempty_line_none(self): + idx, line = _first_nonempty_line(["", " ", "\t"]) + assert idx is None + assert line is None + + def test_first_nonempty_line(self): + idx, line = _first_nonempty_line(["", " hi ", "next"]) + assert idx == 1 + assert line.strip() == "hi" + + def test_detect_known_prefix(self): + assert _detect_diagram_type("graph TD\nA-->B") == "graph" + + def test_detect_after_init_directive(self): + code = "%%{init: {'theme':'dark'}}%%\nflowchart LR\nA-->B" + assert _detect_diagram_type(code) == "flowchart" + + def test_detect_unknown(self): + assert _detect_diagram_type("randomtext") is None + + def test_detect_empty(self): + assert _detect_diagram_type("") is None + + +# ----------------------------------------------------------------------------- +# _balance_check +# ----------------------------------------------------------------------------- + + +class TestBalanceCheck: + def test_balanced(self): + assert 
_balance_check("(a) [b] {c}") == [] + + def test_missing_closer(self): + out = _balance_check("(unclosed") + assert any("missing closers" in e for e in out) + + def test_unexpected_closer(self): + out = _balance_check(")") + assert any("unexpected" in e for e in out) + + def test_unbalanced_quotes(self): + out = _balance_check('"open quote') + assert "unbalanced_quotes" in out + + def test_quotes_ignore_brackets(self): + # Brackets inside quotes don't count + assert _balance_check('"(([[{{"') == [] + + def test_backtick_quotes(self): + assert _balance_check("`(unbalanced inside ticks`") == [] + + def test_escape_handled(self): + assert _balance_check('\\"a') == [] + + def test_single_quote_state(self): + assert _balance_check("'(unbalanced'") == [] + + +# ----------------------------------------------------------------------------- +# basic_validate_mermaid +# ----------------------------------------------------------------------------- + + +class TestBasicValidate: + def test_empty_diagram(self): + v = basic_validate_mermaid("") + assert v.valid is False + assert "empty_diagram" in v.errors + + def test_missing_header(self): + v = basic_validate_mermaid("just text") + assert v.valid is False + + def test_normalization_warning(self): + v = basic_validate_mermaid("```mermaid\ngraph TD\nA-->B\n```") + assert "normalized_input" in v.warnings + + def test_valid_diagram(self): + v = basic_validate_mermaid("graph TD\nA-->B") + assert v.valid is True + assert v.diagram_type == "graph" + + +# ----------------------------------------------------------------------------- +# basic_fix_mermaid +# ----------------------------------------------------------------------------- + + +class TestBasicFix: + def test_removes_markdown_bullets(self): + fixed, applied, v = basic_fix_mermaid("- A-->B") + assert "remove_markdown_bullets" in applied + + def test_normalizes_subgraph_label(self): + fixed, applied, v = basic_fix_mermaid('graph TD\nsubgraph S1["My Group"]\nend') + assert 
"normalize_subgraph_labels" in applied + assert 'subgraph "My Group"' in fixed + + def test_normalizes_subgraph_label_single_quotes(self): + fixed, applied, v = basic_fix_mermaid("graph TD\nsubgraph S1['Label']\nend") + assert "normalize_subgraph_labels" in applied + + def test_appends_missing_brackets(self): + fixed, applied, v = basic_fix_mermaid("graph TD\nA[unclosed") + assert "append_missing_bracket_closers" in applied + + def test_prepends_graph_when_missing_header(self): + fixed, applied, v = basic_fix_mermaid("A-->B") + assert "prepend_graph_td" in applied + assert fixed.startswith("graph TD") + + +# ----------------------------------------------------------------------------- +# _mermaid_render_check +# ----------------------------------------------------------------------------- + + +class TestMermaidRenderCheck: + def test_node_not_found_returns_true(self): + with patch("shutil.which", return_value=None): + ok, err = _mermaid_render_check("graph TD\nA-->B") + assert ok is True + assert err == "" + + def test_subprocess_timeout_returns_true(self): + import subprocess + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", side_effect=subprocess.TimeoutExpired(cmd="x", timeout=1)): + ok, err = _mermaid_render_check("graph TD") + assert ok is True + + def test_subprocess_os_error_returns_true(self): + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", side_effect=OSError("boom")): + ok, err = _mermaid_render_check("graph TD") + assert ok is True + + def test_valid_response_from_node(self): + result = MagicMock(returncode=0, stdout='{"valid": true}', stderr="") + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", return_value=result): + ok, err = _mermaid_render_check("graph TD") + assert ok is True + + def test_invalid_response_with_error(self): + result = MagicMock(returncode=0, stdout='{"valid": false, "error": "bad syntax"}', stderr="") + with 
patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", return_value=result): + ok, err = _mermaid_render_check("graph TD") + assert ok is False + assert "bad syntax" in err + + def test_skipped_response(self): + result = MagicMock(returncode=0, stdout='{"valid": true, "skipped": true}', stderr="") + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", return_value=result): + ok, err = _mermaid_render_check("graph TD") + assert ok is True + + def test_non_zero_with_error_in_stderr(self): + result = MagicMock(returncode=1, stdout="", stderr="Error: parse failure\nmore") + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", return_value=result): + ok, err = _mermaid_render_check("graph TD") + assert ok is False + assert "Error" in err + + def test_non_zero_no_error_lines(self): + result = MagicMock(returncode=1, stdout="", stderr="warning: deprecated\n") + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", return_value=result): + ok, err = _mermaid_render_check("graph TD") + # stderr without 'Error'/'error' falls through to ok=True + assert ok is True + + def test_non_json_stdout_falls_through(self): + result = MagicMock(returncode=0, stdout="not json", stderr="") + with patch("shutil.which", return_value="/usr/bin/node"), \ + patch("subprocess.run", return_value=result): + ok, err = _mermaid_render_check("graph TD") + assert ok is True + + +# ----------------------------------------------------------------------------- +# MCP tool wrappers — call them through their underlying functions +# ----------------------------------------------------------------------------- + + +def _call_tool(tool): + """fastmcp.@mcp.tool() wraps callables; the underlying fn is .fn.""" + if callable(tool): + return tool + fn = getattr(tool, "fn", None) + if fn is not None: + return fn + raise AssertionError(f"Cannot invoke tool {tool!r}") + + +class 
TestMcpToolWrappers: + def test_validate_mermaid_calls_render_when_valid(self): + with patch.object(mod, "_mermaid_render_check", return_value=(True, "")): + out = _call_tool(mod.validate_mermaid)("graph TD\nA-->B") + assert out["valid"] is True + + def test_validate_mermaid_marks_invalid_on_render_failure(self): + with patch.object(mod, "_mermaid_render_check", return_value=(False, "syntax")): + out = _call_tool(mod.validate_mermaid)("graph TD\nA-->B") + assert out["valid"] is False + assert any("mermaid_render_error" in e for e in out["errors"]) + + def test_validate_mermaid_skips_render_when_already_invalid(self): + with patch.object(mod, "_mermaid_render_check") as render: + out = _call_tool(mod.validate_mermaid)("") # empty → invalid + render.assert_not_called() + assert out["valid"] is False + + def test_fix_mermaid_calls_render_when_valid(self): + with patch.object(mod, "_mermaid_render_check", return_value=(True, "")): + out = _call_tool(mod.fix_mermaid)("A-->B") + assert out["validation"]["valid"] is True + + def test_fix_mermaid_render_failure_marks_invalid(self): + with patch.object(mod, "_mermaid_render_check", return_value=(False, "x")): + out = _call_tool(mod.fix_mermaid)("A-->B") + assert out["validation"]["valid"] is False + + def test_validate_in_markdown(self): + md = """```mermaid +graph TD +A-->B +```""" + with patch.object(mod, "_mermaid_render_check", return_value=(True, "")): + out = _call_tool(mod.validate_mermaid_in_markdown)(md) + assert out["blocks_found"] == 1 + assert out["all_valid"] is True + + def test_validate_in_markdown_no_blocks(self): + out = _call_tool(mod.validate_mermaid_in_markdown)("plain text") + assert out["blocks_found"] == 0 + assert out["all_valid"] is True + + def test_fix_in_markdown_replaces_blocks(self): + md = """text\n```mermaid\nA-->B\n```\nmore""" + with patch.object(mod, "_mermaid_render_check", return_value=(True, "")): + out = _call_tool(mod.fix_mermaid_in_markdown)(md) + assert out["blocks_found"] == 1 + 
assert "graph TD" in out["updated_markdown"] + assert len(out["per_block_fixes"]) == 1 diff --git a/src/processor/src/tests/unit/libs/reporting/test_migration_report_generator.py b/src/processor/src/tests/unit/libs/reporting/test_migration_report_generator.py new file mode 100644 index 00000000..f35989ed --- /dev/null +++ b/src/processor/src/tests/unit/libs/reporting/test_migration_report_generator.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import asyncio + +import pytest + +from libs.reporting.migration_report_generator import ( + MigrationReportCollector, + MigrationReportGenerator, +) +from libs.reporting.models.failure_context import ( + FailureSeverity, + FailureType, +) +from libs.reporting.models.migration_report import ReportStatus + + +def _run(coro): + return asyncio.run(coro) + + +class TestMigrationReportCollectorBasics: + def test_init_seeds_environment_and_ids(self): + c = MigrationReportCollector("p1") + assert c.process_id == "p1" + assert isinstance(c.report_id, str) and len(c.report_id) > 0 + assert c.start_time > 0 + assert c._environment_context is not None + assert c._environment_context.python_version + + def test_set_current_step_creates_and_updates_phase(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis", step_phase="phase-a") + assert c._current_step == "analysis" + assert c._step_contexts["analysis"].step_phase == "phase-a" + c.set_current_step("analysis", step_phase="phase-b") + assert c._step_contexts["analysis"].step_phase == "phase-b" + + def test_set_current_step_handles_invalid(self): + c = MigrationReportCollector("p1") + c.set_current_step("", step_phase=None) + assert c._current_step == "unknown" + + def test_set_current_file_records_size(self, tmp_path): + c = MigrationReportCollector("p1") + f = tmp_path / "deploy.yaml" + f.write_text("kind: Deployment\n") + c.set_current_file("deploy.yaml", str(f), 
yaml_kind="Deployment") + assert c._file_contexts["deploy.yaml"].yaml_kind == "Deployment" + assert c._file_contexts["deploy.yaml"].file_size_bytes is not None + c.set_current_file("deploy.yaml", str(f)) + assert c._file_contexts["deploy.yaml"].yaml_kind == "Deployment" + + def test_set_current_file_missing_path_no_size(self): + c = MigrationReportCollector("p1") + c.set_current_file("ghost.yaml", "/no/such/path.yaml") + assert c._file_contexts["ghost.yaml"].file_size_bytes is None + + def test_set_current_agent_appends_activity(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis") + c.set_current_file("a.yaml", "/no/a.yaml") + c.set_current_agent("Azure_Expert", "expert", activity="reviewing") + assert c._current_agent == "Azure_Expert" + assert len(c._agent_activities) == 1 + rec = c._agent_activities[0] + assert rec["agent_name"] == "Azure_Expert" + assert rec["step"] == "analysis" + assert rec["file"] == "a.yaml" + + def test_mark_step_completed_sets_time_when_known(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis") + c.mark_step_completed("analysis", execution_time=1.5) + assert c._step_contexts["analysis"].execution_time_seconds == 1.5 + + def test_mark_step_completed_unknown_step_noop(self): + c = MigrationReportCollector("p1") + c.mark_step_completed("not-a-step", execution_time=1.0) + assert "not-a-step" not in c._step_contexts + + +class TestRecordFailure: + def test_record_failure_auto_classifies(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis") + c.set_current_agent("AzureExpert", "expert") + ctx = c.record_failure(ConnectionError("network connection lost")) + assert ctx.failure_type == FailureType.NETWORK_ERROR + assert ctx.severity == FailureSeverity.LOW + assert ctx.agent_context is not None + assert ctx.step_context is not None + assert c._failure_contexts == [ctx] + + def test_record_failure_truncates_long_stack(self): + c = MigrationReportCollector("p1") + long_stack = "x" 
* 25000 + ctx = c.record_failure(RuntimeError("boom"), stack_trace=long_stack) + assert "[stack trace truncated]" in (ctx.stack_trace or "") + + def test_record_failure_custom_overrides(self): + c = MigrationReportCollector("p1") + ctx = c.record_failure( + Exception("orig"), + failure_type=FailureType.YAML_PARSING_ERROR, + severity=FailureSeverity.MEDIUM, + custom_message="custom", + stack_trace="short trace", + exception_type="MyError", + ) + assert ctx.error_message == "custom" + assert ctx.exception_type == "MyError" + assert ctx.stack_trace == "short trace" + assert ctx.failure_type == FailureType.YAML_PARSING_ERROR + assert ctx.severity == FailureSeverity.MEDIUM + + +class TestClassifiers: + @pytest.mark.parametrize( + "exc,expected", + [ + (ConnectionError("x"), FailureType.NETWORK_ERROR), + (Exception("network connection refused"), FailureType.NETWORK_ERROR), + (Exception("operation timeout"), FailureType.TIMEOUT), + (Exception("auth failed"), FailureType.AUTHENTICATION_FAILURE), + (Exception("credential missing"), FailureType.AUTHENTICATION_FAILURE), + (Exception("permission denied"), FailureType.AUTHENTICATION_FAILURE), + (ValueError("bad value"), FailureType.CONFIGURATION_ERROR), + (TypeError("nope"), FailureType.CONFIGURATION_ERROR), + (Exception("config error"), FailureType.CONFIGURATION_ERROR), + (Exception("yaml parse boom"), FailureType.YAML_PARSING_ERROR), + (Exception("orchestrator failed"), FailureType.ORCHESTRATOR_ERROR), + (Exception("manager crashed"), FailureType.ORCHESTRATOR_ERROR), + (Exception("totally random"), FailureType.UNKNOWN_ERROR), + ], + ) + def test_classify_failure_type(self, exc, expected): + c = MigrationReportCollector("p1") + assert c._classify_failure_type(exc) is expected + + @pytest.mark.parametrize( + "ftype,expected", + [ + (FailureType.AUTHENTICATION_FAILURE, FailureSeverity.CRITICAL), + (FailureType.CONFIGURATION_ERROR, FailureSeverity.CRITICAL), + (FailureType.TIMEOUT, FailureSeverity.HIGH), + 
(FailureType.ORCHESTRATOR_ERROR, FailureSeverity.HIGH), + (FailureType.YAML_PARSING_ERROR, FailureSeverity.MEDIUM), + (FailureType.UNSUPPORTED_API_VERSION, FailureSeverity.MEDIUM), + (FailureType.NETWORK_ERROR, FailureSeverity.LOW), + ], + ) + def test_classify_failure_severity(self, ftype, expected): + c = MigrationReportCollector("p1") + assert c._classify_failure_severity(Exception("x"), ftype) is expected + + +class TestGenerator: + def test_generate_with_no_failures(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis") + c.mark_step_completed("analysis", execution_time=2.0) + c.set_current_file("a.yaml", "/no/a.yaml", yaml_kind="Deployment") + gen = MigrationReportGenerator(c) + report = _run(gen.generate_failure_report(overall_status=ReportStatus.SUCCESS)) + assert report.process_id == "p1" + assert report.failure_analysis is None + assert report.remediation_guide is None + assert report.input_analysis.total_files == 1 + assert any(s.step_name == "analysis" for s in report.step_details) + + def test_generate_with_failures(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis") + c.record_failure(Exception("auth failed")) # CRITICAL + c.record_failure(asyncio.TimeoutError()) # HIGH + c.record_failure(Exception("yaml parse error")) # MEDIUM + c.record_failure(Exception("orchestrator boom")) # HIGH + c.set_current_file("a.yaml", "/no/a.yaml", yaml_kind=None) + gen = MigrationReportGenerator(c) + report = _run(gen.generate_failure_report()) + assert report.failure_analysis is not None + assert report.remediation_guide is not None + assert len(report.failure_analysis.contributing_factors) == 3 + assert len(report.remediation_guide.priority_actions) >= 1 + assert report.input_analysis.file_breakdown.get("Unknown") == 1 + + def test_step_status_partial_when_no_failure_no_completion(self): + c = MigrationReportCollector("p1") + c.set_current_step("design") + gen = MigrationReportGenerator(c) + details = 
gen._create_step_details() + assert details[0].status == "partial" + + def test_step_status_failed_when_failure_attached(self): + c = MigrationReportCollector("p1") + c.set_current_step("design") + c.record_failure(Exception("bad")) + gen = MigrationReportGenerator(c) + details = gen._create_step_details() + assert details[0].status == "failed" + + def test_supporting_data_includes_recent_failures(self): + c = MigrationReportCollector("p1") + c.set_current_step("analysis") + for i in range(5): + c.record_failure(Exception(f"err{i}")) + gen = MigrationReportGenerator(c) + sd = gen._create_supporting_data() + assert len(sd.log_excerpts) == 3 + assert sd.environment_info.get("python_version") diff --git a/src/processor/src/tests/unit/services/test_queue_service_helpers.py b/src/processor/src/tests/unit/services/test_queue_service_helpers.py new file mode 100644 index 00000000..890d74b0 --- /dev/null +++ b/src/processor/src/tests/unit/services/test_queue_service_helpers.py @@ -0,0 +1,336 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio +import base64 +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from azure.core.exceptions import AzureError, ResourceNotFoundError + +from services.queue_service import ( + MigrationQueueMessage, + QueueMigrationService, + QueueServiceConfig, + create_default_migration_request, + is_base64_encoded, +) +from steps.analysis.models.step_param import Analysis_TaskParam + + +def _run(coro): + return asyncio.run(coro) + + +def _service(account: str = "myacct", queue: str = "q") -> QueueMigrationService: + """Bypass __init__ to avoid creating real Azure clients.""" + s = QueueMigrationService.__new__(QueueMigrationService) + s.config = QueueServiceConfig(storage_account_name=account, queue_name=queue) + s.is_running = False + s.app_context = None + s.main_queue = MagicMock() + s.queue_service = MagicMock() + s.active_workers = set() + s._worker_tasks = {} + s._worker_inflight = {} + s._worker_inflight_message = {} + s._worker_inflight_task_param = {} + s._worker_inflight_task = {} + s._control_watcher_task = None + s.instance_id = 99 + s.debug_mode = False + return s + + +class TestModuleHelpers: + def test_is_base64_encoded_true(self): + s = base64.b64encode(b"hello").decode("utf-8") + assert is_base64_encoded(s) is True + + def test_is_base64_encoded_false(self): + assert is_base64_encoded("not_base64!@#") is False + + def test_create_default_migration_request_keys(self): + req = create_default_migration_request(process_id="p1", user_id="u1") + assert req["process_id"] == "p1" + assert req["user_id"] == "u1" + assert req["container_name"] == "processes" + assert req["source_file_folder"] == "p1/source" + assert req["workspace_file_folder"] == "p1/workspace" + assert req["output_file_folder"] == "p1/converted" + + +class TestMigrationQueueMessage: + def _payload(self) -> dict: + return { + "process_id": "p1", + "user_id": "u1", + 
"migration_request": { + "process_id": "p1", + "user_id": "u1", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + + def test_post_init_validates_required_fields(self): + with pytest.raises(ValueError, match="missing mandatory fields"): + MigrationQueueMessage( + process_id="p1", + migration_request={"process_id": "p1"}, + ) + + def test_from_queue_message_b64(self): + data = self._payload() + encoded = base64.b64encode(json.dumps(data).encode("utf-8")).decode("utf-8") + qm = SimpleNamespace(content=encoded) + out = MigrationQueueMessage.from_queue_message(qm) + assert out.process_id == "p1" + assert out.migration_request["container_name"] == "c" + + def test_from_queue_message_plain_json(self): + data = self._payload() + qm = SimpleNamespace(content=json.dumps(data)) + out = MigrationQueueMessage.from_queue_message(qm) + assert out.process_id == "p1" + + def test_from_queue_message_bytes(self): + data = self._payload() + qm = SimpleNamespace(content=json.dumps(data).encode("utf-8")) + out = MigrationQueueMessage.from_queue_message(qm) + assert out.process_id == "p1" + + def test_from_queue_message_auto_completes(self): + # Only process_id given → migration_request is auto-built + data = {"process_id": "p9", "user_id": "u9"} + qm = SimpleNamespace(content=json.dumps(data)) + out = MigrationQueueMessage.from_queue_message(qm) + assert out.migration_request["process_id"] == "p9" + assert out.retry_count == 0 + assert out.priority == "normal" + + def test_from_queue_message_invalid_json_raises(self): + qm = SimpleNamespace(content="not json{") + with pytest.raises(ValueError, match="Invalid queue message format"): + MigrationQueueMessage.from_queue_message(qm) + + def test_from_queue_message_unexpected_content_type(self): + qm = SimpleNamespace(content=12345) + with pytest.raises(ValueError, match="Invalid queue message format"): + 
MigrationQueueMessage.from_queue_message(qm) + + def test_from_queue_message_filters_unexpected_fields(self): + data = self._payload() + data["junk"] = "drop-me" + qm = SimpleNamespace(content=json.dumps(data)) + out = MigrationQueueMessage.from_queue_message(qm) + assert not hasattr(out, "junk") + + +class TestStorageAccountName: + def test_empty(self): + s = _service(account="") + assert s._storage_account_name() == "" + + def test_https_url(self): + s = _service(account="https://mystorage.queue.core.windows.net") + assert s._storage_account_name() == "mystorage" + + def test_http_url(self): + s = _service(account="http://mystorage.dfs.core.windows.net") + assert s._storage_account_name() == "mystorage" + + def test_hostname(self): + s = _service(account="mystorage.queue.core.windows.net") + assert s._storage_account_name() == "mystorage" + + def test_plain_name(self): + s = _service(account="myacct") + assert s._storage_account_name() == "myacct" + + +class TestStatusAndQueueInfo: + def test_get_service_status(self): + s = _service() + s.is_running = True + s.active_workers = {1, 3} + s._worker_inflight = {1: "p1"} + out = s.get_service_status() + assert out["is_running"] is True + assert out["active_workers"] == 2 + assert out["active_worker_ids"] == [1, 3] + assert out["inflight"] == {1: "p1"} + assert out["queue_name"] == "q" + + def test_get_queue_info_success(self): + s = _service() + props = MagicMock() + props.approximate_message_count = 5 + props.metadata = {"k": "v"} + s.main_queue.get_queue_properties.return_value = props + out = _run(s.get_queue_info()) + assert out["main_queue"]["approximate_message_count"] == 5 + assert out["main_queue"]["metadata"] == {"k": "v"} + + def test_get_queue_info_error(self): + s = _service() + s.main_queue.get_queue_properties.side_effect = RuntimeError("nope") + out = _run(s.get_queue_info()) + assert "error" in out + + +class TestEnsureQueuesExist: + def test_swallows_already_exists(self): + s = _service() + 
s.main_queue.create_queue.side_effect = Exception("already exists") + # Should not raise + _run(s._ensure_queues_exist()) + + def test_creates_queue(self): + s = _service() + s.debug_mode = True + s.main_queue.create_queue.return_value = None + _run(s._ensure_queues_exist()) + s.main_queue.create_queue.assert_called_once() + + +class TestDeleteInflightMessage: + def test_no_message_logs_and_returns(self): + s = _service() + _run(s._delete_inflight_queue_message(1)) + s.main_queue.delete_message.assert_not_called() + + def test_deletes_when_message_present(self): + s = _service() + s._worker_inflight_message[1] = ("mid", "popr") + _run(s._delete_inflight_queue_message(1)) + s.main_queue.delete_message.assert_called_once_with("mid", "popr") + + def test_resource_not_found_swallowed(self): + s = _service() + s._worker_inflight_message[1] = ("mid", "popr") + s.main_queue.delete_message.side_effect = ResourceNotFoundError("gone") + _run(s._delete_inflight_queue_message(1)) + + def test_azure_error_swallowed(self): + s = _service() + s._worker_inflight_message[1] = ("mid", "popr") + s.main_queue.delete_message.side_effect = AzureError("boom") + _run(s._delete_inflight_queue_message(1)) + + +class TestBuildTaskParam: + def test_builds_task_param_from_queue_message(self): + s = _service() + data = { + "process_id": "p1", + "user_id": "u1", + "migration_request": { + "process_id": "p1", + "user_id": "u1", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + qm = SimpleNamespace(content=json.dumps(data)) + tp = s._build_task_param(qm) + assert isinstance(tp, Analysis_TaskParam) + assert tp.process_id == "p1" + assert tp.container_name == "c" + + +class TestCleanupTelemetry: + def test_no_app_context(self): + s = _service() + s.app_context = None + # Should silently skip + _run(s._cleanup_process_telemetry("p1")) + + def test_calls_delete_via_app_context(self): + s = 
_service() + tm = MagicMock() + tm.delete_process = AsyncMock() + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=tm) + s.app_context = ctx + _run(s._cleanup_process_telemetry("p1")) + tm.delete_process.assert_awaited_once_with("p1") + + def test_falls_back_when_get_service_async_fails(self): + s = _service() + ctx = MagicMock() + ctx.get_service_async = AsyncMock(side_effect=RuntimeError("boom")) + s.app_context = ctx + with patch("services.queue_service.TelemetryManager") as MockTM: + instance = MockTM.return_value + instance.delete_process = AsyncMock() + _run(s._cleanup_process_telemetry("p1")) + MockTM.assert_called_once_with(ctx) + instance.delete_process.assert_awaited_once_with("p1") + + def test_swallows_telemetry_delete_failures(self): + s = _service() + tm = MagicMock() + tm.delete_process = AsyncMock(side_effect=RuntimeError("delete failed")) + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=tm) + s.app_context = ctx + # Should not raise + _run(s._cleanup_process_telemetry("p1")) + + +class TestStopProcess: + def test_returns_false_when_not_inflight(self): + s = _service() + result = _run(s.stop_process("nope")) + assert result is False + + def test_kills_and_returns_true(self): + s = _service() + s._worker_inflight[7] = "p1" + s._worker_inflight_message[7] = ("m", "r") + s._worker_inflight_task_param[7] = Analysis_TaskParam( + process_id="p1", + container_name="c", + source_file_folder="p1/source", + workspace_file_folder="p1/workspace", + output_file_folder="p1/converted", + ) + # No app_context → telemetry cleanup is a no-op + s.app_context = None + + cleaned: list[str] = [] + + async def _cleanup_blobs(tp): + cleaned.append(tp.process_id) + + s._cleanup_process_blobs = _cleanup_blobs # type: ignore[assignment] + + result = _run(s.stop_process("p1", timeout_seconds=0.1)) + assert result is True + assert cleaned == ["p1"] + s.main_queue.delete_message.assert_called_once_with("m", "r") + + def 
test_kills_without_task_param_skips_blob_cleanup(self): + s = _service() + s._worker_inflight[1] = "p1" + s._worker_inflight_message[1] = ("m", "r") + s.app_context = None + called = [] + + async def _cleanup_blobs(tp): + called.append(tp) + + s._cleanup_process_blobs = _cleanup_blobs # type: ignore[assignment] + + result = _run(s.stop_process("p1", timeout_seconds=0.1)) + assert result is True + assert called == [] diff --git a/src/processor/src/tests/unit/services/test_queue_service_internals.py b/src/processor/src/tests/unit/services/test_queue_service_internals.py new file mode 100644 index 00000000..d5ce627e --- /dev/null +++ b/src/processor/src/tests/unit/services/test_queue_service_internals.py @@ -0,0 +1,806 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Coverage for QueueMigrationService internals: worker loop, processing, +control watcher, blob cleanup, and start/stop lifecycle.""" + +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.queue_service import QueueMigrationService, QueueServiceConfig +from steps.analysis.models.step_param import Analysis_TaskParam + + +def _run(coro): + return asyncio.run(coro) + + +def _service(account: str = "myacct", queue: str = "q") -> QueueMigrationService: + """Bypass __init__ to avoid real Azure clients.""" + s = QueueMigrationService.__new__(QueueMigrationService) + s.config = QueueServiceConfig( + storage_account_name=account, + queue_name=queue, + poll_interval_seconds=0, # don't slow tests + control_poll_interval_seconds=0, + visibility_timeout_minutes=1, + concurrent_workers=1, + ) + s.is_running = False + s.app_context = None + s.main_queue = MagicMock() + s.queue_service = MagicMock() + s.active_workers = set() + s._worker_tasks = {} + s._worker_inflight = {} + s._worker_inflight_message = {} + s._worker_inflight_task_param = {} + 
s._worker_inflight_task = {} + s._control_watcher_task = None + s.instance_id = 123 + s.debug_mode = False + return s + + +def _task_param(pid: str = "p1") -> Analysis_TaskParam: + return Analysis_TaskParam( + process_id=pid, + container_name="c", + source_file_folder=f"{pid}/source", + workspace_file_folder=f"{pid}/workspace", + output_file_folder=f"{pid}/converted", + ) + + +# ----------------------------------------------------------------------------- +# stop_service +# ----------------------------------------------------------------------------- + + +class TestStopService: + def test_stop_service_clears_state_and_closes_clients(self): + s = _service() + s.is_running = True + QueueMigrationService._active_instances.add(s.instance_id) + s._worker_inflight = {1: "p"} + s._worker_inflight_message = {1: ("m", "r")} + s._worker_inflight_task_param = {1: _task_param()} + s._worker_inflight_task = {1: MagicMock()} + _run(s.stop_service()) + assert s.is_running is False + assert s._worker_inflight == {} + assert s._worker_inflight_message == {} + assert s._worker_inflight_task_param == {} + assert s._worker_inflight_task == {} + assert s.instance_id not in QueueMigrationService._active_instances + s.main_queue.close.assert_called_once() + s.queue_service.close.assert_called_once() + + def test_stop_service_cancels_workers(self): + s = _service() + s.is_running = True + + async def _long(): + await asyncio.sleep(60) + + async def _go(): + t1 = asyncio.create_task(_long()) + s._worker_tasks = {1: t1} + await s.stop_service() + assert t1.cancelled() or t1.done() + + _run(_go()) + + def test_stop_service_cancels_control_watcher(self): + s = _service() + s.is_running = True + + async def _long(): + await asyncio.sleep(60) + + async def _go(): + wt = asyncio.create_task(_long()) + s._control_watcher_task = wt + await s.stop_service() + assert wt.cancelled() or wt.done() + assert s._control_watcher_task is None + + _run(_go()) + + def 
test_stop_service_swallows_close_errors(self): + s = _service() + s.is_running = True + s.main_queue.close.side_effect = RuntimeError("boom") + s.queue_service.close.side_effect = RuntimeError("boom") + _run(s.stop_service()) # no raise + + +# ----------------------------------------------------------------------------- +# stop_worker +# ----------------------------------------------------------------------------- + + +class TestStopWorker: + def test_stop_worker_missing_returns_false(self): + s = _service() + ok = _run(s.stop_worker(99)) + assert ok is False + + def test_stop_worker_cancels_completed_task(self): + """stop_worker called against an already-completed task still returns True + and cleans up bookkeeping.""" + s = _service() + + async def _quick(): + return "done" + + async def _go(): + t = asyncio.create_task(_quick()) + await asyncio.sleep(0) # let it finish + s._worker_tasks = {2: t} + s._worker_inflight = {2: "pid"} + ok = await s.stop_worker(2, timeout_seconds=1.0) + assert ok is True + assert 2 not in s._worker_tasks + assert 2 not in s._worker_inflight + + _run(_go()) + + def test_stop_worker_no_inflight_branch(self): + """Cover the 'no inflight' log branch.""" + s = _service() + + async def _quick(): + return None + + async def _go(): + t = asyncio.create_task(_quick()) + await asyncio.sleep(0) + s._worker_tasks = {3: t} + ok = await s.stop_worker(3, timeout_seconds=1.0) + assert ok is True + + _run(_go()) + + +# ----------------------------------------------------------------------------- +# control watcher +# ----------------------------------------------------------------------------- + + +class TestControlWatcher: + def test_idle_when_no_inflight(self): + s = _service() + s.is_running = True + ctx = MagicMock() + ctrl = MagicMock() + ctx.get_service_async = AsyncMock(return_value=ctrl) + s.app_context = ctx + + async def _go(): + task = asyncio.create_task(s._control_watcher_loop()) + await asyncio.sleep(0.01) + s.is_running = False + await 
asyncio.sleep(0.01) + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + _run(_go()) + + def test_processes_kill_request(self): + s = _service() + s.is_running = True + s._worker_inflight = {1: "p1"} + + record = SimpleNamespace(kill_requested=True, kill_state="pending") + # Track invocations: after first ack/mark_executed, flip is_running so loop exits. + ctrl = MagicMock() + ctrl.get = AsyncMock(return_value=record) + + async def _ack(*_a, **_kw): + s.is_running = False # let the loop exit cleanly after this ack + + ctrl.ack_executing = AsyncMock(side_effect=_ack) + ctrl.mark_executed = AsyncMock() + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=ctrl) + s.app_context = ctx + s.stop_process = AsyncMock(return_value=True) + + _run(s._control_watcher_loop()) + ctrl.ack_executing.assert_awaited() + ctrl.mark_executed.assert_awaited() + + def test_skips_records_already_executed(self): + s = _service() + s.is_running = True + s._worker_inflight = {1: "p1"} + + record = SimpleNamespace(kill_requested=True, kill_state="executed") + ctrl = MagicMock() + get_calls = {"n": 0} + + async def _get(_pid): + get_calls["n"] += 1 + if get_calls["n"] >= 1: + s.is_running = False # exit after first iteration + return record + + ctrl.get = AsyncMock(side_effect=_get) + ctrl.ack_executing = AsyncMock() + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=ctrl) + s.app_context = ctx + + _run(s._control_watcher_loop()) + ctrl.ack_executing.assert_not_awaited() + + def test_falls_back_to_direct_construction(self): + s = _service() + s.is_running = False # don't loop + ctx = MagicMock() + ctx.get_service_async = AsyncMock(side_effect=RuntimeError("no svc")) + s.app_context = ctx + with patch("services.queue_service.ProcessControlManager") as MockMgr: + MockMgr.return_value = MagicMock() + _run(s._control_watcher_loop()) + MockMgr.assert_called_once_with(ctx) + + def test_swallows_loop_iteration_errors(self): + s 
= _service() + s.is_running = True + s._worker_inflight = {1: "p1"} + ctrl = MagicMock() + get_calls = {"n": 0} + + async def _get(_pid): + get_calls["n"] += 1 + if get_calls["n"] >= 1: + s.is_running = False # exit after first iteration + raise RuntimeError("bad") + + ctrl.get = AsyncMock(side_effect=_get) + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=ctrl) + s.app_context = ctx + + _run(s._control_watcher_loop()) # exception swallowed by loop + + +# ----------------------------------------------------------------------------- +# _worker_loop +# ----------------------------------------------------------------------------- + + +class TestWorkerLoop: + def test_worker_loop_no_main_queue_sleeps(self): + s = _service() + s.main_queue = None + s.is_running = True + # After one iteration of the no-queue branch, exit cleanly. + original_sleep = asyncio.sleep + sleep_calls = {"n": 0} + + async def _patched_sleep(delay, *a, **kw): + sleep_calls["n"] += 1 + if sleep_calls["n"] >= 1: + s.is_running = False + await original_sleep(0) + + with patch("services.queue_service.asyncio.sleep", new=_patched_sleep): + _run(s._worker_loop(1)) + assert sleep_calls["n"] >= 1 + + def test_worker_loop_swallows_receive_errors(self): + s = _service() + s.is_running = True + + original_sleep = asyncio.sleep + sleep_calls = {"n": 0} + + def _receive(*_a, **_kw): + raise RuntimeError("transient") + + async def _patched_sleep(delay, *a, **kw): + sleep_calls["n"] += 1 + if sleep_calls["n"] >= 1: + s.is_running = False + await original_sleep(0) + + s.main_queue.receive_messages.side_effect = _receive + with patch("services.queue_service.asyncio.sleep", new=_patched_sleep): + _run(s._worker_loop(1)) + + def test_worker_loop_iterates_message(self): + s = _service() + s.is_running = True + + msg = SimpleNamespace(id="m1", pop_receipt="r1", content="x") + + # Configure to yield one message then no more — flip is_running + call_state = {"calls": 0} + + def _receive(*_a, **_kw): + 
call_state["calls"] += 1 + if call_state["calls"] == 1: + return iter([msg]) + s.is_running = False + return iter([]) + + s.main_queue.receive_messages.side_effect = _receive + + async def _process(worker_id, queue_message): # noqa: D401 + return None + + s._process_queue_message = _process # type: ignore[assignment] + + _run(s._worker_loop(7)) + assert call_state["calls"] >= 1 + + def test_worker_loop_handles_job_crash(self): + """Job exception triggers _handle_failed_no_retry path.""" + s = _service() + s.is_running = True + msg = SimpleNamespace(id="m1", pop_receipt="r1", content="x") + call_state = {"calls": 0} + + def _receive(*_a, **_kw): + call_state["calls"] += 1 + if call_state["calls"] == 1: + return iter([msg]) + s.is_running = False + return iter([]) + + s.main_queue.receive_messages.side_effect = _receive + + async def _crash(worker_id, queue_message): # noqa: D401 + raise RuntimeError("boom") + + s._process_queue_message = _crash # type: ignore[assignment] + s._handle_failed_no_retry = AsyncMock() + + _run(s._worker_loop(1)) + s._handle_failed_no_retry.assert_awaited() + + +# ----------------------------------------------------------------------------- +# _process_queue_message +# ----------------------------------------------------------------------------- + + +class TestProcessQueueMessage: + def test_invalid_payload_triggers_failure_no_retry(self): + s = _service() + s._handle_failed_no_retry = AsyncMock() + msg = SimpleNamespace(id="m1", pop_receipt="r1", content="not-json") + _run(s._process_queue_message(1, msg)) + s._handle_failed_no_retry.assert_awaited() + # process_id is "" since parsing failed + kwargs = s._handle_failed_no_retry.await_args.kwargs + assert kwargs.get("process_id") == "" or s._handle_failed_no_retry.await_args.args[1] == "" + + def test_success_path_calls_successful_handler(self): + s = _service() + ctx = MagicMock() + proc = MagicMock() + proc.run = AsyncMock(return_value=SimpleNamespace(is_hard_terminated=False)) + 
ctx.get_service.return_value = proc + s.app_context = ctx + s._handle_successful_processing = AsyncMock() + s._handle_failed_no_retry = AsyncMock() + + import json + payload = { + "process_id": "p1", + "user_id": "u", + "migration_request": { + "process_id": "p1", + "user_id": "u", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + msg = SimpleNamespace(id="m1", pop_receipt="r1", content=json.dumps(payload)) + _run(s._process_queue_message(2, msg)) + s._handle_successful_processing.assert_awaited_once() + s._handle_failed_no_retry.assert_not_awaited() + + def test_hard_terminated_result_routes_to_no_retry_with_process_scope(self): + s = _service() + ctx = MagicMock() + result = SimpleNamespace( + is_hard_terminated=True, + blocking_issues=["a", "b"], + reason="denied", + ) + proc = MagicMock() + proc.run = AsyncMock(return_value=result) + ctx.get_service.return_value = proc + s.app_context = ctx + s._handle_failed_no_retry = AsyncMock() + s._handle_successful_processing = AsyncMock() + + import json + payload = { + "process_id": "p1", + "user_id": "u", + "migration_request": { + "process_id": "p1", + "user_id": "u", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + msg = SimpleNamespace(id="m1", pop_receipt="r1", content=json.dumps(payload)) + _run(s._process_queue_message(3, msg)) + s._handle_failed_no_retry.assert_awaited_once() + kwargs = s._handle_failed_no_retry.await_args.kwargs + assert kwargs.get("cleanup_scope") == "process" + + def test_workflow_returns_none_treated_as_failure(self): + s = _service() + ctx = MagicMock() + proc = MagicMock() + proc.run = AsyncMock(return_value=None) + ctx.get_service.return_value = proc + s.app_context = ctx + s._handle_failed_no_retry = AsyncMock() + s._handle_successful_processing = AsyncMock() + + import json + 
payload = { + "process_id": "p1", + "user_id": "u", + "migration_request": { + "process_id": "p1", + "user_id": "u", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + msg = SimpleNamespace(id="m1", pop_receipt="r1", content=json.dumps(payload)) + _run(s._process_queue_message(4, msg)) + s._handle_failed_no_retry.assert_awaited_once() + + def test_workflow_executor_failed_routes_to_no_retry(self): + s = _service() + ctx = MagicMock() + from steps.migration_processor import WorkflowExecutorFailedException + + proc = MagicMock() + proc.run = AsyncMock(side_effect=WorkflowExecutorFailedException("nope")) + ctx.get_service.return_value = proc + s.app_context = ctx + s._handle_failed_no_retry = AsyncMock() + + import json + payload = { + "process_id": "p1", + "user_id": "u", + "migration_request": { + "process_id": "p1", + "user_id": "u", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + msg = SimpleNamespace(id="m1", pop_receipt="r1", content=json.dumps(payload)) + _run(s._process_queue_message(5, msg)) + s._handle_failed_no_retry.assert_awaited_once() + + def test_unhandled_exception_routes_to_no_retry(self): + s = _service() + ctx = MagicMock() + proc = MagicMock() + proc.run = AsyncMock(side_effect=RuntimeError("kaboom")) + ctx.get_service.return_value = proc + s.app_context = ctx + s._handle_failed_no_retry = AsyncMock() + + import json + payload = { + "process_id": "p1", + "user_id": "u", + "migration_request": { + "process_id": "p1", + "user_id": "u", + "container_name": "c", + "source_file_folder": "p1/source", + "workspace_file_folder": "p1/workspace", + "output_file_folder": "p1/converted", + }, + } + msg = SimpleNamespace(id="m1", pop_receipt="r1", content=json.dumps(payload)) + _run(s._process_queue_message(6, msg)) + 
s._handle_failed_no_retry.assert_awaited_once() + + +# ----------------------------------------------------------------------------- +# _handle_successful_processing +# ----------------------------------------------------------------------------- + + +class TestHandleSuccessfulProcessing: + def test_deletes_message_on_success(self): + s = _service() + s.debug_mode = True + msg = SimpleNamespace(id="m1", pop_receipt="r1") + _run(s._handle_successful_processing(msg, "p1", 1.5)) + s.main_queue.delete_message.assert_called_once_with("m1", "r1") + + def test_swallows_resource_not_found(self): + from azure.core.exceptions import ResourceNotFoundError + s = _service() + s.main_queue.delete_message.side_effect = ResourceNotFoundError("gone") + msg = SimpleNamespace(id="m1", pop_receipt="r1") + _run(s._handle_successful_processing(msg, "p1", 1.5)) + + def test_swallows_azure_error(self): + from azure.core.exceptions import AzureError + s = _service() + s.main_queue.delete_message.side_effect = AzureError("boom") + msg = SimpleNamespace(id="m1", pop_receipt="r1") + _run(s._handle_successful_processing(msg, "p1", 1.5)) + + +# ----------------------------------------------------------------------------- +# _handle_failed_no_retry +# ----------------------------------------------------------------------------- + + +class TestHandleFailedNoRetry: + def test_writes_failure_telemetry_when_app_context_present(self): + s = _service() + ctx = MagicMock() + telemetry = MagicMock() + telemetry.get_current_process = AsyncMock( + return_value=SimpleNamespace(step="design") + ) + telemetry.record_failure_outcome = AsyncMock() + ctx.get_service_async = AsyncMock(return_value=telemetry) + s.app_context = ctx + msg = SimpleNamespace(id="m1", pop_receipt="r1") + _run( + s._handle_failed_no_retry( + msg, "p1", "boom", 0.5, task_param=None, cleanup_scope="output" + ) + ) + telemetry.record_failure_outcome.assert_awaited_once() + + def test_swallows_telemetry_failure(self): + s = _service() + 
ctx = MagicMock() + telemetry = MagicMock() + telemetry.get_current_process = AsyncMock(side_effect=RuntimeError("x")) + telemetry.record_failure_outcome = AsyncMock(side_effect=RuntimeError("y")) + ctx.get_service_async = AsyncMock(return_value=telemetry) + s.app_context = ctx + msg = SimpleNamespace(id="m1", pop_receipt="r1") + # Should not raise even if telemetry blows up + _run(s._handle_failed_no_retry(msg, "p1", "boom", 0.5)) + + def test_skips_telemetry_for_unknown_process(self): + s = _service() + ctx = MagicMock() + ctx.get_service_async = AsyncMock() + s.app_context = ctx + msg = SimpleNamespace(id="m1", pop_receipt="r1") + _run(s._handle_failed_no_retry(msg, "", "boom", 0.5)) + ctx.get_service_async.assert_not_called() + + def test_cleanup_scope_process(self): + s = _service() + s.app_context = None + s._cleanup_process_blobs = AsyncMock() + s._cleanup_output_blobs = AsyncMock() + msg = SimpleNamespace(id="m1", pop_receipt="r1") + tp = _task_param() + _run( + s._handle_failed_no_retry( + msg, "p1", "boom", 0.5, task_param=tp, cleanup_scope="process" + ) + ) + s._cleanup_process_blobs.assert_awaited_once_with(tp) + s._cleanup_output_blobs.assert_not_called() + + def test_cleanup_swallows_blob_errors(self): + s = _service() + s.app_context = None + s._cleanup_output_blobs = AsyncMock(side_effect=RuntimeError("io")) + msg = SimpleNamespace(id="m1", pop_receipt="r1") + # Should not raise + _run(s._handle_failed_no_retry(msg, "p1", "boom", 0.5, task_param=_task_param())) + + def test_swallows_delete_error(self): + from azure.core.exceptions import AzureError + s = _service() + s.app_context = None + s.main_queue.delete_message.side_effect = AzureError("nope") + msg = SimpleNamespace(id="m1", pop_receipt="r1") + _run(s._handle_failed_no_retry(msg, "p1", "boom", 0.5)) + + +# ----------------------------------------------------------------------------- +# _cleanup_process_blobs_sync / _cleanup_output_blobs_sync +# 
----------------------------------------------------------------------------- + + +class TestCleanupBlobsSync: + def test_no_account_skips(self): + s = _service(account="") + s._cleanup_process_blobs_sync(_task_param()) # no exception + + def test_no_blobs_returns_early(self): + s = _service() + with patch("services.queue_service.StorageBlobHelper") as MockHelper, \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()): + helper = MockHelper.return_value + helper.list_blobs.return_value = [] + s._cleanup_process_blobs_sync(_task_param()) + + def test_deletes_blobs_and_dir(self): + s = _service() + with patch("services.queue_service.StorageBlobHelper") as MockHelper, \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()), \ + patch("importlib.import_module") as mock_import: + helper = MockHelper.return_value + helper.list_blobs.return_value = [ + {"name": "p1/file.txt"}, + {"name": "p1/", "is_directory": True}, # directory entry skipped + {"name": "p1/converted"}, # placeholder skipped + {"name": ""}, + ] + helper.delete_multiple_blobs.return_value = {"p1/file.txt": True} + + dl_mod = MagicMock() + DataLakeServiceClient = MagicMock() + dl_mod.DataLakeServiceClient = DataLakeServiceClient + mock_import.return_value = dl_mod + + s._cleanup_process_blobs_sync(_task_param()) + helper.delete_multiple_blobs.assert_called_once() + DataLakeServiceClient.assert_called_once() + + def test_dir_delete_typeerror_recursive(self): + s = _service() + with patch("services.queue_service.StorageBlobHelper") as MockHelper, \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()), \ + patch("importlib.import_module") as mock_import: + helper = MockHelper.return_value + helper.list_blobs.return_value = [{"name": "p1/file.txt"}] + helper.delete_multiple_blobs.return_value = {"p1/file.txt": True} + + dir_client = MagicMock() + dir_client.delete_directory.side_effect = [ + TypeError("got multiple values for 
keyword argument 'recursive'"), + None, + ] + fs = MagicMock() + fs.get_directory_client.return_value = dir_client + dl_client = MagicMock() + dl_client.get_file_system_client.return_value = fs + DataLakeServiceClient = MagicMock(return_value=dl_client) + dl_mod = MagicMock(DataLakeServiceClient=DataLakeServiceClient) + mock_import.return_value = dl_mod + + s._cleanup_process_blobs_sync(_task_param()) + assert dir_client.delete_directory.call_count == 2 + + def test_top_level_exception_swallowed(self): + s = _service() + with patch("services.queue_service.StorageBlobHelper", side_effect=RuntimeError("bad")), \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()): + s._cleanup_process_blobs_sync(_task_param()) # no raise + + def test_output_cleanup_no_account(self): + s = _service(account="") + s._cleanup_output_blobs_sync(_task_param()) + + def test_output_cleanup_refuses_broad_prefix(self): + s = _service() + tp = _task_param() + # Force output_file_folder to broad path matching "" + tp = Analysis_TaskParam( + process_id="p1", + container_name="c", + source_file_folder="p1/source", + workspace_file_folder="p1/workspace", + output_file_folder="p1", # broad + ) + with patch("services.queue_service.StorageBlobHelper") as MockHelper, \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()): + s._cleanup_output_blobs_sync(tp) + MockHelper.assert_not_called() + + def test_output_cleanup_no_blobs(self): + s = _service() + with patch("services.queue_service.StorageBlobHelper") as MockHelper, \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()): + helper = MockHelper.return_value + helper.list_blobs.return_value = [ + {"name": ""}, + {"name": "p1/converted", "is_directory": True}, + {"name": "p1/converted"}, # equals dir name → skipped + ] + s._cleanup_output_blobs_sync(_task_param()) + helper.delete_multiple_blobs.assert_not_called() + + def test_output_cleanup_deletes(self): + s = 
_service() + with patch("services.queue_service.StorageBlobHelper") as MockHelper, \ + patch("services.queue_service.get_azure_credential", return_value=MagicMock()), \ + patch("importlib.import_module") as mock_import: + helper = MockHelper.return_value + helper.list_blobs.return_value = [{"name": "p1/converted/a.yaml"}] + helper.delete_multiple_blobs.return_value = {"p1/converted/a.yaml": True} + + DataLakeServiceClient = MagicMock() + dl_mod = MagicMock(DataLakeServiceClient=DataLakeServiceClient) + mock_import.return_value = dl_mod + + s._cleanup_output_blobs_sync(_task_param()) + helper.delete_multiple_blobs.assert_called_once() + + def test_async_wrappers_invoke_sync(self): + s = _service() + s._cleanup_process_blobs_sync = MagicMock() + s._cleanup_output_blobs_sync = MagicMock() + _run(s._cleanup_process_blobs(_task_param())) + _run(s._cleanup_output_blobs(_task_param())) + s._cleanup_process_blobs_sync.assert_called_once() + s._cleanup_output_blobs_sync.assert_called_once() + + +# ----------------------------------------------------------------------------- +# start_service (high-level smoke) +# ----------------------------------------------------------------------------- + + +class TestStartService: + def test_start_service_already_running_returns(self): + s = _service() + s.is_running = True + _run(s.start_service()) # returns early; no exception + + def test_start_service_runs_and_completes(self): + s = _service() + s.is_running = False + s._ensure_queues_exist = AsyncMock() + s._control_watcher_loop = AsyncMock() + + async def _no_op_worker(self_, worker_id): + return None + + # Patch worker loop to immediate return + s._worker_loop = lambda wid: asyncio.sleep(0) # type: ignore[assignment] + _run(s.start_service()) + assert s.is_running is False # finally clause + + +# ----------------------------------------------------------------------------- +# process_message wrapper +# ----------------------------------------------------------------------------- + 
+ +class TestProcessMessageEntrypoint: + def test_calls_worker_loop_with_id_1(self): + s = _service() + s._worker_loop = AsyncMock() + _run(s.process_message()) + s._worker_loop.assert_awaited_once_with(worker_id=1) diff --git a/src/processor/src/tests/unit/steps/analysis/test_analysis_executor.py b/src/processor/src/tests/unit/steps/analysis/test_analysis_executor.py index c0f2691d..1c2d1298 100644 --- a/src/processor/src/tests/unit/steps/analysis/test_analysis_executor.py +++ b/src/processor/src/tests/unit/steps/analysis/test_analysis_executor.py @@ -62,9 +62,11 @@ async def execute(self, task_param=None): ), ) - # Avoid huge ASCII art in test output. + # Avoid huge ASCII art in test output (text2art may not be imported in this module). monkeypatch.setattr( - "steps.analysis.workflow.analysis_executor.text2art", lambda _s: "ART" + "steps.analysis.workflow.analysis_executor.text2art", + lambda _s: "ART", + raising=False, ) monkeypatch.setattr( "steps.analysis.workflow.analysis_executor.AnalysisOrchestrator", @@ -115,7 +117,9 @@ async def execute(self, task_param=None): ) monkeypatch.setattr( - "steps.analysis.workflow.analysis_executor.text2art", lambda _s: "ART" + "steps.analysis.workflow.analysis_executor.text2art", + lambda _s: "ART", + raising=False, ) monkeypatch.setattr( "steps.analysis.workflow.analysis_executor.AnalysisOrchestrator", diff --git a/src/processor/src/tests/unit/steps/documentation/test_documentation_orchestrator_execute.py b/src/processor/src/tests/unit/steps/documentation/test_documentation_orchestrator_execute.py new file mode 100644 index 00000000..d1cfd9f8 --- /dev/null +++ b/src/processor/src/tests/unit/steps/documentation/test_documentation_orchestrator_execute.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from steps.convert.models.step_output import Yaml_ExtendedBooleanResult +from steps.documentation.orchestration.documentation_orchestrator import ( + DocumentationOrchestrator, +) + + +def _run(coro): + return asyncio.run(coro) + + +def _make_orch(): + """Create an instance bypassing __init__ to keep tests isolated.""" + orch = DocumentationOrchestrator.__new__(DocumentationOrchestrator) + orch.initialized = True + orch.step_name = "Documentation" + orch.app_context = MagicMock() + orch.memory_store = None + orch.agents = {} + return orch + + +class TestPrepareAgentInfos: + def test_raises_when_mcp_tools_none(self): + orch = _make_orch() + orch.mcp_tools = None + with pytest.raises(ValueError, match=r"MCP tools must be prepared"): + _run(orch.prepare_agent_infos()) + + def test_builds_agents_with_registry_entries(self, tmp_path): + orch = _make_orch() + orch.mcp_tools = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + orch.task_param = Yaml_ExtendedBooleanResult(process_id="proc-X") + + registry_entries = [ + {"agent_name": "EKS Expert", "prompt_file": "prompt_eks_expert.txt"}, + {"agent_name": "GKE Expert", "prompt_file": "prompt_gke_expert.txt"}, + {"agent_name": "", "prompt_file": "skip.txt"}, # invalid agent_name + {"agent_name": "Bad", "prompt_file": ""}, # invalid prompt_file + {"agent_name": "Missing", "prompt_file": "nonexistent.txt"}, # path missing + {"agent_name": 42, "prompt_file": "x.txt"}, # wrong type + ] + + # Patch helpers on the instance + with patch.object( + DocumentationOrchestrator, + "load_platform_registry", + return_value=registry_entries, + ), patch.object( + DocumentationOrchestrator, + "read_prompt_file", + return_value="PROMPT BODY", + ), patch( + "steps.documentation.orchestration.documentation_orchestrator.Path" + ) as path_cls: + # Make Path(...).exists() True only for known prompt files. 
+ existing = { + "prompt_eks_expert.txt", + "prompt_gke_expert.txt", + } + + class _FakePath: + def __init__(self, *parts): + self._parts = [str(p) for p in parts] + + def __truediv__(self, other): + return _FakePath(*self._parts, other) + + def resolve(self): + return self + + @property + def parents(self): + # Pretend parents[3] returns repo root that supports __truediv__ + return [self, self, self, _FakePath("repo_root")] + + @property + def parent(self): + return self + + def exists(self): + name = self._parts[-1] + return name in existing + + def __str__(self): + return "/".join(self._parts) + + def __fspath__(self): + return str(self) + + path_cls.side_effect = lambda *a, **k: _FakePath(*a) + + agent_infos = _run(orch.prepare_agent_infos()) + + names = [a.agent_name for a in agent_infos] + # Built-ins + assert "Technical Writer" in names + assert "AKS Expert" in names + assert "Azure Architect" in names + assert "Chief Architect" in names + # Registry experts that exist + assert "EKS Expert" in names + assert "GKE Expert" in names + # Coordinator + ResultGenerator are appended last + assert "Coordinator" in names + assert "ResultGenerator" in names + # Skipped invalid entries + assert "Bad" not in names + assert "Missing" not in names + assert 42 not in names + + +class TestForwardingHooks: + def test_on_agent_response_calls_super(self): + orch = _make_orch() + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response(MagicMock())) + assert super_call.await_count == 1 + + def test_on_agent_response_stream_calls_super(self): + orch = _make_orch() + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response_stream", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response_stream(MagicMock())) + assert super_call.await_count == 1 + + def test_on_orchestration_complete_logs(self, caplog): + orch = _make_orch() + result = 
MagicMock() + result.execution_time_seconds = 12.5 + with caplog.at_level("INFO"): + _run(orch.on_orchestration_complete(result)) + assert any( + "Documentation Orchestration complete" in r.message for r in caplog.records + ) diff --git a/src/processor/src/tests/unit/steps/test_migration_processor_extras.py b/src/processor/src/tests/unit/steps/test_migration_processor_extras.py new file mode 100644 index 00000000..b84023b0 --- /dev/null +++ b/src/processor/src/tests/unit/steps/test_migration_processor_extras.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from steps.migration_processor import ( + MigrationProcessor, + WorkflowExecutorFailedException, + WorkflowOutputMissingException, +) + + +def _run(coro): + return asyncio.run(coro) + + +class TestWorkflowOutputMissingException: + def test_message_includes_source(self): + exc = WorkflowOutputMissingException("analysis") + assert "analysis" in str(exc) + assert exc.source_executor_id == "analysis" + + def test_message_handles_none_source(self): + exc = WorkflowOutputMissingException(None) + assert "" in str(exc) + + +class TestWorkflowExecutorFailedException: + def test_details_to_dict_with_none(self): + assert WorkflowExecutorFailedException._details_to_dict(None) == {"details": None} + + def test_details_to_dict_with_dict(self): + d = {"executor_id": "x"} + assert WorkflowExecutorFailedException._details_to_dict(d) == d + + def test_details_to_dict_with_pydantic_v2_object(self): + obj = MagicMock() + obj.model_dump = MagicMock(return_value={"executor_id": "v2"}) + del obj.dict # ensure we don't fall through to vars() + result = WorkflowExecutorFailedException._details_to_dict(obj) + assert result == {"executor_id": "v2"} + + def test_details_to_dict_with_pydantic_v1_object(self): + class V1Like: + def dict(self): + return {"executor_id": "v1"} + + 
result = WorkflowExecutorFailedException._details_to_dict(V1Like()) + assert result == {"executor_id": "v1"} + + def test_details_to_dict_with_pydantic_v2_failure_falls_back(self): + class Bad: + def model_dump(self): + raise RuntimeError("nope") + + def dict(self): + return {"from": "dict"} + + result = WorkflowExecutorFailedException._details_to_dict(Bad()) + assert result == {"from": "dict"} + + def test_details_to_dict_falls_back_to_vars(self): + class Plain: + def __init__(self): + self.executor_id = "plain" + self.message = "ok" + + result = WorkflowExecutorFailedException._details_to_dict(Plain()) + assert result["executor_id"] == "plain" + + def test_details_to_dict_falls_back_to_repr_on_error(self): + class NoVars: + __slots__ = () + + result = WorkflowExecutorFailedException._details_to_dict(NoVars()) + assert "details" in result + + def test_format_message_with_traceback(self): + msg = WorkflowExecutorFailedException._format_message({ + "executor_id": "x", + "error_type": "ValueError", + "message": "bad", + "traceback": "Traceback...", + }) + assert "Traceback" in msg + assert "x" in msg + + def test_format_message_without_traceback(self): + msg = WorkflowExecutorFailedException._format_message({ + "executor_id": "x", + "error_type": "ValueError", + "message": "bad", + }) + assert "WorkflowErrorDetails" in msg + + def test_format_message_with_unknown_fields(self): + msg = WorkflowExecutorFailedException._format_message({}) + assert "" in msg + + def test_constructor_stores_details(self): + exc = WorkflowExecutorFailedException({"executor_id": "x", "message": "m"}) + assert exc.details == {"executor_id": "x", "message": "m"} + + +class TestCreateMemoryStore: + def _make_processor(self): + p = MigrationProcessor.__new__(MigrationProcessor) + p.app_context = MagicMock() + return p + + def test_disabled_when_env_off(self, monkeypatch): + monkeypatch.setenv("SHARED_MEMORY_ENABLED", "0") + p = self._make_processor() + result = 
_run(p._create_memory_store("proc-1")) + assert result is None + + def test_returns_none_when_no_service_config(self, monkeypatch): + monkeypatch.setenv("SHARED_MEMORY_ENABLED", "true") + p = self._make_processor() + helper = MagicMock() + helper.settings.get_service_config.return_value = None + p.app_context.get_service.return_value = helper + result = _run(p._create_memory_store("proc-1")) + assert result is None + + def test_returns_none_when_no_embedding_deployment(self, monkeypatch): + monkeypatch.setenv("SHARED_MEMORY_ENABLED", "true") + p = self._make_processor() + helper = MagicMock() + cfg = MagicMock(embedding_deployment_name="") + helper.settings.get_service_config.return_value = cfg + p.app_context.get_service.return_value = helper + result = _run(p._create_memory_store("proc-1")) + assert result is None + + def test_returns_none_on_exception(self, monkeypatch): + monkeypatch.setenv("SHARED_MEMORY_ENABLED", "true") + p = self._make_processor() + p.app_context.get_service.side_effect = RuntimeError("fail") + result = _run(p._create_memory_store("proc-1")) + assert result is None diff --git a/src/processor/src/tests/unit/steps/test_migration_processor_run.py b/src/processor/src/tests/unit/steps/test_migration_processor_run.py new file mode 100644 index 00000000..acd4ee40 --- /dev/null +++ b/src/processor/src/tests/unit/steps/test_migration_processor_run.py @@ -0,0 +1,344 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for MigrationProcessor.run() event-stream handling.""" + +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agent_framework import ( + ExecutorCompletedEvent, + ExecutorFailedEvent, + ExecutorInvokedEvent, + WorkflowFailedEvent, + WorkflowOutputEvent, + WorkflowStartedEvent, +) +from agent_framework._workflows._events import WorkflowErrorDetails + +from steps.analysis.models.step_param import Analysis_TaskParam +from steps.migration_processor import ( + MigrationProcessor, + WorkflowExecutorFailedException, +) + + +def _run(coro): + return asyncio.run(coro) + + +def _make_input(process_id="p-1") -> Analysis_TaskParam: + return Analysis_TaskParam( + process_id=process_id, + container_name="processes", + source_file_folder=f"{process_id}/source", + output_file_folder=f"{process_id}/converted", + workspace_file_folder=f"{process_id}/workspace", + ) + + +def _make_processor(events: list, memory_store=None) -> MigrationProcessor: + """Create a MigrationProcessor whose workflow streams the given events.""" + proc = MigrationProcessor.__new__(MigrationProcessor) + proc.app_context = MagicMock() + + telemetry = MagicMock() + telemetry.init_process = AsyncMock() + telemetry.update_process_status = AsyncMock() + telemetry.transition_to_phase = AsyncMock() + telemetry.record_step_result = AsyncMock() + telemetry.record_final_outcome = AsyncMock() + telemetry.record_failure_outcome = AsyncMock() + + proc.app_context.get_service_async = AsyncMock(return_value=telemetry) + proc.app_context._instances = {} + proc.app_context.add_singleton = MagicMock() + + proc._telemetry = telemetry # expose for assertions + + async def _stream(_input): + for ev in events: + yield ev + + workflow = MagicMock() + workflow.run_stream = _stream + proc.workflow = workflow + + # Patch _create_memory_store as an AsyncMock returning the provided value. 
+ proc._create_memory_store = AsyncMock(return_value=memory_store) + + return proc + + +class TestRunSuccessFlow: + def test_workflow_started_then_normal_output_returns_data(self): + data = SimpleNamespace(is_hard_terminated=False, value="ok") + events = [ + WorkflowStartedEvent(), + ExecutorInvokedEvent(executor_id="analysis", data=_make_input()), + ExecutorCompletedEvent(executor_id="analysis", data={"r": 1}), + ExecutorInvokedEvent(executor_id="design", data=_make_input()), + WorkflowOutputEvent(data=data, source_executor_id="design"), + ] + proc = _make_processor(events) + result = _run(proc.run(_make_input())) + assert result is data + proc._telemetry.init_process.assert_awaited() + proc._telemetry.update_process_status.assert_any_await( + process_id="p-1", status="completed" + ) + + def test_invoked_event_for_non_analysis_triggers_transition_phase(self): + data = SimpleNamespace(is_hard_terminated=False) + events = [ + WorkflowStartedEvent(), + # Documentation invocation should map to "Documentation" display + ExecutorInvokedEvent(executor_id="documentation", data=_make_input()), + WorkflowOutputEvent(data=data, source_executor_id="documentation"), + ] + proc = _make_processor(events) + _run(proc.run(_make_input())) + # transition_to_phase should be awaited with the documentation phase + calls = proc._telemetry.transition_to_phase.await_args_list + assert any( + c.kwargs.get("phase") == "Initializing Documentation" for c in calls + ) + + def test_invoked_event_unknown_executor_uses_capitalize(self): + data = SimpleNamespace(is_hard_terminated=False) + events = [ + WorkflowStartedEvent(), + ExecutorInvokedEvent(executor_id="custom", data=_make_input()), + WorkflowOutputEvent(data=data, source_executor_id="custom"), + ] + proc = _make_processor(events) + _run(proc.run(_make_input())) + calls = proc._telemetry.transition_to_phase.await_args_list + assert any( + c.kwargs.get("phase") == "Initializing Custom" for c in calls + ) + + +class 
TestRunHardTerminationFlow: + def test_hard_terminated_returns_data_and_records_failure(self): + data = SimpleNamespace( + is_hard_terminated=True, + reason="Blocked", + blocking_issues=["NEED_HUMAN_REVIEW"], + ) + events = [ + WorkflowStartedEvent(), + WorkflowOutputEvent(data=data, source_executor_id="analysis"), + ] + proc = _make_processor(events) + result = _run(proc.run(_make_input())) + assert result is data + proc._telemetry.record_failure_outcome.assert_awaited() + proc._telemetry.update_process_status.assert_any_await( + process_id="p-1", status="failed" + ) + + def test_hard_terminated_security_policy_collects_evidence(self): + data = SimpleNamespace( + is_hard_terminated=True, + reason="Blocked", + blocking_issues=["SECURITY_POLICY_VIOLATION"], + ) + events = [ + WorkflowStartedEvent(), + WorkflowOutputEvent(data=data, source_executor_id="analysis"), + ] + proc = _make_processor(events) + + with patch( + "utils.security_policy_evidence.collect_security_policy_evidence", + return_value={ + "findings": [ + { + "blob": "secret.yaml", + "secret_key_names": ["AWS_KEY"], + "signals": ["AKIA"], + } + ] + }, + ) as collector: + result = _run(proc.run(_make_input())) + + assert result is data + collector.assert_called_once() + # reason was enriched with redacted evidence block + assert "SECURITY POLICY EVIDENCE" in data.reason + + def test_hard_terminated_security_policy_handles_collector_error(self): + data = SimpleNamespace( + is_hard_terminated=True, + reason="Blocked", + blocking_issues=["SECURITY_POLICY_VIOLATION"], + ) + events = [ + WorkflowStartedEvent(), + WorkflowOutputEvent(data=data, source_executor_id="analysis"), + ] + proc = _make_processor(events) + with patch( + "utils.security_policy_evidence.collect_security_policy_evidence", + side_effect=RuntimeError("boom"), + ): + result = _run(proc.run(_make_input())) + assert result is data + # Ensure failure outcome still recorded (didn't crash on inner exception) + 
proc._telemetry.record_failure_outcome.assert_awaited() + + +class TestRunOutputMissingFlow: + def test_missing_output_raises_workflow_executor_failed_exception(self): + events = [ + WorkflowStartedEvent(), + WorkflowOutputEvent(data=None, source_executor_id="analysis"), + ] + proc = _make_processor(events) + with pytest.raises(WorkflowExecutorFailedException) as excinfo: + _run(proc.run(_make_input())) + assert "completed without producing output" in str(excinfo.value) + proc._telemetry.record_failure_outcome.assert_awaited() + + def test_missing_output_with_none_source_uses_unknown(self): + events = [ + WorkflowStartedEvent(), + WorkflowOutputEvent(data=None, source_executor_id=None), + ] + proc = _make_processor(events) + with pytest.raises(WorkflowExecutorFailedException): + _run(proc.run(_make_input())) + + +class TestRunWorkflowFailedFlow: + def test_workflow_failed_event_raises_with_details(self): + details = WorkflowErrorDetails( + error_type="ValueError", + message="invalid yaml", + traceback="Traceback ...", + executor_id="yaml", + ) + events = [ + WorkflowStartedEvent(), + ExecutorInvokedEvent(executor_id="yaml", data=_make_input()), + WorkflowFailedEvent(details=details), + ] + proc = _make_processor(events) + with pytest.raises(WorkflowExecutorFailedException) as excinfo: + _run(proc.run(_make_input())) + assert "yaml" in str(excinfo.value) + proc._telemetry.update_process_status.assert_any_await( + process_id="p-1", status="failed" + ) + + def test_workflow_failed_classifies_context_size_message(self): + details = WorkflowErrorDetails( + error_type="RuntimeError", + message="context window exceeded for token limit", + traceback="tb", + executor_id="design", + ) + events = [ + WorkflowStartedEvent(), + WorkflowFailedEvent(details=details), + ] + proc = _make_processor(events) + with pytest.raises(WorkflowExecutorFailedException): + _run(proc.run(_make_input())) + + def test_workflow_failed_classifies_context_error_type(self): + details = 
WorkflowErrorDetails( + error_type="ContextLengthExceededError", + message="too long", + traceback=None, + executor_id="analysis", + ) + events = [ + WorkflowStartedEvent(), + WorkflowFailedEvent(details=details), + ] + proc = _make_processor(events) + with pytest.raises(WorkflowExecutorFailedException): + _run(proc.run(_make_input())) + + def test_executor_failed_event_is_silently_ignored(self): + # ExecutorFailedEvent does not raise on its own; WorkflowFailedEvent does. + details = WorkflowErrorDetails( + error_type="X", message="m", traceback=None, executor_id="analysis" + ) + data = SimpleNamespace(is_hard_terminated=False) + events = [ + WorkflowStartedEvent(), + ExecutorFailedEvent(executor_id="analysis", details=details), + WorkflowOutputEvent(data=data, source_executor_id="analysis"), + ] + proc = _make_processor(events) + result = _run(proc.run(_make_input())) + assert result is data + + +class TestRunMemoryStoreLifecycle: + def test_memory_store_is_registered_and_closed(self): + data = SimpleNamespace(is_hard_terminated=False) + events = [ + WorkflowStartedEvent(), + ExecutorCompletedEvent(executor_id="analysis", data=None), + WorkflowOutputEvent(data=data, source_executor_id="analysis"), + ] + memory_store = MagicMock() + memory_store.get_count = AsyncMock(return_value=3) + memory_store.close = AsyncMock() + proc = _make_processor(events, memory_store=memory_store) + _run(proc.run(_make_input())) + # Singleton replaced + proc.app_context.add_singleton.assert_called_once() + memory_store.close.assert_awaited() + + def test_memory_store_close_error_is_swallowed(self): + data = SimpleNamespace(is_hard_terminated=False) + events = [ + WorkflowStartedEvent(), + WorkflowOutputEvent(data=data, source_executor_id="analysis"), + ] + memory_store = MagicMock() + memory_store.get_count = AsyncMock(side_effect=RuntimeError("x")) + memory_store.close = AsyncMock() + proc = _make_processor(events, memory_store=memory_store) + # Should not raise + result = 
_run(proc.run(_make_input())) + assert result is data + + def test_executor_completed_with_memory_store_logs_count(self): + data = SimpleNamespace(is_hard_terminated=False) + events = [ + WorkflowStartedEvent(), + ExecutorCompletedEvent( + executor_id="analysis", data={"some": "result"} + ), + WorkflowOutputEvent(data=data, source_executor_id="design"), + ] + memory_store = MagicMock() + memory_store.get_count = AsyncMock(return_value=7) + memory_store.close = AsyncMock() + proc = _make_processor(events, memory_store=memory_store) + _run(proc.run(_make_input())) + # get_count called at least once during ExecutorCompletedEvent and at finally + assert memory_store.get_count.await_count >= 2 + # record_step_result called for the executor completed event with data + proc._telemetry.record_step_result.assert_any_await( + process_id="p-1", + step_name="analysis", + step_result={"some": "result"}, + execution_time_seconds=pytest.approx( + proc._telemetry.record_step_result.await_args_list[0] + .kwargs["execution_time_seconds"], + rel=1, + ), + ) diff --git a/src/processor/src/tests/unit/steps/test_step_orchestrator_agent_infos.py b/src/processor/src/tests/unit/steps/test_step_orchestrator_agent_infos.py new file mode 100644 index 00000000..390efbbf --- /dev/null +++ b/src/processor/src/tests/unit/steps/test_step_orchestrator_agent_infos.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for prepare_agent_infos and forwarding hooks of analysis/design/yaml orchestrators.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from steps.analysis.models.step_param import Analysis_TaskParam +from steps.analysis.orchestration.analysis_orchestrator import AnalysisOrchestrator +from steps.convert.orchestration.yaml_convert_orchestrator import ( + YamlConvertOrchestrator, +) +from steps.design.orchestration.design_orchestrator import DesignOrchestrator + + +def _run(coro): + return asyncio.run(coro) + + +def _make(cls): + o = cls.__new__(cls) + o.initialized = True + o.app_context = MagicMock() + o.memory_store = None + o.agents = {} + return o + + +REGISTRY_ENTRIES = [ + {"agent_name": "EKS Expert", "prompt_file": "prompt_eks_expert.txt"}, + {"agent_name": "GKE Expert", "prompt_file": "prompt_gke_expert.txt"}, + # invalid entries — must be skipped + {"agent_name": "", "prompt_file": "x.txt"}, + {"agent_name": "X", "prompt_file": ""}, + {"agent_name": 1, "prompt_file": "y.txt"}, + {"agent_name": "Z", "prompt_file": 1}, +] + + +class TestAnalysisOrchestrator: + def test_prepare_agent_infos_raises_when_mcp_tools_missing(self): + orch = _make(AnalysisOrchestrator) + orch.mcp_tools = None + with pytest.raises(ValueError, match=r"MCP tools must be prepared"): + _run(orch.prepare_agent_infos()) + + def test_prepare_agent_infos_builds_full_set(self): + orch = _make(AnalysisOrchestrator) + orch.mcp_tools = [MagicMock(), MagicMock(), MagicMock()] + orch.task_param = Analysis_TaskParam( + process_id="p1", + container_name="processes", + source_file_folder="p1/source", + output_file_folder="p1/converted", + workspace_file_folder="p1/workspace", + ) + with patch.object( + AnalysisOrchestrator, + "load_platform_registry", + return_value=REGISTRY_ENTRIES, + ), patch.object( + AnalysisOrchestrator, + "read_prompt_file", + return_value="PROMPT", + ): + infos = 
_run(orch.prepare_agent_infos()) + names = [i.agent_name for i in infos] + for must in ( + "EKS Expert", + "GKE Expert", + "AKS Expert", + "Chief Architect", + "Coordinator", + "ResultGenerator", + ): + assert must in names + # Coordinator should be near the end with proper participant rendering. + assert names[-2] == "Coordinator" + assert names[-1] == "ResultGenerator" + + def test_on_agent_response_calls_super(self): + orch = _make(AnalysisOrchestrator) + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response(MagicMock())) + assert super_call.await_count == 1 + + def test_on_agent_response_stream_calls_super(self): + orch = _make(AnalysisOrchestrator) + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response_stream", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response_stream(MagicMock())) + assert super_call.await_count == 1 + + def test_on_orchestration_complete_runs(self, capsys): + orch = _make(AnalysisOrchestrator) + result = MagicMock() + result.execution_time_seconds = 4.2 + _run(orch.on_orchestration_complete(result)) + out = capsys.readouterr().out + assert "Analysis Orchestration complete." 
in out + + +class TestDesignOrchestrator: + def _task_param(self): + # design uses self.task_param.output.process_id + tp = MagicMock() + tp.output = MagicMock() + tp.output.process_id = "p1" + return tp + + def test_prepare_agent_infos_builds_full_set(self): + orch = _make(DesignOrchestrator) + orch.mcp_tools = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + orch.task_param = self._task_param() + with patch.object( + DesignOrchestrator, + "load_platform_registry", + return_value=REGISTRY_ENTRIES, + ), patch.object( + DesignOrchestrator, + "read_prompt_file", + return_value="PROMPT", + ): + infos = _run(orch.prepare_agent_infos()) + names = [i.agent_name for i in infos] + for must in ( + "EKS Expert", + "GKE Expert", + "AKS Expert", + "Chief Architect", + "Coordinator", + "ResultGenerator", + ): + assert must in names + + def test_on_agent_response_calls_super(self): + orch = _make(DesignOrchestrator) + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response(MagicMock())) + assert super_call.await_count == 1 + + def test_on_agent_response_stream_calls_super(self): + orch = _make(DesignOrchestrator) + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response_stream", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response_stream(MagicMock())) + assert super_call.await_count == 1 + + def test_on_orchestration_complete_runs(self, capsys): + orch = _make(DesignOrchestrator) + result = MagicMock() + result.execution_time_seconds = 1.5 + _run(orch.on_orchestration_complete(result)) + # design prints to stdout + out = capsys.readouterr().out + assert "Design" in out or "Elapsed" in out or out == "" + + +class TestYamlConvertOrchestrator: + def test_prepare_agent_infos_builds_full_set(self): + orch = _make(YamlConvertOrchestrator) + orch.mcp_tools = [MagicMock(), MagicMock(), MagicMock()] + # task_param for yaml convert uses 
self.task_param.process_id directly + tp = MagicMock() + tp.process_id = "p1" + orch.task_param = tp + with patch.object( + YamlConvertOrchestrator, + "load_platform_registry", + return_value=REGISTRY_ENTRIES, + ), patch.object( + YamlConvertOrchestrator, + "read_prompt_file", + return_value="PROMPT", + ): + infos = _run(orch.prepare_agent_infos()) + names = [i.agent_name for i in infos] + for must in ( + "YAML Expert", + "AKS Expert", + "Azure Architect", + "QA Engineer", + "Chief Architect", + "Coordinator", + "ResultGenerator", + ): + assert must in names + + def test_prepare_agent_infos_raises_when_mcp_tools_missing(self): + orch = _make(YamlConvertOrchestrator) + orch.mcp_tools = None + with pytest.raises(ValueError, match=r"MCP tools must be prepared"): + _run(orch.prepare_agent_infos()) + + def test_on_agent_response_calls_super(self): + orch = _make(YamlConvertOrchestrator) + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response(MagicMock())) + assert super_call.await_count == 1 + + def test_on_agent_response_stream_calls_super(self): + orch = _make(YamlConvertOrchestrator) + with patch( + "libs.base.orchestrator_base.OrchestratorBase.on_agent_response_stream", + new_callable=AsyncMock, + ) as super_call: + _run(orch.on_agent_response_stream(MagicMock())) + assert super_call.await_count == 1 + + def test_on_orchestration_complete_logs(self, caplog): + orch = _make(YamlConvertOrchestrator) + result = MagicMock() + result.execution_time_seconds = 2.0 + with caplog.at_level("INFO"): + _run(orch.on_orchestration_complete(result)) + assert any( + "Yaml Convert Orchestration complete" in r.message + for r in caplog.records + ) diff --git a/src/processor/src/tests/unit/test_main.py b/src/processor/src/tests/unit/test_main.py new file mode 100644 index 00000000..2f30b89c --- /dev/null +++ b/src/processor/src/tests/unit/test_main.py @@ -0,0 +1,88 @@ +# Copyright (c) 
Microsoft Corporation. +# Licensed under the MIT License. + +"""Coverage for src/main.py — the direct-execution entry point. + +Tests instantiate Application via __new__ to avoid loading .env / Azure +credentials, and verify register_services() and run() wiring. +""" + +from __future__ import annotations + +import asyncio +import logging +from unittest.mock import AsyncMock, MagicMock, patch + + +def _run(coro): + return asyncio.run(coro) + + +def _make_app(): + """Build Application without invoking ApplicationBase.__init__.""" + import main as main_mod + app = main_mod.Application.__new__(main_mod.Application) + app.application_context = MagicMock() + app.application_context.llm_settings = MagicMock() + return app + + +class TestApplicationInitializeAndRegister: + def test_initialize_logs_and_registers(self, caplog): + import main as main_mod + app = _make_app() + # Make the chain returned by add_singleton fluent + chain = MagicMock() + chain.add_singleton.return_value = chain + chain.add_async_singleton.return_value = chain + chain.add_transient.return_value = chain + # Pretend the framework helper service exists + helper = MagicMock() + app.application_context.add_singleton.return_value = chain + app.application_context.get_service.return_value = helper + + with caplog.at_level(logging.INFO): + app.initialize() + # register_services was called via initialize() + assert app.application_context.add_singleton.call_count >= 1 + helper.initialize.assert_called_once() + + def test_register_services_handles_cosmos_import_error(self): + app = _make_app() + chain = MagicMock() + chain.add_singleton.return_value = chain + chain.add_async_singleton.return_value = chain + chain.add_transient.return_value = chain + helper = MagicMock() + app.application_context.add_singleton.return_value = chain + app.application_context.get_service.return_value = helper + + # Simulate cosmos checkpoint module failing to import + with patch.dict( + "sys.modules", + 
{"libs.agent_framework.cosmos_checkpoint_storage": None}, + ): + app.register_services() + # Should not raise — the except path is exercised + + +class TestApplicationRun: + def test_run_calls_migration_processor(self): + app = _make_app() + proc = MagicMock() + proc.run = AsyncMock() + app.application_context.get_service.return_value = proc + _run(app.run()) + proc.run.assert_awaited_once() + + +class TestMainCoroutine: + def test_main_constructs_initializes_runs(self): + import main as main_mod + with patch.object(main_mod, "Application") as MockApp: + instance = MockApp.return_value + instance.run = AsyncMock() + instance.initialize = MagicMock() + _run(main_mod.main()) + instance.initialize.assert_called_once() + instance.run.assert_awaited_once() diff --git a/src/processor/src/tests/unit/test_main_service.py b/src/processor/src/tests/unit/test_main_service.py new file mode 100644 index 00000000..5c365564 --- /dev/null +++ b/src/processor/src/tests/unit/test_main_service.py @@ -0,0 +1,270 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from main_service import QueueMigrationServiceApp + + +def _run(coro): + return asyncio.run(coro) + + +def _make_app(queue_service=None, control_api=None, debug=False, ctx=None): + """Build QueueMigrationServiceApp via __new__ — avoids ApplicationBase init.""" + app = QueueMigrationServiceApp.__new__(QueueMigrationServiceApp) + app.queue_service = queue_service + app.control_api = control_api + app.config_override = {} + app.debug_mode = debug + app.application_context = ctx or MagicMock() + app.app_context = app.application_context + return app + + +class TestServiceStatus: + def test_is_service_running_false_without_queue(self): + app = _make_app(queue_service=None) + assert app.is_service_running() is False + + def test_is_service_running_uses_queue_state(self): + q = MagicMock() + q.is_running = True + app = _make_app(queue_service=q) + assert app.is_service_running() is True + + @patch("main_service.asyncio") + def test_get_service_status_not_initialized(self, mock_asyncio): + mock_loop = MagicMock() + mock_loop.time.return_value = 1000.0 + mock_asyncio.get_event_loop.return_value = mock_loop + app = _make_app(queue_service=None) + status = app.get_service_status() + assert status["status"] == "not_initialized" + assert status["running"] is False + assert status["docker_health"] == "unhealthy" + + def test_get_service_status_when_running(self): + q = MagicMock() + q.is_running = True + q.get_service_status.return_value = {"status": "running"} + app = _make_app(queue_service=q) + status = app.get_service_status() + assert status["docker_health"] == "healthy" + assert status["running"] is True + + +class TestBuildServiceConfig: + def test_uses_env_var_defaults(self, monkeypatch): + for k in [ + "VISIBILITY_TIMEOUT_MINUTES", + "POLL_INTERVAL_SECONDS", + "MESSAGE_TIMEOUT_MINUTES", + "CONCURRENT_WORKERS", + ]: + monkeypatch.delenv(k, 
raising=False) + ctx = MagicMock() + ctx.configuration.storage_queue_account = "acct" + ctx.configuration.storage_account_process_queue = "queue" + app = _make_app(ctx=ctx) + cfg = app._build_service_config() + assert cfg.storage_account_name == "acct" + assert cfg.queue_name == "queue" + assert cfg.visibility_timeout_minutes == 5 + assert cfg.concurrent_workers == 1 + + def test_applies_override(self, monkeypatch): + ctx = MagicMock() + ctx.configuration.storage_queue_account = "acct" + ctx.configuration.storage_account_process_queue = "queue" + app = _make_app(ctx=ctx) + cfg = app._build_service_config({"concurrent_workers": 7, "ignored_field": "x"}) + assert cfg.concurrent_workers == 7 + + def test_uses_env_var_overrides(self, monkeypatch): + monkeypatch.setenv("VISIBILITY_TIMEOUT_MINUTES", "12") + monkeypatch.setenv("POLL_INTERVAL_SECONDS", "3") + monkeypatch.setenv("MESSAGE_TIMEOUT_MINUTES", "30") + monkeypatch.setenv("CONCURRENT_WORKERS", "4") + ctx = MagicMock() + ctx.configuration.storage_queue_account = "acct" + ctx.configuration.storage_account_process_queue = "queue" + app = _make_app(ctx=ctx, debug=True) + cfg = app._build_service_config() + assert cfg.visibility_timeout_minutes == 12 + assert cfg.poll_interval_seconds == 3 + assert cfg.message_timeout_minutes == 30 + assert cfg.concurrent_workers == 4 + + +class TestBuildControlApi: + def test_disabled_by_env(self, monkeypatch): + monkeypatch.setenv("CONTROL_API_ENABLED", "0") + app = _make_app() + result = _run(app._build_control_api()) + assert result is None + + def test_builds_with_env_settings(self, monkeypatch): + monkeypatch.setenv("CONTROL_API_ENABLED", "1") + monkeypatch.setenv("CONTROL_API_TOKEN", "token-x") + monkeypatch.setenv("CONTROL_API_HOST", "127.0.0.1") + monkeypatch.setenv("CONTROL_API_PORT", "9090") + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=MagicMock()) + app = _make_app(ctx=ctx) + with patch("main_service.ControlApiServer") as srv_cls: + 
srv_cls.return_value = "server" + result = _run(app._build_control_api()) + assert result == "server" + + def test_invalid_port_falls_back(self, monkeypatch): + monkeypatch.setenv("CONTROL_API_ENABLED", "1") + monkeypatch.setenv("CONTROL_API_PORT", "not-a-number") + ctx = MagicMock() + ctx.get_service_async = AsyncMock(return_value=MagicMock()) + app = _make_app(ctx=ctx) + with patch("main_service.ControlApiServer") as srv_cls, patch( + "main_service.ControlApiConfig" + ) as cfg_cls: + srv_cls.return_value = "s" + _run(app._build_control_api()) + kwargs = cfg_cls.call_args.kwargs + assert kwargs["port"] == 8080 + + def test_falls_back_to_new_control_manager_on_di_error(self, monkeypatch): + monkeypatch.setenv("CONTROL_API_ENABLED", "1") + ctx = MagicMock() + ctx.get_service_async = AsyncMock(side_effect=RuntimeError("not registered")) + app = _make_app(ctx=ctx) + with patch("main_service.ProcessControlManager") as pcm_cls, patch( + "main_service.ControlApiServer" + ) as srv_cls: + srv_cls.return_value = "s" + _run(app._build_control_api()) + pcm_cls.assert_called_once_with(app.application_context) + + +class TestStartShutdown: + def test_start_service_raises_without_init(self): + app = _make_app(queue_service=None) + with pytest.raises(RuntimeError, match="not initialized"): + _run(app.start_service()) + + def test_start_service_runs_and_shuts_down(self): + q = MagicMock() + q.start_service = AsyncMock() + q.stop_service = AsyncMock() + q.is_running = True + app = _make_app(queue_service=q) + # Avoid building control API + app._build_control_api = AsyncMock(return_value=None) + _run(app.start_service()) + q.start_service.assert_awaited_once() + + def test_start_service_with_control_api_enabled(self): + q = MagicMock() + q.start_service = AsyncMock() + q.stop_service = AsyncMock() + api = MagicMock() + api.start = AsyncMock() + api.stop = AsyncMock() + app = _make_app(queue_service=q) + app._build_control_api = AsyncMock(return_value=api) + _run(app.start_service()) 
+ api.start.assert_awaited_once() + api.stop.assert_awaited_once() + + def test_start_service_handles_keyboard_interrupt(self): + q = MagicMock() + q.start_service = AsyncMock(side_effect=KeyboardInterrupt()) + q.stop_service = AsyncMock() + app = _make_app(queue_service=q) + app._build_control_api = AsyncMock(return_value=None) + _run(app.start_service()) # swallowed, no raise + + def test_start_service_handles_generic_exception(self): + q = MagicMock() + q.start_service = AsyncMock(side_effect=RuntimeError("boom")) + q.stop_service = AsyncMock() + app = _make_app(queue_service=q) + app._build_control_api = AsyncMock(return_value=None) + _run(app.start_service()) # swallowed, no raise + + def test_start_service_handles_build_control_api_exception(self): + q = MagicMock() + q.start_service = AsyncMock() + q.stop_service = AsyncMock() + app = _make_app(queue_service=q) + app._build_control_api = AsyncMock(side_effect=RuntimeError("nope")) + _run(app.start_service()) # warned + control_api stays None + + def test_shutdown_service_clears_state(self): + q = MagicMock() + q.stop_service = AsyncMock() + api = MagicMock() + api.stop = AsyncMock() + app = _make_app(queue_service=q, control_api=api) + _run(app.shutdown_service()) + assert app.queue_service is None + assert app.control_api is None + + def test_force_stop_service(self): + q = MagicMock() + q.force_stop = AsyncMock() + app = _make_app(queue_service=q) + _run(app.force_stop_service()) + assert app.queue_service is None + + def test_force_stop_no_queue(self): + app = _make_app(queue_service=None) + _run(app.force_stop_service()) # no-op + + +class TestRunEntrypoint: + def test_run_calls_start_service(self): + app = _make_app() + app.start_service = AsyncMock() + _run(app.run()) + app.start_service.assert_awaited_once() + + +class TestRunQueueService: + def test_run_queue_service_runs_app(self): + from main_service import run_queue_service + + with patch("main_service.QueueMigrationServiceApp") as app_cls: + 
instance = MagicMock() + instance.run = AsyncMock() + instance.queue_service = MagicMock() + instance.queue_service.stop_service = AsyncMock() + app_cls.return_value = instance + _run(run_queue_service(debug_mode=True)) + instance.run.assert_awaited_once() + + def test_run_queue_service_handles_keyboard_interrupt(self): + from main_service import run_queue_service + + with patch("main_service.QueueMigrationServiceApp") as app_cls: + instance = MagicMock() + instance.run = AsyncMock(side_effect=KeyboardInterrupt()) + instance.queue_service = MagicMock() + instance.queue_service.stop_service = AsyncMock() + app_cls.return_value = instance + _run(run_queue_service()) # no raise + + def test_run_queue_service_reraises_other_exceptions(self): + from main_service import run_queue_service + + with patch("main_service.QueueMigrationServiceApp") as app_cls: + instance = MagicMock() + instance.run = AsyncMock(side_effect=ValueError("oops")) + instance.queue_service = MagicMock() + instance.queue_service.stop_service = AsyncMock() + app_cls.return_value = instance + with pytest.raises(ValueError): + _run(run_queue_service()) diff --git a/src/processor/src/tests/unit/utils/test_agent_telemetry.py b/src/processor/src/tests/unit/utils/test_agent_telemetry.py new file mode 100644 index 00000000..e2c97b68 --- /dev/null +++ b/src/processor/src/tests/unit/utils/test_agent_telemetry.py @@ -0,0 +1,607 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import asyncio +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from utils import agent_telemetry as at +from utils.agent_telemetry import ( + AgentActivity, + AgentActivityHistory, + AgentActivityRepository, + ProcessStatus, + TelemetryManager, + _build_step_lap_times, + _byte_len_text, + _get_process_blob_container_name, + _get_storage_connection_string, + _get_utc_timestamp, + _parse_utc_timestamp, + _sha256_text, + get_orchestration_agents, +) + + +def _run(coro): + return asyncio.run(coro) + + +# ---------- pure helpers ---------- + + +class TestPureHelpers: + def test_sha256_text_deterministic(self): + assert _sha256_text("a") == _sha256_text("a") + assert _sha256_text("a") != _sha256_text("b") + + def test_byte_len_text_handles_unicode(self): + assert _byte_len_text("abc") == 3 + # 'é' is 2 bytes in UTF-8 + assert _byte_len_text("é") == 2 + + def test_get_orchestration_agents_returns_coordinator(self): + assert get_orchestration_agents() == {"Coordinator"} + + def test_get_process_blob_container_default(self, monkeypatch): + monkeypatch.delenv("PROCESS_BLOB_CONTAINER_NAME", raising=False) + assert _get_process_blob_container_name() == "processes" + + def test_get_process_blob_container_env_used(self, monkeypatch): + monkeypatch.setenv("PROCESS_BLOB_CONTAINER_NAME", "mybox") + assert _get_process_blob_container_name() == "mybox" + + def test_get_process_blob_container_blank_falls_back(self, monkeypatch): + monkeypatch.setenv("PROCESS_BLOB_CONTAINER_NAME", " ") + assert _get_process_blob_container_name() == "processes" + + def test_get_storage_connection_string_present(self, monkeypatch): + monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", " conn ") + assert _get_storage_connection_string() == "conn" + + def test_get_storage_connection_string_none(self, monkeypatch): + for key in ["AZURE_STORAGE_CONNECTION_STRING", "STORAGE_CONNECTION_STRING", 
"AzureWebJobsStorage"]: + monkeypatch.delenv(key, raising=False) + assert _get_storage_connection_string() is None + + def test_get_utc_timestamp_format(self): + ts = _get_utc_timestamp() + assert ts.endswith(" UTC") + assert _parse_utc_timestamp(ts) is not None + + def test_parse_utc_timestamp_invalid(self): + assert _parse_utc_timestamp("") is None + assert _parse_utc_timestamp("not a date") is None + assert _parse_utc_timestamp(None) is None # type: ignore[arg-type] + assert _parse_utc_timestamp(123) is None # type: ignore[arg-type] + + +class TestBuildStepLapTimes: + def test_empty_returns_no_items(self): + items, total = _build_step_lap_times(None) + assert items == [] + assert total == 0.0 + + def test_completed_step_uses_elapsed_seconds(self): + timings = { + "analysis": { + "started_at": "2024-01-01 00:00:00 UTC", + "ended_at": "2024-01-01 00:00:10 UTC", + "elapsed_seconds": 10.0, + } + } + items, total = _build_step_lap_times(timings) + assert len(items) == 1 + assert items[0]["status"] == "completed" + assert items[0]["elapsed_seconds"] == 10.0 + assert total == 10.0 + + def test_running_step_status_and_elapsed(self): + timings = { + "design": { + "started_at": _get_utc_timestamp(), + } + } + items, _ = _build_step_lap_times(timings) + assert items[0]["status"] == "running" + # elapsed should be roughly 0 (just started) + assert items[0]["elapsed_seconds"] is not None + + def test_unknown_status_when_no_timestamps(self): + timings = {"design": {"some": "data"}} + items, _ = _build_step_lap_times(timings) + assert items[0]["status"] == "unknown" + + def test_invalid_timing_skipped(self): + timings = {"": {"started_at": ""}, "x": "not-a-dict", "ok": {}} + items, _ = _build_step_lap_times(timings) + assert {it["step"] for it in items} == {"ok"} + + def test_preferred_order(self): + timings = { + "documentation": {"elapsed_seconds": 1}, + "analysis": {"elapsed_seconds": 2}, + "yaml": {"elapsed_seconds": 3}, + "design": {"elapsed_seconds": 4}, + "extra": 
{"elapsed_seconds": 5}, + } + items, total = _build_step_lap_times(timings) + order = [it["step"] for it in items] + assert order[:4] == ["analysis", "design", "yaml", "documentation"] + assert order[-1] == "extra" + assert total == 15.0 + + def test_derives_elapsed_from_timestamps_when_no_seconds(self): + timings = { + "analysis": { + "started_at": "2024-01-01 00:00:00 UTC", + "ended_at": "2024-01-01 00:00:30 UTC", + } + } + items, _ = _build_step_lap_times(timings) + assert items[0]["elapsed_seconds"] == 30.0 + + +# ---------- pydantic dataclass defaults ---------- + + +class TestPydanticModels: + def test_agent_activity_defaults(self): + a = AgentActivity(name="X") + assert a.current_action == "idle" + assert a.is_active is False + assert a.message_word_count == 0 + assert a.activity_history == [] + + def test_process_status_defaults(self): + ps = ProcessStatus(id="p") + assert ps.status == "running" + assert ps.agents == {} + assert ps.step_timings == {} + + def test_agent_activity_history_default_timestamp(self): + h = AgentActivityHistory(action="speaking") + assert h.action == "speaking" + assert h.message_preview == "" + + +# ---------- AgentActivityRepository init guard ---------- + + +class TestAgentActivityRepository: + def test_raises_without_configuration(self): + ctx = SimpleNamespace(configuration=None) + with pytest.raises(ValueError): + AgentActivityRepository(ctx) + + +# ---------- TelemetryManager constructor ---------- + + +class TestTelemetryManagerConstruction: + def test_dev_mode_when_no_app_context(self): + tm = TelemetryManager() + assert tm.repository is None + assert tm.app_context is None + + def test_dev_mode_for_localhost_url(self): + cfg = SimpleNamespace(cosmos_db_account_url="http://localhost:8081") + ctx = SimpleNamespace(configuration=cfg) + tm = TelemetryManager(ctx) + assert tm.repository is None + + def test_dev_mode_for_template_placeholder(self): + cfg = SimpleNamespace(cosmos_db_account_url="http://") + ctx = 
SimpleNamespace(configuration=cfg) + tm = TelemetryManager(ctx) + assert tm.repository is None + + def test_production_creates_repository(self): + cfg = SimpleNamespace( + cosmos_db_account_url="https://prod.documents.azure.com:443/", + cosmos_db_database_name="db", + cosmos_db_container_name="c", + ) + ctx = SimpleNamespace(configuration=cfg) + with patch.object(at, "AgentActivityRepository") as repo_cls: + repo_cls.return_value = "repo-instance" + tm = TelemetryManager(ctx) + assert tm.repository == "repo-instance" + + +# ---------- TelemetryManager methods (dev mode no-ops + with mocked repo) ---------- + + +def _tm_with_repo(): + tm = TelemetryManager() # dev mode, repository = None + tm.repository = MagicMock() + tm.repository.get_async = AsyncMock() + tm.repository.add_async = AsyncMock() + tm.repository.update_async = AsyncMock() + tm.repository.delete_async = AsyncMock() + return tm + + +class TestTelemetryManagerDevModeNoops: + def test_delete_process_noop(self): + tm = TelemetryManager() + _run(tm.delete_process("p")) + + def test_init_process_noop_in_dev(self): + tm = TelemetryManager() + _run(tm.init_process("p", "phase", "analysis")) + + def test_get_current_process_returns_none(self): + tm = TelemetryManager() + assert _run(tm.get_current_process("p")) is None + + def test_get_process_outcome_empty_string(self): + tm = TelemetryManager() + assert _run(tm.get_process_outcome("p")) == "" + + +class TestTelemetryManagerWithRepo: + def test_delete_process_calls_repository_when_record_exists(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = ProcessStatus(id="p") + _run(tm.delete_process("p")) + tm.repository.delete_async.assert_awaited_once_with("p") + + def test_delete_process_skips_when_missing(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = None + _run(tm.delete_process("p")) + tm.repository.delete_async.assert_not_called() + + def test_delete_process_swallows_errors(self): + tm = _tm_with_repo() + 
tm.repository.get_async.side_effect = RuntimeError("boom") + _run(tm.delete_process("p")) # must not raise + + def test_init_process_seeds_step_timing(self): + tm = _tm_with_repo() + _run(tm.init_process("p1", "phase", "analysis")) + added = tm.repository.add_async.await_args.args[0] + assert added.id == "p1" + assert "analysis" in added.step_timings + assert "started_at" in added.step_timings["analysis"] + + def test_init_process_recovers_when_add_conflict(self): + tm = _tm_with_repo() + tm.repository.add_async.side_effect = [Exception("conflict"), None] + _run(tm.init_process("p1", "phase", "analysis")) + tm.repository.delete_async.assert_awaited_once_with("p1") + assert tm.repository.add_async.await_count == 2 + + def test_update_agent_activity_creates_agent_and_sets_state(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p") + tm.repository.get_async.return_value = ps + _run( + tm.update_agent_activity( + "p", + "Azure_Expert", + "thinking", + message_preview="Analyzing", + full_message="Analyzing details", + ) + ) + assert "Azure_Expert" in ps.agents + a = ps.agents["Azure_Expert"] + assert a.is_active is True + assert a.is_currently_thinking is True + assert a.participation_status == "thinking" + assert a.last_full_message == "Analyzing details" + assert a.message_word_count == 2 + + def test_update_agent_activity_sets_speaking(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", agents={"X": AgentActivity(name="X", current_action="ready")}) + tm.repository.get_async.return_value = ps + _run(tm.update_agent_activity("p", "X", "speaking", message_preview="say")) + assert ps.agents["X"].participation_status == "speaking" + + def test_update_agent_activity_completes(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", agents={"X": AgentActivity(name="X")}) + tm.repository.get_async.return_value = ps + _run(tm.update_agent_activity("p", "X", "completed")) + assert ps.agents["X"].participation_status == "completed" + + def 
test_update_agent_activity_truncates_long_preview(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p") + tm.repository.get_async.return_value = ps + long_text = "x" * 500 + _run(tm.update_agent_activity("p", "X", "speaking", message_preview=long_text)) + assert ps.agents["X"].last_message_preview.endswith("...") + + def test_update_agent_activity_no_process_returns_silently(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = None + _run(tm.update_agent_activity("p", "X", "thinking")) + tm.repository.update_async.assert_not_called() + + def test_update_agent_activity_step_reset_increments_counter(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", step="design", agents={"X": AgentActivity(name="X")}) + tm.repository.get_async.return_value = ps + _run( + tm.update_agent_activity( + "p", "X", "thinking", reset_for_new_step=True + ) + ) + assert ps.agents["X"].step_reset_count == 1 + + def test_update_agent_activity_other_agents_become_inactive(self): + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + agents={ + "A": AgentActivity(name="A", is_active=True), + "B": AgentActivity(name="B", is_active=True), + }, + ) + tm.repository.get_async.return_value = ps + _run(tm.update_agent_activity("p", "A", "thinking")) + assert ps.agents["B"].is_active is False + + def test_update_process_status_running(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", status="running") + tm.repository.get_async.return_value = ps + _run(tm.update_process_status("p", "running")) + assert ps.status == "running" + + def test_update_process_status_terminal_marks_agents_idle(self): + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + agents={"A": AgentActivity(name="A", is_active=True, is_currently_speaking=True)}, + ) + tm.repository.get_async.return_value = ps + _run(tm.update_process_status("p", "completed")) + assert ps.status == "completed" + assert ps.phase == "end" + a = ps.agents["A"] + assert a.is_active is False and a.is_currently_speaking 
is False + assert a.participation_status == "standby" + + def test_set_agent_idle(self): + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + agents={"A": AgentActivity(name="A", is_active=True, current_action="speaking")}, + ) + tm.repository.get_async.return_value = ps + _run(tm.set_agent_idle("p", "A")) + assert ps.agents["A"].current_action == "idle" + assert ps.agents["A"].is_active is False + + def test_set_agent_idle_unknown_agent_noop(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", agents={}) + tm.repository.get_async.return_value = ps + _run(tm.set_agent_idle("p", "missing")) + tm.repository.update_async.assert_not_called() + + def test_update_phase_changes_phase(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", phase="analysis") + tm.repository.get_async.return_value = ps + _run(tm.update_phase("p", "design")) + assert ps.phase == "design" + + def test_update_phase_no_process_noop(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = None + _run(tm.update_phase("p", "design")) + tm.repository.update_async.assert_not_called() + + def test_transition_to_phase_seeds_timing(self): + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + phase="analysis", + agents={"A": AgentActivity(name="A")}, + ) + tm.repository.get_async.return_value = ps + _run(tm.transition_to_phase("p", "design phase", "design")) + assert ps.phase == "design phase" + assert ps.step == "design" + assert "design" in ps.step_timings + assert ps.agents["A"].participation_status == "ready" + + def test_complete_all_participant_agents(self): + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + agents={ + "Coordinator": AgentActivity(name="Coordinator", is_active=True), + "X": AgentActivity(name="X", is_active=True), + }, + ) + tm.repository.get_async.return_value = ps + _run(tm.complete_all_participant_agents("p")) + assert ps.agents["X"].current_action == "completed" + # Coordinator (orchestration) untouched + assert 
ps.agents["Coordinator"].is_active is True + + def test_record_failure(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", step="design") + tm.repository.get_async.return_value = ps + _run( + tm.record_failure( + "p", + "boom", + failure_details="bad", + failure_step="", + failure_agent="A", + stack_trace="trace", + ) + ) + assert ps.status == "failed" + assert ps.failure_reason == "boom" + assert ps.failure_step == "design" # used current step + assert ps.failure_agent == "A" + + def test_get_process_outcome_completed(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = ProcessStatus(id="p", status="completed") + assert _run(tm.get_process_outcome("p")) == "Process completed successfully" + + def test_get_process_outcome_failed(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = ProcessStatus( + id="p", status="failed", failure_reason="boom" + ) + assert "boom" in _run(tm.get_process_outcome("p")) + + def test_get_process_outcome_running(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = ProcessStatus(id="p", status="running") + assert _run(tm.get_process_outcome("p")) == "Process is still running" + + def test_get_process_outcome_other_status(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = ProcessStatus(id="p", status="qa_review") + assert _run(tm.get_process_outcome("p")) == "Status: qa_review" + + def test_get_process_outcome_no_process(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = None + assert _run(tm.get_process_outcome("p")) == "No active process" + + def test_track_tool_usage_updates_agent(self): + tm = _tm_with_repo() + ps = ProcessStatus(id="p", agents={}) + tm.repository.get_async.return_value = ps + _run( + tm.track_tool_usage( + "p", "A", "blob_ops", "list", tool_details="x" * 80, tool_result_preview="y" * 200 + ) + ) + assert ps.agents["A"].current_action == "using_tool" + assert ps.agents["A"].activity_history + assert 
ps.agents["A"].reasoning_steps + + def test_track_tool_usage_no_process(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = None + _run(tm.track_tool_usage("p", "A", "t", "a")) + tm.repository.update_async.assert_not_called() + + +# ---------- _get_ready_status_message ---------- + + +class TestReadyStatusMessage: + def _tm(self): + return TelemetryManager() + + @pytest.mark.parametrize( + "phase,expected_substr", + [ + ("analysis phase", "platform analysis"), + ("design phase", "Azure architecture"), + ("yaml conversion", "YAML conversion"), + ("documentation phase", "documentation"), + ("final", "expert discussion for migration step"), + ], + ) + def test_coordinator_messages(self, phase, expected_substr): + msg = self._tm()._get_ready_status_message( + "Coordinator", "step", phase, "ready" + ) + assert expected_substr in msg + + def test_analysis_system_agent(self): + msg = self._tm()._get_ready_status_message( + "system_observer", "step", "analysis", "ready" + ) + assert "source platform" in msg + + def test_analysis_other_agent(self): + msg = self._tm()._get_ready_status_message( + "Some_Expert", "Inspect", "analysis", "ready" + ) + assert "inspect" in msg + + def test_design_azure_agent(self): + msg = self._tm()._get_ready_status_message( + "Azure_Expert", "design", "design", "ready" + ) + assert "Azure recommendations" in msg + + def test_yaml_with_yaml_agent(self): + msg = self._tm()._get_ready_status_message( + "yaml_expert", "convert", "yaml", "ready" + ) + assert "YAML configurations" in msg + + def test_documentation_writer(self): + msg = self._tm()._get_ready_status_message( + "technical_writer_one", "write", "documentation", "ready" + ) + assert "comprehensive documentation" in msg + + def test_unknown_phase_default(self): + msg = self._tm()._get_ready_status_message( + "Foo", "", "weird-phase", "ready" + ) + assert "Ready for" in msg + + +# ---------- render_agent_status ---------- + + +class TestRenderAgentStatus: + def 
test_returns_not_found_when_no_process(self): + tm = _tm_with_repo() + tm.repository.get_async.return_value = None + result = _run(tm.render_agent_status("p")) + assert result["status"] == "not_found" + assert result["agents"] == [] + + def test_renders_speaking_agent(self): + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + phase="analysis", + agents={ + "Azure_Expert": AgentActivity( + name="Azure_Expert", + participation_status="speaking", + is_currently_speaking=True, + current_speaking_content="Talking", + message_word_count=2, + ) + }, + ) + tm.repository.get_async.return_value = ps + result = _run(tm.render_agent_status("p")) + assert result["agents"] == ['\u2713[] Azure Expert: Speaking - "Talking" (2 words)'] + + def test_renders_ready_agent_uses_context_message(self): + # render_agent_status doesn't return formatted lines as 'agents' but + # we just ensure no error and proper structure (returns dict). + tm = _tm_with_repo() + ps = ProcessStatus( + id="p", + phase="analysis", + agents={ + "Coordinator": AgentActivity( + name="Coordinator", + participation_status="ready", + last_message_preview="processing", + ) + }, + ) + tm.repository.get_async.return_value = ps + result = _run(tm.render_agent_status("p")) + assert isinstance(result, dict) + assert "agents" in result and len(result["agents"]) == 1 diff --git a/src/processor/src/tests/unit/utils/test_agent_telemetry_records.py b/src/processor/src/tests/unit/utils/test_agent_telemetry_records.py new file mode 100644 index 00000000..5a3200c2 --- /dev/null +++ b/src/processor/src/tests/unit/utils/test_agent_telemetry_records.py @@ -0,0 +1,342 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Coverage for record_step_result / record_final_outcome / record_failure_outcome.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import utils.agent_telemetry as at +from utils.agent_telemetry import ProcessStatus, TelemetryManager + + +def _run(coro): + return asyncio.run(coro) + + +def _tm_with_repo(record: ProcessStatus | None = None) -> TelemetryManager: + tm = TelemetryManager() + tm.repository = MagicMock() + tm.repository.get_async = AsyncMock(return_value=record) + tm.repository.update_async = AsyncMock() + tm.repository.add_async = AsyncMock() + tm.repository.delete_async = AsyncMock() + return tm + + +class TestRecordStepResult: + def test_no_repo_returns_silently(self): + tm = TelemetryManager() + _run(tm.record_step_result("p", "analysis", {"x": 1})) + + def test_missing_record_warns_and_returns(self): + tm = _tm_with_repo(None) + _run(tm.record_step_result("p", "analysis", {"x": 1})) + tm.repository.update_async.assert_not_awaited() + + def test_records_and_normalizes_singleton_list(self): + rec = ProcessStatus(id="p") + rec.step_timings = {"analysis": {"started_at": "2025-01-01T00:00:00Z"}} + tm = _tm_with_repo(rec) + _run( + tm.record_step_result( + "p", "analysis", [{"foo": "bar"}], execution_time_seconds=2.5 + ) + ) + assert rec.step_results["analysis"]["result"] == {"foo": "bar"} + assert rec.step_timings["analysis"]["elapsed_seconds"] == 2.5 + tm.repository.update_async.assert_awaited() + + def test_uses_timestamp_elapsed_when_perf_too_small(self): + # candidate < 0.5 and timestamps show >5s -> use ts_elapsed + rec = ProcessStatus(id="p") + rec.step_timings = { + "design": {"started_at": "2025-01-01 00:00:00 UTC"} + } + tm = _tm_with_repo(rec) + with patch.object(at, "_get_utc_timestamp", return_value="2025-01-01 00:00:30 UTC"): + _run( + tm.record_step_result( + "p", "design", {"r": 1}, execution_time_seconds=0.001 + ) + ) + assert 
rec.step_timings["design"]["elapsed_seconds"] == 30.0
+
+    def test_only_timestamp_elapsed_when_no_perf(self):
+        rec = ProcessStatus(id="p")
+        rec.step_timings = {
+            "yaml": {"started_at": "2025-01-01 00:00:00 UTC"}
+        }
+        tm = _tm_with_repo(rec)
+        with patch.object(at, "_get_utc_timestamp", return_value="2025-01-01 00:00:10 UTC"):
+            _run(tm.record_step_result("p", "yaml", {"r": 1}))
+        assert rec.step_timings["yaml"]["elapsed_seconds"] == 10.0
+
+    def test_swallows_update_exception(self):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        tm.repository.update_async.side_effect = RuntimeError("boom")
+        # Should not raise: telemetry persistence is best-effort.
+        _run(tm.record_step_result("p", "analysis", {"r": 1}))
+
+
+class TestRecordFinalOutcome:
+    def test_no_repo_returns(self):
+        tm = TelemetryManager()
+        _run(tm.record_final_outcome("p", {}, success=True))
+
+    def test_missing_record_warns_and_returns(self):
+        tm = _tm_with_repo(None)
+        _run(tm.record_final_outcome("p", {"x": 1}))
+        tm.repository.update_async.assert_not_awaited()
+
+    def test_records_legacy_generated_files_collection(self):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        outcome = {
+            "GeneratedFilesCollection": {
+                "analysis": [{"file_name": "a.md", "file_type": "md", "content_summary": "s"}],
+                "yaml": [
+                    {
+                        "source_file": "src.yaml",
+                        "converted_file": "out.yaml",
+                        "file_type": "deployment",
+                        "conversion_status": "Success",
+                        "accuracy_rating": "high",
+                    }
+                ],
+                "total_files_generated": 2,
+            },
+            "ProcessMetrics": {
+                "platform_detected": "EKS",
+                "conversion_accuracy": "high",
+                "documentation_completeness": "high",
+                "enterprise_readiness": "ready",
+            },
+        }
+        _run(tm.record_final_outcome("p", outcome, success=True))
+        assert rec.status == "completed"
+        assert len(rec.generated_files) == 2
+        assert rec.conversion_metrics["platform_detected"] == "EKS"
+        assert rec.conversion_metrics["total_files_generated"] == 2
+        # finalized_generated includes one artifact for migration_report
+        assert (
+            rec.final_outcome["finalized_generated"]["artifacts"][0]["type"]
+            == "migration_report"
+        )
+
+    def test_records_termination_output_and_conversion_report(self):
+        rec = ProcessStatus(id="p")
+        # The yaml step result has a conversion_report_file pointer.
+        rec.step_results = {
+            "yaml": {
+                "result": {
+                    "termination_output": {"conversion_report_file": "p/output/conv.md"}
+                }
+            }
+        }
+        tm = _tm_with_repo(rec)
+        outcome = {
+            "termination_output": {
+                "generated_files": {
+                    "documentation": [
+                        {"file_name": "d.md", "file_type": "md", "content_summary": ""}
+                    ],
+                    "total_files_generated": 1,
+                },
+                "process_metrics": {"platform_detected": "GKE"},
+            }
+        }
+        _run(tm.record_final_outcome("p", outcome, success=True))
+        artifact_types = [
+            a["type"] for a in rec.final_outcome["finalized_generated"]["artifacts"]
+        ]
+        assert "conversion_report" in artifact_types
+
+    def test_failure_path_sets_failed_status(self):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        _run(tm.record_final_outcome("p", {}, success=False))
+        assert rec.status == "failed"
+
+    def test_extraction_exception_is_swallowed(self):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        # Pass a deliberately wrong outcome_data shape — a plain string (not a
+        # dict) in the collection slot — to exercise the extraction error guard.
+        outcome = {"GeneratedFilesCollection": "not-a-dict"}
+        _run(tm.record_final_outcome("p", outcome, success=True))
+        # No exception escaped; the record was still marked completed.
+        assert rec.status == "completed"
+
+
+class TestRecordFailureOutcome:
+    def test_no_repo_returns(self):
+        tm = TelemetryManager()
+        _run(
+            tm.record_failure_outcome(
+                "p", error_message="x", failed_step="analysis"
+            )
+        )
+
+    def test_missing_record_warns_and_returns(self):
+        tm = _tm_with_repo(None)
+        _run(
+            tm.record_failure_outcome(
+                "p", error_message="x", failed_step="analysis"
+            )
+        )
+        tm.repository.update_async.assert_not_awaited()
+
+    def test_records_failure_with_traceback_inline(self):
+        rec = ProcessStatus(id="p")
+        rec.step_timings = {"analysis": {"started_at": "2025-01-01T00:00:00Z"}}
+        tm = _tm_with_repo(rec)
+        _run(
+            tm.record_failure_outcome(
+                "p",
+                error_message="oops",
+                failed_step="analysis",
+                failure_details={"traceback": "short tb"},
+                execution_time_seconds=3.0,
+            )
+        )
+        assert rec.status == "failed"
+        assert rec.failure_reason == "oops"
+        assert rec.failure_step == "analysis"
+        assert rec.step_timings["analysis"]["elapsed_seconds"] == 3.0
+
+    def test_records_failure_offloads_large_traceback(self, monkeypatch):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        big_tb = "x" * 1000
+        monkeypatch.setenv("TELEMETRY_TRACEBACK_INLINE_MAX_BYTES", "100")
+
+        async def _fake_upload(**_kwargs):
+            return {"blob": "debug/traceback.txt"}
+
+        with patch.object(at, "_upload_text_to_process_blob", new=_fake_upload):
+            _run(
+                tm.record_failure_outcome(
+                    "p",
+                    error_message="big",
+                    failed_step="design",
+                    failure_details={"traceback": big_tb},
+                )
+            )
+        details = rec.final_outcome["failure_details"]
+        assert "traceback" not in details
+        assert details["traceback_artifact"] == {"blob": "debug/traceback.txt"}
+
+    def test_swallows_offload_exception(self, monkeypatch):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        big_tb = "x" * 1000
+        monkeypatch.setenv("TELEMETRY_TRACEBACK_INLINE_MAX_BYTES", "100")
+
+        async def _fail(**_kwargs):
+            raise RuntimeError("blob fail")
+
+        with patch.object(at, "_upload_text_to_process_blob", new=_fail):
+            # Should not raise: a failed blob offload must not mask the failure.
+            _run(
+                tm.record_failure_outcome(
+                    "p",
+                    error_message="big",
+                    failed_step="design",
+                    failure_details={"traceback": big_tb},
+                )
+            )
+
+    def test_swallows_update_exception(self):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        tm.repository.update_async.side_effect = RuntimeError("boom")
+        # Should not raise: telemetry persistence is best-effort.
+        _run(
+            tm.record_failure_outcome(
+                "p", error_message="x", failed_step="analysis"
+            )
+        )
+
+
+class TestGetFinalResultsSummary:
+    def test_no_repo_returns_empty(self):
+        tm = TelemetryManager()
+        assert _run(tm.get_final_results_summary("p")) == {}
+
+    def test_missing_returns_error(self):
+        tm = _tm_with_repo(None)
+        assert _run(tm.get_final_results_summary("p")) == {"error": "No active process"}
+
+    def test_returns_summary(self):
+        rec = ProcessStatus(id="p")
+        rec.status = "completed"
+        rec.step_results = {"analysis": {"result": {}}}
+        rec.generated_files = [{"file_name": "x"}]
+        rec.conversion_metrics = {"k": "v"}
+        tm = _tm_with_repo(rec)
+        out = _run(tm.get_final_results_summary("p"))
+        assert out["status"] == "completed"
+        assert out["generated_files_count"] == 1
+        assert "completed_steps" in out
+
+
+class TestRecordUiData:
+    def test_no_repo_returns(self):
+        tm = TelemetryManager()
+        _run(tm.record_ui_data("p", {"x": 1}))
+
+    def test_missing_record_warns_and_returns(self):
+        tm = _tm_with_repo(None)
+        _run(tm.record_ui_data("p", {"x": 1}))
+        tm.repository.update_async.assert_not_awaited()
+
+    def test_records_ui_data(self):
+        rec = ProcessStatus(id="p")
+        tm = _tm_with_repo(rec)
+        ui_data = {
+            "file_manifest": {
+                "converted_files": [{"a": 1}],
+                "failed_files": [],
+                "report_files": [{"b": 2}],
+            },
+            "dashboard_metrics": {"completion_percentage": 99.0},
+        }
+        _run(tm.record_ui_data("p", ui_data))
+        assert rec.ui_telemetry_data["file_manifest"]["converted_files"] == [{"a": 1}]
+        tm.repository.update_async.assert_awaited()
+
+    def test_swallows_exception(self):
+        tm = _tm_with_repo(ProcessStatus(id="p"))
+        tm.repository.update_async.side_effect = RuntimeError("x")
+        # Should not raise: telemetry persistence is best-effort.
+        _run(tm.record_ui_data("p", {"file_manifest": {}, "dashboard_metrics": {}}))
+
+
+class TestGetUiTelemetryData:
+    def test_no_repo_returns_empty(self):
+        tm = TelemetryManager()
+        assert _run(tm.get_ui_telemetry_data("p")) == {}
+
+    def test_missing_returns_empty(self):
+        tm = _tm_with_repo(None)
+        assert _run(tm.get_ui_telemetry_data("p")) == {}
+
+    def test_returns_data_when_present(self):
+        rec = ProcessStatus(id="p")
+        rec.ui_telemetry_data = {"a": 1}  # type: ignore[attr-defined]
+        tm = _tm_with_repo(rec)
+        assert _run(tm.get_ui_telemetry_data("p")) == {"a": 1}
+
+    def test_returns_fallback_when_completed_and_empty(self):
+        rec = ProcessStatus(id="p")
+        rec.status = "completed"
+        rec.generated_files = [{"x": 1}, {"y": 2}]
+        tm = _tm_with_repo(rec)
+        out = _run(tm.get_ui_telemetry_data("p"))
+        assert out["dashboard_metrics"]["files_processed"] == 2
diff --git a/src/processor/src/tests/unit/utils/test_console_util.py b/src/processor/src/tests/unit/utils/test_console_util.py
new file mode 100644
index 00000000..6dd0f0ff
--- /dev/null
+++ b/src/processor/src/tests/unit/utils/test_console_util.py
@@ -0,0 +1,74 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import pytest
+
+from utils.console_util import ConsoleColors, format_agent_message, get_role_style
+
+
+class TestGetRoleStyle:
+    @pytest.mark.parametrize(
+        "agent_name",
+        [
+            "Chief Architect",
+            "GKE Expert",
+            "EKS Expert",
+            "Azure Expert",
+            "YAML Expert",
+            "OpenShift Expert",
+            "AKS Expert",
+            "Rancher Expert",
+            "Tanzu Expert",
+            "OnPremK8s Expert",
+            "Technical Writer",
+            "QA Engineer",
+        ],
+    )
+    def test_known_agents_return_styled_label_and_color(self, agent_name):
+        label, color = get_role_style(agent_name)
+        assert ConsoleColors.RESET in label
+        assert color.startswith("\033[")
+
+    def test_unknown_agent_returns_coordinator_default(self):
+        label, color = get_role_style("Some Unknown Role")
+        assert "COORDINATOR" in label
+        assert color == ConsoleColors.WHITE
+
+    def test_none_name_returns_coordinator_default(self):
+        label, color = get_role_style(None)
+        assert "COORDINATOR" in label
+        assert color == ConsoleColors.WHITE
+
+
+class TestFormatAgentMessage:
+    def test_includes_role_and_content_and_resets(self):
+        out = format_agent_message("Azure Expert", "hello", timestamp="")
+        assert "AZURE EXPERT" in out
+        assert "hello" in out
+        assert ConsoleColors.RESET in out
+
+    def test_appends_timestamp_when_provided(self):
+        out = format_agent_message("Azure Expert", "hi", timestamp="12:00:00")
+        assert "(12:00:00)" in out
+
+    def test_truncates_long_content_with_ellipsis(self):
+        content = "x" * 500
+        out = format_agent_message("Azure Expert", content, "", max_content_length=10)
+        assert "xxxxxxxxx…" in out
+
+    def test_max_content_length_one_returns_single_ellipsis(self):
+        out = format_agent_message("Azure Expert", "hello world", "", max_content_length=1)
+        assert "…" in out
+
+    def test_none_content_renders_as_empty(self):
+        out = format_agent_message("Azure Expert", None, "")
+        assert "AZURE EXPERT" in out
+
+    def test_non_string_content_is_stringified(self):
+        out = format_agent_message("Azure Expert", 12345, "")
+        assert "12345" in out
+
+    def test_max_content_length_disabled_when_zero(self):
+        content = "x" * 50
+        out = format_agent_message("Azure Expert", content, "", max_content_length=0)
+        assert content in out
diff --git a/src/processor/src/tests/unit/utils/test_credential_util.py b/src/processor/src/tests/unit/utils/test_credential_util.py
new file mode 100644
index 00000000..c12b97db
--- /dev/null
+++ b/src/processor/src/tests/unit/utils/test_credential_util.py
@@ -0,0 +1,169 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from utils import credential_util
+
+
+@pytest.fixture(autouse=True)
+def _clear_azure_env(monkeypatch):
+    """Remove Azure-hosting marker env vars so each test starts clean."""
+    for key in [
+        "WEBSITE_SITE_NAME",
+        "AZURE_CLIENT_ID",
+        "MSI_ENDPOINT",
+        "IDENTITY_ENDPOINT",
+        "KUBERNETES_SERVICE_HOST",
+        "CONTAINER_REGISTRY_LOGIN",
+    ]:
+        monkeypatch.delenv(key, raising=False)
+
+
+class TestGetAzureCredentialSync:
+    def test_azure_environment_with_user_assigned_returns_managed_identity(
+        self, monkeypatch
+    ):
+        monkeypatch.setenv("AZURE_CLIENT_ID", "client-123")
+        with patch.object(credential_util, "ManagedIdentityCredential") as mic:
+            mic.return_value = MagicMock(name="managed")
+            cred = credential_util.get_azure_credential()
+            mic.assert_called_once_with(client_id="client-123")
+            assert cred is mic.return_value
+
+    def test_azure_environment_without_client_id_uses_system_assigned(
+        self, monkeypatch
+    ):
+        monkeypatch.setenv("WEBSITE_SITE_NAME", "site")
+        with patch.object(credential_util, "ManagedIdentityCredential") as mic:
+            mic.return_value = MagicMock(name="managed")
+            credential_util.get_azure_credential()
+            mic.assert_called_once_with()
+
+    def test_local_returns_first_successful_cli_credential(self):
+        with patch.object(credential_util, "AzureCliCredential") as cli, patch.object(
+            credential_util, "AzureDeveloperCliCredential"
+        ) as azd:
+            cli.return_value = MagicMock(name="cli")
+            azd.return_value = MagicMock(name="azd")
+            cred = credential_util.get_azure_credential()
+            assert cred is cli.return_value
+
+    def test_local_falls_back_to_default_when_all_cli_fail(self):
+        with patch.object(
+            credential_util, "AzureCliCredential", side_effect=RuntimeError("nope")
+        ), patch.object(
+            credential_util,
+            "AzureDeveloperCliCredential",
+            side_effect=RuntimeError("nope"),
+        ), patch.object(credential_util, "DefaultAzureCredential") as default:
+            default.return_value = MagicMock(name="default")
+            cred = credential_util.get_azure_credential()
+            assert cred is default.return_value
+
+
+class TestGetAsyncAzureCredential:
+    def test_async_azure_environment_user_assigned(self, monkeypatch):
+        monkeypatch.setenv("AZURE_CLIENT_ID", "client-xyz")
+        with patch.object(credential_util, "AsyncManagedIdentityCredential") as mic:
+            mic.return_value = MagicMock(name="async-managed")
+            cred = credential_util.get_async_azure_credential()
+            mic.assert_called_once_with(client_id="client-xyz")
+            assert cred is mic.return_value
+
+    def test_async_azure_environment_system_assigned(self, monkeypatch):
+        monkeypatch.setenv("MSI_ENDPOINT", "http://msi/")
+        with patch.object(credential_util, "AsyncManagedIdentityCredential") as mic:
+            mic.return_value = MagicMock(name="async-managed")
+            credential_util.get_async_azure_credential()
+            mic.assert_called_once_with()
+
+    def test_async_local_uses_first_successful(self):
+        with patch.object(credential_util, "AsyncAzureCliCredential") as cli, patch.object(
+            credential_util, "AsyncAzureDeveloperCliCredential"
+        ) as azd:
+            cli.return_value = MagicMock(name="async-cli")
+            azd.return_value = MagicMock(name="async-azd")
+            cred = credential_util.get_async_azure_credential()
+            assert cred is cli.return_value
+
+    def test_async_local_falls_back_to_default(self):
+        with patch.object(
+            credential_util,
+            "AsyncAzureCliCredential",
+            side_effect=RuntimeError("nope"),
+        ), patch.object(
+            credential_util,
+            "AsyncAzureDeveloperCliCredential",
+            side_effect=RuntimeError("nope"),
+        ), patch.object(credential_util, "AsyncDefaultAzureCredential") as default:
+            default.return_value = MagicMock(name="async-default")
+            cred = credential_util.get_async_azure_credential()
+            assert cred is default.return_value
+
+
+class TestBearerTokenProviders:
+    def test_get_bearer_token_provider_uses_credential(self):
+        with patch.object(
+            credential_util, "get_azure_credential", return_value=MagicMock()
+        ) as cred_fn, patch.object(
+            credential_util, "identity_get_bearer_token_provider"
+        ) as token_fn:
+            token_fn.return_value = "token-callable"
+            res = credential_util.get_bearer_token_provider()
+            cred_fn.assert_called_once()
+            token_fn.assert_called_once()
+            assert res == "token-callable"
+
+    def test_async_get_bearer_token_provider_uses_credential(self):
+        import asyncio
+
+        with patch.object(
+            credential_util, "get_async_azure_credential", new=AsyncMock(return_value=MagicMock())
+        ) as cred_fn, patch.object(
+            credential_util, "identity_get_async_bearer_token_provider"
+        ) as token_fn:
+            token_fn.return_value = "async-token-callable"
+            res = asyncio.run(credential_util.get_async_bearer_token_provider())
+            cred_fn.assert_called_once()
+            token_fn.assert_called_once()
+            assert res == "async-token-callable"
+
+
+class TestValidateAzureAuthentication:
+    def test_local_environment_recommendations(self):
+        cred = MagicMock()
+        cred.__class__.__name__ = "AzureCliCredential"
+        with patch.object(credential_util, "get_azure_credential", return_value=cred):
+            info = credential_util.validate_azure_authentication()
+            assert info["environment"] == "local_development"
+            assert info["status"] == "configured"
+
+    def test_azure_hosted_with_user_assigned(self, monkeypatch):
+        monkeypatch.setenv("AZURE_CLIENT_ID", "uami-id")
+        with patch.object(
+            credential_util, "get_azure_credential", return_value=MagicMock()
+        ):
+            info = credential_util.validate_azure_authentication()
+            assert info["environment"] == "azure_hosted"
+            assert info["credential_type"] == "managed_identity"
+
+    def test_azure_hosted_system_assigned(self, monkeypatch):
+        monkeypatch.setenv("WEBSITE_SITE_NAME", "site")
+        with patch.object(
+            credential_util, "get_azure_credential", return_value=MagicMock()
+        ):
+            info = credential_util.validate_azure_authentication()
+            assert info["environment"] == "azure_hosted"
+
+    def test_credential_failure_reports_error(self):
+        with patch.object(
+            credential_util,
+            "get_azure_credential",
+            side_effect=RuntimeError("creds bad"),
+        ):
+            info = credential_util.validate_azure_authentication()
+            assert info["status"] == "error"
+            assert "creds bad" in info["error"]
diff --git a/src/processor/src/tests/unit/utils/test_logging_utils.py b/src/processor/src/tests/unit/utils/test_logging_utils.py
new file mode 100644
index 00000000..2ba4fc5e
--- /dev/null
+++ b/src/processor/src/tests/unit/utils/test_logging_utils.py
@@ -0,0 +1,156 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import logging
+from unittest.mock import MagicMock
+
+import pytest
+
+from utils.logging_utils import (
+    LogMessages,
+    _format_specific_error_details,
+    configure_application_logging,
+    create_migration_logger,
+    get_error_details,
+    log_error_with_context,
+    safe_log,
+)
+
+
+class TestConfigureApplicationLogging:
+    def test_production_mode_sets_warning_levels(self):
+        configure_application_logging(debug_mode=False)
+        assert logging.getLogger("httpx").level == logging.WARNING
+        assert logging.getLogger("azure.cosmos").level == logging.WARNING
+
+    def test_debug_mode_keeps_http_warning_but_info_for_others(self):
+        configure_application_logging(debug_mode=True)
+        assert logging.getLogger("httpx").level == logging.WARNING
+        assert logging.getLogger("asyncio").level == logging.INFO
+
+
+class TestCreateMigrationLogger:
+    def test_creates_logger_with_handler(self):
+        logger = create_migration_logger("test.migration.unique1")
+        assert logger.handlers
+        assert logger.level == logging.INFO
+
+    def test_does_not_duplicate_handlers_on_repeat_calls(self):
+        name = "test.migration.unique2"
+        logger1 = create_migration_logger(name)
+        handler_count = len(logger1.handlers)
+        logger2 = create_migration_logger(name)
+        assert len(logger2.handlers) == handler_count
+
+    def test_respects_level_argument(self):
+        logger = create_migration_logger("test.migration.debug", level=logging.DEBUG)
+        assert logger.level == logging.DEBUG
+
+
+class TestSafeLog:
+    def test_substitutes_variables_in_template(self):
+        logger = MagicMock(spec=logging.Logger)
+        safe_log(logger, "info", "value={value}", value=42)
+        logger.info.assert_called_once_with("value=42")
+
+    def test_complex_objects_converted_to_strings(self):
+        logger = MagicMock(spec=logging.Logger)
+        safe_log(logger, "warning", "data={d}", d={"a": 1})
+        called = logger.warning.call_args[0][0]
+        assert "{'a': 1}" in called
+
+    def test_exception_value_safely_stringified(self):
+        logger = MagicMock(spec=logging.Logger)
+        safe_log(logger, "error", "err={e}", e=ValueError("boom"))
+        called = logger.error.call_args[0][0]
+        assert "boom" in called
+
+    def test_format_failure_raises_runtime_error(self):
+        logger = MagicMock(spec=logging.Logger)
+        with pytest.raises(RuntimeError):
+            safe_log(logger, "info", "missing {missing_key}", other=1)
+        assert logger.error.called
+
+
+class TestGetErrorDetails:
+    def test_basic_exception_details(self):
+        try:
+            raise ValueError("boom")
+        except ValueError as e:
+            details = get_error_details(e)
+            assert details["exception_type"] == "ValueError"
+            assert details["exception_message"] == "boom"
+
+    def test_chained_exception_details(self):
+        try:
+            try:
+                raise ValueError("orig")
+            except ValueError as inner:
+                raise RuntimeError("wrap") from inner
+        except RuntimeError as e:
+            details = get_error_details(e)
+            assert details["exception_cause"] is not None
+            assert "orig" in details["exception_cause"]
+
+    def test_http_response_error_includes_http_fields(self):
+        from azure.core.exceptions import HttpResponseError
+
+        err = HttpResponseError(message="bad")
+        err.status_code = 503
+        err.reason = "Service Unavailable"
+        details = get_error_details(err)
+        assert details["http_status_code"] == 503
+        assert details["http_reason"] == "Service Unavailable"
+
+
+class TestFormatSpecificErrorDetails:
+    def test_http_details_formatted(self):
+        out = _format_specific_error_details(
+            {"http_status_code": 500, "http_reason": "Server Error"}
+        )
+        assert "HTTP Status Code: 500" in out
+        assert "HTTP Reason: Server Error" in out
+
+    def test_service_error_code_formatted(self):
+        out = _format_specific_error_details({"service_error_code": "SVC42"})
+        assert "Service Error Code: SVC42" in out
+
+    def test_azure_chat_completion_error_with_model_and_endpoint(self):
+        out = _format_specific_error_details(
+            {
+                "azure_chat_completion_error": True,
+                "model_deployment": "gpt-4o",
+                "endpoint": "https://example.openai.azure.com",
+            }
+        )
+        assert "Azure ChatCompletion Error Detected" in out
+        assert "gpt-4o" in out
+        assert "openai.azure.com" in out
+
+    def test_empty_dict_returns_empty_string(self):
+        assert _format_specific_error_details({}) == ""
+
+
+class TestLogErrorWithContext:
+    def test_logs_error_and_returns_details(self):
+        logger = MagicMock(spec=logging.Logger)
+        try:
+            raise ValueError("ctx-err")
+        except ValueError as e:
+            details = log_error_with_context(logger, e, context="MyOp", k="v")
+
+        assert details["exception_type"] == "ValueError"
+        assert details["additional_context"] == {"k": "v"}
+        assert logger.error.called
+
+
+class TestLogMessages:
+    def test_format_templates_have_placeholders(self):
+        formatted = LogMessages.ERROR_STEP_FAILED.format(step="analysis", error="x")
+        assert "analysis" in formatted and "x" in formatted
+
+        formatted = LogMessages.SUCCESS_COMPLETED.format(operation="op", details="d")
+        assert "op" in formatted and "d" in formatted
+
+        formatted = LogMessages.INFO_PROCESSING.format(item="thing")
+        assert "thing" in formatted
diff --git a/src/processor/src/tests/unit/utils/test_prompt_util.py b/src/processor/src/tests/unit/utils/test_prompt_util.py
new file mode 100644
index 00000000..63f69b5b
--- /dev/null
+++ b/src/processor/src/tests/unit/utils/test_prompt_util.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from utils.prompt_util import TemplateUtility
+
+
+class TestTemplateUtility:
+    def test_render_substitutes_variables(self):
+        out = TemplateUtility.render("Hello {{ name }}!", name="Ada")
+        assert out == "Hello Ada!"
+
+    def test_render_with_no_placeholders_returns_original(self):
+        out = TemplateUtility.render("plain text")
+        assert out == "plain text"
+
+    def test_render_supports_multiple_variables(self):
+        out = TemplateUtility.render("{{ a }} + {{ b }} = {{ c }}", a=1, b=2, c=3)
+        assert out == "1 + 2 = 3"
+
+    def test_render_from_file_reads_and_renders(self, tmp_path):
+        f = tmp_path / "template.txt"
+        f.write_text("Hi {{ user }}", encoding="utf-8")
+        out = TemplateUtility.render_from_file(str(f), user="bob")
+        assert out == "Hi bob"
+
+    def test_render_supports_loops(self):
+        tpl = "{% for x in items %}{{ x }},{% endfor %}"
+        out = TemplateUtility.render(tpl, items=[1, 2, 3])
+        assert out == "1,2,3,"
diff --git a/src/processor/src/tests/unit/utils/test_security_policy_evidence.py b/src/processor/src/tests/unit/utils/test_security_policy_evidence.py
new file mode 100644
index 00000000..1205905f
--- /dev/null
+++ b/src/processor/src/tests/unit/utils/test_security_policy_evidence.py
@@ -0,0 +1,224 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from utils import security_policy_evidence as spe
+
+
+def _blob(name, size=10):
+    return SimpleNamespace(name=name, size=size)
+
+
+def _container_with_blobs(name_to_text: dict, blobs):
+    """Build a mock container client whose list_blobs returns `blobs` and whose
+    get_blob_client(name) returns a blob client serving name_to_text[name]."""
+    cc = MagicMock()
+    cc.list_blobs.return_value = blobs
+
+    def _get_blob_client(name):
+        bc = MagicMock()
+        text = name_to_text.get(name, "")
+        bc.download_blob.return_value.readall.return_value = text.encode("utf-8")
+        return bc
+
+    cc.get_blob_client.side_effect = _get_blob_client
+    return cc
+
+
+@pytest.fixture
+def patch_client():
+    """Patch `_get_blob_service_client` so no real Azure call is made."""
+
+    def _apply(container_client):
+        client = MagicMock()
+        client.get_container_client.return_value = container_client
+        return patch.object(spe, "_get_blob_service_client", return_value=client)
+
+    return _apply
+
+
+class TestGetBlobServiceClient:
+    def test_account_name_uses_credential(self, monkeypatch):
+        monkeypatch.setenv("STORAGE_ACCOUNT_NAME", "myacct")
+        with patch.object(spe, "BlobServiceClient") as bsc, patch.object(
+            spe, "get_azure_credential", return_value="cred"
+        ):
+            bsc.return_value = "client"
+            result = spe._get_blob_service_client()
+            bsc.assert_called_once_with(
+                account_url="https://myacct.blob.core.windows.net",
+                credential="cred",
+            )
+            assert result == "client"
+
+    def test_alt_account_env_used(self, monkeypatch):
+        monkeypatch.delenv("STORAGE_ACCOUNT_NAME", raising=False)
+        monkeypatch.setenv("AZURE_STORAGE_ACCOUNT_NAME", "alt")
+        with patch.object(spe, "BlobServiceClient") as bsc, patch.object(
+            spe, "get_azure_credential", return_value="c"
+        ):
+            spe._get_blob_service_client()
+            assert "alt.blob.core.windows.net" in bsc.call_args.kwargs["account_url"]
+
+    def test_connection_string_fallback(self, monkeypatch):
+        monkeypatch.delenv("STORAGE_ACCOUNT_NAME", raising=False)
+        monkeypatch.delenv("AZURE_STORAGE_ACCOUNT_NAME", raising=False)
+        monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", "DefaultEndpointsProtocol=https;...")
+        with patch.object(spe, "BlobServiceClient") as bsc:
+            bsc.from_connection_string.return_value = "from-cs"
+            result = spe._get_blob_service_client()
+            bsc.from_connection_string.assert_called_once()
+            assert result == "from-cs"
+
+    def test_missing_config_raises(self, monkeypatch):
+        for key in [
+            "STORAGE_ACCOUNT_NAME",
+            "AZURE_STORAGE_ACCOUNT_NAME",
+            "AZURE_STORAGE_CONNECTION_STRING",
+            "STORAGE_CONNECTION_STRING",
+            "AzureWebJobsStorage",
+        ]:
+            monkeypatch.delenv(key, raising=False)
+        with pytest.raises(RuntimeError):
+            spe._get_blob_service_client()
+
+
+class TestCollectSecurityPolicyEvidence:
+    def test_empty_folder_returns_zero_findings(self, patch_client):
+        cc = _container_with_blobs({}, [])
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder="proj/"
+            )
+        assert result["scanned_files"] == 0
+        assert result["findings"] == []
+        assert result["errors"] == []
+        assert result["source_folder"] == "proj"
+
+    def test_list_blobs_failure_surfaces_error(self, patch_client):
+        cc = MagicMock()
+        cc.list_blobs.side_effect = RuntimeError("listing blew up")
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder="x"
+            )
+        assert result["findings"] == []
+        assert any("list_blobs_failed" in e for e in result["errors"])
+
+    def test_skips_non_relevant_extensions(self, patch_client):
+        cc = _container_with_blobs(
+            {"foo.png": "irrelevant"}, [_blob("foo.png")]
+        )
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        assert result["scanned_files"] == 0
+
+    def test_skips_keep_files(self, patch_client):
+        cc = _container_with_blobs(
+            {"folder/.keep": ""}, [_blob("folder/.keep")]
+        )
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        assert result["scanned_files"] == 0
+
+    def test_skips_oversized_files(self, patch_client):
+        cc = _container_with_blobs(
+            {"big.yaml": "x"}, [_blob("big.yaml", size=10_000_000)]
+        )
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c",
+                source_folder="",
+                max_bytes_per_file=1024,
+            )
+        assert result["scanned_files"] == 0
+        assert result["skipped_files"] == 1
+
+    def test_max_files_cap_respected(self, patch_client):
+        names = [f"f{i}.yaml" for i in range(5)]
+        cc = _container_with_blobs({n: "no signals here" for n in names}, [_blob(n) for n in names])
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder="", max_files=2
+            )
+        assert result["scanned_files"] == 2
+
+    def test_detects_aws_pattern(self, patch_client):
+        text = "key: AKIAIOSFODNN7EXAMPLE\n"
+        cc = _container_with_blobs({"a.yaml": text}, [_blob("a.yaml")])
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        assert result["scanned_files"] == 1
+        assert result["findings"][0]["signals"] == ["aws_access_key_id_pattern"]
+
+    def test_detects_gcp_and_generic(self, patch_client):
+        text = "private_key_id: abc\npassword: hunter2\n"
+        cc = _container_with_blobs({"x.json": text}, [_blob("x.json")])
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        signals = result["findings"][0]["signals"]
+        assert "gcp_service_account_key_fields" in signals
+        assert "generic_secret_keywords" in signals
+
+    def test_extracts_secret_key_names_from_k8s_secret(self, patch_client):
+        text = (
+            "apiVersion: v1\n"
+            "kind: Secret\n"
+            "metadata:\n"
+            "  name: app\n"
+            "data:\n"
+            "  username: dXNlcg==\n"
+            "  password: cGFzcw==\n"
+            "type: Opaque\n"
+        )
+        cc = _container_with_blobs({"s.yaml": text}, [_blob("s.yaml")])
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        finding = result["findings"][0]
+        assert "k8s_kind_secret" in finding["signals"]
+        assert "username" in finding["secret_key_names"]
+        assert "password" in finding["secret_key_names"]
+
+    def test_per_file_download_error_recorded_and_continues(self, patch_client):
+        good_text = "password: foo\n"
+        cc = MagicMock()
+        cc.list_blobs.return_value = [_blob("bad.yaml"), _blob("good.yaml")]
+
+        def _get_blob_client(name):
+            bc = MagicMock()
+            if name == "bad.yaml":
+                bc.download_blob.side_effect = RuntimeError("download failed")
+            else:
+                bc.download_blob.return_value.readall.return_value = good_text.encode()
+            return bc
+
+        cc.get_blob_client.side_effect = _get_blob_client
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        assert any("bad.yaml" in e for e in result["errors"])
+        assert any(f["blob"] == "good.yaml" for f in result["findings"])
+
+    def test_no_signals_no_finding(self, patch_client):
+        cc = _container_with_blobs({"plain.txt": "nothing interesting"}, [_blob("plain.txt")])
+        with patch_client(cc):
+            result = spe.collect_security_policy_evidence(
+                container_name="c", source_folder=""
+            )
+        assert result["scanned_files"] == 1
+        assert result["findings"] == []