name: Validate Plugin Smoke

on:
  pull_request:
    paths:
      - "plugins.json"
      - "scripts/validate_plugins/**"
      - ".github/workflows/validate-plugin-smoke.yml"
  schedule:
    - cron: "0 2 * * *"
  workflow_dispatch:
    inputs:
      plugin_names:
        description: "Comma-separated plugin keys from plugins.json"
        required: false
        default: ""
      plugin_limit:
        description: "Validate the first N plugins when plugin_names is empty. Leave blank or use -1 for all plugins"
        required: false
        default: ""
      astrbot_ref:
        description: "AstrBot git ref to validate against"
        required: false
        default: "master"
      max_workers:
        description: "Maximum concurrent plugin validations"
        required: false
        default: "8"

jobs:
  validate-plugin-smoke:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      # SECURITY: workflow inputs are passed through `env:` instead of being
      # interpolated with ${{ }} directly inside the script, so a crafted
      # input value cannot inject shell commands or extra $GITHUB_ENV lines.
      - name: Set manual validation inputs
        if: github.event_name == 'workflow_dispatch'
        env:
          INPUT_ASTRBOT_REF: ${{ inputs.astrbot_ref }}
          INPUT_PLUGIN_NAMES: ${{ inputs.plugin_names }}
          INPUT_PLUGIN_LIMIT: ${{ inputs.plugin_limit }}
          INPUT_MAX_WORKERS: ${{ inputs.max_workers }}
        run: |
          {
            echo "ASTRBOT_REF=$INPUT_ASTRBOT_REF"
            echo "PLUGIN_NAME_LIST=$INPUT_PLUGIN_NAMES"
            echo "PLUGIN_LIMIT=$INPUT_PLUGIN_LIMIT"
            echo "MAX_WORKERS=$INPUT_MAX_WORKERS"
            echo "SHOULD_VALIDATE=true"
          } >> "$GITHUB_ENV"

      - name: Set scheduled validation inputs
        if: github.event_name == 'schedule'
        run: |
          {
            echo "ASTRBOT_REF=master"
            echo "PLUGIN_NAME_LIST="
            echo "PLUGIN_LIMIT="
            echo "MAX_WORKERS=8"
            echo "SHOULD_VALIDATE=true"
            echo "VALIDATION_NOTE=Running scheduled full plugin validation."
          } >> "$GITHUB_ENV"

      # Writes ASTRBOT_REF / PLUGIN_NAME_LIST / SHOULD_VALIDATE / ... to
      # $GITHUB_ENV based on the plugins.json diff against the base branch.
      - name: Detect changed plugins from pull request
        if: github.event_name == 'pull_request'
        run: python scripts/validate_plugins/detect_changed_plugins.py

      - name: Show PR diff selection
        if: github.event_name == 'pull_request'
        run: |
          if [ "$SHOULD_VALIDATE" != "true" ]; then
            printf '%s\n' "${VALIDATION_NOTE:-Smoke validation skipped.}"
          else
            printf 'Selected plugins: %s\n' "$PLUGIN_NAME_LIST"
          fi

      - name: Set up Python
        if: env.SHOULD_VALIDATE == 'true'
        uses: actions/setup-python@v6
        with:
          python-version: "3.12"

      - name: Install validator dependencies
        if: env.SHOULD_VALIDATE == 'true'
        run: python -m pip install --upgrade pip pyyaml

      - name: Clone AstrBot
        if: env.SHOULD_VALIDATE == 'true'
        run: git clone --depth 1 --branch "$ASTRBOT_REF" "https://github.com/AstrBotDevs/AstrBot" ".cache/AstrBot"

      - name: Install AstrBot dependencies
        if: env.SHOULD_VALIDATE == 'true'
        run: python -m pip install -r ".cache/AstrBot/requirements.txt"

      - name: Run plugin smoke validator
        if: env.SHOULD_VALIDATE == 'true'
        run: |
          args=(
            --astrbot-path ".cache/AstrBot"
            --report-path "validation-report.json"
          )

          if [ -n "${PLUGIN_NAME_LIST:-}" ]; then
            args+=(--plugin-name-list "$PLUGIN_NAME_LIST")
          fi

          if [ -n "${PLUGIN_LIMIT:-}" ]; then
            args+=(--limit "$PLUGIN_LIMIT")
          fi

          if [ -n "${MAX_WORKERS:-}" ]; then
            args+=(--max-workers "$MAX_WORKERS")
          fi

          python scripts/validate_plugins/run.py "${args[@]}"

      # Upload even on failure so failed runs still produce a report.
      - name: Upload validation report
        if: always()
        uses: actions/upload-artifact@v7
        with:
          name: validation-report
          path: validation-report.json
          if-no-files-found: warn
