diff --git a/scripts/validate_plugins/run.py b/scripts/validate_plugins/run.py index f95b9741..6426c209 100644 --- a/scripts/validate_plugins/run.py +++ b/scripts/validate_plugins/run.py @@ -630,6 +630,21 @@ async def _return_self(): return _return_self().__await__() + async def __aenter__(self) -> "NullStub": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + del exc_type, exc, tb + return False + + def get(self, key=None, default=None): + del key + return default + + def pop(self, key=None, default=None): + del key + return default + def __iter__(self): return iter(()) @@ -637,9 +652,47 @@ def __bool__(self) -> bool: return False +class DummyConfig(dict): + def __init__(self, initial=None) -> None: + super().__init__(initial or {}) + + def __missing__(self, key): + del key + return NullStub() + + def __getattr__(self, name: str): + return self[name] + + class DummyContext: def __init__(self) -> None: self._star_manager = None + self._astrbot_root = Path(os.environ.get("ASTRBOT_ROOT", Path.cwd())).resolve() + self._data_root = self._astrbot_root / "data" + self._plugin_data_dir = self._data_root / "plugin_data" + self._config = DummyConfig( + { + "wake_prefix": [], + "dashboard": DummyConfig(), + "admins_id": [], + "admin_ids": [], + "platform_settings": DummyConfig( + { + "aiocqhttp": {}, + "qqofficial": {}, + "telegram": {}, + "gewechat": {}, + "wechatpadpro": {}, + } + ), + "data_dir": str(self._data_root), + } + ) + self.config = self._config + + def _ensure_plugin_data_dir(self) -> Path: + self._plugin_data_dir.mkdir(parents=True, exist_ok=True) + return self._plugin_data_dir def get_all_stars(self): try: @@ -669,6 +722,16 @@ def register_llm_tool(self, name: str, func_args, desc: str, func_obj) -> None: def unregister_llm_tool(self, name: str) -> None: del name + def get_config(self, umo: str | None = None): + del umo + return self._config + + def get_context_config(self): + return self.get_config() + + def get_data_dir(self) -> str: + return str(self._ensure_plugin_data_dir()) + def __getattr__(self, name: str) -> NullStub: del name return NullStub() diff --git a/tests/test_validate_plugins.py b/tests/test_validate_plugins.py index 5d4df159..cc17887e 100644 --- a/tests/test_validate_plugins.py +++ b/tests/test_validate_plugins.py @@ -385,6 +385,72 @@ def test_load_plugins_index_rejects_non_dict_values(self): os.remove(index_path) +class DummyContextStubTests(unittest.IsolatedAsyncioTestCase): + def test_dummy_context_defers_plugin_data_dir_creation_until_requested(self): + module = load_validator_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + astrbot_root = Path(tmp_dir) / "astrbot-root" + plugin_data_dir = astrbot_root / "data" / "plugin_data" + + with mock.patch.dict(os.environ, {"ASTRBOT_ROOT": str(astrbot_root)}, clear=True): + context = module.DummyContext() + dir_exists_before = plugin_data_dir.exists() + created_dir = Path(context.get_data_dir()) + dir_exists_after = plugin_data_dir.is_dir() + + self.assertFalse(dir_exists_before) + self.assertEqual(created_dir.resolve(), plugin_data_dir.resolve()) + self.assertTrue(dir_exists_after) + + def test_dummy_context_returns_worker_data_dir_for_plugin_storage(self): + module = load_validator_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + astrbot_root = Path(tmp_dir) / "astrbot-root" + with mock.patch.dict(os.environ, {"ASTRBOT_ROOT": str(astrbot_root)}, clear=True): + data_dir = Path(module.DummyContext().get_data_dir()) + data_dir_exists = data_dir.is_dir() + + self.assertEqual(data_dir.resolve(), (astrbot_root / "data" / "plugin_data").resolve()) + self.assertTrue(data_dir_exists) + + async def test_null_stub_supports_async_database_context_pattern(self): + module = load_validator_module() + + db = module.DummyContext().get_db() + + async with db.get_db() as session: + self.assertIsInstance(session, module.NullStub) + async with session.begin() as transaction: + self.assertIs(transaction, session) + result = await session.execute("SELECT 1") + + self.assertIs(result, session) + + async def test_null_stub_returns_defaults_for_restart_style_config_access(self): + module = load_validator_module() + + with mock.patch.dict(os.environ, {}, clear=True): + dashboard_config = module.DummyContext().get_config().get("dashboard", {}) + + self.assertEqual(dashboard_config.get("host", "127.0.0.1"), "127.0.0.1") + self.assertEqual( + int(os.environ.get("DASHBOARD_PORT", dashboard_config.get("port", 6185))), + 6185, + ) + + def test_dummy_context_exposes_dict_like_config_defaults(self): + module = load_validator_module() + + with mock.patch.dict(os.environ, {}, clear=True): + context = module.DummyContext() + + self.assertEqual(context.get_config()["wake_prefix"], []) + self.assertEqual(context.get_config()["dashboard"].get("port", 6185), 6185) + self.assertEqual(context._config.get("expire_seconds", 300), 300) + + class ValidationProgressTests(unittest.TestCase): def test_build_parser_defaults_max_workers_to_sixteen(self): module = load_validator_module()