b/scripts/validate_plugins/detect_changed_plugins.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from pathlib import Path + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from scripts.validate_plugins.plugins_map import load_plugins_map_text + + +DEFAULT_ASTRBOT_REF = "master" +ASTRBOT_REMOTE_URL = "https://github.com/AstrBotDevs/AstrBot" + + +def load_plugins_map(text: str, *, source_name: str) -> dict[str, dict]: + return load_plugins_map_text(text, source_name=source_name) + + +def detect_changed_plugin_names(*, base: dict[str, dict], head: dict[str, dict]) -> list[str]: + return [name for name, payload in head.items() if base.get(name) != payload] + + +def fetch_base_ref(base_ref: str) -> None: + subprocess.run(["git", "fetch", "origin", base_ref, "--depth", "1"], check=True) + + +def read_base_plugins_json(base_ref: str) -> str: + return subprocess.check_output( + ["git", "show", f"origin/{base_ref}:plugins.json"], + text=True, + stderr=subprocess.DEVNULL, + ) + + +def resolve_astrbot_ref() -> str: + try: + default_head = subprocess.check_output( + ["git", "ls-remote", "--symref", ASTRBOT_REMOTE_URL, "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError: + return DEFAULT_ASTRBOT_REF + + for line in default_head.splitlines(): + if line.startswith("ref: refs/heads/") and line.endswith("\tHEAD"): + return line.split("refs/heads/", 1)[1].split("\t", 1)[0] + return DEFAULT_ASTRBOT_REF + + +def detect_pull_request_selection(*, repo_root: Path, base_ref: str) -> dict[str, object]: + try: + fetch_base_ref(base_ref) + base = load_plugins_map(read_base_plugins_json(base_ref), source_name=f"base ref {base_ref}") + except (subprocess.CalledProcessError, ValueError): + base = {} + + head_text = (repo_root / "plugins.json").read_text(encoding="utf-8") + try: + head = 
load_plugins_map(head_text, source_name="PR head") + except ValueError as exc: + raise ValueError(f"plugins.json is invalid on the PR head: {exc}") from exc + + changed = detect_changed_plugin_names(base=base, head=head) + validation_note = "" + if not changed: + validation_note = "No plugin entries changed in plugins.json; skipping smoke validation." + + return { + "changed": changed, + "should_validate": bool(changed), + "validation_note": validation_note, + } + + +def write_github_env( + *, + env_path: Path, + astrbot_ref: str, + changed: list[str], + should_validate: bool, + validation_note: str, +) -> None: + with env_path.open("a", encoding="utf-8") as handle: + handle.write(f"ASTRBOT_REF={astrbot_ref}\n") + handle.write(f"PLUGIN_NAME_LIST={','.join(changed)}\n") + handle.write("PLUGIN_LIMIT=\n") + handle.write(f"SHOULD_VALIDATE={'true' if should_validate else 'false'}\n") + handle.write(f"VALIDATION_NOTE={validation_note}\n") + + +def main() -> int: + base_ref = os.environ["GITHUB_BASE_REF"] + github_env = Path(os.environ["GITHUB_ENV"]) + repo_root = Path.cwd() + + try: + result = detect_pull_request_selection(repo_root=repo_root, base_ref=base_ref) + except ValueError as exc: + print(str(exc), file=sys.stderr) + return 1 + + write_github_env( + env_path=github_env, + astrbot_ref=resolve_astrbot_ref(), + changed=result["changed"], + should_validate=result["should_validate"], + validation_note=result["validation_note"], + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/validate_plugins/plugins_map.py b/scripts/validate_plugins/plugins_map.py new file mode 100644 index 00000000..07b689cd --- /dev/null +++ b/scripts/validate_plugins/plugins_map.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import json +from pathlib import Path + + +def validate_plugins_map(data: object, *, source_name: str) -> dict[str, dict]: + if not isinstance(data, dict): + raise ValueError("plugins.json must contain a JSON object") 
+ + for name, payload in data.items(): + if not isinstance(name, str): + raise ValueError( + f"plugins.json on the {source_name} has a non-string key: {name!r}" + ) + if not isinstance(payload, dict): + raise ValueError( + f"plugins.json entry {name!r} on the {source_name} must be a JSON object" + ) + + return data + + +def load_plugins_map_text(text: str, *, source_name: str) -> dict[str, dict]: + try: + data = json.loads(text) + except json.JSONDecodeError as exc: + raise ValueError(f"plugins.json is invalid on the {source_name}: {exc}") from exc + + return validate_plugins_map(data, source_name=source_name) + + +def load_plugins_map_file(path: Path, *, source_name: str) -> dict[str, dict]: + return load_plugins_map_text(path.read_text(encoding="utf-8"), source_name=source_name) diff --git a/scripts/validate_plugins/run.py b/scripts/validate_plugins/run.py new file mode 100644 index 00000000..941f5c2b --- /dev/null +++ b/scripts/validate_plugins/run.py @@ -0,0 +1,835 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import asyncio +import concurrent.futures +import hashlib +import json +import os +import re +import shutil +import subprocess +import sys +import tempfile +import traceback +from pathlib import Path +from urllib.parse import urlparse + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from scripts.validate_plugins.plugins_map import load_plugins_map_file + +try: + import yaml +except ImportError: # pragma: no cover - optional in local unit tests + yaml = None + + +REQUIRED_METADATA_FIELDS = ("name", "desc", "version", "author") +DEFAULT_CLONE_TIMEOUT = 120 +DEFAULT_MAX_WORKERS = 8 +CONFLICT_MARKERS = ("<<<<<<<", "=======", ">>>>>>>") + + +class MetadataLoadError(ValueError): + pass + + +def positive_int(raw_value: str) -> int: + value = int(raw_value) + if value <= 0: + raise argparse.ArgumentTypeError("must be a positive integer") + return value + + +def build_result( 
+ *, + plugin: str, + repo: str, + normalized_repo_url: str | None, + ok: bool, + stage: str, + message: str, + severity: str | None = None, + plugin_dir_name: str | None = None, + details: dict | str | None = None, +) -> dict: + resolved_severity = severity or ("pass" if ok else "fail") + resolved_ok = True if resolved_severity in {"pass", "warn"} else ok + result = { + "plugin": plugin, + "repo": repo, + "normalized_repo_url": normalized_repo_url, + "ok": resolved_ok, + "stage": stage, + "message": message, + "severity": resolved_severity, + } + if plugin_dir_name: + result["plugin_dir_name"] = plugin_dir_name + if details is not None: + result["details"] = details + return result + + +def normalize_repo_url(repo_url: str) -> str: + parsed = urlparse(repo_url.strip()) + if parsed.scheme not in {"http", "https"}: + raise ValueError("repo URL must use http or https") + if parsed.netloc.lower() != "github.com": + raise ValueError("repo URL must point to github.com") + + parts = [part for part in parsed.path.split("/") if part] + if len(parts) != 2: + raise ValueError("repo URL must include owner and repository") + + owner, repo = parts[0], parts[1] + if repo.endswith(".git"): + repo = repo[:-4] + if not owner or not repo: + raise ValueError("repo URL owner or repository is empty") + + return f"https://github.com/{owner}/{repo}" + + +def select_plugins( + *, + plugins: dict, + requested_names: list[str] | None, + limit: int | None, +) -> list[tuple[str, dict]]: + if requested_names: + selected = [] + for name in requested_names: + if name not in plugins: + raise KeyError(f"plugin not found: {name}") + selected.append((name, plugins[name])) + return selected + + items = list(plugins.items()) + if limit is None or limit < 0: + return items + return items[:limit] + + +def _parse_simple_yaml(path: Path) -> dict: + """Very small YAML subset parser used as a fallback when PyYAML is unavailable. 
+ + Supported format: + - Flat mapping of `key: value` pairs + - No indentation (no nested objects or multiline continuations) + - No lists (`- item` syntax) + - `#` starts a comment when preceded by whitespace (or at line start) + """ + + def parse_value(raw_value: str) -> str: + value = raw_value.strip() + if not value: + return "" + + if value[0] in {'"', "'"}: + quote = value[0] + end_index = value.rfind(quote) + if end_index > 0: + return value[1:end_index] + + value = re.split(r"\s+#", value, maxsplit=1)[0].rstrip() + return value.strip("\"'") + + result: dict[str, str] = {} + for lineno, raw_line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): + stripped = raw_line.strip() + if not stripped: + continue + if stripped.startswith("#"): + continue + if raw_line[0].isspace(): + raise ValueError( + f"Unsupported YAML indentation in {path} at line {lineno}: {raw_line!r}" + ) + + line = stripped + if line.startswith("-"): + raise ValueError( + f"Unsupported YAML list syntax in {path} at line {lineno}: {raw_line!r}" + ) + if ":" not in line: + raise ValueError( + f"Unsupported YAML content (expected 'key: value') in {path} at line {lineno}: {raw_line!r}" + ) + + key, value = line.split(":", 1) + key = key.strip() + if not key: + raise ValueError(f"Empty key is not allowed in {path} at line {lineno}: {raw_line!r}") + if key in result: + raise ValueError(f"Duplicate key '{key}' in {path} at line {lineno}") + + result[key] = parse_value(value) + return result + + +def load_metadata(path: Path) -> dict: + text = path.read_text(encoding="utf-8") + if any(marker in text for marker in CONFLICT_MARKERS): + raise MetadataLoadError( + "could not find expected ':' (merge conflict markers found in metadata.yaml)" + ) + + if yaml is not None: + try: + loaded = yaml.safe_load(text) + except yaml.YAMLError as exc: + raise MetadataLoadError(str(exc)) from exc + if loaded is None: + return {} + if not isinstance(loaded, dict): + raise 
MetadataLoadError("metadata.yaml must contain a mapping at the top level") + return loaded + + try: + return _parse_simple_yaml(path) + except ValueError as exc: + raise MetadataLoadError(str(exc)) from exc + + +def precheck_plugin_directory(plugin_dir: Path) -> dict: + metadata_path = plugin_dir / "metadata.yaml" + if not metadata_path.exists(): + return { + "ok": False, + "stage": "metadata", + "message": "missing metadata.yaml", + } + + try: + metadata = load_metadata(metadata_path) + except MetadataLoadError as exc: + return { + "ok": False, + "stage": "metadata", + "message": "invalid metadata.yaml", + "details": str(exc), + } + + missing = [ + field + for field in REQUIRED_METADATA_FIELDS + if not isinstance(metadata.get(field), str) or not metadata[field].strip() + ] + if missing: + return { + "ok": False, + "severity": "warn", + "stage": "metadata", + "message": f"missing required metadata fields: {', '.join(missing)}", + } + + try: + plugin_name = validate_plugin_dir_name(metadata["name"]) + except ValueError as exc: + return { + "ok": False, + "stage": "metadata", + "message": "invalid plugin directory name", + "details": str(exc), + } + + entry_candidates = [plugin_dir / "main.py", plugin_dir / f"{plugin_name}.py"] + if not any(path.exists() for path in entry_candidates): + return { + "ok": False, + "stage": "entrypoint", + "message": f"missing main.py or {plugin_name}.py", + } + + return { + "ok": True, + "stage": "precheck", + "message": "ok", + "metadata": metadata, + "plugin_dir_name": plugin_name, + } + + +def build_worker_command( + *, + script_path: Path, + astrbot_path: Path, + plugin_source_dir: Path, + plugin_dir_name: str, + normalized_repo_url: str, +) -> list[str]: + return [ + sys.executable, + str(script_path), + "--worker", + "--astrbot-path", + str(astrbot_path), + "--plugin-source-dir", + str(plugin_source_dir), + "--plugin-dir-name", + plugin_dir_name, + "--normalized-repo-url", + normalized_repo_url, + ] + + +def 
build_worker_sys_path(*, astrbot_root: Path, astrbot_path: Path) -> list[str]: + return [str(astrbot_root.resolve()), str(astrbot_path.resolve())] + + +def build_report(results: list[dict]) -> dict: + passed = sum(1 for result in results if result.get("severity") == "pass") + warned = sum(1 for result in results if result.get("severity") == "warn") + failed = sum(1 for result in results if result.get("severity") == "fail") + return { + "summary": { + "total": len(results), + "passed": passed, + "warned": warned, + "failed": failed, + }, + "results": results, + } + + +def load_plugins_index(path: Path) -> dict[str, dict]: + return load_plugins_map_file(path, source_name="plugins.json") + + +def combine_requested_names( + plugin_names: list[str] | None, + plugin_name_list: str | None, +) -> list[str]: + names = [name.strip() for name in (plugin_names or [])] + if plugin_name_list: + names.extend(part.strip() for part in plugin_name_list.split(",")) + return [name for name in names if name] + + +def sanitize_name(name: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "-", name).strip("-") + return sanitized or "plugin" + + +def validate_plugin_dir_name(name: str) -> str: + candidate = name.strip() + if not candidate or candidate in {".", ".."}: + raise ValueError("unsafe plugin_dir_name") + if "/" in candidate or "\\" in candidate: + raise ValueError("unsafe plugin_dir_name") + if ".." 
in candidate: + raise ValueError("unsafe plugin_dir_name") + return candidate + + +def build_plugin_clone_dir(work_dir: Path, plugin: str) -> Path: + digest = hashlib.sha256(plugin.encode("utf-8")).hexdigest()[:8] + return work_dir / f"{sanitize_name(plugin)}-{digest}" + + +def _normalize_process_output(output: str | bytes | None) -> str | None: + if output is None: + return None + if isinstance(output, bytes): + output = output.decode("utf-8", errors="replace") + normalized = output.strip() + return normalized or None + + +def build_process_output_details( + *, + stdout: str | bytes | None, + stderr: str | bytes | None, +) -> dict | None: + details = {} + stdout_text = _normalize_process_output(stdout) + stderr_text = _normalize_process_output(stderr) + if stdout_text: + details["stdout"] = stdout_text + if stderr_text: + details["stderr"] = stderr_text + return details or None + + +def clone_plugin_repo( + repo_url: str, + destination: Path, + *, + timeout: int = DEFAULT_CLONE_TIMEOUT, +) -> None: + subprocess.run( + ["git", "clone", "--depth", "1", repo_url, str(destination)], + check=True, + capture_output=True, + text=True, + timeout=timeout, + ) + + +def parse_worker_output( + *, + plugin: str, + repo: str, + normalized_repo_url: str, + completed: subprocess.CompletedProcess[str], + plugin_dir_name: str, +) -> dict: + stdout = completed.stdout.strip() + if stdout: + for line in reversed(stdout.splitlines()): + try: + payload = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + payload["plugin"] = plugin + payload["repo"] = repo + payload["normalized_repo_url"] = normalized_repo_url + payload.setdefault("plugin_dir_name", plugin_dir_name) + return payload + + stderr = completed.stderr.strip() + message = stderr or stdout or "worker returned no structured output" + return build_result( + plugin=plugin, + repo=repo, + normalized_repo_url=normalized_repo_url, + ok=False, + stage="worker", + message=message, + 
def validate_plugin(
    *,
    plugin: str,
    plugin_data: dict,
    astrbot_path: Path,
    script_path: Path,
    work_dir: Path,
    clone_timeout: int,
    load_timeout: int,
) -> dict:
    """Validate a single plugin end to end and return one report entry.

    Pipeline: normalize the repo URL -> shallow-clone into work_dir ->
    static precheck of metadata/entrypoint -> run the load check in a
    subprocess worker. Each stage short-circuits with a failure result
    built by build_result(); this function never raises for a bad plugin.
    """
    # Stage 1: the plugins.json entry must carry a non-blank "repo" string.
    repo_url = plugin_data.get("repo")
    if not isinstance(repo_url, str) or not repo_url.strip():
        return build_result(
            plugin=plugin,
            repo="",
            normalized_repo_url=None,
            ok=False,
            stage="repo_url",
            message="missing repo field",
        )

    # Stage 2: canonicalize the URL; rejects non-github.com hosts.
    try:
        normalized_repo_url = normalize_repo_url(repo_url)
    except ValueError as exc:
        return build_result(
            plugin=plugin,
            repo=repo_url,
            normalized_repo_url=None,
            ok=False,
            stage="repo_url",
            message=str(exc),
        )

    # Stage 3: shallow clone. Clone failures and timeouts get distinct
    # stages ("clone" vs "clone_timeout") so the report can tell them apart.
    plugin_clone_dir = build_plugin_clone_dir(work_dir, plugin)
    try:
        clone_plugin_repo(
            normalized_repo_url,
            plugin_clone_dir,
            timeout=clone_timeout,
        )
    except subprocess.CalledProcessError as exc:
        # Prefer git's stderr, then stdout, then the exception text.
        message = exc.stderr.strip() or exc.stdout.strip() or str(exc)
        return build_result(
            plugin=plugin,
            repo=repo_url,
            normalized_repo_url=normalized_repo_url,
            ok=False,
            stage="clone",
            message=message,
        )
    except subprocess.TimeoutExpired as exc:
        return build_result(
            plugin=plugin,
            repo=repo_url,
            normalized_repo_url=normalized_repo_url,
            ok=False,
            stage="clone_timeout",
            message=f"git clone timed out after {clone_timeout} seconds",
            details=build_process_output_details(stdout=exc.stdout, stderr=exc.stderr),
        )

    # Stage 4: static checks (metadata.yaml, required fields, entrypoint).
    # The precheck's own stage/severity/details are forwarded verbatim.
    precheck = precheck_plugin_directory(plugin_clone_dir)
    if not precheck["ok"]:
        return build_result(
            plugin=plugin,
            repo=repo_url,
            normalized_repo_url=normalized_repo_url,
            ok=False,
            stage=precheck["stage"],
            message=precheck["message"],
            severity=precheck.get("severity"),
            details=precheck.get("details"),
        )

    # Stage 5: run the actual load check in an isolated worker process so a
    # misbehaving plugin cannot take down the orchestrator.
    plugin_dir_name = precheck["plugin_dir_name"]
    command = build_worker_command(
        script_path=script_path,
        astrbot_path=astrbot_path,
        plugin_source_dir=plugin_clone_dir,
        plugin_dir_name=plugin_dir_name,
        normalized_repo_url=normalized_repo_url,
    )

    try:
        completed = subprocess.run(
            command,
            check=False,
            capture_output=True,
            text=True,
            timeout=load_timeout,
        )
    except subprocess.TimeoutExpired as exc:
        return build_result(
            plugin=plugin,
            repo=repo_url,
            normalized_repo_url=normalized_repo_url,
            ok=False,
            stage="timeout",
            message=f"worker timed out after {load_timeout} seconds",
            plugin_dir_name=plugin_dir_name,
            details=build_process_output_details(stdout=exc.stdout, stderr=exc.stderr),
        )

    # The worker prints a JSON result line; parse_worker_output recovers it
    # (or synthesizes a failure entry from raw output).
    return parse_worker_output(
        plugin=plugin,
        repo=repo_url,
        normalized_repo_url=normalized_repo_url,
        completed=completed,
        plugin_dir_name=plugin_dir_name,
    )
class NullStub:
    """Inert placeholder for any AstrBot API the dummy context cannot
    provide: absorbs attribute access, calls, iteration and awaiting,
    and is always falsy."""

    def __getattr__(self, _name: str) -> "NullStub":
        return self

    def __call__(self, *_args, **_kwargs) -> "NullStub":
        return self

    def __await__(self):
        # Awaiting a stub yields the stub itself.
        async def _identity():
            return self

        return _identity().__await__()

    def __iter__(self):
        return iter(())

    def __bool__(self) -> bool:
        return False


class DummyContext:
    """Minimal stand-in for AstrBot's plugin Context used by the load check.

    Explicitly defined methods mimic the handful of calls plugins commonly
    make during load; everything else falls through to a NullStub.
    """

    def __init__(self) -> None:
        self._star_manager = None

    def get_all_stars(self):
        """Return the registered stars, or [] when AstrBot is unavailable
        (or the registry cannot be materialized)."""
        try:
            from astrbot.core.star.star import star_registry

            return list(star_registry)
        except Exception:
            return []

    def get_registered_star(self, star_name: str):
        """First registered star whose .name matches, else None."""
        return next(
            (
                star
                for star in self.get_all_stars()
                if getattr(star, "name", None) == star_name
            ),
            None,
        )

    def activate_llm_tool(self, name: str) -> bool:
        """Pretend activation always succeeds."""
        return True

    def deactivate_llm_tool(self, name: str) -> bool:
        """Pretend deactivation always succeeds."""
        return True

    def register_llm_tool(self, name: str, func_args, desc: str, func_obj) -> None:
        """No-op; tools are not tracked by the dummy context."""
        return None

    def unregister_llm_tool(self, name: str) -> None:
        """No-op counterpart of register_llm_tool."""
        return None

    def __getattr__(self, _name: str) -> NullStub:
        return NullStub()
except Exception as exc: + return build_result( + plugin=plugin_dir_name, + repo=normalized_repo_url, + normalized_repo_url=normalized_repo_url, + ok=False, + stage="astrbot_import", + message=str(exc), + plugin_dir_name=plugin_dir_name, + details=traceback.format_exc(), + ) + + context = DummyContext() + manager = PluginManager(context, {}) + + try: + success, error = await manager.load(specified_dir_name=plugin_dir_name) + except Exception as exc: + return build_result( + plugin=plugin_dir_name, + repo=normalized_repo_url, + normalized_repo_url=normalized_repo_url, + ok=False, + stage="load", + message=str(exc), + plugin_dir_name=plugin_dir_name, + details=traceback.format_exc(), + ) + + if success: + return build_result( + plugin=plugin_dir_name, + repo=normalized_repo_url, + normalized_repo_url=normalized_repo_url, + ok=True, + stage="load", + message="plugin loaded successfully", + plugin_dir_name=plugin_dir_name, + ) + + return build_result( + plugin=plugin_dir_name, + repo=normalized_repo_url, + normalized_repo_url=normalized_repo_url, + ok=False, + stage="load", + message=str(error) if error else "plugin load failed", + plugin_dir_name=plugin_dir_name, + details=manager.failed_plugin_dict.get(plugin_dir_name), + ) + + +def run_worker(args: argparse.Namespace) -> int: + temp_root = Path(tempfile.mkdtemp(prefix="astrbot-plugin-worker-")) + try: + astrbot_root = temp_root / "astrbot-root" + plugin_store = astrbot_root / "data" / "plugins" + plugin_config = astrbot_root / "data" / "config" + plugin_store.mkdir(parents=True, exist_ok=True) + plugin_config.mkdir(parents=True, exist_ok=True) + + source_dir = Path(args.plugin_source_dir).resolve() + target_dir = plugin_store / args.plugin_dir_name + shutil.copytree(source_dir, target_dir, dirs_exist_ok=True) + + os.environ["ASTRBOT_ROOT"] = str(astrbot_root) + os.environ.setdefault("TESTING", "true") + sys.path[:0] = build_worker_sys_path( + astrbot_root=astrbot_root, + astrbot_path=Path(args.astrbot_path), + ) + + 
DEFAULT_CLONE_TIMEOUT = 120  # default `git clone` timeout (seconds)
DEFAULT_MAX_WORKERS = 8  # default number of concurrent validations


def positive_int(raw_value: str) -> int:
    """argparse ``type=`` hook accepting only strictly positive integers."""
    value = int(raw_value)
    if value <= 0:
        raise argparse.ArgumentTypeError("must be a positive integer")
    return value


def build_parser() -> argparse.ArgumentParser:
    """Declare the CLI shared by orchestrator and --worker modes.

    Worker-only flags are optional here; main() enforces their presence
    when --worker is given.
    """
    parser = argparse.ArgumentParser(description="Validate AstrBot plugins")

    # Orchestrator options.
    parser.add_argument("--plugins-json", default="plugins.json")
    parser.add_argument("--plugin-name", action="append", dest="plugin_names")
    parser.add_argument("--plugin-name-list")
    parser.add_argument(
        "--limit",
        type=int,
        help="Validate the first N plugins when plugin names are empty. Omit or use -1 for all plugins.",
    )
    parser.add_argument("--astrbot-path")
    parser.add_argument("--report-path", default="validation-report.json")
    parser.add_argument("--work-dir")
    parser.add_argument("--clone-timeout", type=positive_int, default=DEFAULT_CLONE_TIMEOUT)
    parser.add_argument("--load-timeout", type=positive_int, default=300)
    parser.add_argument("--max-workers", type=positive_int, default=DEFAULT_MAX_WORKERS)

    # Worker-mode options.
    parser.add_argument("--worker", action="store_true")
    parser.add_argument("--plugin-source-dir")
    parser.add_argument("--plugin-dir-name")
    parser.add_argument("--normalized-repo-url")
    return parser
mode requires: {', '.join(missing)}") + return run_worker(args) + + if not args.astrbot_path: + parser.error("--astrbot-path is required") + + requested_names = combine_requested_names(args.plugin_names, args.plugin_name_list) + plugins = load_plugins_index(Path(args.plugins_json)) + selected = select_plugins( + plugins=plugins, + requested_names=requested_names or None, + limit=args.limit, + ) + + temp_dir = None + work_dir = Path(args.work_dir) if args.work_dir else None + if work_dir is None: + temp_dir = tempfile.TemporaryDirectory(prefix="astrbot-plugin-validate-") + work_dir = Path(temp_dir.name) + work_dir.mkdir(parents=True, exist_ok=True) + + try: + results = validate_selected_plugins( + selected=selected, + astrbot_path=Path(args.astrbot_path).resolve(), + script_path=Path(__file__).resolve(), + work_dir=work_dir, + clone_timeout=args.clone_timeout, + load_timeout=args.load_timeout, + max_workers=args.max_workers, + ) + finally: + if temp_dir is not None: + temp_dir.cleanup() + + report = build_report(results) + report_path = Path(args.report_path) + report_path.write_text( + json.dumps(report, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + + print( + json.dumps( + { + "report_path": str(report_path), + "summary": report["summary"], + }, + ensure_ascii=False, + ) + ) + return 0 if report["summary"]["failed"] == 0 else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_detect_changed_plugins.py b/tests/test_detect_changed_plugins.py new file mode 100644 index 00000000..64678af7 --- /dev/null +++ b/tests/test_detect_changed_plugins.py @@ -0,0 +1,147 @@ +import importlib.util +import tempfile +import unittest +from pathlib import Path +from unittest import mock + + +ROOT = Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "scripts" / "validate_plugins" / "detect_changed_plugins.py" + + +def 
load_detection_module(): + if not MODULE_PATH.exists(): + raise AssertionError(f"detection script missing: {MODULE_PATH}") + + spec = importlib.util.spec_from_file_location("detect_changed_plugins", MODULE_PATH) + if spec is None or spec.loader is None: + raise AssertionError("unable to load detection module spec") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +class LoadPluginsMapTests(unittest.TestCase): + def test_load_plugins_map_requires_json_object(self): + module = load_detection_module() + + with self.assertRaisesRegex(ValueError, "plugins.json must contain a JSON object"): + module.load_plugins_map('[{"name": "bad"}]', source_name="head") + + def test_load_plugins_map_returns_dict_for_valid_json(self): + module = load_detection_module() + + plugins = module.load_plugins_map('{"plugin-a": {"repo": "https://github.com/example/a"}}', source_name="head") + + self.assertEqual(plugins, {"plugin-a": {"repo": "https://github.com/example/a"}}) + + def test_load_plugins_map_rejects_non_dict_entries(self): + module = load_detection_module() + + with self.assertRaisesRegex(ValueError, "plugins.json entry 'plugin-a' on the PR head must be a JSON object"): + module.load_plugins_map('{"plugin-a": "bad"}', source_name="PR head") + + +class ChangedPluginDetectionTests(unittest.TestCase): + def test_detect_changed_plugin_names_returns_only_modified_entries(self): + module = load_detection_module() + + changed = module.detect_changed_plugin_names( + base={"plugin-a": {"repo": "a"}, "plugin-b": {"repo": "b"}}, + head={"plugin-a": {"repo": "a"}, "plugin-b": {"repo": "changed"}, "plugin-c": {"repo": "c"}}, + ) + + self.assertEqual(changed, ["plugin-b", "plugin-c"]) + + +class AstrbotRefTests(unittest.TestCase): + def test_resolve_astrbot_ref_uses_remote_default_branch(self): + module = load_detection_module() + + with mock.patch.object(module.subprocess, "check_output", return_value="ref: 
refs/heads/main\tHEAD\nabc\tHEAD\n"): + ref = module.resolve_astrbot_ref() + + self.assertEqual(ref, "main") + + def test_resolve_astrbot_ref_falls_back_to_master(self): + module = load_detection_module() + + with mock.patch.object(module.subprocess, "check_output", side_effect=module.subprocess.CalledProcessError(1, ["git"])): + ref = module.resolve_astrbot_ref() + + self.assertEqual(ref, "master") + + +class PullRequestDetectionTests(unittest.TestCase): + def test_detect_pull_request_selection_handles_fetch_base_ref_failure(self): + module = load_detection_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + repo_root = Path(tmp_dir) + plugins_json = repo_root / "plugins.json" + plugins_json.write_text('{"plugin-a": {"repo": "https://github.com/example/a"}}', encoding="utf-8") + + with mock.patch.object( + module, + "fetch_base_ref", + side_effect=module.subprocess.CalledProcessError(1, ["git", "fetch"]), + ): + result = module.detect_pull_request_selection(repo_root=repo_root, base_ref="main") + + self.assertEqual(result["changed"], ["plugin-a"]) + self.assertEqual(result["validation_note"], "") + + def test_detect_pull_request_selection_handles_missing_base_file(self): + module = load_detection_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + repo_root = Path(tmp_dir) + plugins_json = repo_root / "plugins.json" + plugins_json.write_text('{"plugin-a": {"repo": "https://github.com/example/a"}}', encoding="utf-8") + + with mock.patch.object(module, "fetch_base_ref") as fetch_mock: + with mock.patch.object(module, "read_base_plugins_json", side_effect=module.subprocess.CalledProcessError(1, ["git"])): + result = module.detect_pull_request_selection(repo_root=repo_root, base_ref="main") + + fetch_mock.assert_called_once_with("main") + self.assertEqual(result["changed"], ["plugin-a"]) + self.assertEqual(result["validation_note"], "") + + def test_detect_pull_request_selection_raises_on_invalid_head_json(self): + module = load_detection_module() 
+ + with tempfile.TemporaryDirectory() as tmp_dir: + repo_root = Path(tmp_dir) + (repo_root / "plugins.json").write_text('{bad json', encoding="utf-8") + + with mock.patch.object(module, "fetch_base_ref"): + with mock.patch.object(module, "read_base_plugins_json", return_value='{}'): + with self.assertRaisesRegex(ValueError, "plugins.json is invalid on the PR head"): + module.detect_pull_request_selection(repo_root=repo_root, base_ref="main") + + def test_write_github_env_outputs_expected_values(self): + module = load_detection_module() + + with tempfile.NamedTemporaryFile("w+", delete=False) as handle: + env_path = Path(handle.name) + + try: + module.write_github_env( + env_path=env_path, + astrbot_ref="master", + changed=["plugin-a", "plugin-b"], + should_validate=True, + validation_note="", + ) + content = env_path.read_text(encoding="utf-8") + finally: + env_path.unlink(missing_ok=True) + + self.assertIn("ASTRBOT_REF=master\n", content) + self.assertIn("PLUGIN_NAME_LIST=plugin-a,plugin-b\n", content) + self.assertIn("SHOULD_VALIDATE=true\n", content) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_validate_plugins.py b/tests/test_validate_plugins.py new file mode 100644 index 00000000..e0f0699e --- /dev/null +++ b/tests/test_validate_plugins.py @@ -0,0 +1,814 @@ +import importlib.util +import json +import os +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path +from unittest import mock + + +ROOT = Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "scripts" / "validate_plugins" / "run.py" + + +def load_validator_module(): + if not MODULE_PATH.exists(): + raise AssertionError(f"validator script missing: {MODULE_PATH}") + + spec = importlib.util.spec_from_file_location("validate_plugins_run", MODULE_PATH) + if spec is None or spec.loader is None: + raise AssertionError("unable to load validator module spec") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) 
+ return module + + +class NormalizeRepoUrlTests(unittest.TestCase): + def test_strips_git_suffix_trailing_slash_and_query(self): + module = load_validator_module() + + self.assertEqual( + module.normalize_repo_url( + "https://github.com/example/demo-plugin.git/?tab=readme-ov-file" + ), + "https://github.com/example/demo-plugin", + ) + + def test_rejects_non_github_urls(self): + module = load_validator_module() + + with self.assertRaises(ValueError): + module.normalize_repo_url("https://gitlab.com/example/demo-plugin") + + def test_rejects_non_http_schemes(self): + module = load_validator_module() + + for url in ( + "git://github.com/example/demo-plugin", + "ssh://github.com/example/demo-plugin", + ): + with self.subTest(url=url): + with self.assertRaisesRegex(ValueError, "repo URL must use http or https"): + module.normalize_repo_url(url) + + def test_rejects_missing_owner_or_repository(self): + module = load_validator_module() + + for url in ( + "https://github.com/", + "https://github.com/example", + "https://github.com/example/", + "https://github.com//demo-plugin", + "https://github.com/example//", + ): + with self.subTest(url=url): + with self.assertRaisesRegex(ValueError, "repo URL must include owner and repository"): + module.normalize_repo_url(url) + + def test_strips_leading_and_trailing_whitespace(self): + module = load_validator_module() + + self.assertEqual( + module.normalize_repo_url(" https://github.com/example/demo-plugin "), + "https://github.com/example/demo-plugin", + ) + + +class SelectPluginsTests(unittest.TestCase): + def test_returns_all_plugins_when_limit_is_none(self): + module = load_validator_module() + plugins = { + "plugin-a": {"repo": "https://github.com/example/plugin-a"}, + "plugin-b": {"repo": "https://github.com/example/plugin-b"}, + } + + selected = module.select_plugins( + plugins=plugins, + requested_names=None, + limit=None, + ) + + self.assertEqual([item[0] for item in selected], ["plugin-a", "plugin-b"]) + + def 
test_returns_all_plugins_when_limit_is_negative_one(self): + module = load_validator_module() + plugins = { + "plugin-a": {"repo": "https://github.com/example/plugin-a"}, + "plugin-b": {"repo": "https://github.com/example/plugin-b"}, + } + + selected = module.select_plugins( + plugins=plugins, + requested_names=None, + limit=-1, + ) + + self.assertEqual([item[0] for item in selected], ["plugin-a", "plugin-b"]) + + def test_prefers_explicit_names_in_requested_order(self): + module = load_validator_module() + plugins = { + "plugin-a": {"repo": "https://github.com/example/plugin-a"}, + "plugin-b": {"repo": "https://github.com/example/plugin-b"}, + "plugin-c": {"repo": "https://github.com/example/plugin-c"}, + } + + selected = module.select_plugins( + plugins=plugins, + requested_names=["plugin-c", "plugin-a"], + limit=None, + ) + + self.assertEqual([item[0] for item in selected], ["plugin-c", "plugin-a"]) + + def test_respects_positive_limit_when_names_not_requested(self): + module = load_validator_module() + plugins = { + "plugin-a": {"repo": "https://github.com/example/plugin-a"}, + "plugin-b": {"repo": "https://github.com/example/plugin-b"}, + "plugin-c": {"repo": "https://github.com/example/plugin-c"}, + } + + selected = module.select_plugins( + plugins=plugins, + requested_names=None, + limit=1, + ) + + self.assertEqual([item[0] for item in selected], ["plugin-a"]) + + def test_raises_key_error_for_unknown_requested_plugin(self): + module = load_validator_module() + plugins = { + "known-plugin": {"repo": "https://github.com/example/known-plugin"}, + } + + with self.assertRaisesRegex(KeyError, "plugin not found: missing-plugin"): + module.select_plugins( + plugins=plugins, + requested_names=["known-plugin", "missing-plugin"], + limit=None, + ) + + +class HelperFunctionTests(unittest.TestCase): + def test_combine_requested_names_merges_trims_and_drops_empty_values(self): + module = load_validator_module() + + combined = module.combine_requested_names( + 
plugin_names=["foo", " bar ", "", " "], + plugin_name_list="baz, qux , ,foo ", + ) + + self.assertEqual(combined, ["foo", "bar", "baz", "qux", "foo"]) + + def test_combine_requested_names_handles_none_inputs(self): + module = load_validator_module() + + self.assertEqual(module.combine_requested_names(None, None), []) + + def test_sanitize_name_replaces_invalid_chars_and_falls_back_when_needed(self): + module = load_validator_module() + + self.assertEqual(module.sanitize_name(" -invalid name!*?- "), "invalid-name") + self.assertEqual(module.sanitize_name("valid-name_123"), "valid-name_123") + self.assertEqual(module.sanitize_name(" "), "plugin") + self.assertEqual(module.sanitize_name("!!!"), "plugin") + + def test_build_plugin_clone_dir_is_unique_for_colliding_sanitized_names(self): + module = load_validator_module() + + first = module.build_plugin_clone_dir(Path("/tmp/work"), "foo bar") + second = module.build_plugin_clone_dir(Path("/tmp/work"), "foo/bar") + + self.assertNotEqual(first, second) + self.assertEqual(first.parent, Path("/tmp/work")) + self.assertEqual(second.parent, Path("/tmp/work")) + + def test_build_process_output_details_keeps_partial_timeout_logs(self): + module = load_validator_module() + + details = module.build_process_output_details( + stdout="line one\nline two\n", + stderr=b"warning\n", + ) + + self.assertEqual(details, {"stdout": "line one\nline two", "stderr": "warning"}) + + def test_parse_simple_yaml_handles_comments_quotes_and_whitespace(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write( + "# leading comment\n\n" + "key1: value1 # trailing comment\n" + 'key2: " spaced value "\n' + "key3: 'another value'\n" + "key4: value-with-#-hash\n" + ) + metadata_path = Path(handle.name) + + try: + parsed = module._parse_simple_yaml(metadata_path) + finally: + os.remove(metadata_path) + + self.assertEqual(parsed["key1"], "value1") + 
self.assertEqual(parsed["key2"], " spaced value ") + self.assertEqual(parsed["key3"], "another value") + self.assertEqual(parsed["key4"], "value-with-#-hash") + + def test_parse_simple_yaml_rejects_indented_lines(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write("name: demo\n nested: nope\n") + metadata_path = Path(handle.name) + + try: + with self.assertRaisesRegex(ValueError, "Unsupported YAML indentation"): + module._parse_simple_yaml(metadata_path) + finally: + os.remove(metadata_path) + + def test_parse_simple_yaml_rejects_list_syntax(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write("- item\n") + metadata_path = Path(handle.name) + + try: + with self.assertRaisesRegex(ValueError, "Unsupported YAML list syntax"): + module._parse_simple_yaml(metadata_path) + finally: + os.remove(metadata_path) + + def test_parse_simple_yaml_rejects_duplicate_keys(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write("name: first\nname: second\n") + metadata_path = Path(handle.name) + + try: + with self.assertRaisesRegex(ValueError, "Duplicate key 'name'"): + module._parse_simple_yaml(metadata_path) + finally: + os.remove(metadata_path) + + def test_load_metadata_uses_yaml_safe_load_when_available(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write("name: should-be-overridden\n") + metadata_path = Path(handle.name) + + fake_yaml = mock.Mock() + fake_yaml.safe_load.return_value = {"name": "from-yaml", "version": "1.0.0"} + + try: + with mock.patch.object(module, "yaml", fake_yaml): + metadata = module.load_metadata(metadata_path) + finally: + os.remove(metadata_path) + + self.assertEqual(metadata, {"name": "from-yaml", "version": 
"1.0.0"}) + fake_yaml.safe_load.assert_called_once() + + def test_load_metadata_rejects_non_mapping_yaml_root(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write("- item\n") + metadata_path = Path(handle.name) + + fake_yaml = mock.Mock() + fake_yaml.safe_load.return_value = ["item"] + fake_yaml.YAMLError = ValueError + + try: + with mock.patch.object(module, "yaml", fake_yaml): + with self.assertRaisesRegex( + module.MetadataLoadError, + "metadata.yaml must contain a mapping at the top level", + ): + module.load_metadata(metadata_path) + finally: + os.remove(metadata_path) + + def test_load_metadata_uses_simple_parser_when_yaml_unavailable(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write('name: demo-plugin\nversion: "0.2.3"\n') + metadata_path = Path(handle.name) + + yaml_backup = getattr(module, "yaml", None) + try: + module.yaml = None + metadata = module.load_metadata(metadata_path) + finally: + module.yaml = yaml_backup + os.remove(metadata_path) + + self.assertEqual(metadata.get("name"), "demo-plugin") + self.assertEqual(metadata.get("version"), "0.2.3") + + def test_load_metadata_wraps_fallback_parse_errors(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as handle: + handle.write("name: demo\n nested: nope\n") + metadata_path = Path(handle.name) + + yaml_backup = getattr(module, "yaml", None) + try: + module.yaml = None + with self.assertRaisesRegex(module.MetadataLoadError, "Unsupported YAML indentation"): + module.load_metadata(metadata_path) + finally: + module.yaml = yaml_backup + os.remove(metadata_path) + + def test_load_plugins_index_accepts_valid_object(self): + module = load_validator_module() + + index_obj = { + "good-plugin": {"name": "Good Plugin", "repo": "https://github.com/example/good"}, + 
"another-plugin": {"name": "Another Plugin"}, + } + + with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as handle: + json.dump(index_obj, handle) + index_path = Path(handle.name) + + try: + plugins = module.load_plugins_index(index_path) + finally: + os.remove(index_path) + + self.assertEqual(plugins, index_obj) + + def test_load_plugins_index_rejects_json_array(self): + module = load_validator_module() + + with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as handle: + json.dump([{"name": "array-entry"}], handle) + index_path = Path(handle.name) + + try: + with self.assertRaisesRegex(ValueError, "plugins.json must contain a JSON object"): + module.load_plugins_index(index_path) + finally: + os.remove(index_path) + + def test_load_plugins_index_rejects_non_dict_values(self): + module = load_validator_module() + + index_obj = { + "valid-plugin": {"name": "Valid Plugin", "repo": "https://github.com/example/valid"}, + "not-a-dict": "just-a-string", + } + + with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as handle: + json.dump(index_obj, handle) + index_path = Path(handle.name) + + try: + with self.assertRaisesRegex(ValueError, "plugins.json entry 'not-a-dict'.*must be a JSON object"): + module.load_plugins_index(index_path) + finally: + os.remove(index_path) + + +class ValidationProgressTests(unittest.TestCase): + def test_build_parser_defaults_max_workers_to_eight(self): + module = load_validator_module() + + args = module.build_parser().parse_args(["--astrbot-path", "/tmp/AstrBot"]) + + self.assertEqual(args.max_workers, 8) + + def test_build_parser_rejects_non_positive_worker_and_timeout_values(self): + module = load_validator_module() + + with self.assertRaises(SystemExit): + module.build_parser().parse_args(["--astrbot-path", "/tmp/AstrBot", "--max-workers", "0"]) + + with self.assertRaises(SystemExit): + module.build_parser().parse_args(["--astrbot-path", "/tmp/AstrBot", "--clone-timeout", "0"]) + + with 
self.assertRaises(SystemExit): + module.build_parser().parse_args(["--astrbot-path", "/tmp/AstrBot", "--load-timeout", "0"]) + + def test_validate_selected_plugins_emits_progress_and_result_lines(self): + module = load_validator_module() + selected = [ + ("plugin-a", {"repo": "https://github.com/example/plugin-a"}), + ("plugin-b", {"repo": "https://github.com/example/plugin-b"}), + ] + fake_results = [ + {"plugin": "plugin-a", "ok": True, "severity": "pass", "stage": "load", "message": "ok"}, + {"plugin": "plugin-b", "ok": False, "severity": "warn", "stage": "metadata", "message": "missing required metadata fields: desc"}, + ] + + with mock.patch.object(module, "validate_plugin", side_effect=fake_results) as validate_mock: + with mock.patch("builtins.print") as print_mock: + results = module.validate_selected_plugins( + selected=selected, + astrbot_path=Path("/tmp/AstrBot"), + script_path=Path("/tmp/run.py"), + work_dir=Path("/tmp/work"), + clone_timeout=60, + load_timeout=300, + max_workers=8, + ) + + self.assertEqual(results, fake_results) + self.assertEqual(validate_mock.call_count, 2) + print_mock.assert_any_call("[1/2] Queued plugin-a", flush=True) + print_mock.assert_any_call("[1/2] PASS plugin-a [load] ok", flush=True) + print_mock.assert_any_call("[2/2] WARN plugin-b [metadata] missing required metadata fields: desc", flush=True) + + def test_validate_selected_plugins_preserves_result_order_with_out_of_order_completion(self): + module = load_validator_module() + selected = [ + ("plugin-a", {"repo": "https://github.com/example/plugin-a"}), + ("plugin-b", {"repo": "https://github.com/example/plugin-b"}), + ("plugin-c", {"repo": "https://github.com/example/plugin-c"}), + ] + futures = [mock.Mock(name="future-a"), mock.Mock(name="future-b"), mock.Mock(name="future-c")] + future_to_result = { + futures[0]: (1, {"plugin": "plugin-a", "ok": True, "stage": "load", "message": "a"}), + futures[1]: (2, {"plugin": "plugin-b", "ok": False, "stage": "metadata", 
"message": "b"}), + futures[2]: (3, {"plugin": "plugin-c", "ok": True, "stage": "load", "message": "c"}), + } + + executor = mock.MagicMock() + executor.__enter__.return_value = executor + executor.__exit__.return_value = False + executor.submit.side_effect = futures + + def future_result(future): + return future_to_result[future] + + for future in futures: + future.result.side_effect = lambda _timeout=None, future=future: future_result(future) + + with mock.patch.object(module.concurrent.futures, "ThreadPoolExecutor", return_value=executor) as pool_mock: + with mock.patch.object(module.concurrent.futures, "as_completed", return_value=[futures[2], futures[0], futures[1]]): + with mock.patch("builtins.print") as print_mock: + results = module.validate_selected_plugins( + selected=selected, + astrbot_path=Path("/tmp/AstrBot"), + script_path=Path("/tmp/run.py"), + work_dir=Path("/tmp/work"), + clone_timeout=60, + load_timeout=300, + max_workers=8, + ) + + pool_mock.assert_called_once_with(max_workers=8) + self.assertEqual([item["plugin"] for item in results], ["plugin-a", "plugin-b", "plugin-c"]) + print_mock.assert_any_call("[1/3] Queued plugin-a", flush=True) + print_mock.assert_any_call("[3/3] PASS plugin-c [load] c", flush=True) + + +class ValidatePluginTests(unittest.TestCase): + def setUp(self): + self.module = load_validator_module() + self.plugin_key = "demo-plugin" + self.plugin_data = {"repo": "https://github.com/example/demo-plugin"} + self.astrbot_path = Path("/tmp/AstrBot") + self.script_path = Path("/tmp/run.py") + self.work_dir = Path("/tmp/work") + + def call_validate_plugin(self, plugin_data=None): + return self.module.validate_plugin( + plugin=self.plugin_key, + plugin_data=self.plugin_data if plugin_data is None else plugin_data, + astrbot_path=self.astrbot_path, + script_path=self.script_path, + work_dir=self.work_dir, + clone_timeout=30, + load_timeout=60, + ) + + def test_missing_repo_field_sets_repo_url_stage(self): + result = 
self.call_validate_plugin(plugin_data={}) + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "repo_url") + self.assertEqual(result["message"], "missing repo field") + + def test_invalid_repo_url_sets_repo_url_stage(self): + with mock.patch.object(self.module, "normalize_repo_url", side_effect=ValueError("invalid repo URL")): + result = self.call_validate_plugin() + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "repo_url") + self.assertEqual(result["message"], "invalid repo URL") + + def test_clone_called_process_error_uses_stderr_or_stdout(self): + error = subprocess.CalledProcessError( + returncode=1, + cmd=["git", "clone"], + output="clone stdout", + stderr="clone stderr", + ) + + with mock.patch.object(self.module, "clone_plugin_repo", side_effect=error): + result = self.call_validate_plugin() + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "clone") + self.assertIn("clone stderr", result["message"]) + + def test_clone_timeout_uses_process_output_details(self): + timeout = subprocess.TimeoutExpired(cmd=["git", "clone"], timeout=30, output="slow", stderr="warning") + + with mock.patch.object(self.module, "clone_plugin_repo", side_effect=timeout): + with mock.patch.object( + self.module, + "build_process_output_details", + return_value={"stdout": "slow", "stderr": "warning"}, + ) as details_mock: + result = self.call_validate_plugin() + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "clone_timeout") + self.assertEqual(result["details"], {"stdout": "slow", "stderr": "warning"}) + details_mock.assert_called_once_with(stdout="slow", stderr="warning") + + def test_precheck_failure_is_mapped_into_result(self): + with mock.patch.object(self.module, "clone_plugin_repo"): + with mock.patch.object( + self.module, + "precheck_plugin_directory", + return_value={"ok": False, "stage": "metadata", "message": "invalid metadata", "details": "line 3"}, + ): + result = self.call_validate_plugin() 
+ + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "metadata") + self.assertEqual(result["message"], "invalid metadata") + self.assertEqual(result["details"], "line 3") + + def test_precheck_warning_is_non_fatal_in_final_result(self): + with mock.patch.object(self.module, "clone_plugin_repo"): + with mock.patch.object( + self.module, + "precheck_plugin_directory", + return_value={ + "ok": False, + "severity": "warn", + "stage": "metadata", + "message": "missing required metadata fields: desc", + }, + ): + result = self.call_validate_plugin() + + self.assertTrue(result["ok"]) + self.assertEqual(result["severity"], "warn") + self.assertEqual(result["stage"], "metadata") + + def test_load_timeout_uses_process_output_details(self): + timeout = subprocess.TimeoutExpired( + cmd=[sys.executable, str(self.script_path)], + timeout=60, + output="timeout-stdout", + stderr="timeout-stderr", + ) + + with mock.patch.object( + self.module, + "precheck_plugin_directory", + return_value={"ok": True, "plugin_dir_name": "demo-plugin", "message": "ok", "stage": "precheck"}, + ): + with mock.patch.object(self.module, "clone_plugin_repo"): + with mock.patch.object(subprocess, "run", side_effect=timeout): + with mock.patch.object( + self.module, + "build_process_output_details", + return_value={"stdout": "timeout-stdout", "stderr": "timeout-stderr"}, + ) as details_mock: + result = self.call_validate_plugin() + + self.assertEqual(result["stage"], "timeout") + self.assertEqual(result["plugin_dir_name"], "demo-plugin") + self.assertEqual(result["details"], {"stdout": "timeout-stdout", "stderr": "timeout-stderr"}) + details_mock.assert_called_once_with(stdout="timeout-stdout", stderr="timeout-stderr") + + def test_successful_clone_and_precheck_invokes_worker_and_parses_output(self): + completed = subprocess.CompletedProcess( + args=["python3", "run.py"], + returncode=0, + stdout='{"ok": true}', + stderr="", + ) + parsed_output = {"ok": True, "stage": "load", "message": 
"plugin loaded successfully"} + + with mock.patch.object( + self.module, + "precheck_plugin_directory", + return_value={"ok": True, "plugin_dir_name": "demo_plugin", "message": "ok", "stage": "precheck"}, + ) as precheck_mock: + with mock.patch.object(self.module, "clone_plugin_repo"): + with mock.patch.object(subprocess, "run", return_value=completed) as run_mock: + with mock.patch.object(self.module, "parse_worker_output", return_value=parsed_output) as parse_mock: + result = self.call_validate_plugin() + + self.assertEqual(result, parsed_output) + precheck_mock.assert_called_once() + run_mock.assert_called_once() + parse_mock.assert_called_once_with( + plugin=self.plugin_key, + repo=self.plugin_data["repo"], + normalized_repo_url=self.plugin_data["repo"], + completed=completed, + plugin_dir_name="demo_plugin", + ) + + +class MetadataValidationTests(unittest.TestCase): + def test_reports_missing_required_metadata_fields(self): + module = load_validator_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + plugin_dir = Path(tmp_dir) + (plugin_dir / "metadata.yaml").write_text( + "name: demo_plugin\nauthor: AstrBot Team\n", + encoding="utf-8", + ) + (plugin_dir / "main.py").write_text("print('hello')\n", encoding="utf-8") + + result = module.precheck_plugin_directory(plugin_dir) + + self.assertFalse(result["ok"]) + self.assertEqual(result["severity"], "warn") + self.assertEqual(result["stage"], "metadata") + self.assertIn("desc", result["message"]) + self.assertIn("version", result["message"]) + + def test_reports_invalid_metadata_yaml_without_raising(self): + module = load_validator_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + plugin_dir = Path(tmp_dir) + (plugin_dir / "metadata.yaml").write_text( + "name: demo_plugin\n<<<<<<< HEAD\ndesc: broken\n=======\ndesc: fixed\n>>>>>>> branch\n", + encoding="utf-8", + ) + (plugin_dir / "main.py").write_text("print('hello')\n", encoding="utf-8") + + result = 
module.precheck_plugin_directory(plugin_dir) + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "metadata") + self.assertIn("invalid metadata.yaml", result["message"]) + self.assertIn("could not find expected ':'", result["details"]) + + def test_rejects_unsafe_plugin_dir_name_from_metadata(self): + module = load_validator_module() + + with tempfile.TemporaryDirectory() as tmp_dir: + plugin_dir = Path(tmp_dir) + (plugin_dir / "metadata.yaml").write_text( + "name: ../escape\ndesc: demo\nversion: 1.0.0\nauthor: AstrBot Team\n", + encoding="utf-8", + ) + (plugin_dir / "main.py").write_text("print('hello')\n", encoding="utf-8") + + result = module.precheck_plugin_directory(plugin_dir) + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "metadata") + self.assertEqual(result["message"], "invalid plugin directory name") + self.assertIn("unsafe plugin_dir_name", result["details"]) + + +class WorkerCommandTests(unittest.TestCase): + def test_build_worker_command_contains_required_arguments(self): + module = load_validator_module() + + command = module.build_worker_command( + script_path=Path("/tmp/run.py"), + astrbot_path=Path("/tmp/astrbot"), + plugin_source_dir=Path("/tmp/plugin-src"), + plugin_dir_name="demo_plugin", + normalized_repo_url="https://github.com/example/demo-plugin", + ) + + self.assertEqual(command[0], sys.executable) + self.assertEqual(command[1], "/tmp/run.py") + self.assertIn("--worker", command) + self.assertIn("--astrbot-path", command) + self.assertIn("--plugin-source-dir", command) + self.assertIn("--plugin-dir-name", command) + self.assertIn("--normalized-repo-url", command) + + +class WorkerSysPathTests(unittest.TestCase): + def test_worker_sys_path_includes_astrbot_root_before_codebase(self): + module = load_validator_module() + + sys_path_entries = module.build_worker_sys_path( + astrbot_root=Path("/tmp/astrbot-root"), + astrbot_path=Path("/tmp/AstrBot"), + ) + + self.assertEqual( + [Path(item) for item in 
sys_path_entries], + [Path("/tmp/astrbot-root").resolve(), Path("/tmp/AstrBot").resolve()], + ) + + +class WorkerLoadCheckTests(unittest.IsolatedAsyncioTestCase): + async def test_stringifies_non_string_plugin_load_error_message(self): + module = load_validator_module() + + class FakeManager: + def __init__(self, context, config): + del context, config + self.failed_plugin_dict = {"demo_plugin": {"error": "detail"}} + + async def load(self, specified_dir_name: str): + del specified_dir_name + return False, {"reason": "boom"} + + with mock.patch.dict(sys.modules, {"astrbot.core.star.star_manager": mock.Mock(PluginManager=FakeManager)}): + result = await module.run_worker_load_check("demo_plugin", "https://github.com/example/demo") + + self.assertFalse(result["ok"]) + self.assertEqual(result["stage"], "load") + self.assertEqual(result["message"], "{'reason': 'boom'}") + + +class ReportBuilderTests(unittest.TestCase): + def test_build_report_counts_passed_warned_and_failed_results(self): + module = load_validator_module() + + report = module.build_report( + [ + {"plugin": "plugin-a", "ok": True, "severity": "pass", "stage": "load", "message": "ok"}, + {"plugin": "plugin-b", "ok": False, "severity": "warn", "stage": "metadata", "message": "missing desc"}, + {"plugin": "plugin-c", "ok": False, "severity": "fail", "stage": "load", "message": "boom"}, + ] + ) + + self.assertEqual(report["summary"]["total"], 3) + self.assertEqual(report["summary"]["passed"], 1) + self.assertEqual(report["summary"]["failed"], 1) + self.assertEqual(report["summary"]["warned"], 1) + self.assertEqual(report["results"][1]["plugin"], "plugin-b") + + +class WorkerOutputParsingTests(unittest.TestCase): + def test_parse_worker_output_keeps_market_plugin_key(self): + module = load_validator_module() + completed = subprocess.CompletedProcess( + args=["python3", "run.py"], + returncode=1, + stdout='{"plugin": "demo_plugin", "ok": false, "stage": "load", "message": "boom"}', + stderr="", + ) + + result 
= module.parse_worker_output( + plugin="market-plugin-key", + repo="https://github.com/example/demo-plugin?tab=readme-ov-file", + normalized_repo_url="https://github.com/example/demo-plugin", + completed=completed, + plugin_dir_name="demo_plugin", + ) + + self.assertEqual(result["plugin"], "market-plugin-key") + self.assertEqual(result["plugin_dir_name"], "demo_plugin") + + def test_parse_worker_output_uses_last_json_line_after_logs(self): + module = load_validator_module() + completed = subprocess.CompletedProcess( + args=["python3", "run.py"], + returncode=1, + stdout='log line\n{"plugin": "demo_plugin", "ok": false, "stage": "load", "message": "boom"}', + stderr="", + ) + + result = module.parse_worker_output( + plugin="market-plugin-key", + repo="https://github.com/example/demo-plugin", + normalized_repo_url="https://github.com/example/demo-plugin", + completed=completed, + plugin_dir_name="demo_plugin", + ) + + self.assertEqual(result["plugin"], "market-plugin-key") + self.assertEqual(result["stage"], "load") + self.assertEqual(result["message"], "boom") + + +if __name__ == "__main__": + unittest.main()