diff --git a/.gitignore b/.gitignore
index 5eb9616c8c..afd1659b8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,4 +63,5 @@ GenieData/
.kilocode/
.worktrees/
+.astrbot_sdk_testing/
dashboard/bun.lock
diff --git a/astrbot-sdk/LICENSE b/astrbot-sdk/LICENSE
new file mode 100644
index 0000000000..51d7fd4c87
--- /dev/null
+++ b/astrbot-sdk/LICENSE
@@ -0,0 +1,11 @@
+AstrBot SDK repository notice
+=============================
+
+This repository does not currently publish a standalone open-source license text.
+
+This file exists so the source repository and its `vendor/` subtree snapshot carry
+the same notice instead of silently omitting licensing information.
+
+Unless the maintainers publish different licensing terms, do not assume this
+repository grants redistribution or modification rights beyond applicable law and
+explicit permission from the maintainers.
diff --git a/astrbot-sdk/README.md b/astrbot-sdk/README.md
new file mode 100644
index 0000000000..9cd71c50f0
--- /dev/null
+++ b/astrbot-sdk/README.md
@@ -0,0 +1,14 @@
+# AstrBot SDK Vendor Snapshot
+
+This directory is the minimized subtree payload consumed by the AstrBot main
+repository.
+
+- `src/astrbot_sdk/` keeps the runtime SDK package plus the minimal testing
+ helpers that AstrBot and SDK-generated templates still treat as part of the
+ vendored contract
+- agent skill markdown reference assets beyond the minimal scaffold are excluded
+- root project-note templates for `astr init` stay vendored because the CLI
+ still generates `AGENTS.md` / `CLAUDE.md` by default
+- `pyproject.toml` keeps the src-layout package discovery but drops dev/test-only metadata
+- `VENDORED.md` describes the vendoring contract
+- tests, docs, CI files, and other source-repo-only content stay outside this directory
diff --git a/astrbot-sdk/VENDORED.md b/astrbot-sdk/VENDORED.md
new file mode 100644
index 0000000000..0937882566
--- /dev/null
+++ b/astrbot-sdk/VENDORED.md
@@ -0,0 +1,22 @@
+# Vendored Snapshot Notes
+
+This directory is a minimized snapshot for the AstrBot main repository to import
+via `git subtree`.
+
+- The source of truth is this `astrbot-sdk` repository.
+- `vendor/src/astrbot_sdk/` is synchronized from `src/astrbot_sdk/`.
+- Vendored snapshots keep the runtime SDK plus the minimal testing helpers
+ (`testing.py`, `_testing_support.py`, `_internal/testing_support.py`) because
+ AstrBot and SDK-generated test templates still depend on them.
+- Vendored snapshots retain the default `AGENTS.md` / `CLAUDE.md` project-note
+ templates and the minimal `astrbot-plugin-dev` skill scaffold used by
+ `astr init --agents`, but still exclude larger markdown reference assets that
+ are not needed by the subtree consumer.
+- `vendor/pyproject.toml` keeps src-layout package discovery, but strips
+ test/dev-only sections so the subtree stays runtime-focused.
+- Do not edit vendored files directly inside the AstrBot main repository.
+- Tests and broader documentation remain only in the SDK source repository.
+ The vendored snapshot only keeps the runtime-facing templates required by
+ `astr init`.
+- If the vendored copy needs changes, update the SDK source repository first and
+ regenerate the `vendor/` snapshot.
diff --git a/astrbot-sdk/pyproject.toml b/astrbot-sdk/pyproject.toml
new file mode 100644
index 0000000000..db6eff3658
--- /dev/null
+++ b/astrbot-sdk/pyproject.toml
@@ -0,0 +1,50 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "astrbot-sdk"
+version = "0.1.0"
+description = "AstrBot SDK with s5r runtime, worker protocol, and plugin tooling"
+readme = "README.md"
+requires-python = ">=3.12"
+classifiers = [
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
+]
+dependencies = [
+ "aiohttp>=3.13.2",
+ "anthropic>=0.72.1",
+ "certifi>=2025.10.5",
+ "click>=8.3.0",
+ "docstring-parser>=0.17.0",
+ "google-genai>=1.50.0",
+ "loguru>=0.7.3",
+ "msgpack>=1.1.1",
+ "openai>=2.7.2",
+ "pydantic>=2.12.3",
+ "pyyaml>=6.0.3",
+ "uv>=0.9.17",
+]
+
+[project.scripts]
+astr = "astrbot_sdk.cli:cli"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/astrbot_sdk"]
+exclude = ["/src/astrbot_sdk/AGENTS.md"]
+
+[tool.hatch.build.targets.sdist]
+include = [
+ "/src",
+ "/README.md",
+ "/LICENSE",
+]
+
+# ============================================================
+# Optional Dependencies
+# ============================================================
diff --git a/astrbot-sdk/src/astrbot_sdk/__init__.py b/astrbot-sdk/src/astrbot_sdk/__init__.py
new file mode 100644
index 0000000000..fb211b4489
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/__init__.py
@@ -0,0 +1,213 @@
+"""AstrBot SDK 的顶层公共 API。
+
+这里仅重新导出 astrbot-sdk 推荐直接导入的稳定入口。
+
+新插件应直接使用此模块的导出:
+ from astrbot_sdk import Star, Context, MessageEvent
+ from astrbot_sdk.decorators import on_command, on_message
+
+迁移期适配入口位于独立模块;此处只暴露 astrbot-sdk 原生主入口。
+"""
+
+from .clients.managers import (
+ ConversationCreateParams,
+ ConversationManagerClient,
+ ConversationRecord,
+ ConversationUpdateParams,
+ KnowledgeBaseCreateParams,
+ KnowledgeBaseDocumentRecord,
+ KnowledgeBaseDocumentUploadParams,
+ KnowledgeBaseManagerClient,
+ KnowledgeBaseRecord,
+ KnowledgeBaseRetrieveResult,
+ KnowledgeBaseRetrieveResultItem,
+ KnowledgeBaseUpdateParams,
+ MessageHistoryManagerClient,
+ MessageHistoryPage,
+ MessageHistoryRecord,
+ MessageHistorySender,
+ PersonaCreateParams,
+ PersonaManagerClient,
+ PersonaRecord,
+ PersonaUpdateParams,
+)
+from .clients.metadata import PluginMetadata, StarMetadata
+from .clients.permission import (
+ PermissionCheckResult,
+ PermissionClient,
+ PermissionManagerClient,
+)
+from .clients.platform import PlatformError, PlatformStats, PlatformStatus
+from .clients.provider import (
+ ManagedProviderRecord,
+ ProviderChangeEvent,
+ ProviderManagerClient,
+)
+from .clients.session import SessionPluginManager, SessionServiceManager
+from .commands import CommandGroup, command_group, print_cmd_tree
+from .context import Context
+from .conversation import (
+ ConversationClosed,
+ ConversationReplaced,
+ ConversationSession,
+ ConversationState,
+)
+from .decorators import (
+ admin_only,
+ background_task,
+ conversation_command,
+ cooldown,
+ group_only,
+ http_api,
+ message_types,
+ on_command,
+ on_event,
+ on_message,
+ on_provider_change,
+ on_schedule,
+ platforms,
+ priority,
+ private_only,
+ provide_capability,
+ rate_limit,
+ register_skill,
+ require_admin,
+ require_permission,
+ validate_config,
+)
+from .errors import AstrBotError
+from .events import MessageEvent
+from .filters import (
+ CustomFilter,
+ MessageTypeFilter,
+ PlatformFilter,
+ all_of,
+ any_of,
+ custom_filter,
+)
+from .message.components import (
+ At,
+ AtAll,
+ BaseMessageComponent,
+ File,
+ Forward,
+ Image,
+ MediaHelper,
+ Plain,
+ Poke,
+ Record,
+ Reply,
+ UnknownComponent,
+ Video,
+)
+from .message.result import (
+ EventResultType,
+ MessageBuilder,
+ MessageChain,
+ MessageEventResult,
+)
+from .message.session import MessageSession
+from .plugin_kv import PluginKVStoreMixin
+from .schedule import ScheduleContext
+from .session_waiter import SessionController, session_waiter
+from .star import Star
+from .star_tools import StarTools
+from .types import GreedyStr
+
+__all__ = [
+ "AstrBotError",
+ "At",
+ "AtAll",
+ "BaseMessageComponent",
+ "CommandGroup",
+ "ConversationClosed",
+ "ConversationCreateParams",
+ "ConversationManagerClient",
+ "ConversationReplaced",
+ "ConversationRecord",
+ "ConversationSession",
+ "ConversationState",
+ "ConversationUpdateParams",
+ "Context",
+ "CustomFilter",
+ "EventResultType",
+ "File",
+ "Forward",
+ "GreedyStr",
+ "Image",
+ "KnowledgeBaseCreateParams",
+ "KnowledgeBaseDocumentRecord",
+ "KnowledgeBaseDocumentUploadParams",
+ "KnowledgeBaseManagerClient",
+ "KnowledgeBaseRecord",
+ "KnowledgeBaseRetrieveResult",
+ "KnowledgeBaseRetrieveResultItem",
+ "KnowledgeBaseUpdateParams",
+ "ManagedProviderRecord",
+ "MediaHelper",
+ "MessageHistoryManagerClient",
+ "MessageHistoryPage",
+ "MessageHistoryRecord",
+ "MessageHistorySender",
+ "MessageEvent",
+ "MessageEventResult",
+ "MessageChain",
+ "MessageBuilder",
+ "MessageSession",
+ "MessageTypeFilter",
+ "Plain",
+ "PluginKVStoreMixin",
+ "PluginMetadata",
+ "PermissionCheckResult",
+ "PermissionClient",
+ "PermissionManagerClient",
+ "PlatformFilter",
+ "PlatformError",
+ "PlatformStats",
+ "PlatformStatus",
+ "Poke",
+ "PersonaCreateParams",
+ "PersonaManagerClient",
+ "PersonaRecord",
+ "PersonaUpdateParams",
+ "ProviderChangeEvent",
+ "ProviderManagerClient",
+ "Record",
+ "Reply",
+ "ScheduleContext",
+ "SessionPluginManager",
+ "SessionServiceManager",
+ "SessionController",
+ "Star",
+ "StarMetadata",
+ "StarTools",
+ "UnknownComponent",
+ "Video",
+ "admin_only",
+ "all_of",
+ "any_of",
+ "background_task",
+ "cooldown",
+ "conversation_command",
+ "command_group",
+ "custom_filter",
+ "group_only",
+ "http_api",
+ "message_types",
+ "on_command",
+ "on_event",
+ "on_message",
+ "on_provider_change",
+ "on_schedule",
+ "platforms",
+ "print_cmd_tree",
+ "priority",
+ "provide_capability",
+ "private_only",
+ "rate_limit",
+ "require_admin",
+ "require_permission",
+ "register_skill",
+ "session_waiter",
+ "validate_config",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/__main__.py b/astrbot-sdk/src/astrbot_sdk/__main__.py
new file mode 100644
index 0000000000..624fd22f4c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/__main__.py
@@ -0,0 +1,11 @@
+"""`python -m astrbot_sdk` 的 CLI 入口。"""
+
+from .cli import cli
+
+
+def main() -> None:
+ cli()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/astrbot-sdk/src/astrbot_sdk/_command_model.py b/astrbot-sdk/src/astrbot_sdk/_command_model.py
new file mode 100644
index 0000000000..fd8f1ad851
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_command_model.py
@@ -0,0 +1,17 @@
+from ._internal.command_model import (
+ COMMAND_MODEL_DOCS_URL,
+ CommandModelParseResult,
+ ResolvedCommandModelParam,
+ format_command_model_help,
+ parse_command_model_remainder,
+ resolve_command_model_param,
+)
+
+__all__ = [
+ "COMMAND_MODEL_DOCS_URL",
+ "CommandModelParseResult",
+ "ResolvedCommandModelParam",
+ "format_command_model_help",
+ "parse_command_model_remainder",
+ "resolve_command_model_param",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py b/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py
new file mode 100644
index 0000000000..6ccc0d22e9
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py
@@ -0,0 +1,7 @@
+"""Internal implementation modules for astrbot_sdk.
+
+This package groups private helpers that are not part of the public SDK API.
+Imports outside the SDK should avoid depending on these modules directly.
+"""
+
+__all__: list[str] = []
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py b/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py
new file mode 100644
index 0000000000..6237826b8f
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py
@@ -0,0 +1,236 @@
+from __future__ import annotations
+
+import inspect
+from dataclasses import dataclass
+from typing import Any
+
+from pydantic import BaseModel
+
+from ..errors import AstrBotError
+from ..runtime._command_matching import split_command_remainder
+from .injected_params import is_framework_injected_parameter
+from .typing_utils import unwrap_optional
+
+# TODO:文档内容喵
+COMMAND_MODEL_DOCS_URL = "https://docs.astrbot.org/sdk/parameter-injection"
+
+
+@dataclass(slots=True)
+class ResolvedCommandModelParam:
+ name: str
+ model_cls: type[BaseModel]
+
+
+@dataclass(slots=True)
+class CommandModelParseResult:
+ model: BaseModel | None = None
+ help_text: str | None = None
+
+
+def resolve_command_model_param(handler: Any) -> ResolvedCommandModelParam | None:
+ try:
+ signature = inspect.signature(handler)
+ except (TypeError, ValueError):
+ return None
+ try:
+ type_hints = inspect.get_annotations(handler, eval_str=True)
+ except Exception:
+ type_hints = {}
+
+ candidates: list[ResolvedCommandModelParam] = []
+ other_names: list[str] = []
+ for parameter in signature.parameters.values():
+ if parameter.kind not in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ ):
+ continue
+ annotation = type_hints.get(parameter.name)
+ if _is_injected_parameter(parameter.name, annotation):
+ continue
+ normalized, _is_optional = unwrap_optional(annotation)
+ if isinstance(normalized, type) and issubclass(normalized, BaseModel):
+ candidates.append(
+ ResolvedCommandModelParam(
+ name=parameter.name,
+ model_cls=normalized,
+ )
+ )
+ continue
+ other_names.append(parameter.name)
+
+ if not candidates:
+ return None
+ if len(candidates) > 1 or other_names:
+ names = [item.name for item in candidates]
+ raise ValueError(
+ "Command BaseModel injection requires exactly one non-injected BaseModel "
+ f"parameter, got models={names!r} others={other_names!r}"
+ )
+ _validate_supported_model(candidates[0].model_cls)
+ return candidates[0]
+
+
+def parse_command_model_remainder(
+ *,
+ remainder: str,
+ model_param: ResolvedCommandModelParam,
+ command_name: str,
+) -> CommandModelParseResult:
+ tokens = split_command_remainder(remainder)
+ if any(token in {"-h", "--help"} for token in tokens):
+ return CommandModelParseResult(
+ help_text=format_command_model_help(command_name, model_param.model_cls)
+ )
+
+ fields = model_param.model_cls.model_fields
+ explicit_values: dict[str, Any] = {}
+ positional_values: dict[str, Any] = {}
+ positional_field_names = [
+ name
+ for name, field in fields.items()
+ if _supported_scalar_type(field.annotation)[0] is not bool
+ ]
+ positional_index = 0
+ index = 0
+ while index < len(tokens):
+ token = tokens[index]
+ if not token.startswith("--"):
+ assigned = False
+ while positional_index < len(positional_field_names):
+ field_name = positional_field_names[positional_index]
+ positional_index += 1
+ if field_name in explicit_values or field_name in positional_values:
+ continue
+ positional_values[field_name] = token
+ assigned = True
+ break
+ if not assigned:
+ raise _command_parse_error("Too many positional arguments")
+ index += 1
+ continue
+
+ raw_name = token[2:]
+ if not raw_name:
+ raise _command_parse_error("Invalid option '--'")
+ explicit_value: str | None = None
+ if "=" in raw_name:
+ raw_name, explicit_value = raw_name.split("=", 1)
+ negated = raw_name.startswith("no-")
+ # 与 argparse/click 惯例一致:--foo-bar 自动映射为字段名 foo_bar
+ cli_name = raw_name[3:] if negated else raw_name
+ field_name = cli_name.replace("-", "_")
+ field = fields.get(field_name)
+ if field is None:
+ raise _command_parse_error(f"Unknown option: --{raw_name}")
+ option_name = _format_option_name(field_name)
+ negated_option_name = f"--no-{option_name[2:]}"
+ if field_name in explicit_values:
+ raise _command_parse_error(f"Duplicate option: {option_name}")
+ field_type, _is_optional = _supported_scalar_type(field.annotation)
+ if field_type is bool:
+ if explicit_value is not None:
+ raise _command_parse_error(
+ f"Boolean option '{option_name}' only supports {option_name} or {negated_option_name}"
+ )
+ explicit_values[field_name] = not negated
+ index += 1
+ continue
+ if negated:
+ raise _command_parse_error(
+ f"Non-boolean option '{option_name}' does not support {negated_option_name}"
+ )
+ if explicit_value is None:
+ index += 1
+ if index >= len(tokens):
+ raise _command_parse_error(f"Missing value for option: {option_name}")
+ explicit_value = tokens[index]
+ explicit_values[field_name] = explicit_value
+ index += 1
+
+ values = {**positional_values, **explicit_values}
+
+ try:
+ model = model_param.model_cls.model_validate(values)
+ except Exception as exc:
+ raise AstrBotError.invalid_input(
+ "命令参数解析失败",
+ hint=str(exc),
+ docs_url=COMMAND_MODEL_DOCS_URL,
+ details={
+ "command": command_name,
+ "parameter": model_param.name,
+ "values": values,
+ },
+ ) from exc
+ return CommandModelParseResult(model=model)
+
+
+def format_command_model_help(command_name: str, model_cls: type[BaseModel]) -> str:
+ _validate_supported_model(model_cls)
+ lines = [f"用法: /{command_name} [options]"]
+ if model_cls.model_fields:
+ lines.append("参数:")
+ for name, field in model_cls.model_fields.items():
+ field_type, is_optional = _supported_scalar_type(field.annotation)
+ type_name = getattr(field_type, "__name__", str(field_type))
+ required = field.is_required()
+ default_text = ""
+ if not required:
+ default_text = f",默认 {field.default!r}"
+ elif is_optional:
+ default_text = ",默认 None"
+ description = str(field.description or "").strip()
+ detail = f"{name}: {type_name}"
+ if description:
+ detail += f" - {description}"
+ detail += ",必填" if required else ",可选"
+ detail += default_text
+ if field_type is bool:
+ option_name = _format_option_name(name)
+ detail += f",使用 {option_name} / --no-{option_name[2:]}"
+ lines.append(detail)
+ return "\n".join(lines)
+
+
+def _validate_supported_model(model_cls: type[BaseModel]) -> None:
+ for name, field in model_cls.model_fields.items():
+ try:
+ _supported_scalar_type(field.annotation)
+ except TypeError as exc:
+ raise ValueError(
+ f"Unsupported command model field '{name}': {exc}"
+ ) from exc
+
+
+def _supported_scalar_type(annotation: Any) -> tuple[type[Any], bool]:
+ normalized, is_optional = unwrap_optional(annotation)
+ if normalized in {str, int, float, bool}:
+ return normalized, is_optional
+ raise TypeError("only str/int/float/bool and Optional variants are supported")
+
+
+def _format_option_name(field_name: str) -> str:
+ # Surface the canonical CLI spelling so parse errors match the user's option syntax.
+ return f"--{field_name.replace('_', '-')}"
+
+
+def _command_parse_error(message: str) -> AstrBotError:
+ return AstrBotError.invalid_input(
+ message,
+ docs_url=COMMAND_MODEL_DOCS_URL,
+ )
+
+
+def _is_injected_parameter(name: str, annotation: Any) -> bool:
+ return is_framework_injected_parameter(name, annotation)
+
+
+__all__ = [
+ "COMMAND_MODEL_DOCS_URL",
+ "CommandModelParseResult",
+ "ResolvedCommandModelParam",
+ "format_command_model_help",
+ "parse_command_model_remainder",
+ "resolve_command_model_param",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py b/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py
new file mode 100644
index 0000000000..c1e47356f1
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py
@@ -0,0 +1,531 @@
+from __future__ import annotations
+
+import asyncio
+import inspect
+from contextlib import suppress
+from dataclasses import dataclass, field
+from typing import Any
+
+from pydantic import ValidationError
+
+from ..context import Context as RuntimeContext
+from ..decorators import (
+ BackgroundTaskMeta,
+ HttpApiMeta,
+ ValidateConfigMeta,
+ get_background_task_meta,
+ get_http_api_meta,
+ get_provider_change_meta,
+ get_skill_meta,
+ get_validate_config_meta,
+)
+from ..star import Star
+from .sdk_logger import logger
+from .star_runtime import bind_star_runtime
+
+_RUNTIME_STATE_ATTR = "__astrbot_decorator_runtime_state__"
+_VALIDATED_CONFIGS_ATTR = "__astrbot_validated_configs__"
+
+
+@dataclass(slots=True)
+class DecoratorRuntimeState:
+ http_apis: list[tuple[str, list[str]]] = field(default_factory=list)
+ provider_hooks: list[asyncio.Task[None]] = field(default_factory=list)
+ background_tasks: list[asyncio.Task[Any]] = field(default_factory=list)
+ registered_skills: list[str] = field(default_factory=list)
+
+
+def _runtime_state(instance: Any) -> DecoratorRuntimeState:
+ state = getattr(instance, _RUNTIME_STATE_ATTR, None)
+ if isinstance(state, DecoratorRuntimeState):
+ return state
+ state = DecoratorRuntimeState()
+ setattr(instance, _RUNTIME_STATE_ATTR, state)
+ return state
+
+
+def _iter_bound_methods(instance: Any):
+ seen_names: set[str] = set()
+ for name in dir(instance.__class__):
+ if name.startswith("__") or name in seen_names:
+ continue
+ seen_names.add(name)
+ try:
+ raw_attr = inspect.getattr_static(instance, name)
+ except AttributeError:
+ continue
+ if isinstance(raw_attr, property):
+ continue
+ bound = getattr(instance, name, None)
+ if not callable(bound):
+ continue
+ raw = getattr(bound, "__func__", bound)
+ yield name, bound, raw
+
+
+def _validated_config_store(instance: Any) -> dict[str, Any]:
+ values = getattr(instance, _VALIDATED_CONFIGS_ATTR, None)
+ if isinstance(values, dict):
+ return values
+ values = {}
+ setattr(instance, _VALIDATED_CONFIGS_ATTR, values)
+ return values
+
+
+def _positional_arg_count(func: Any) -> int:
+ try:
+ signature = inspect.signature(func)
+ except (TypeError, ValueError):
+ return 0
+ return sum(
+ 1
+ for parameter in signature.parameters.values()
+ if parameter.kind
+ in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ )
+ )
+
+
+def _call_with_optional_context(bound: Any, context: RuntimeContext) -> Any:
+ return bound(context) if _positional_arg_count(bound) >= 1 else bound()
+
+
+async def _await_if_needed(value: Any) -> Any:
+ if inspect.isawaitable(value):
+ return await value
+ return value
+
+
+def _decorator_target_name(instance: Any, method_name: str | None = None) -> str:
+ class_name = instance.__class__.__name__
+ if method_name is None:
+ return class_name
+ return f"{class_name}.{method_name}"
+
+
+def _decorator_error(
+ *,
+ instance: Any,
+ decorator_name: str,
+ exc: Exception,
+ method_name: str | None = None,
+ details: str | None = None,
+) -> RuntimeError:
+ message = f"{_decorator_target_name(instance, method_name)} {decorator_name} failed"
+ if details:
+ message += f" ({details})"
+ message += f": {exc}"
+ return RuntimeError(message)
+
+
+def _http_api_details(meta: HttpApiMeta) -> str:
+ details = [f"route={meta.route!r}", f"methods={list(meta.methods)!r}"]
+ if meta.capability_name:
+ details.append(f"capability_name={meta.capability_name!r}")
+ return ", ".join(details)
+
+
+def _provider_change_details(meta: Any) -> str:
+ return f"provider_types={list(meta.provider_types)!r}"
+
+
+def _background_task_details(meta: BackgroundTaskMeta, method_name: str) -> str:
+ description = meta.description or f"background_task:{method_name}"
+ return (
+ f"description={description!r}, auto_start={meta.auto_start!r}, "
+ f"on_error={meta.on_error!r}"
+ )
+
+
+def _skill_details(name: str, path: str) -> str:
+ return f"name={name!r}, path={path!r}"
+
+
+def _normalize_provider_type(value: Any) -> str:
+ enum_value = getattr(value, "value", None)
+ if isinstance(enum_value, str):
+ return enum_value.strip().lower()
+ return str(value).strip().lower()
+
+
+def _is_valid_schema_expected_type(value: Any) -> bool:
+ if isinstance(value, type):
+ return True
+ return (
+ isinstance(value, tuple)
+ and len(value) > 0
+ and all(isinstance(item, type) for item in value)
+ )
+
+
+async def _run_model_validation(
+ *,
+ instance: Any,
+ method_name: str,
+ meta: ValidateConfigMeta,
+ config: dict[str, Any],
+) -> None:
+ if meta.model is not None:
+ try:
+ validated = meta.model.model_validate(config)
+ except ValidationError as exc:
+ raise ValueError(str(exc)) from exc
+ _validated_config_store(instance)[method_name] = validated
+ return
+
+ assert meta.schema is not None
+ validated = _validate_schema_config(meta.schema, config)
+ _validated_config_store(instance)[method_name] = validated
+
+
+def _validate_schema_config(
+ schema: dict[str, Any],
+ config: dict[str, Any],
+) -> dict[str, Any]:
+ validated: dict[str, Any] = {}
+ errors: list[str] = []
+
+ for field_name, field_schema in schema.items():
+ if not isinstance(field_schema, dict):
+ errors.append(f"{field_name}: schema entry must be an object")
+ continue
+ present = field_name in config
+ value = config.get(field_name, field_schema.get("default"))
+ required = bool(field_schema.get("required", False))
+ if value is None:
+ if required and "default" not in field_schema:
+ errors.append(f"{field_name}: is required")
+ validated[field_name] = value
+ continue
+ expected_type = field_schema.get("type")
+ if expected_type is not None and not _is_valid_schema_expected_type(
+ expected_type
+ ):
+ errors.append(
+ f"{field_name}: invalid schema 'type' entry {expected_type!r}; "
+ "expected a type or tuple of types"
+ )
+ continue
+ if expected_type is not None and not isinstance(value, expected_type):
+ errors.append(
+ f"{field_name}: expected {getattr(expected_type, '__name__', expected_type)}, "
+ f"got {type(value).__name__}"
+ )
+ continue
+ if isinstance(value, (int, float)) and not isinstance(value, bool):
+ minimum = field_schema.get("min")
+ maximum = field_schema.get("max")
+ range_value = field_schema.get("range")
+ if minimum is not None and value < minimum:
+ errors.append(f"{field_name}: must be >= {minimum}")
+ if maximum is not None and value > maximum:
+ errors.append(f"{field_name}: must be <= {maximum}")
+ if (
+ isinstance(range_value, tuple)
+ and len(range_value) == 2
+ and not (range_value[0] <= value <= range_value[1])
+ ):
+ errors.append(
+ f"{field_name}: must be within [{range_value[0]}, {range_value[1]}]"
+ )
+ if required and not present and "default" not in field_schema:
+ errors.append(f"{field_name}: is required")
+ validated[field_name] = value
+
+ if errors:
+ raise ValueError("validate_config schema failed: " + "; ".join(errors))
+ return validated
+
+
+async def _run_validate_config(instance: Any, context: RuntimeContext) -> None:
+ config_payload = await context.metadata.get_plugin_config()
+ config = dict(config_payload or {})
+ for method_name, _bound, raw in _iter_bound_methods(instance):
+ meta = get_validate_config_meta(raw)
+ if meta is None:
+ continue
+ try:
+ await _run_model_validation(
+ instance=instance,
+ method_name=method_name,
+ meta=meta,
+ config=config,
+ )
+ except Exception as exc:
+ raise _decorator_error(
+ instance=instance,
+ method_name=method_name,
+ decorator_name="@validate_config",
+ exc=exc,
+ ) from exc
+
+
+async def _register_http_apis(instance: Any, context: RuntimeContext) -> None:
+ state = _runtime_state(instance)
+ for method_name, bound, raw in _iter_bound_methods(instance):
+ meta = get_http_api_meta(raw)
+ if meta is None:
+ continue
+ try:
+ await _register_http_api(bound=bound, meta=meta, context=context)
+ except Exception as exc:
+ raise _decorator_error(
+ instance=instance,
+ method_name=method_name,
+ decorator_name="@http_api",
+ details=_http_api_details(meta),
+ exc=exc,
+ ) from exc
+ state.http_apis.append((meta.route, list(meta.methods)))
+
+
+async def _register_http_api(
+ *,
+ bound: Any,
+ meta: HttpApiMeta,
+ context: RuntimeContext,
+) -> None:
+ if meta.capability_name:
+ await context.http.register_api(
+ route=meta.route,
+ handler_capability=meta.capability_name,
+ methods=list(meta.methods),
+ description=meta.description,
+ )
+ return
+ await context.http.register_api(
+ route=meta.route,
+ handler=bound,
+ methods=list(meta.methods),
+ description=meta.description,
+ )
+
+
+async def _register_provider_change_hooks(
+ instance: Any,
+ context: RuntimeContext,
+) -> None:
+ state = _runtime_state(instance)
+ for method_name, bound, raw in _iter_bound_methods(instance):
+ meta = get_provider_change_meta(raw)
+ if meta is None:
+ continue
+ target_name = _decorator_target_name(instance, method_name)
+
+ async def callback(
+ provider_id: str,
+ provider_type: Any,
+ umo: str | None,
+ *,
+ _bound=bound,
+ _meta=meta,
+ ) -> None:
+ if _meta.provider_types:
+ current_type = _normalize_provider_type(provider_type)
+ if current_type not in _meta.provider_types:
+ return
+ owner = instance if isinstance(instance, Star) else None
+ try:
+ with bind_star_runtime(owner, context):
+ result = _bound(provider_id, provider_type, umo)
+ await _await_if_needed(result)
+ except Exception as exc:
+ raise RuntimeError(
+ f"{target_name} @on_provider_change callback failed "
+ f"(provider_id={provider_id!r}, provider_type={provider_type!r}, "
+ f"umo={umo!r}): {exc}"
+ ) from exc
+
+ try:
+ task = await context.provider_manager.register_provider_change_hook(
+ callback
+ )
+ except Exception as exc:
+ raise _decorator_error(
+ instance=instance,
+ method_name=method_name,
+ decorator_name="@on_provider_change",
+ details=_provider_change_details(meta),
+ exc=exc,
+ ) from exc
+ # TODO: provider.manager.watch_changes is currently restricted to
+ # reserved/system plugins. If this decorator should be public-facing,
+ # the capability boundary needs to be widened or a dedicated event feed
+ # should be introduced.
+ state.provider_hooks.append(task)
+
+
+async def _start_background_tasks(instance: Any, context: RuntimeContext) -> None:
+ state = _runtime_state(instance)
+ for method_name, bound, raw in _iter_bound_methods(instance):
+ meta = get_background_task_meta(raw)
+ if meta is None or not meta.auto_start:
+ continue
+ try:
+ task = await context.register_task(
+ _background_runner(
+ instance=instance,
+ bound=bound,
+ context=context,
+ meta=meta,
+ method_name=method_name,
+ ),
+ meta.description
+ or f"background_task:{instance.__class__.__name__}.{method_name}",
+ )
+ except Exception as exc:
+ raise _decorator_error(
+ instance=instance,
+ method_name=method_name,
+ decorator_name="@background_task",
+ details=_background_task_details(meta, method_name),
+ exc=exc,
+ ) from exc
+ state.background_tasks.append(task)
+
+
+async def _background_runner(
+ *,
+ instance: Any,
+ bound: Any,
+ context: RuntimeContext,
+ meta: BackgroundTaskMeta,
+ method_name: str,
+) -> None:
+ while True:
+ try:
+ owner = instance if isinstance(instance, Star) else None
+ with bind_star_runtime(owner, context):
+ result = _call_with_optional_context(bound, context)
+ await _await_if_needed(result)
+ return
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ if meta.on_error != "restart":
+ raise _decorator_error(
+ instance=instance,
+ method_name=method_name,
+ decorator_name="@background_task",
+ details=_background_task_details(meta, method_name),
+ exc=exc,
+ ) from exc
+ context.logger.exception(
+ "SDK decorator background_task restarting after failure: plugin_id={} task={} details={}",
+ context.plugin_id,
+ f"{instance.__class__.__name__}.{method_name}",
+ _background_task_details(meta, method_name),
+ )
+
+
+def _iter_class_and_method_meta_entries(
+ instance: Any,
+ getter,
+) -> list[tuple[str, Any]]:
+ values = [
+ (_decorator_target_name(instance), meta) for meta in getter(instance.__class__)
+ ]
+ for method_name, _bound, raw in _iter_bound_methods(instance):
+ values.extend(
+ (_decorator_target_name(instance, method_name), meta)
+ for meta in getter(raw)
+ )
+ return values
+
+
+async def _register_skills(instance: Any, context: RuntimeContext) -> None:
+ state = _runtime_state(instance)
+ for target_name, meta in _iter_class_and_method_meta_entries(
+ instance, get_skill_meta
+ ):
+ try:
+ await context.register_skill(
+ name=meta.name,
+ path=meta.path,
+ description=meta.description,
+ )
+ except Exception as exc:
+ raise RuntimeError(
+ f"{target_name} @register_skill failed "
+ f"({_skill_details(meta.name, meta.path)}): {exc}"
+ ) from exc
+ state.registered_skills.append(meta.name)
+
+
+async def _teardown_decorator_resources(instance: Any, context: RuntimeContext) -> None:
+ state = _runtime_state(instance)
+
+ for task in reversed(state.provider_hooks):
+ with suppress(asyncio.CancelledError):
+ await context.provider_manager.unregister_provider_change_hook(task)
+ state.provider_hooks.clear()
+
+ for task in reversed(state.background_tasks):
+ if not task.done():
+ task.cancel()
+ for task in reversed(state.background_tasks):
+ with suppress(asyncio.CancelledError, Exception):
+ await task
+ state.background_tasks.clear()
+
+ for route, methods in reversed(state.http_apis):
+ try:
+ await context.http.unregister_api(route, methods)
+ except Exception:
+ logger.exception(
+ "decorator http_api cleanup failed: plugin_id={} route={}",
+ context.plugin_id,
+ route,
+ )
+ state.http_apis.clear()
+
+ for name in reversed(state.registered_skills):
+ with suppress(Exception):
+ await context.unregister_skill(name)
+ state.registered_skills.clear()
+
+
+async def _invoke_hook(
+ *,
+ instance: Any,
+ hook: Any | None,
+ context: RuntimeContext,
+) -> None:
+ if hook is None:
+ return
+ owner = instance if isinstance(instance, Star) else None
+ with bind_star_runtime(owner, context):
+ result = _call_with_optional_context(hook, context)
+ await _await_if_needed(result)
+
+
+async def run_lifecycle_with_decorators(
+ *,
+ instance: Any,
+ hook: Any | None,
+ method_name: str,
+ context: RuntimeContext,
+) -> None:
+ # Wrap decorator-managed startup failures with decorator-specific context so
+ # plugin authors do not only see a generic worker initialize timeout.
+ # Keep the lifecycle wrapper centralized so decorator-managed resources still
+ # work when plugins override on_start/on_stop without calling super().
+ if method_name == "on_start":
+ await _run_validate_config(instance, context)
+ await _invoke_hook(instance=instance, hook=hook, context=context)
+ await _register_http_apis(instance, context)
+ await _register_provider_change_hooks(instance, context)
+ await _register_skills(instance, context)
+ await _start_background_tasks(instance, context)
+ return
+
+ try:
+ await _invoke_hook(instance=instance, hook=hook, context=context)
+ finally:
+ if method_name == "on_stop":
+ await _teardown_decorator_resources(instance, context)
+
+
+__all__ = ["run_lifecycle_with_decorators"]
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py b/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py
new file mode 100644
index 0000000000..ced6229f93
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import functools
+import inspect
+from typing import Any
+
+try:
+ from typing import get_type_hints
+except ImportError: # pragma: no cover
+ get_type_hints = None
+
+from .typing_utils import unwrap_optional
+
# Parameter names the framework always fills in itself; handlers never
# receive user-supplied values for these, so they are excluded from the
# legacy positional-argument mapping.
_INJECTED_PARAMETER_NAMES = {
    "event",
    "ctx",
    "context",
    "sched",
    "schedule",
    "conversation",
    "conv",
}
+
+
def is_framework_injected_parameter(name: str, annotation: Any) -> bool:
    """Return True when a handler parameter is supplied by the framework.

    A parameter counts as injected either by well-known name (``event``,
    ``ctx``, ...) or because its (optionally Optional-wrapped) annotation is
    one of the framework-provided runtime types or a subclass thereof.
    """
    if name in _INJECTED_PARAMETER_NAMES:
        return True
    unwrapped, _ = unwrap_optional(annotation)
    if unwrapped is None:
        return False
    try:
        known_types = _framework_injected_types()
    except Exception:
        # Importing the framework types may fail in stripped-down
        # environments; treat the parameter as plugin-owned in that case.
        return False
    if unwrapped in known_types:
        return True
    return isinstance(unwrapped, type) and issubclass(unwrapped, known_types)
+
+
def legacy_arg_parameter_names(handler: Any) -> list[str]:
    """List the positional parameter names that legacy args should fill.

    Framework-injected parameters (by name or annotation) and non-positional
    parameters are skipped. Returns an empty list when the handler has no
    introspectable signature.
    """
    try:
        signature = inspect.signature(handler)
    except (TypeError, ValueError):
        return []

    type_hints: dict[str, Any] = {}
    if get_type_hints is not None:
        try:
            type_hints = get_type_hints(handler)
        except Exception:
            # Unresolvable forward references etc. — fall back to no hints.
            type_hints = {}

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return [
        parameter.name
        for parameter in signature.parameters.values()
        if parameter.kind in positional_kinds
        and not is_framework_injected_parameter(
            parameter.name, type_hints.get(parameter.name)
        )
    ]
+
+
@functools.lru_cache(maxsize=1)
def _framework_injected_types() -> tuple[type[Any], ...]:
    """Return the framework types whose annotations mark injected parameters.

    Imported lazily (and cached once) because these modules themselves import
    from the _internal package; a top-level import would be circular.
    """
    from ..clients.llm import LLMResponse
    from ..context import Context
    from ..conversation import ConversationSession
    from ..events import MessageEvent
    from ..llm.entities import ProviderRequest
    from ..message.result import MessageEventResult
    from ..schedule import ScheduleContext

    return (
        Context,
        MessageEvent,
        ScheduleContext,
        ConversationSession,
        ProviderRequest,
        LLMResponse,
        MessageEventResult,
    )
+
+
+__all__ = ["is_framework_injected_parameter", "legacy_arg_parameter_names"]
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py b/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py
new file mode 100644
index 0000000000..2fe2ec1d5e
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py
@@ -0,0 +1,86 @@
+"""插件调用者身份上下文管理。
+
+本模块使用 contextvars 实现跨异步任务传播插件身份,
+用于在 capability 调用时自动识别调用者插件。
+
+典型场景:
+ - http.register_api: 记录哪个插件注册了 API
+ - metadata.get_plugin_config: 只允许查询当前插件自己的配置
+ - 能力路由层权限校验
+
+使用方式:
+ with caller_plugin_scope("my_plugin"):
+ # 在此作用域内,current_caller_plugin_id() 返回 "my_plugin"
+ await ctx.http.register_api(...)
+
+注意:
+ contextvars 会自动传播到子任务(asyncio.create_task),
+ 无需手动传递。
+"""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+from contextvars import ContextVar, Token
+
+# 存储当前调用者插件 ID 的上下文变量
+_CALLER_PLUGIN_ID: ContextVar[str | None] = ContextVar(
+ "astrbot_sdk_caller_plugin_id",
+ default=None,
+)
+
+
+def current_caller_plugin_id() -> str | None:
+ """获取当前上下文中的调用者插件 ID。
+
+ Returns:
+ 当前插件 ID,如果不在插件调用上下文中则返回 None
+ """
+ return _CALLER_PLUGIN_ID.get()
+
+
+def bind_caller_plugin_id(plugin_id: str | None) -> Token[str | None]:
+ """绑定调用者插件 ID 到当前上下文。
+
+ Args:
+ plugin_id: 插件 ID,空字符串会被视为 None
+
+ Returns:
+ 用于后续 reset 的 Token
+
+ Note:
+ 通常使用 caller_plugin_scope 上下文管理器而非直接调用此函数
+ """
+ normalized = plugin_id.strip() if isinstance(plugin_id, str) else ""
+ return _CALLER_PLUGIN_ID.set(normalized or None)
+
+
+def reset_caller_plugin_id(token: Token[str | None]) -> None:
+ """重置调用者插件 ID 到之前的状态。
+
+ Args:
+ token: bind_caller_plugin_id 返回的 Token
+ """
+ _CALLER_PLUGIN_ID.reset(token)
+
+
+@contextmanager
+def caller_plugin_scope(plugin_id: str | None) -> Iterator[None]:
+ """创建一个绑定插件身份的上下文作用域。
+
+ Args:
+ plugin_id: 要绑定的插件 ID
+
+ Yields:
+ None
+
+ 示例:
+ with caller_plugin_scope("my_plugin"):
+ await some_capability_call()
+ """
+ token = bind_caller_plugin_id(plugin_id)
+ try:
+ yield
+ finally:
+ reset_caller_plugin_id(token)
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py b/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py
new file mode 100644
index 0000000000..d13720b500
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py
@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import json
+import math
+import re
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+
def is_ttl_memory_entry(value: Any) -> bool:
    """Return whether a stored memory payload uses the TTL wrapper shape."""

    if not isinstance(value, dict):
        return False
    return "value" in value and "ttl_seconds" in value
+
+
+def memory_value_for_search(stored: Any) -> dict[str, Any] | None:
+ """Unwrap the search payload from a stored memory record when possible."""
+
+ if not isinstance(stored, dict):
+ return None
+ if is_ttl_memory_entry(stored):
+ value = stored.get("value")
+ return value if isinstance(value, dict) else None
+ return stored
+
+
def extract_memory_text(stored: Any) -> str:
    """Pick the canonical text that keyword/vector search should index."""

    value = stored
    # Unwrap the TTL wrapper shape before inspecting the payload.
    if isinstance(value, dict) and "value" in value and "ttl_seconds" in value:
        value = value["value"]
    if not isinstance(value, dict):
        return ""
    for candidate_field in ("embedding_text", "content", "summary", "title", "text"):
        candidate = value.get(candidate_field)
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
    # No textual field present — fall back to a deterministic JSON rendering.
    return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
+
+
+def memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
+ """Translate a TTL in seconds into an absolute UTC expiration timestamp."""
+
+ try:
+ ttl = int(ttl_seconds)
+ except (TypeError, ValueError):
+ return None
+ if ttl < 1:
+ return None
+ return datetime.now(timezone.utc) + timedelta(seconds=ttl)
+
+
+def memory_expiration_from_stored_payload(stored: Any) -> datetime | None:
+ """Recover an absolute expiration timestamp from a stored TTL payload."""
+
+ if not is_ttl_memory_entry(stored) or not isinstance(stored, dict):
+ return None
+ raw_expires_at = stored.get("expires_at")
+ if isinstance(raw_expires_at, (int, float)):
+ return datetime.fromtimestamp(float(raw_expires_at), tz=timezone.utc)
+ if not isinstance(raw_expires_at, str):
+ return None
+
+ normalized = raw_expires_at.strip()
+ if not normalized:
+ return None
+ if normalized.endswith("Z"):
+ normalized = f"{normalized[:-1]}+00:00"
+ try:
+ expires_at = datetime.fromisoformat(normalized)
+ except ValueError:
+ return None
+ if expires_at.tzinfo is None:
+ expires_at = expires_at.replace(tzinfo=timezone.utc)
+ return expires_at.astimezone(timezone.utc)
+
+
def normalize_memory_namespace(value: Any) -> str:
    """Normalize a namespace path into a stable slash-delimited string."""

    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        # Normalize each element recursively and splice the segments together.
        segments: list[str] = []
        for element in value:
            normalized = normalize_memory_namespace(element)
            segments.extend(part for part in normalized.split("/") if part.strip())
        return "/".join(segments)
    text = str(value).strip().replace("\\", "/")
    if not text:
        return ""
    return "/".join(part.strip() for part in text.split("/") if part.strip())
+
+
def join_memory_namespace(*parts: Any) -> str:
    """Join namespace segments while preserving the root namespace as empty."""

    collected: list[str] = []
    for part in parts:
        normalized = normalize_memory_namespace(part)
        if not normalized:
            # Empty parts contribute nothing (root stays root).
            continue
        collected.extend(
            segment for segment in normalized.split("/") if segment.strip()
        )
    return "/".join(collected)
+
+
def memory_namespace_matches(
    candidate: str,
    namespace: str | None,
    *,
    include_descendants: bool,
) -> bool:
    """Check whether a stored namespace belongs to the requested scope."""

    # None means "no namespace filter at all" — everything matches.
    if namespace is None:
        return True
    normalized_candidate = normalize_memory_namespace(candidate)
    normalized_scope = normalize_memory_namespace(namespace)
    if not normalized_scope:
        # Root scope: everything (with descendants) or only the root itself.
        return include_descendants or not normalized_candidate
    if normalized_candidate == normalized_scope:
        return True
    if not include_descendants:
        return False
    return normalized_candidate.startswith(normalized_scope + "/")
+
+
def display_memory_namespace(value: Any) -> str | None:
    """Return a user-facing namespace value (None for the root namespace)."""

    normalized = normalize_memory_namespace(value)
    if normalized:
        return normalized
    return None
+
+
+def _memory_query_terms(value: str) -> list[str]:
+ normalized = re.sub(r"\s+", " ", str(value).strip().casefold())
+ if not normalized:
+ return []
+ terms = [item for item in re.findall(r"\w+", normalized, flags=re.UNICODE) if item]
+ if terms:
+ return terms
+ compact = normalized.replace(" ", "")
+ return [compact] if compact else []
+
+
def memory_keyword_score(query: str, key: str, text: str) -> float:
    """Score a keyword hit the same way across runtime and core bridge."""

    folded_query = str(query).casefold()
    # An empty query matches everything equally.
    if not folded_query:
        return 1.0
    folded_key = str(key).casefold()
    folded_text = str(text).casefold()

    score = 0.0
    # Whole-query substring hits: key matches beat text matches.
    if folded_query in folded_key:
        score = 1.0
    if folded_query in folded_text:
        score = max(score, 0.92)

    terms = _memory_query_terms(folded_query)
    if not terms:
        return score

    # Partial per-term hits are weighted by hit ratio, below full matches.
    term_count = len(terms)
    key_hits = sum(1 for term in terms if term in folded_key)
    text_hits = sum(1 for term in terms if term in folded_text)
    if key_hits:
        score = max(score, 0.5 + 0.5 * (key_hits / term_count))
    if text_hits:
        score = max(score, 0.35 + 0.55 * (text_hits / term_count))
    return min(score, 1.0)
+
+
+def cosine_similarity(left: list[float], right: list[float]) -> float:
+ """Compute cosine similarity defensively for embedding vectors."""
+
+ if not left or not right or len(left) != len(right):
+ return 0.0
+ left_norm = math.sqrt(sum(value * value for value in left))
+ right_norm = math.sqrt(sum(value * value for value in right))
+ if left_norm <= 0 or right_norm <= 0:
+ return 0.0
+ return sum(a * b for a, b in zip(left, right, strict=False)) / (
+ left_norm * right_norm
+ )
+
+
def normalize_embedding(vector: list[float]) -> list[float]:
    """Normalize an embedding for cosine/inner-product search."""

    if not vector:
        return []
    magnitude = math.sqrt(sum(component * component for component in vector))
    if magnitude <= 0:
        # Zero (or degenerate) vectors cannot be scaled; zero them out instead.
        return [0.0] * len(vector)
    return [float(component) / magnitude for component in vector]
+
+
def memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
    """Normalize cached sidecar data into a stable memory index record."""

    if not isinstance(entry, dict):
        # Anything but a dict is treated as "no cached sidecar data".
        return {"text": text, "embedding": None, "provider_id": None}

    raw_embedding = entry.get("embedding")
    embedding = (
        [float(item) for item in raw_embedding]
        if isinstance(raw_embedding, list)
        else None
    )
    raw_provider = entry.get("provider_id")
    provider_id = str(raw_provider).strip() if raw_provider is not None else None
    return {
        "text": str(entry.get("text", text)),
        "embedding": embedding,
        "provider_id": provider_id,
    }
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py
new file mode 100644
index 0000000000..471875e2fb
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import re
+from pathlib import Path
+
# Plugin IDs: 1-128 chars of [A-Za-z0-9._-]; the first and last character
# must be a letter, digit, or underscore (no leading/trailing dot or hyphen).
PLUGIN_ID_PATTERN = re.compile(r"^[A-Za-z0-9_](?:[A-Za-z0-9._-]{0,126}[A-Za-z0-9_])?$")
# Device names Windows refuses to use as file/directory names.
_WINDOWS_RESERVED_PLUGIN_IDS = (
    {"CON", "PRN", "AUX", "NUL"}
    | {f"COM{digit}" for digit in range(1, 10)}
    | {f"LPT{digit}" for digit in range(1, 10)}
)


def validate_plugin_id(plugin_id: str) -> str:
    """Validate and return a normalized (whitespace-stripped) plugin ID.

    Raises:
        ValueError: if the ID is empty, contains disallowed characters, or
            collides with a reserved Windows device name (directly or via
            its leading dot-separated component).
    """
    normalized = str(plugin_id).strip()
    if not normalized:
        raise ValueError("plugin_id must not be empty")
    if PLUGIN_ID_PATTERN.fullmatch(normalized) is None:
        raise ValueError(
            "plugin_id must use only letters, digits, dots, underscores, or hyphens"
        )
    upper = normalized.upper()
    # "CON.anything" is just as unusable on Windows as "CON" itself.
    stem = upper.split(".", 1)[0]
    if upper in _WINDOWS_RESERVED_PLUGIN_IDS or stem in _WINDOWS_RESERVED_PLUGIN_IDS:
        raise ValueError("plugin_id must not use a reserved Windows device name")
    return normalized
+
+
def plugin_capability_prefix(plugin_id: str) -> str:
    """Return the dotted capability-name prefix owned by *plugin_id*."""
    validated = validate_plugin_id(plugin_id)
    return validated + "."
+
+
def capability_belongs_to_plugin(capability_name: str, plugin_id: str) -> bool:
    """Return True when *capability_name* falls under the plugin's prefix."""
    prefix = plugin_capability_prefix(plugin_id)
    return str(capability_name).strip().startswith(prefix)
+
+
def plugin_http_route_root(plugin_id: str) -> str:
    """Return the root HTTP route reserved for *plugin_id*."""
    return "/" + validate_plugin_id(plugin_id)
+
+
def http_route_belongs_to_plugin(route: str, plugin_id: str) -> bool:
    """Return True when *route* is the plugin's root route or nested under it."""
    cleaned = str(route).strip()
    root = plugin_http_route_root(plugin_id)
    if cleaned == root:
        return True
    return cleaned.startswith(root + "/")
+
+
def resolve_plugin_data_dir(root: Path, plugin_id: str) -> Path:
    """Resolve the per-plugin data directory under *root*.

    The plugin ID is validated first, and the resolved candidate must stay
    inside the resolved root (defense in depth against path escapes).

    Raises:
        ValueError: on an invalid plugin ID or an escaping path.
    """
    plugin_name = validate_plugin_id(plugin_id)
    base = root.resolve()
    target = (base / plugin_name).resolve()
    try:
        target.relative_to(base)
    except ValueError as exc:
        raise ValueError("plugin_id escapes the plugin data root") from exc
    return target
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py
new file mode 100644
index 0000000000..b89fb8dc18
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py
@@ -0,0 +1,313 @@
+from __future__ import annotations
+
+import asyncio
+import inspect
+import os
+import time
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+try:
+ from astrbot.core.config.default import VERSION as _ASTRBOT_VERSION
+except Exception: # noqa: BLE001
+ _ASTRBOT_VERSION = ""
+
+__all__ = ["PluginLogEntry", "PluginLogger"]
+
+
@dataclass(slots=True)
class PluginLogEntry:
    """A single log record captured for a plugin's live log stream."""

    # Upper-case level name, e.g. "INFO" or "ERROR".
    level: str
    # Unix timestamp (seconds) at publish time.
    time: float
    # Fully formatted message text.
    message: str
    # ID of the plugin that emitted the record.
    plugin_id: str
    # Extra key/value context carried over from PluginLogger.bind().
    context: dict[str, Any] = field(default_factory=dict)
+
+
+class _PluginLogBroker:
+ def __init__(self, plugin_id: str) -> None:
+ self.plugin_id = plugin_id
+ self._subscribers: set[asyncio.Queue[PluginLogEntry]] = set()
+
+ def publish(self, entry: PluginLogEntry) -> None:
+ for queue in list(self._subscribers):
+ try:
+ queue.put_nowait(entry)
+ except asyncio.QueueFull:
+ continue
+
+ async def watch(self) -> AsyncIterator[PluginLogEntry]:
+ queue: asyncio.Queue[PluginLogEntry] = asyncio.Queue()
+ self._subscribers.add(queue)
+ try:
+ while True:
+ yield await queue.get()
+ finally:
+ self._subscribers.discard(queue)
+
+
# One shared broker per plugin ID, reused across PluginLogger instances.
_BROKERS: dict[str, _PluginLogBroker] = {}

# 4-character level aliases keeping console columns aligned.
_SHORT_LEVEL_NAMES = {
    "DEBUG": "DBUG",
    "INFO": "INFO",
    "WARNING": "WARN",
    "ERROR": "ERRO",
    "CRITICAL": "CRIT",
}

# ANSI escape sequences used by the console renderer.
_ANSI_RESET = "\u001b[0m"
_ANSI_GREEN = "\u001b[32m"
_ANSI_LEVEL_COLORS = {
    "DEBUG": "\u001b[1;34m",
    "INFO": "\u001b[1;36m",
    "WARNING": "\u001b[1;33m",
    "ERROR": "\u001b[31m",
    "CRITICAL": "\u001b[1;31m",
}
+
+
def _get_short_level_name(level_name: str) -> str:
    """Map a logging level name to its 4-character display alias."""
    key = level_name.upper()
    if key in _SHORT_LEVEL_NAMES:
        return _SHORT_LEVEL_NAMES[key]
    return level_name[:4].upper()
+
+
+def _build_source_file(pathname: str | None) -> str:
+ if not pathname:
+ return "unknown"
+ dirname = os.path.dirname(pathname)
+ return (
+ os.path.basename(dirname) + "." + os.path.basename(pathname).replace(".py", "")
+ )
+
+
+def _plugin_tag_from_path(pathname: str | None) -> str:
+ if not pathname:
+ return "[Plug]"
+ norm_path = os.path.normpath(pathname)
+ if any(
+ marker in norm_path
+ for marker in (
+ os.path.normpath("data/plugins"),
+ os.path.normpath("data/sdk_plugins"),
+ os.path.normpath("astrbot/builtin_stars"),
+ )
+ ):
+ return "[Plug]"
+ return "[Core]"
+
+
def _level_color(level: str) -> str:
    """Return the ANSI color prefix for a level (reset code when unknown)."""
    try:
        return _ANSI_LEVEL_COLORS[level.upper()]
    except KeyError:
        return _ANSI_RESET
+
+
def _get_broker(plugin_id: str) -> _PluginLogBroker:
    """Return the shared per-plugin log broker, creating it on first use."""
    try:
        return _BROKERS[plugin_id]
    except KeyError:
        broker = _BROKERS[plugin_id] = _PluginLogBroker(plugin_id)
        return broker
+
+
class PluginLogger:
    """Per-plugin logger facade.

    Every record is rendered to the console in AstrBot's format (via the
    wrapped loguru-style logger when possible) and simultaneously published
    to the plugin's in-process broker so live log streams can observe it.
    """

    def __init__(
        self,
        *,
        plugin_id: str,
        logger: Any,
        bound_context: dict[str, Any] | None = None,
    ) -> None:
        """Create a facade for *plugin_id* wrapping *logger*.

        Args:
            plugin_id: Owning plugin's ID; also selects the shared broker.
            logger: Underlying logger (loguru-compatible where possible).
            bound_context: Extra key/value context attached to every entry.
        """
        self._plugin_id = plugin_id
        self._logger = logger
        # One broker per plugin ID, shared by all facade instances.
        self._broker = _get_broker(plugin_id)
        self._bound_context = dict(bound_context or {})

    @property
    def plugin_id(self) -> str:
        """ID of the plugin this logger belongs to."""
        return self._plugin_id

    def bind(self, **kwargs: Any) -> PluginLogger:
        """Return a new facade with extra context bound (loguru-style).

        The underlying logger's own ``bind`` is used when available; any
        failure falls back to the unbound logger. The bound context is also
        attached to published PluginLogEntry records.
        """
        bind = getattr(self._logger, "bind", None)
        next_logger = self._logger
        if callable(bind):
            try:
                next_logger = bind(**kwargs)
            except Exception:
                next_logger = self._logger
        return PluginLogger(
            plugin_id=self._plugin_id,
            logger=next_logger,
            bound_context={**self._bound_context, **kwargs},
        )

    def opt(self, *args: Any, **kwargs: Any) -> PluginLogger:
        """Return a new facade with loguru ``opt(...)`` applied when supported."""
        opt = getattr(self._logger, "opt", None)
        next_logger = self._logger
        if callable(opt):
            try:
                next_logger = opt(*args, **kwargs)
            except Exception:
                next_logger = self._logger
        return PluginLogger(
            plugin_id=self._plugin_id,
            logger=next_logger,
            bound_context=self._bound_context,
        )

    async def watch(self) -> AsyncIterator[PluginLogEntry]:
        """Async-iterate this plugin's published log entries."""
        async for entry in self._broker.watch():
            yield entry

    def log(self, level: str, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit at an arbitrary level name (normalized to upper case)."""
        normalized_level = str(level).upper()
        self._emit_console(normalized_level, message, *args, **kwargs)
        self._publish(normalized_level, message, *args, **kwargs)

    def debug(self, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit a DEBUG record to console and broker."""
        self._emit_console("DEBUG", message, *args, **kwargs)
        self._publish("DEBUG", message, *args, **kwargs)

    def info(self, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit an INFO record to console and broker."""
        self._emit_console("INFO", message, *args, **kwargs)
        self._publish("INFO", message, *args, **kwargs)

    def warning(self, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit a WARNING record to console and broker."""
        self._emit_console("WARNING", message, *args, **kwargs)
        self._publish("WARNING", message, *args, **kwargs)

    def error(self, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit an ERROR record to console and broker."""
        self._emit_console("ERROR", message, *args, **kwargs)
        self._publish("ERROR", message, *args, **kwargs)

    def exception(self, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit an ERROR record with a traceback on the console.

        Note: the published PluginLogEntry carries only the formatted
        message, not the traceback.
        """
        self._emit_console("ERROR", message, *args, exception=True, **kwargs)
        self._publish("ERROR", message, *args, **kwargs)

    def _emit_console(
        self,
        level: str,
        message: Any,
        *args: Any,
        exception: bool = False,
        **kwargs: Any,
    ) -> None:
        """Render to the console, preferring the loguru raw-opt path."""
        if self._emit_console_with_opt(
            level,
            message,
            *args,
            exception=exception,
            **kwargs,
        ):
            return
        self._emit_console_fallback(
            level,
            message,
            *args,
            exception=exception,
            **kwargs,
        )

    def _emit_console_with_opt(
        self,
        level: str,
        message: Any,
        *args: Any,
        exception: bool = False,
        **kwargs: Any,
    ) -> bool:
        """Emit one pre-formatted line via loguru ``opt(raw=True)``.

        Returns True on success; False tells the caller to use the plain
        fallback path instead.
        """
        opt = getattr(self._logger, "opt", None)
        if not callable(opt):
            return False
        formatted_message = self._format_message(message, *args, **kwargs)
        # Attribute the record to the plugin frame that called the logger.
        pathname, source_line = self._caller_info()
        plugin_tag = _plugin_tag_from_path(pathname)
        source_file = _build_source_file(pathname)
        # Warnings and above additionally carry the AstrBot version tag.
        version_tag = (
            f" [v{_ASTRBOT_VERSION}]"
            if _ASTRBOT_VERSION and level in {"WARNING", "ERROR", "CRITICAL"}
            else ""
        )
        timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
        level_text = _get_short_level_name(level)
        level_color = _level_color(level)
        line = (
            f"{_ANSI_GREEN}[{timestamp}]{_ANSI_RESET} {plugin_tag} "
            f"{level_color}[{level_text}]{_ANSI_RESET}{version_tag} "
            f"[{source_file}:{source_line}]: {level_color}{formatted_message}{_ANSI_RESET}"
        )
        try:
            emitter = opt(raw=True, exception=True) if exception else opt(raw=True)
            log = getattr(emitter, "log", None)
            if not callable(log):
                return False
            log(level, line + "\n")
            return True
        except Exception:
            return False

    def _emit_console_fallback(
        self,
        level: str,
        message: Any,
        *args: Any,
        exception: bool = False,
        **kwargs: Any,
    ) -> None:
        """Best-effort emit through whatever the wrapped logger supports.

        Tries level-named methods first (``exception`` before ``error`` when
        a traceback was requested), then a generic ``log`` as a last resort.
        Never raises.
        """
        method_names = []
        if exception:
            method_names.append("exception")
        method_names.append(str(level).lower())
        if exception:
            method_names.append("error")
        for method_name in method_names:
            method = getattr(self._logger, method_name, None)
            if not callable(method):
                continue
            try:
                method(message, *args, **kwargs)
            except Exception:
                continue
            return
        log = getattr(self._logger, "log", None)
        if callable(log):
            try:
                log(level, self._format_message(message, *args, **kwargs))
            except Exception:
                return

    def _caller_info(self) -> tuple[str | None, int]:
        """Walk the stack past this module's frames to the real caller."""
        frame = inspect.currentframe()
        if frame is None:
            return None, 0
        frame = frame.f_back
        while frame is not None and frame.f_globals.get("__name__") == __name__:
            frame = frame.f_back
        if frame is None:
            return None, 0
        return str(frame.f_code.co_filename), int(frame.f_lineno)

    def _publish(self, level: str, message: Any, *args: Any, **kwargs: Any) -> None:
        """Publish the formatted record to the plugin's log broker."""
        entry = PluginLogEntry(
            level=level,
            time=time.time(),
            message=self._format_message(message, *args, **kwargs),
            plugin_id=self._plugin_id,
            context=dict(self._bound_context),
        )
        self._broker.publish(entry)

    @staticmethod
    def _format_message(message: Any, *args: Any, **kwargs: Any) -> str:
        """Brace-format *message*; non-str messages and format errors degrade to str."""
        if not isinstance(message, str):
            return str(message)
        text = message
        if not args and not kwargs:
            return text
        try:
            return text.format(*args, **kwargs)
        except Exception:
            return text

    def __getattr__(self, name: str) -> Any:
        # Unknown attributes fall through to the wrapped logger.
        return getattr(self._logger, name)
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/sdk_logger.py b/astrbot-sdk/src/astrbot_sdk/_internal/sdk_logger.py
new file mode 100644
index 0000000000..687926ffea
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/sdk_logger.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+import os
+
+from loguru import logger as _raw_loguru_logger
+
+try:
+ from astrbot.core.config.default import VERSION as _ASTRBOT_VERSION
+except Exception: # noqa: BLE001
+ _ASTRBOT_VERSION = ""
+
+_SHORT_LEVEL_NAMES = {
+ "DEBUG": "DBUG",
+ "INFO": "INFO",
+ "WARNING": "WARN",
+ "ERROR": "ERRO",
+ "CRITICAL": "CRIT",
+}
+
+
+def _get_short_level_name(level_name: str) -> str:
+ return _SHORT_LEVEL_NAMES.get(level_name.upper(), level_name[:4].upper())
+
+
+def _build_source_file(pathname: str | None) -> str:
+ if not pathname:
+ return "unknown"
+ dirname = os.path.dirname(pathname)
+ return (
+ os.path.basename(dirname) + "." + os.path.basename(pathname).replace(".py", "")
+ )
+
+
def _patch_record(record: dict) -> None:
    """Loguru patcher: seed the extra fields AstrBot's log format expects.

    Uses setdefault throughout so values bound upstream (for example by a
    plugin-scoped logger) win over these defaults.
    """
    extra = record["extra"]
    extra.setdefault("plugin_tag", "[Core]")
    extra.setdefault("short_levelname", _get_short_level_name(record["level"].name))
    # Warnings and above (level no >= 30) also carry the AstrBot version.
    severity = record["level"].no
    if _ASTRBOT_VERSION and severity >= 30:
        version_tag = f" [v{_ASTRBOT_VERSION}]"
    else:
        version_tag = ""
    extra.setdefault("astrbot_version_tag", version_tag)
    extra.setdefault("source_file", _build_source_file(record["file"].path))
    extra.setdefault("source_line", record["line"])
    extra.setdefault("is_trace", False)
+
+
+logger = _raw_loguru_logger.patch(_patch_record)
+
+__all__ = ["logger"]
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/star_runtime.py b/astrbot-sdk/src/astrbot_sdk/_internal/star_runtime.py
new file mode 100644
index 0000000000..37211735e6
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/star_runtime.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from ..context import Context
+ from ..star import Star
+
+
# Context-local bindings for the currently running star and its runtime
# context; contextvars propagate automatically into tasks spawned inside
# the bound scope.
_CURRENT_STAR_CONTEXT: "ContextVar[Context | None]" = ContextVar(
    "astrbot_sdk_current_star_context",
    default=None,
)
_CURRENT_STAR_INSTANCE: "ContextVar[Star | None]" = ContextVar(
    "astrbot_sdk_current_star_instance",
    default=None,
)


def current_star_context() -> "Context | None":
    """Return the runtime context bound to the current execution context."""
    return _CURRENT_STAR_CONTEXT.get()


def current_runtime_context() -> "Context | None":
    """Alias of current_star_context kept for API symmetry."""
    return _CURRENT_STAR_CONTEXT.get()


def current_star_instance() -> "Star | None":
    """Return the star instance bound to the current execution context."""
    return _CURRENT_STAR_INSTANCE.get()


@contextmanager
def bind_star_runtime(star: "Star | None", ctx: "Context | None") -> Iterator[None]:
    """Bind *star* and *ctx* for the duration of the scope.

    The star's own instance-level binding is installed as well (when a star
    is given), and all bindings are restored in reverse order on exit.
    """
    context_token = _CURRENT_STAR_CONTEXT.set(ctx)
    star_token = _CURRENT_STAR_INSTANCE.set(star)
    instance_token = star._bind_runtime_context(ctx) if star is not None else None
    try:
        yield
    finally:
        if star is not None and instance_token is not None:
            star._reset_runtime_context(instance_token)
        _CURRENT_STAR_INSTANCE.reset(star_token)
        _CURRENT_STAR_CONTEXT.reset(context_token)
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py b/astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py
new file mode 100644
index 0000000000..2594d453e9
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py
@@ -0,0 +1,591 @@
+"""Shared support primitives for local SDK testing."""
+
+from __future__ import annotations
+
+import asyncio
+import typing
+from collections.abc import Mapping
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, TextIO
+
+from ..context import CancelToken
+from ..context import Context as RuntimeContext
+from ..events import MessageEvent
+from ..protocol.messages import EventMessage, PeerInfo
+from ..runtime._streaming import StreamExecution
+from ..runtime.capability_router import CapabilityRouter
+
+
+def _clone_payload_mapping(value: Any) -> dict[str, Any] | None:
+ if not isinstance(value, dict):
+ return None
+ return {str(key): item for key, item in value.items()}
+
+
+@dataclass(slots=True)
+class RecordedSend:
+ kind: str
+ message_id: str
+ session_id: str
+ text: str | None = None
+ image_url: str | None = None
+ chain: list[dict[str, Any]] | None = None
+ target: dict[str, Any] | None = None
+ raw: dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def session(self) -> str:
+ return self.session_id
+
+ @classmethod
+ def from_payload(cls, payload: dict[str, Any]) -> RecordedSend:
+ if "text" in payload:
+ kind = "text"
+ elif "image_url" in payload:
+ kind = "image"
+ elif "chain" in payload:
+ kind = "chain"
+ else:
+ kind = "unknown"
+ return cls(
+ kind=kind,
+ message_id=str(payload.get("message_id", "")),
+ session_id=str(payload.get("session", "")),
+ text=payload.get("text") if isinstance(payload.get("text"), str) else None,
+ image_url=(
+ payload.get("image_url")
+ if isinstance(payload.get("image_url"), str)
+ else None
+ ),
+ chain=(
+ [dict(item) for item in payload.get("chain", [])]
+ if isinstance(payload.get("chain"), list)
+ else None
+ ),
+ target=_clone_payload_mapping(payload.get("target")),
+ raw=dict(payload),
+ )
+
+
class StdoutPlatformSink:
    """Collects RecordedSend items and optionally echoes them to a stream."""

    def __init__(self, stream: "TextIO | None" = None) -> None:
        self._stream = stream
        self.records: "list[RecordedSend]" = []

    def record(self, item: "RecordedSend") -> None:
        """Store *item* and, when a stream is configured, echo one line."""
        self.records.append(item)
        if self._stream is not None:
            self._stream.write(self._format(item) + "\n")
            self._stream.flush()

    def clear(self) -> None:
        """Drop all captured records."""
        self.records.clear()

    def _format(self, item: "RecordedSend") -> str:
        """Render a single-line human-readable summary of *item*."""
        session = item.session_id
        if item.kind == "text":
            return f"[text][{session}] {item.text or ''}"
        if item.kind == "image":
            return f"[image][{session}] {item.image_url or ''}"
        if item.kind == "chain":
            return f"[chain][{session}] {len(item.chain or [])} components"
        return f"[send][{session}] {item.raw}"
+
+
+class InMemoryDB:
+ def __init__(self, store: dict[str, Any]) -> None:
+ self._store = store
+
+ def get(self, key: str, default: Any = None) -> Any:
+ return self._store.get(key, default)
+
+ def set(self, key: str, value: Any) -> None:
+ self._store[key] = value
+
+ def delete(self, key: str) -> None:
+ self._store.pop(key, None)
+
+ def list(self, prefix: str | None = None) -> list[str]:
+ keys = sorted(self._store.keys())
+ if prefix is None:
+ return keys
+ return [key for key in keys if key.startswith(prefix)]
+
+ def get_many(self, keys: list[str]) -> list[dict[str, Any]]:
+ return [{"key": key, "value": self._store.get(key)} for key in keys]
+
+ def set_many(self, items: list[dict[str, Any]]) -> None:
+ for item in items:
+ self.set(str(item.get("key", "")), item.get("value"))
+
+
+class InMemoryMemory:
+ def __init__(
+ self,
+ store: dict[str, dict[str, Any]],
+ *,
+ expires_at: dict[str, datetime | None] | None = None,
+ ) -> None:
+ self._store = store
+ self._expires_at = expires_at if expires_at is not None else {}
+
+ @staticmethod
+ def _is_ttl_entry(value: Any) -> bool:
+ """判断测试 memory 值是否使用 TTL 包装结构。
+
+ Args:
+ value: 待检查的存储值。
+
+ Returns:
+ bool: 如果包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
+ """
+ return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
+
+ @classmethod
+ def _search_text(cls, value: Any) -> str:
+ """提取测试用 memory.search 的匹配文本。
+
+ Args:
+ value: 当前存储的 memory 值。
+
+ Returns:
+ str: 用于本地测试搜索的文本内容。
+ """
+ if cls._is_ttl_entry(value):
+ value = value.get("value")
+ if not isinstance(value, dict):
+ return ""
+ for field_name in ("embedding_text", "content", "summary", "title", "text"):
+ item = value.get(field_name)
+ if isinstance(item, str) and item.strip():
+ return item.strip()
+ return str(value)
+
+ def _is_expired(self, key: str) -> bool:
+ """判断测试 memory 键是否已经过期。
+
+ Args:
+ key: memory 条目的键。
+
+ Returns:
+ bool: 如果当前时间已超过过期时间则返回 ``True``。
+ """
+ expires_at = self._expires_at.get(key)
+ return expires_at is not None and expires_at <= datetime.now(timezone.utc)
+
+ def _purge_if_expired(self, key: str) -> bool:
+ """在测试 helper 中清理已过期的 memory 条目。
+
+ Args:
+ key: memory 条目的键。
+
+ Returns:
+ bool: 如果条目已过期并被清理则返回 ``True``。
+ """
+ if not self._is_expired(key):
+ return False
+ self._store.pop(key, None)
+ self._expires_at.pop(key, None)
+ return True
+
+ def get(self, key: str, default: Any = None) -> Any:
+ if self._purge_if_expired(key):
+ return default
+ return self._store.get(key, default)
+
+ def save(self, key: str, value: dict[str, Any]) -> None:
+ self._store[key] = dict(value)
+
+ def delete(self, key: str) -> None:
+ self._store.pop(key, None)
+ self._expires_at.pop(key, None)
+
+ def search(self, query: str) -> list[dict[str, Any]]:
+ results: list[dict[str, Any]] = []
+ for key, value in list(self._store.items()):
+ if self._purge_if_expired(key):
+ continue
+ if query in key or query in self._search_text(value):
+ results.append({"key": key, "value": value})
+ return results
+
+
class MockLLMClient:
    """Wrapper around an LLM client that queues canned responses for tests."""

    def __init__(self, client: Any, router: "MockCapabilityRouter") -> None:
        self._client = client
        self._router = router

    def mock_response(self, text: str) -> None:
        """Queue *text* as the next non-streaming LLM reply."""
        self._router.enqueue_llm_response(text)

    def mock_stream_response(self, text: str) -> None:
        """Queue *text* as the next streaming LLM reply."""
        self._router.enqueue_llm_stream_response(text)

    def clear_mock_responses(self) -> None:
        """Drop all queued canned replies."""
        self._router.clear_llm_responses()

    def __getattr__(self, name: str) -> Any:
        # Everything else falls through to the wrapped client.
        return getattr(self._client, name)
+
+
+class MockPlatformClient:
+ def __init__(self, client: Any, sink: StdoutPlatformSink) -> None:
+ self._client = client
+ self._sink = sink
+
+ @property
+ def records(self) -> list[RecordedSend]:
+ return list(self._sink.records)
+
+ def assert_sent(
+ self,
+ expected_text: str | None = None,
+ *,
+ kind: str = "text",
+ count: int | None = None,
+ ) -> None:
+ matched = [item for item in self._sink.records if item.kind == kind]
+ if expected_text is not None:
+ matched = [item for item in matched if item.text == expected_text]
+ if count is not None:
+ if len(matched) != count:
+ raise AssertionError(
+ f"expected {count} sent records, got {len(matched)}: {matched}"
+ )
+ return
+ if not matched:
+ raise AssertionError(
+ f"expected sent record kind={kind!r} text={expected_text!r}, got {self._sink.records}"
+ )
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._client, name)
+
+
class MockCapabilityRouter(CapabilityRouter):
    """CapabilityRouter test double with queued LLM replies and recorded sends.

    ``llm.chat`` / ``llm.chat_raw`` / ``llm.stream_chat`` are short-circuited
    with queued (or echoed) responses; every other capability is delegated to
    the real router, after which any newly sent platform messages are mirrored
    into ``platform_sink``.
    """

    def __init__(self, *, platform_sink: StdoutPlatformSink | None = None) -> None:
        # Queues must exist before super().__init__() in case the base class
        # dispatches anything during construction.
        self.platform_sink = platform_sink or StdoutPlatformSink()
        self._llm_responses: list[str] = []
        self._llm_stream_responses: list[str] = []
        super().__init__()
        # Re-point the storage facades at the base router's backing stores so
        # the in-memory doubles observe the same data the router mutates.
        self.db = InMemoryDB(self.db_store)
        self.memory = InMemoryMemory(
            self.memory_store,
            expires_at=self._memory_expires_at,
        )

    # NOTE(review): the next five overrides only delegate to super();
    # presumably kept to pin these hooks as public testing surface — confirm.
    def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]:
        return super().list_dynamic_command_routes(plugin_id)

    def remove_dynamic_command_routes_for_plugin(self, plugin_id: str) -> None:
        super().remove_dynamic_command_routes_for_plugin(plugin_id)

    def emit_provider_change(
        self,
        provider_id: str,
        provider_type: str,
        umo: str | None = None,
    ) -> None:
        super().emit_provider_change(provider_id, provider_type, umo)

    def record_platform_error(
        self,
        platform_id: str,
        message: str,
        *,
        traceback: str | None = None,
    ) -> None:
        super().record_platform_error(platform_id, message, traceback=traceback)

    def set_platform_stats(self, platform_id: str, stats: dict[str, Any]) -> None:
        super().set_platform_stats(platform_id, stats)

    def enqueue_llm_response(self, text: str) -> None:
        # FIFO queue consumed by _take_llm_response.
        self._llm_responses.append(text)

    def enqueue_llm_stream_response(self, text: str) -> None:
        # FIFO queue consumed by _take_llm_stream_response.
        self._llm_stream_responses.append(text)

    def clear_llm_responses(self) -> None:
        """Drop all queued replies, both streaming and non-streaming."""
        self._llm_responses.clear()
        self._llm_stream_responses.clear()

    async def execute(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        stream: bool,
        cancel_token,
        request_id: str,
    ) -> dict[str, Any] | StreamExecution:
        """Intercept LLM capabilities with mock data; delegate everything else."""
        if capability == "llm.chat":
            return {"text": self._take_llm_response(str(payload.get("prompt", "")))}
        if capability == "llm.chat_raw":
            text = self._take_llm_response(str(payload.get("prompt", "")))
            return {
                "text": text,
                # Fake usage numbers: character counts stand in for tokens.
                "usage": {
                    "input_tokens": len(str(payload.get("prompt", ""))),
                    "output_tokens": len(text),
                },
                "finish_reason": "stop",
                "tool_calls": [],
                "role": "assistant",
                "reasoning_content": None,
                "reasoning_signature": None,
            }
        if capability == "llm.stream_chat":
            text = self._take_llm_stream_response(str(payload.get("prompt", "")))

            # Emit one character per chunk, honouring cancellation and yielding
            # control to the event loop between chunks.
            async def iterator() -> typing.AsyncIterator[dict[str, Any]]:
                for char in text:
                    cancel_token.raise_if_cancelled()
                    await asyncio.sleep(0)
                    yield {"text": char}

            return StreamExecution(
                iterator=iterator(),
                finalize=lambda chunks: {
                    "text": "".join(item.get("text", "") for item in chunks)
                },
            )
        # Non-LLM capability: run it for real, then mirror only the messages
        # sent during this call into the sink.
        before = len(self.sent_messages)
        result = await super().execute(
            capability,
            payload,
            stream=stream,
            cancel_token=cancel_token,
            request_id=request_id,
        )
        self._flush_platform_records(before)
        return result

    def _flush_platform_records(self, start_index: int) -> None:
        # Mirror every message appended after start_index into the sink.
        for payload in self.sent_messages[start_index:]:
            self.platform_sink.record(RecordedSend.from_payload(payload))

    def _take_llm_response(self, prompt: str) -> str:
        # Pop the oldest queued reply; fall back to echoing the prompt.
        if self._llm_responses:
            return self._llm_responses.pop(0)
        return f"Echo: {prompt}"

    def _take_llm_stream_response(self, prompt: str) -> str:
        # Prefer the stream queue, then the plain queue, then echo.
        if self._llm_stream_responses:
            return self._llm_stream_responses.pop(0)
        if self._llm_responses:
            return self._llm_responses.pop(0)
        return f"Echo: {prompt}"
+
+
class MockPeer:
    """In-process peer that routes capability calls straight into a local router.

    Mimics the remote-peer surface (peer info, capability descriptors) while
    executing everything locally via the router.
    """

    def __init__(self, router: MockCapabilityRouter) -> None:
        self._router = router
        self._counter = 0
        self.remote_peer = PeerInfo(
            name="astrbot-local-core",
            role="core",
            version="local",
        )
        self.remote_capabilities = list(router.all_descriptors())
        self.remote_capability_map = {
            descriptor.name: descriptor for descriptor in self.remote_capabilities
        }
        self.remote_handlers: list[Any] = []
        self.remote_provided_capabilities: list[Any] = []
        self.remote_metadata = {"mode": "local"}

    async def invoke(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        stream: bool = False,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Execute *capability* non-streaming and return its result payload."""
        if stream:
            raise ValueError("stream=True 请使用 invoke_stream()")
        result = await self._router.execute(
            capability,
            payload,
            stream=False,
            cancel_token=CancelToken(),
            request_id=request_id or self._next_id(),
        )
        return typing.cast(dict[str, Any], result)

    async def invoke_stream(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        request_id: str | None = None,
        include_completed: bool = False,
    ):
        """Execute *capability* as a stream, yielding phase EventMessages."""
        request_id = request_id or self._next_id()
        raw = await self._router.execute(
            capability,
            payload,
            stream=True,
            cancel_token=CancelToken(),
            request_id=request_id,
        )
        execution = typing.cast(StreamExecution, raw)

        async def event_iterator():
            # Phase order: started -> delta per chunk -> optional completed.
            yield EventMessage.model_validate({"id": request_id, "phase": "started"})
            collected: list[dict[str, Any]] = []
            async for chunk in execution.iterator:
                if execution.collect_chunks:
                    collected.append(chunk)
                yield EventMessage.model_validate(
                    {"id": request_id, "phase": "delta", "data": chunk}
                )
            output = execution.finalize(collected)
            if include_completed:
                yield EventMessage.model_validate(
                    {"id": request_id, "phase": "completed", "output": output}
                )

        return event_iterator()

    def _next_id(self) -> str:
        # Monotonic zero-padded local ids: local_0001, local_0002, ...
        self._counter += 1
        return f"local_{self._counter:04d}"
+
+
+def _normalize_plugin_metadata(
+ plugin_id: str,
+ plugin_metadata: Mapping[str, Any] | None,
+) -> dict[str, Any]:
+ if plugin_metadata is None:
+ plugin_metadata = {}
+ declared_name = plugin_metadata.get("name")
+ if declared_name is not None and str(declared_name) != plugin_id:
+ raise ValueError(
+ "MockContext.plugin_metadata['name'] 必须与 plugin_id 一致,"
+ f"当前收到 {declared_name!r} != {plugin_id!r}"
+ )
+ description = plugin_metadata.get("description")
+ if description is None:
+ description = plugin_metadata.get("desc", "")
+ return {
+ "name": plugin_id,
+ "display_name": str(plugin_metadata.get("display_name") or plugin_id),
+ "description": str(description or ""),
+ "author": str(plugin_metadata.get("author") or ""),
+ "version": str(plugin_metadata.get("version") or "0.0.0"),
+ "enabled": bool(plugin_metadata.get("enabled", True)),
+ "reserved": bool(plugin_metadata.get("reserved", False)),
+ "support_platforms": [
+ str(item)
+ for item in plugin_metadata.get("support_platforms", [])
+ if isinstance(item, str)
+ ]
+ if isinstance(plugin_metadata.get("support_platforms"), list)
+ else [],
+ "astrbot_version": (
+ str(plugin_metadata.get("astrbot_version"))
+ if plugin_metadata.get("astrbot_version") is not None
+ else None
+ ),
+ }
+
+
class MockContext(RuntimeContext):
    """RuntimeContext wired entirely to in-process test doubles.

    Builds a MockCapabilityRouter + MockPeer pair so capability calls never
    leave the process, registers the plugin's metadata on the router, and
    swaps the llm/platform clients for mock-aware wrappers.
    """

    def __init__(
        self,
        *,
        plugin_id: str = "test-plugin",
        logger: Any | None = None,
        cancel_token: CancelToken | None = None,
        platform_sink: StdoutPlatformSink | None = None,
        plugin_metadata: Mapping[str, Any] | None = None,
    ) -> None:
        # Sink/router/peer must exist before RuntimeContext.__init__ runs,
        # since the base class is handed the peer.
        self.platform_sink = platform_sink or StdoutPlatformSink()
        self.router = MockCapabilityRouter(platform_sink=self.platform_sink)
        self.mock_peer = MockPeer(self.router)
        super().__init__(
            peer=self.mock_peer,
            plugin_id=plugin_id,
            cancel_token=cancel_token,
            logger=logger,
        )
        # Register the plugin so router-side lookups see consistent metadata;
        # raises ValueError if metadata's "name" conflicts with plugin_id.
        self.router.upsert_plugin(
            metadata=_normalize_plugin_metadata(plugin_id, plugin_metadata),
            config={},
        )
        # Wrap the clients created by RuntimeContext with mock-aware proxies.
        self.llm = MockLLMClient(self.llm, self.router)
        self.platform = MockPlatformClient(self.platform, self.platform_sink)

    @property
    def sent_messages(self) -> list[RecordedSend]:
        # Snapshot copy; callers cannot mutate the sink's backing list.
        return list(self.platform_sink.records)

    @property
    def event_actions(self) -> list[dict[str, Any]]:
        # Snapshot copy of the actions accumulated on the router.
        return list(self.router.event_actions)
+
+
class MockMessageEvent(MessageEvent):
    """MessageEvent test double that records every reply it produces.

    Replies are always appended to ``self.replies``; when a MockContext is
    attached, they are additionally sent through the context's platform
    client.
    """

    def __init__(
        self,
        *,
        text: str = "",
        user_id: str | None = "test-user",
        group_id: str | None = None,
        platform: str | None = "test",
        session_id: str | None = "test-session",
        raw: dict[str, Any] | None = None,
        context: MockContext | None = None,
    ) -> None:
        # Recorded reply texts, in send order.
        self.replies: list[str] = []
        super().__init__(
            text=text,
            user_id=user_id,
            group_id=group_id,
            platform=platform,
            session_id=session_id,
            raw=raw,
            context=context,
        )
        if context is not None:
            # Route replies through the mock platform client as well.
            self.bind_runtime_reply(context)
        elif self._reply_handler is None:
            # No context and no handler from the superclass: capture locally.
            self.bind_reply_handler(self._capture_reply)

    @property
    def is_private(self) -> bool:
        # A message with no group is treated as a private chat.
        return self.group_id is None

    def bind_runtime_reply(self, context: MockContext) -> None:
        """Bind a reply handler that records locally and sends via *context*."""
        self._context = context

        async def reply(text: str) -> None:
            self.replies.append(text)
            # session_ref takes priority over the bare session_id when set.
            await context.platform.send(self.session_ref or self.session_id, text)

        self.bind_reply_handler(reply)

    async def _capture_reply(self, text: str) -> None:
        # Default handler: just remember what would have been sent.
        self.replies.append(text)
+
+
+__all__ = [
+ "InMemoryDB",
+ "InMemoryMemory",
+ "MockCapabilityRouter",
+ "MockContext",
+ "MockLLMClient",
+ "MockMessageEvent",
+ "MockPeer",
+ "MockPlatformClient",
+ "RecordedSend",
+ "StdoutPlatformSink",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py b/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py
new file mode 100644
index 0000000000..7cac7421ba
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import typing
+from types import UnionType
+from typing import Any
+
+
+def unwrap_optional(annotation: Any) -> tuple[Any, bool]:
+ origin = typing.get_origin(annotation)
+ if origin in {typing.Union, UnionType}:
+ args = [item for item in typing.get_args(annotation) if item is not type(None)]
+ if len(args) == 1:
+ return args[0], True
+ return annotation, False
+
+
+__all__ = ["unwrap_optional"]
diff --git a/astrbot-sdk/src/astrbot_sdk/_memory_backend.py b/astrbot-sdk/src/astrbot_sdk/_memory_backend.py
new file mode 100644
index 0000000000..50f94cbced
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_memory_backend.py
@@ -0,0 +1,1515 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import re
+import sqlite3
+import threading
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, cast
+
+from ._internal.memory_utils import (
+ cosine_similarity,
+ display_memory_namespace,
+ extract_memory_text,
+ join_memory_namespace,
+ memory_keyword_score,
+ memory_namespace_matches,
+ memory_value_for_search,
+ normalize_embedding,
+ normalize_memory_namespace,
+)
+
+
+def _utcnow() -> datetime:
+ # Centralize time access so expiry tests can advance time without mutating SQLite internals.
+ return datetime.now(timezone.utc)
+
+
+def _sql_placeholders(count: int) -> str:
+ if count <= 0:
+ raise ValueError("count must be positive")
+ return ", ".join("?" for _ in range(count))
+
+
def _normalize_scope_namespace(namespace: str | None) -> str | None:
    """Normalize *namespace*, passing ``None`` (the all-namespaces scope) through."""
    return None if namespace is None else normalize_memory_namespace(namespace)
+
+
+def _escape_like_value(value: str) -> str:
+ return str(value).replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+
+
+EmbedMany = Callable[[list[str]], Awaitable[list[list[float]]] | list[list[float]]]
+EmbedOne = Callable[[str], Awaitable[list[float]] | list[float]]
+
+
@dataclass(slots=True)
class MemorySearchResult:
    """One scored hit returned by a memory search."""

    key: str
    namespace: str
    value: dict[str, Any] | None
    score: float
    match_type: str

    def to_payload(self) -> dict[str, Any]:
        """Serialize for callers; namespace is omitted when it has no display form."""
        payload: dict[str, Any] = {
            "key": self.key,
            "value": self.value,
            "score": self.score,
            "match_type": self.match_type,
        }
        display = display_memory_namespace(self.namespace)
        if display is not None:
            payload["namespace"] = display
        return payload
+
+
@dataclass(slots=True)
class _StoredRecord:
    """Internal snapshot of one ``memory_records`` row used by search/embedding passes."""

    namespace: str
    key: str
    stored: dict[str, Any]  # decoded stored_json payload (value plus optional TTL metadata)
    search_text: str  # flattened text used for keyword matching and embedding input
    updated_at: str  # timestamp string as stored in the updated_at column
+
+
@dataclass(slots=True)
class _VectorCandidate:
    """Internal vector-search hit: a stored record plus its similarity score."""

    namespace: str
    key: str
    stored: dict[str, Any]  # decoded stored_json payload
    search_text: str  # flattened text the embedding was computed from
    score: float  # similarity versus the query embedding (higher is better)
+
+
+class PluginMemoryBackend:
+ """Persistent plugin-scoped memory backend with namespace-aware search."""
+
+ def __init__(self, data_dir: Path) -> None:
+ self._base_dir = Path(data_dir) / "memory"
+ self._db_path = self._base_dir / "memory.sqlite3"
+ self._vector_dir = self._base_dir / "vectors"
+ self._lock = threading.RLock()
+ self._initialized = False
+ self._fts_enabled = False
+ self._vector_indexes: dict[str, Any | None] = {}
+ self._vector_fallbacks: dict[str, list[tuple[int, list[float]]]] = {}
+
+ async def save(
+ self,
+ key: str,
+ value: dict[str, Any],
+ *,
+ namespace: str | None = None,
+ ) -> None:
+ await asyncio.to_thread(
+ self._save_sync,
+ str(key),
+ dict(value),
+ normalize_memory_namespace(namespace),
+ None,
+ )
+
+ async def save_with_ttl(
+ self,
+ key: str,
+ value: dict[str, Any],
+ ttl_seconds: int,
+ *,
+ namespace: str | None = None,
+ ) -> None:
+ expires_at = _utcnow().timestamp() + max(int(ttl_seconds), 0)
+ await asyncio.to_thread(
+ self._save_sync,
+ str(key),
+ dict(value),
+ normalize_memory_namespace(namespace),
+ {
+ "ttl_seconds": int(ttl_seconds),
+ "expires_at": datetime.fromtimestamp(
+ expires_at,
+ tz=timezone.utc,
+ ).isoformat(),
+ },
+ )
+
+ async def get(
+ self,
+ key: str,
+ *,
+ namespace: str | None = None,
+ ) -> dict[str, Any] | None:
+ return await asyncio.to_thread(
+ self._get_sync,
+ str(key),
+ normalize_memory_namespace(namespace),
+ )
+
+ async def list_keys(
+ self,
+ *,
+ namespace: str | None = None,
+ ) -> list[str]:
+ return await asyncio.to_thread(
+ self._list_keys_sync,
+ normalize_memory_namespace(namespace),
+ )
+
+ async def exists(
+ self,
+ key: str,
+ *,
+ namespace: str | None = None,
+ ) -> bool:
+ return await asyncio.to_thread(
+ self._exists_sync,
+ str(key),
+ normalize_memory_namespace(namespace),
+ )
+
+ async def get_many(
+ self,
+ keys: list[str],
+ *,
+ namespace: str | None = None,
+ ) -> list[dict[str, Any]]:
+ normalized_namespace = normalize_memory_namespace(namespace)
+ return await asyncio.to_thread(
+ self._get_many_sync,
+ [str(item) for item in keys],
+ normalized_namespace,
+ )
+
+ async def delete(
+ self,
+ key: str,
+ *,
+ namespace: str | None = None,
+ ) -> bool:
+ return await asyncio.to_thread(
+ self._delete_sync,
+ str(key),
+ normalize_memory_namespace(namespace),
+ )
+
+ async def clear_namespace(
+ self,
+ *,
+ namespace: str | None = None,
+ include_descendants: bool = False,
+ ) -> int:
+ normalized_namespace = _normalize_scope_namespace(namespace)
+ return await asyncio.to_thread(
+ self._clear_namespace_sync,
+ normalized_namespace,
+ bool(include_descendants),
+ )
+
+ async def delete_many(
+ self,
+ keys: list[str],
+ *,
+ namespace: str | None = None,
+ ) -> int:
+ normalized_namespace = normalize_memory_namespace(namespace)
+ return await asyncio.to_thread(
+ self._delete_many_sync,
+ [str(item) for item in keys],
+ normalized_namespace,
+ )
+
+ async def count(
+ self,
+ *,
+ namespace: str | None = None,
+ include_descendants: bool = False,
+ ) -> int:
+ normalized_namespace = _normalize_scope_namespace(namespace)
+ return await asyncio.to_thread(
+ self._count_sync,
+ normalized_namespace,
+ bool(include_descendants),
+ )
+
+ async def stats(
+ self,
+ *,
+ namespace: str | None = None,
+ include_descendants: bool = True,
+ ) -> dict[str, Any]:
+ normalized_namespace = _normalize_scope_namespace(namespace)
+ return await asyncio.to_thread(
+ self._stats_sync,
+ normalized_namespace,
+ bool(include_descendants),
+ )
+
    async def search(
        self,
        query: str,
        *,
        namespace: str | None = None,
        include_descendants: bool = True,
        mode: str,
        limit: int | None,
        min_score: float | None,
        provider_id: str | None = None,
        embed_one: EmbedOne | None = None,
        embed_many: EmbedMany | None = None,
    ) -> list[dict[str, Any]]:
        """Search memory in keyword / vector / hybrid mode and return scored payloads.

        Keyword candidates are always collected; vector candidates are added
        only for "vector"/"hybrid" modes when a *provider_id* is given.
        Results below *min_score* (or with a combined score of 0) are dropped,
        then sorted by descending score and truncated to *limit*.
        """
        normalized_namespace = _normalize_scope_namespace(namespace)
        # Unknown/blank modes fall back to keyword search.
        normalized_mode = str(mode).strip().lower() or "keyword"
        query_text = str(query)

        await asyncio.to_thread(self._purge_expired_sync)

        keyword_candidates = await asyncio.to_thread(
            self._keyword_candidates_sync,
            query_text,
            normalized_namespace,
            bool(include_descendants),
            limit,
        )

        vector_candidates: list[_VectorCandidate] = []
        if normalized_mode in {"vector", "hybrid"} and provider_id:
            # Backfill any missing embeddings first so the index is complete.
            await self._ensure_embeddings(
                provider_id=provider_id,
                namespace=normalized_namespace,
                include_descendants=bool(include_descendants),
                embed_one=embed_one,
                embed_many=embed_many,
            )
            if embed_one is not None:
                raw_query_embedding = await _maybe_await(embed_one(query_text))
                query_embedding = normalize_embedding(
                    [float(item) for item in raw_query_embedding]
                )
                vector_candidates = await asyncio.to_thread(
                    self._vector_candidates_sync,
                    provider_id,
                    query_embedding,
                    normalized_namespace,
                    bool(include_descendants),
                    limit,
                )

        # Merge both candidate sets keyed by (namespace, key); every entry
        # gets a keyword score, and vector hits contribute their best score.
        merged: dict[tuple[str, str], dict[str, Any]] = {}
        for record in keyword_candidates:
            identity = (record.namespace, record.key)
            merged[identity] = {
                "namespace": record.namespace,
                "key": record.key,
                "stored": record.stored,
                "keyword_score": memory_keyword_score(
                    query_text,
                    record.key,
                    record.search_text,
                ),
                "vector_score": 0.0,
            }
        for record in vector_candidates:
            identity = (record.namespace, record.key)
            current = merged.setdefault(
                identity,
                {
                    "namespace": record.namespace,
                    "key": record.key,
                    "stored": record.stored,
                    "keyword_score": memory_keyword_score(
                        query_text,
                        record.key,
                        record.search_text,
                    ),
                    "vector_score": 0.0,
                },
            )
            current["vector_score"] = max(
                float(current["vector_score"]),
                float(record.score),
            )

        results: list[MemorySearchResult] = []
        for item in merged.values():
            keyword_score = max(0.0, float(item["keyword_score"]))
            vector_score = max(0.0, float(item["vector_score"]))
            score = self._combined_score(
                mode=normalized_mode,
                keyword_score=keyword_score,
                vector_score=vector_score,
            )
            if score <= 0:
                continue
            if min_score is not None and score < float(min_score):
                continue

            # Label how the hit matched: pure-keyword mode (or keyword-only
            # evidence) -> "keyword"; vector mode or no keyword evidence ->
            # "vector"; both signals present in hybrid mode -> "hybrid".
            if normalized_mode == "keyword" or (
                keyword_score > 0 and vector_score <= 0
            ):
                match_type = "keyword"
            elif normalized_mode == "vector" or keyword_score <= 0:
                match_type = "vector"
            else:
                match_type = "hybrid"

            results.append(
                MemorySearchResult(
                    key=str(item["key"]),
                    namespace=str(item["namespace"]),
                    value=memory_value_for_search(item["stored"]),
                    score=score,
                    match_type=match_type,
                )
            )

        # Deterministic order: best score first, ties broken by namespace/key.
        results.sort(key=lambda item: (-item.score, item.namespace, item.key))
        if limit is not None and limit >= 0:
            results = results[:limit]
        return [item.to_payload() for item in results]
+
    async def _ensure_embeddings(
        self,
        *,
        provider_id: str,
        namespace: str | None,
        include_descendants: bool,
        embed_one: EmbedOne | None,
        embed_many: EmbedMany | None,
    ) -> None:
        """Backfill embeddings for records missing them, then refresh the index.

        Prefers the batched *embed_many* callable; falls back to calling
        *embed_one* per record.  With neither available, records are upserted
        with empty vectors (see _upsert_embeddings_sync's padding behavior).
        """
        missing = await asyncio.to_thread(
            self._missing_embeddings_sync,
            provider_id,
            namespace,
            include_descendants,
        )
        if missing:
            texts = [record.search_text for record in missing]
            embeddings: list[list[float]]
            if embed_many is not None:
                raw_embeddings = await _maybe_await(embed_many(texts))
                embeddings = [
                    normalize_embedding([float(value) for value in item])
                    for item in raw_embeddings
                ]
            elif embed_one is not None:
                embeddings = []
                for text in texts:
                    raw_vector = await _maybe_await(embed_one(text))
                    embeddings.append(
                        normalize_embedding([float(value) for value in raw_vector])
                    )
            else:
                embeddings = []
            await asyncio.to_thread(
                self._upsert_embeddings_sync,
                provider_id,
                missing,
                embeddings,
            )
        # Always rebuild/load the vector index so searches see a clean state.
        await asyncio.to_thread(self._ensure_vector_index_sync, provider_id)
+
    def _save_sync(
        self,
        key: str,
        value: dict[str, Any],
        namespace: str,
        ttl_metadata: dict[str, Any] | None,
    ) -> None:
        """Thread-side upsert of one record; invalidates its stale embeddings.

        With *ttl_metadata*, the value is wrapped in a
        ``{"value": ..., "ttl_seconds": ..., "expires_at": ...}`` envelope so
        expiry can be enforced on read.
        """
        with self._lock:
            conn = self._connect()
            try:
                self._purge_expired_locked(conn)
                stored = dict(value)
                expires_at: str | None = None
                if ttl_metadata is not None:
                    # Blank expires_at strings are treated as "no expiry".
                    expires_at = str(ttl_metadata.get("expires_at", "")).strip() or None
                    stored = {
                        "value": dict(value),
                        "ttl_seconds": int(ttl_metadata.get("ttl_seconds", 0)),
                    }
                    if expires_at is not None:
                        stored["expires_at"] = expires_at
                search_text = extract_memory_text(stored)
                stored_json = json.dumps(
                    stored,
                    ensure_ascii=False,
                    sort_keys=True,
                    default=str,
                )
                updated_at = _utcnow().isoformat()
                conn.execute(
                    """
                    INSERT INTO memory_records(namespace, key, stored_json, search_text, expires_at, updated_at)
                    VALUES(?, ?, ?, ?, ?, ?)
                    ON CONFLICT(namespace, key) DO UPDATE SET
                        stored_json = excluded.stored_json,
                        search_text = excluded.search_text,
                        expires_at = excluded.expires_at,
                        updated_at = excluded.updated_at
                    """,
                    (namespace, key, stored_json, search_text, expires_at, updated_at),
                )
                # Keep the FTS mirror in step with the canonical row.
                self._sync_fts_row_locked(
                    conn,
                    namespace=namespace,
                    key=key,
                    search_text=search_text,
                )
                # The text changed, so any cached embeddings for this row are
                # stale: drop them and mark the affected providers dirty so
                # their vector indexes get rebuilt lazily.
                provider_rows = conn.execute(
                    """
                    SELECT DISTINCT provider_id
                    FROM memory_embeddings
                    WHERE namespace = ? AND key = ?
                    """,
                    (namespace, key),
                ).fetchall()
                conn.execute(
                    "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?",
                    (namespace, key),
                )
                for row in provider_rows:
                    provider_id = str(row[0]).strip()
                    if provider_id:
                        self._mark_vector_dirty_locked(conn, provider_id)
                conn.commit()
            finally:
                conn.close()
+
+ def _get_sync(self, key: str, namespace: str) -> dict[str, Any] | None:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ row = conn.execute(
+ """
+ SELECT stored_json
+ FROM memory_records
+ WHERE namespace = ? AND key = ?
+ """,
+ (namespace, key),
+ ).fetchone()
+ if row is None:
+ return None
+ stored = self._load_stored_json(row[0])
+ return memory_value_for_search(stored)
+ finally:
+ conn.close()
+
+ def _list_keys_sync(self, namespace: str) -> list[str]:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ rows = conn.execute(
+ """
+ SELECT key
+ FROM memory_records
+ WHERE namespace = ?
+ ORDER BY key COLLATE NOCASE ASC, key ASC
+ """,
+ (namespace,),
+ ).fetchall()
+ return [str(row[0]) for row in rows]
+ finally:
+ conn.close()
+
+ def _exists_sync(self, key: str, namespace: str) -> bool:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ row = conn.execute(
+ """
+ SELECT 1
+ FROM memory_records
+ WHERE namespace = ? AND key = ?
+ LIMIT 1
+ """,
+ (namespace, key),
+ ).fetchone()
+ return row is not None
+ finally:
+ conn.close()
+
+ def _get_many_sync(self, keys: list[str], namespace: str) -> list[dict[str, Any]]:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ if not keys:
+ return []
+ lookup_keys = list(dict.fromkeys(keys))
+ placeholders = _sql_placeholders(len(lookup_keys))
+ rows = conn.execute(
+ f"""
+ SELECT key, stored_json
+ FROM memory_records
+ WHERE namespace = ? AND key IN ({placeholders})
+ """,
+ (namespace, *lookup_keys),
+ ).fetchall()
+ stored_by_key = {
+ str(row[0]): self._load_stored_json(row[1]) for row in rows
+ }
+ return [
+ {
+ "key": key,
+ "value": memory_value_for_search(stored_by_key.get(key)),
+ }
+ for key in keys
+ ]
+ finally:
+ conn.close()
+
+ def _delete_sync(self, key: str, namespace: str) -> bool:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ deleted = self._delete_record_locked(conn, namespace=namespace, key=key)
+ conn.commit()
+ return deleted
+ finally:
+ conn.close()
+
+ def _clear_namespace_sync(
+ self,
+ namespace: str | None,
+ include_descendants: bool,
+ ) -> int:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ deleted = self._delete_scope_locked(
+ conn,
+ namespace=namespace,
+ include_descendants=include_descendants,
+ )
+ conn.commit()
+ return deleted
+ finally:
+ conn.close()
+
    def _delete_many_sync(self, keys: list[str], namespace: str) -> int:
        """Thread-side bulk delete across records, embeddings, and the FTS mirror.

        Returns the number of ``memory_records`` rows removed and marks the
        vector index dirty for every provider that had embeddings for the
        deleted keys.
        """
        with self._lock:
            conn = self._connect()
            try:
                self._purge_expired_locked(conn)
                unique_keys = list(dict.fromkeys(keys))
                if not unique_keys:
                    # Commit anyway so the expiry purge above is persisted.
                    conn.commit()
                    return 0
                placeholders = _sql_placeholders(len(unique_keys))
                # Capture affected providers BEFORE deleting their embeddings.
                provider_rows = conn.execute(
                    f"""
                    SELECT DISTINCT provider_id
                    FROM memory_embeddings
                    WHERE namespace = ? AND key IN ({placeholders})
                    """,
                    (namespace, *unique_keys),
                ).fetchall()
                conn.execute(
                    f"DELETE FROM memory_embeddings WHERE namespace = ? AND key IN ({placeholders})",
                    (namespace, *unique_keys),
                )
                deleted = conn.execute(
                    f"DELETE FROM memory_records WHERE namespace = ? AND key IN ({placeholders})",
                    (namespace, *unique_keys),
                ).rowcount
                if self._fts_enabled:
                    conn.execute(
                        f"DELETE FROM memory_records_fts WHERE namespace = ? AND key IN ({placeholders})",
                        (namespace, *unique_keys),
                    )
                for row in provider_rows:
                    provider_id = str(row[0]).strip()
                    if provider_id:
                        self._mark_vector_dirty_locked(conn, provider_id)
                conn.commit()
                return deleted
            finally:
                conn.close()
+
+ def _count_sync(
+ self,
+ namespace: str | None,
+ include_descendants: bool,
+ ) -> int:
+ with self._lock:
+ conn = self._connect()
+ try:
+ self._purge_expired_locked(conn)
+ where_sql, params = self._namespace_where(
+ namespace,
+ include_descendants=include_descendants,
+ )
+ return int(
+ conn.execute(
+ f"SELECT COUNT(*) FROM memory_records WHERE {where_sql}",
+ params,
+ ).fetchone()[0]
+ )
+ finally:
+ conn.close()
+
    def _stats_sync(
        self,
        namespace: str | None,
        include_descendants: bool,
    ) -> dict[str, Any]:
        """Thread-side aggregation of storage and index statistics for a scope.

        Returns item/byte/TTL counts, namespace cardinality, embedding
        coverage, and per-provider vector-index dirtiness.
        """
        with self._lock:
            conn = self._connect()
            try:
                self._purge_expired_locked(conn)
                where_sql, params = self._namespace_where(
                    namespace,
                    include_descendants=include_descendants,
                )
                total_items = int(
                    conn.execute(
                        f"SELECT COUNT(*) FROM memory_records WHERE {where_sql}",
                        params,
                    ).fetchone()[0]
                )
                # Rows that carry an expiry timestamp.
                ttl_entries = int(
                    conn.execute(
                        f"""
                        SELECT COUNT(*)
                        FROM memory_records
                        WHERE {where_sql} AND expires_at IS NOT NULL
                        """,
                        params,
                    ).fetchone()[0]
                )
                # Approximate payload size: key length plus serialized value.
                total_bytes = int(
                    conn.execute(
                        f"""
                        SELECT COALESCE(SUM(LENGTH(key) + LENGTH(stored_json)), 0)
                        FROM memory_records
                        WHERE {where_sql}
                        """,
                        params,
                    ).fetchone()[0]
                )
                namespace_count = int(
                    conn.execute(
                        f"""
                        SELECT COUNT(DISTINCT namespace)
                        FROM memory_records
                        WHERE {where_sql}
                        """,
                        params,
                    ).fetchone()[0]
                )
                embedding_where_sql, embedding_params = self._namespace_where(
                    namespace,
                    include_descendants=include_descendants,
                    alias="e",
                )
                # Distinct records that have at least one embedding (any provider).
                embedded_items = int(
                    conn.execute(
                        f"""
                        SELECT COUNT(*)
                        FROM (
                            SELECT DISTINCT e.namespace, e.key
                            FROM memory_embeddings e
                            WHERE {embedding_where_sql}
                        )
                        """,
                        embedding_params,
                    ).fetchone()[0]
                )
                indexed_items = total_items
                # Records still awaiting an embedding, clamped at zero.
                dirty_items = max(indexed_items - embedded_items, 0)
                provider_rows = conn.execute(
                    """
                    SELECT provider_id, dirty
                    FROM memory_vector_state
                    ORDER BY provider_id
                    """
                ).fetchall()
                return {
                    "total_items": total_items,
                    "total_bytes": total_bytes,
                    "ttl_entries": ttl_entries,
                    "namespace": (
                        None
                        if namespace is None
                        else normalize_memory_namespace(namespace)
                    ),
                    "namespace_count": namespace_count,
                    "indexed_items": indexed_items,
                    "embedded_items": embedded_items,
                    "dirty_items": dirty_items,
                    "fts_enabled": self._fts_enabled,
                    "vector_backend": self._vector_backend_label(),
                    "vector_indexes": [
                        {
                            "provider_id": str(provider_id),
                            "dirty": bool(dirty),
                        }
                        for provider_id, dirty in provider_rows
                    ],
                }
            finally:
                conn.close()
+
+ def _keyword_candidates_sync(
+ self,
+ query: str,
+ namespace: str | None,
+ include_descendants: bool,
+ limit: int | None,
+ ) -> list[_StoredRecord]:
+ with self._lock:
+ conn = self._connect()
+ try:
+ fetch_limit = max((int(limit) if limit is not None else 10) * 8, 50)
+ where_sql, params = self._namespace_where(
+ namespace,
+ include_descendants=include_descendants,
+ )
+ seen: set[tuple[str, str]] = set()
+ records: list[_StoredRecord] = []
+ fts_query = self._fts_query(query)
+ if self._fts_enabled and fts_query is not None:
+ fts_where_sql, fts_params = self._namespace_where(
+ namespace,
+ include_descendants=include_descendants,
+ alias="r",
+ )
+ rows = conn.execute(
+ f"""
+ SELECT r.namespace, r.key, r.stored_json, r.search_text, r.updated_at
+ FROM memory_records_fts f
+ JOIN memory_records r
+ ON r.namespace = f.namespace AND r.key = f.key
+ WHERE {fts_where_sql} AND memory_records_fts MATCH ?
+ ORDER BY bm25(memory_records_fts), r.updated_at DESC
+ LIMIT ?
+ """,
+ (*fts_params, fts_query, fetch_limit),
+ ).fetchall()
+ for row in rows:
+ record = self._stored_record_from_row(row)
+ identity = (record.namespace, record.key)
+ if identity not in seen:
+ seen.add(identity)
+ records.append(record)
+
+ like_query = f"%{str(query).strip()}%"
+ if not records or len(records) < fetch_limit:
+ rows = conn.execute(
+ f"""
+ SELECT namespace, key, stored_json, search_text, updated_at
+ FROM memory_records
+ WHERE {where_sql}
+ AND (? = '%%' OR key LIKE ? COLLATE NOCASE OR search_text LIKE ? COLLATE NOCASE)
+ ORDER BY updated_at DESC
+ LIMIT ?
+ """,
+ (*params, like_query, like_query, like_query, fetch_limit),
+ ).fetchall()
+ for row in rows:
+ record = self._stored_record_from_row(row)
+ identity = (record.namespace, record.key)
+ if identity not in seen:
+ seen.add(identity)
+ records.append(record)
+ return records
+ finally:
+ conn.close()
+
+ def _missing_embeddings_sync(
+ self,
+ provider_id: str,
+ namespace: str | None,
+ include_descendants: bool,
+ ) -> list[_StoredRecord]:
+ with self._lock:
+ conn = self._connect()
+ try:
+ where_sql, params = self._namespace_where(
+ namespace,
+ include_descendants=include_descendants,
+ alias="r",
+ )
+ rows = conn.execute(
+ f"""
+ SELECT r.namespace, r.key, r.stored_json, r.search_text, r.updated_at
+ FROM memory_records r
+ LEFT JOIN memory_embeddings e
+ ON e.namespace = r.namespace
+ AND e.key = r.key
+ AND e.provider_id = ?
+ WHERE {where_sql} AND e.id IS NULL
+ ORDER BY r.updated_at DESC
+ """,
+ (provider_id, *params),
+ ).fetchall()
+ return [self._stored_record_from_row(row) for row in rows]
+ finally:
+ conn.close()
+
+ def _upsert_embeddings_sync(
+ self,
+ provider_id: str,
+ records: list[_StoredRecord],
+ embeddings: list[list[float]],
+ ) -> None:
+ if not records:
+ return
+ with self._lock:
+ conn = self._connect()
+ try:
+ for index, record in enumerate(records):
+ vector = embeddings[index] if index < len(embeddings) else []
+ conn.execute(
+ """
+ INSERT INTO memory_embeddings(namespace, key, provider_id, embedding_json, updated_at)
+ VALUES(?, ?, ?, ?, ?)
+ ON CONFLICT(namespace, key, provider_id) DO UPDATE SET
+ embedding_json = excluded.embedding_json,
+ updated_at = excluded.updated_at
+ """,
+ (
+ record.namespace,
+ record.key,
+ provider_id,
+ json.dumps(
+ vector, ensure_ascii=False, separators=(",", ":")
+ ),
+ _utcnow().isoformat(),
+ ),
+ )
+ self._mark_vector_dirty_locked(conn, provider_id)
+ conn.commit()
+ finally:
+ conn.close()
+
+ def _vector_candidates_sync(
+ self,
+ provider_id: str,
+ query_embedding: list[float],
+ namespace: str | None,
+ include_descendants: bool,
+ limit: int | None,
+ ) -> list[_VectorCandidate]:
+ if not query_embedding:
+ return []
+ with self._lock:
+ conn = self._connect()
+ try:
+ index = self._vector_indexes.get(provider_id)
+ fetch_limit = max((int(limit) if limit is not None else 10) * 10, 50)
+ if index is not None and self._faiss_available():
+ return self._faiss_vector_candidates_locked(
+ conn=conn,
+ provider_id=provider_id,
+ query_embedding=query_embedding,
+ namespace=namespace,
+ include_descendants=include_descendants,
+ fetch_limit=fetch_limit,
+ )
+ return self._fallback_vector_candidates_locked(
+ conn=conn,
+ provider_id=provider_id,
+ query_embedding=query_embedding,
+ namespace=namespace,
+ include_descendants=include_descendants,
+ fetch_limit=fetch_limit,
+ )
+ finally:
+ conn.close()
+
    def _ensure_vector_index_sync(self, provider_id: str) -> None:
        """Load or rebuild the vector index for *provider_id* and clear its dirty flag.

        Order of preference: reuse the in-memory index when clean; load the
        on-disk FAISS file when clean; otherwise rebuild from the
        ``memory_embeddings`` rows — into FAISS when available, else into the
        pure-Python (id, vector) fallback list.
        """
        with self._lock:
            conn = self._connect()
            try:
                self._init_storage_locked(conn)
                row = conn.execute(
                    """
                    SELECT dirty
                    FROM memory_vector_state
                    WHERE provider_id = ?
                    """,
                    (provider_id,),
                ).fetchone()
                # No state row means the index was never built: treat as dirty.
                dirty = True if row is None else bool(row[0])
                if not dirty and provider_id in self._vector_indexes:
                    # Clean and already loaded in memory: nothing to do.
                    return

                index_path = (
                    self._vector_dir / f"{self._safe_filename(provider_id)}.faiss"
                )
                if not dirty and index_path.exists() and self._faiss_available():
                    # Clean on-disk index: try loading it; on any failure we
                    # fall through to a full rebuild (best-effort by design).
                    try:
                        faiss = self._import_faiss()
                        self._vector_indexes[provider_id] = faiss.read_index(
                            str(index_path)
                        )
                        self._vector_fallbacks.pop(provider_id, None)
                        return
                    except Exception:
                        pass

                # Rebuild: gather all non-empty embedding vectors with their ids.
                rows = conn.execute(
                    """
                    SELECT id, embedding_json
                    FROM memory_embeddings
                    WHERE provider_id = ?
                    ORDER BY id
                    """,
                    (provider_id,),
                ).fetchall()
                ids: list[int] = []
                vectors: list[list[float]] = []
                for raw_id, raw_vector in rows:
                    vector = self._load_embedding_json(raw_vector)
                    if not vector:
                        # Empty vectors (e.g. failed embeds) are not indexable.
                        continue
                    ids.append(int(raw_id))
                    vectors.append(vector)

                if self._faiss_available() and vectors:
                    # Inner-product index over ids; persisted to disk for reuse.
                    faiss = self._import_faiss()
                    np = self._import_numpy()
                    dimension = len(vectors[0])
                    base_index = faiss.IndexFlatIP(dimension)
                    index = faiss.IndexIDMap2(base_index)
                    index.add_with_ids(
                        np.array(vectors, dtype="float32"),
                        np.array(ids, dtype="int64"),
                    )
                    self._vector_indexes[provider_id] = index
                    self._vector_fallbacks.pop(provider_id, None)
                    self._vector_dir.mkdir(parents=True, exist_ok=True)
                    faiss.write_index(index, str(index_path))
                else:
                    # No FAISS (or nothing to index): use the Python fallback.
                    self._vector_indexes[provider_id] = None
                    self._vector_fallbacks[provider_id] = list(
                        zip(ids, vectors, strict=False)
                    )
                # Persist the clean state so later calls can short-circuit.
                conn.execute(
                    """
                    INSERT INTO memory_vector_state(provider_id, dirty, updated_at)
                    VALUES(?, 0, ?)
                    ON CONFLICT(provider_id) DO UPDATE SET
                        dirty = 0,
                        updated_at = excluded.updated_at
                    """,
                    (provider_id, _utcnow().isoformat()),
                )
                conn.commit()
            finally:
                conn.close()
+
+    def _faiss_vector_candidates_locked(
+        self,
+        *,
+        conn: sqlite3.Connection,
+        provider_id: str,
+        query_embedding: list[float],
+        namespace: str | None,
+        include_descendants: bool,
+        fetch_limit: int,
+    ) -> list[_VectorCandidate]:
+        """Search the cached FAISS index and join hits back to stored records.
+
+        Namespace filtering happens after the FAISS search, so the search
+        limit is doubled and retried until either *fetch_limit* candidates
+        survive filtering or the whole index has been searched.
+        """
+        index = self._vector_indexes.get(provider_id)
+        if index is None:
+            return []
+        np = self._import_numpy()
+        total_count = int(getattr(index, "ntotal", 0) or 0)
+        if total_count <= 0:
+            return []
+
+        collected: list[_VectorCandidate] = []
+        seen: set[tuple[str, str]] = set()
+        current_limit = min(fetch_limit, total_count)
+        while current_limit > 0:
+            scores, ids = index.search(
+                np.array([query_embedding], dtype="float32"),
+                current_limit,
+            )
+            # FAISS pads missing results with id -1; drop those.
+            raw_ids = [int(item) for item in ids[0] if int(item) >= 0]
+            # Inner-product scores are clamped to be non-negative.
+            score_map = {
+                int(item_id): max(0.0, float(score))
+                for item_id, score in zip(raw_ids, scores[0], strict=False)
+            }
+            if not score_map:
+                break
+            placeholders = ",".join("?" for _ in score_map)
+            rows = conn.execute(
+                f"""
+                SELECT e.id, r.namespace, r.key, r.stored_json, r.search_text
+                FROM memory_embeddings e
+                JOIN memory_records r
+                    ON r.namespace = e.namespace AND r.key = e.key
+                WHERE e.provider_id = ?
+                    AND e.id IN ({placeholders})
+                """,
+                (provider_id, *score_map.keys()),
+            ).fetchall()
+            row_map = {int(row[0]): row for row in rows}
+            # Iterate in FAISS result order to preserve score ranking.
+            for item_id in raw_ids:
+                row = row_map.get(item_id)
+                if row is None:
+                    continue
+                record_namespace = normalize_memory_namespace(row[1])
+                if not memory_namespace_matches(
+                    record_namespace,
+                    namespace,
+                    include_descendants=include_descendants,
+                ):
+                    continue
+                identity = (record_namespace, str(row[2]))
+                if identity in seen:
+                    continue
+                seen.add(identity)
+                collected.append(
+                    _VectorCandidate(
+                        namespace=record_namespace,
+                        key=str(row[2]),
+                        stored=self._load_stored_json(row[3]),
+                        search_text=str(row[4]),
+                        score=max(0.0, score_map.get(item_id, 0.0)),
+                    )
+                )
+            if len(collected) >= fetch_limit or current_limit >= total_count:
+                break
+            # Widen the search window (capped at the index size) and retry.
+            next_limit = min(total_count, current_limit * 2)
+            if next_limit == current_limit:
+                break
+            current_limit = next_limit
+        return collected
+
+    def _fallback_vector_candidates_locked(
+        self,
+        *,
+        conn: sqlite3.Connection,
+        provider_id: str,
+        query_embedding: list[float],
+        namespace: str | None,
+        include_descendants: bool,
+        fetch_limit: int,
+    ) -> list[_VectorCandidate]:
+        """Exact cosine-similarity scan used when no FAISS index is available.
+
+        Loads every embedding row for *provider_id*, filters by namespace,
+        and keeps only strictly positive scores, sorted best-first with a
+        deterministic (namespace, key) tiebreak.
+        """
+        rows = conn.execute(
+            """
+            SELECT e.namespace, e.key, e.embedding_json, r.stored_json, r.search_text
+            FROM memory_embeddings e
+            JOIN memory_records r
+                ON r.namespace = e.namespace AND r.key = e.key
+            WHERE e.provider_id = ?
+            """,
+            (provider_id,),
+        ).fetchall()
+        candidates: list[_VectorCandidate] = []
+        for raw_namespace, raw_key, raw_embedding, raw_stored, raw_search_text in rows:
+            record_namespace = normalize_memory_namespace(raw_namespace)
+            if not memory_namespace_matches(
+                record_namespace,
+                namespace,
+                include_descendants=include_descendants,
+            ):
+                continue
+            embedding = self._load_embedding_json(raw_embedding)
+            score = max(0.0, cosine_similarity(query_embedding, embedding))
+            # Zero or negative similarity contributes nothing to ranking.
+            if score <= 0:
+                continue
+            candidates.append(
+                _VectorCandidate(
+                    namespace=record_namespace,
+                    key=str(raw_key),
+                    stored=self._load_stored_json(raw_stored),
+                    search_text=str(raw_search_text),
+                    score=score,
+                )
+            )
+        candidates.sort(key=lambda item: (-item.score, item.namespace, item.key))
+        return candidates[:fetch_limit]
+
+    def _purge_expired_sync(self) -> None:
+        """Open a connection and delete all expired records (thread-safe wrapper)."""
+        with self._lock:
+            conn = self._connect()
+            try:
+                self._purge_expired_locked(conn)
+                conn.commit()
+            finally:
+                conn.close()
+
+    def _purge_expired_locked(self, conn: sqlite3.Connection) -> None:
+        """Delete every record whose `expires_at` is at or before now (UTC).
+
+        Expiry comparison relies on ISO-8601 strings sorting chronologically.
+        Caller holds the lock and commits the transaction.
+        """
+        self._init_storage_locked(conn)
+        now_iso = _utcnow().isoformat()
+        rows = conn.execute(
+            """
+            SELECT namespace, key
+            FROM memory_records
+            WHERE expires_at IS NOT NULL AND expires_at <= ?
+            """,
+            (now_iso,),
+        ).fetchall()
+        for namespace, key in rows:
+            # Per-record deletion also cleans embeddings/FTS and marks
+            # affected vector indexes dirty.
+            self._delete_record_locked(
+                conn,
+                namespace=normalize_memory_namespace(namespace),
+                key=str(key),
+            )
+
+    def _delete_record_locked(
+        self,
+        conn: sqlite3.Connection,
+        *,
+        namespace: str,
+        key: str,
+    ) -> bool:
+        """Delete one record plus its embeddings and FTS row.
+
+        Returns True if a `memory_records` row was actually removed. Providers
+        whose embeddings referenced the record get their vector index marked
+        dirty so it is rebuilt lazily.
+        """
+        # Capture affected providers before the embeddings rows disappear.
+        provider_rows = conn.execute(
+            """
+            SELECT DISTINCT provider_id
+            FROM memory_embeddings
+            WHERE namespace = ? AND key = ?
+            """,
+            (namespace, key),
+        ).fetchall()
+        conn.execute(
+            "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?",
+            (namespace, key),
+        )
+        deleted = (
+            conn.execute(
+                "DELETE FROM memory_records WHERE namespace = ? AND key = ?",
+                (namespace, key),
+            ).rowcount
+            > 0
+        )
+        if self._fts_enabled:
+            conn.execute(
+                "DELETE FROM memory_records_fts WHERE namespace = ? AND key = ?",
+                (namespace, key),
+            )
+        for row in provider_rows:
+            provider_id = str(row[0]).strip()
+            if provider_id:
+                self._mark_vector_dirty_locked(conn, provider_id)
+        return deleted
+
+    def _delete_scope_locked(
+        self,
+        conn: sqlite3.Connection,
+        *,
+        namespace: str | None,
+        include_descendants: bool,
+    ) -> int:
+        """Bulk-delete every record in a namespace scope; return the count.
+
+        Mirrors `_delete_record_locked` (embeddings, FTS, dirty-marking) but
+        operates on the whole matched set in a few batched statements.
+        NOTE(review): the `(namespace, key) IN ((?, ?), ...)` row-value form
+        requires a SQLite build with row-value support (3.15+) — confirm the
+        minimum supported SQLite version.
+        """
+        where_sql, params = self._namespace_where(
+            namespace,
+            include_descendants=include_descendants,
+        )
+        affected_rows = conn.execute(
+            f"""
+            SELECT namespace, key
+            FROM memory_records
+            WHERE {where_sql}
+            """,
+            params,
+        ).fetchall()
+        if not affected_rows:
+            return 0
+
+        # Flatten matched (namespace, key) pairs into positional params for
+        # the row-value IN clauses below.
+        pair_placeholders = ", ".join("(?, ?)" for _ in affected_rows)
+        pair_params = tuple(
+            value
+            for raw_namespace, raw_key in affected_rows
+            for value in (normalize_memory_namespace(raw_namespace), str(raw_key))
+        )
+
+        # Capture affected providers before their embedding rows are deleted.
+        provider_rows = conn.execute(
+            f"""
+            SELECT DISTINCT provider_id
+            FROM memory_embeddings
+            WHERE (namespace, key) IN ({pair_placeholders})
+            """,
+            pair_params,
+        ).fetchall()
+        conn.execute(
+            f"""
+            DELETE FROM memory_embeddings
+            WHERE (namespace, key) IN ({pair_placeholders})
+            """,
+            pair_params,
+        )
+        if self._fts_enabled:
+            conn.execute(
+                f"""
+                DELETE FROM memory_records_fts
+                WHERE (namespace, key) IN ({pair_placeholders})
+                """,
+                pair_params,
+            )
+        deleted = conn.execute(
+            f"""
+            DELETE FROM memory_records
+            WHERE (namespace, key) IN ({pair_placeholders})
+            """,
+            pair_params,
+        ).rowcount
+        for row in provider_rows:
+            provider_id = str(row[0]).strip()
+            if provider_id:
+                self._mark_vector_dirty_locked(conn, provider_id)
+        return deleted
+
+    def _connect(self) -> sqlite3.Connection:
+        """Open a SQLite connection to the backing DB, initializing schema once.
+
+        Rows are returned as `sqlite3.Row` for name-based access; the base
+        directory is created on demand.
+        """
+        self._base_dir.mkdir(parents=True, exist_ok=True)
+        conn = sqlite3.connect(self._db_path)
+        conn.row_factory = sqlite3.Row
+        self._init_storage_locked(conn)
+        return conn
+
+    def _init_storage_locked(self, conn: sqlite3.Connection) -> None:
+        """Create tables/indexes on first use; no-op after `_initialized` is set.
+
+        WAL journaling + NORMAL sync favor concurrent readers over strict
+        durability. FTS5 is optional: if the virtual-table creation fails
+        (SQLite built without FTS5), keyword search degrades gracefully via
+        the `_fts_enabled` flag.
+        """
+        if self._initialized:
+            return
+        conn.execute("PRAGMA journal_mode=WAL")
+        conn.execute("PRAGMA synchronous=NORMAL")
+        conn.execute(
+            """
+            CREATE TABLE IF NOT EXISTS memory_records (
+                namespace TEXT NOT NULL,
+                key TEXT NOT NULL,
+                stored_json TEXT NOT NULL,
+                search_text TEXT NOT NULL,
+                expires_at TEXT,
+                updated_at TEXT NOT NULL,
+                PRIMARY KEY(namespace, key)
+            )
+            """
+        )
+        conn.execute(
+            """
+            CREATE INDEX IF NOT EXISTS idx_memory_records_namespace
+            ON memory_records(namespace)
+            """
+        )
+        conn.execute(
+            """
+            CREATE INDEX IF NOT EXISTS idx_memory_records_expires_at
+            ON memory_records(expires_at)
+            """
+        )
+        try:
+            conn.execute(
+                """
+                CREATE VIRTUAL TABLE IF NOT EXISTS memory_records_fts
+                USING fts5(namespace UNINDEXED, key, search_text, tokenize='unicode61')
+                """
+            )
+            self._fts_enabled = True
+        except sqlite3.OperationalError:
+            # SQLite compiled without FTS5 support.
+            self._fts_enabled = False
+        conn.execute(
+            """
+            CREATE TABLE IF NOT EXISTS memory_embeddings (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                namespace TEXT NOT NULL,
+                key TEXT NOT NULL,
+                provider_id TEXT NOT NULL,
+                embedding_json TEXT NOT NULL,
+                updated_at TEXT NOT NULL,
+                UNIQUE(namespace, key, provider_id)
+            )
+            """
+        )
+        conn.execute(
+            """
+            CREATE INDEX IF NOT EXISTS idx_memory_embeddings_provider
+            ON memory_embeddings(provider_id, namespace)
+            """
+        )
+        conn.execute(
+            """
+            CREATE TABLE IF NOT EXISTS memory_vector_state (
+                provider_id TEXT PRIMARY KEY,
+                dirty INTEGER NOT NULL DEFAULT 1,
+                updated_at TEXT NOT NULL
+            )
+            """
+        )
+        conn.commit()
+        self._initialized = True
+
+    def _sync_fts_row_locked(
+        self,
+        conn: sqlite3.Connection,
+        *,
+        namespace: str,
+        key: str,
+        search_text: str,
+    ) -> None:
+        """Replace the FTS row for (namespace, key) with fresh *search_text*.
+
+        Delete-then-insert keeps the FTS table in step with the records table;
+        no-op when FTS5 is unavailable.
+        """
+        if not self._fts_enabled:
+            return
+        conn.execute(
+            "DELETE FROM memory_records_fts WHERE namespace = ? AND key = ?",
+            (namespace, key),
+        )
+        conn.execute(
+            """
+            INSERT INTO memory_records_fts(namespace, key, search_text)
+            VALUES(?, ?, ?)
+            """,
+            (namespace, key, search_text),
+        )
+
+    def _mark_vector_dirty_locked(
+        self,
+        conn: sqlite3.Connection,
+        provider_id: str,
+    ) -> None:
+        """Flag *provider_id*'s vector index as stale and drop cached copies.
+
+        The index is rebuilt lazily by `_ensure_vector_index_sync` on the
+        next vector query.
+        """
+        conn.execute(
+            """
+            INSERT INTO memory_vector_state(provider_id, dirty, updated_at)
+            VALUES(?, 1, ?)
+            ON CONFLICT(provider_id) DO UPDATE SET
+                dirty = 1,
+                updated_at = excluded.updated_at
+            """,
+            (provider_id, _utcnow().isoformat()),
+        )
+        # Evict both the FAISS index and the exact-scan fallback cache.
+        self._vector_indexes.pop(provider_id, None)
+        self._vector_fallbacks.pop(provider_id, None)
+
+    @staticmethod
+    def _combined_score(
+        *,
+        mode: str,
+        keyword_score: float,
+        vector_score: float,
+    ) -> float:
+        """Combine keyword and vector scores according to search *mode*.
+
+        "keyword" and "vector" modes pass the respective score through.
+        Hybrid mode weights vector 0.65 / keyword 0.35 with a +0.05 bonus
+        when both channels matched, clamped to 1.0.
+        """
+        if mode == "keyword":
+            return keyword_score
+        if mode == "vector":
+            return vector_score
+        if keyword_score > 0 and vector_score > 0:
+            return min(1.0, 0.65 * vector_score + 0.35 * keyword_score + 0.05)
+        if vector_score > 0:
+            return min(1.0, vector_score)
+        return min(1.0, keyword_score)
+
+    @staticmethod
+    def _load_stored_json(raw_value: Any) -> dict[str, Any]:
+        """Decode a stored-record payload into a dict; non-dicts become {}.
+
+        NOTE: malformed JSON strings raise `json.JSONDecodeError` rather
+        than being coerced to {}.
+        """
+        if isinstance(raw_value, dict):
+            return dict(raw_value)
+        if isinstance(raw_value, str):
+            decoded = json.loads(raw_value)
+            return dict(decoded) if isinstance(decoded, dict) else {}
+        return {}
+
+    @staticmethod
+    def _load_embedding_json(raw_value: Any) -> list[float]:
+        """Decode an embedding column into a list of floats; else return []."""
+        if isinstance(raw_value, list):
+            return [float(item) for item in raw_value]
+        if isinstance(raw_value, str):
+            decoded = json.loads(raw_value)
+            if isinstance(decoded, list):
+                return [float(item) for item in decoded]
+        return []
+
+    @staticmethod
+    def _stored_record_from_row(row: Any) -> _StoredRecord:
+        """Build a `_StoredRecord` from a positional row.
+
+        Expects columns in order: namespace, key, stored_json, search_text,
+        updated_at — matching the SELECTs over `memory_records`.
+        """
+        return _StoredRecord(
+            namespace=normalize_memory_namespace(row[0]),
+            key=str(row[1]),
+            stored=PluginMemoryBackend._load_stored_json(row[2]),
+            search_text=str(row[3]),
+            updated_at=str(row[4]),
+        )
+
+    @staticmethod
+    def _namespace_where(
+        namespace: str | None,
+        *,
+        include_descendants: bool,
+    alias: str | None = None,
+    ) -> tuple[str, tuple[Any, ...]]:
+        """Build a parameterized SQL predicate matching a namespace scope.
+
+        Returns (where_sql, params). `None` matches everything; the empty
+        (root) namespace with descendants also matches everything; descendant
+        matching uses a LIKE on `ns/%` with LIKE-escaped special characters.
+        """
+        column = f"{alias}.namespace" if alias else "namespace"
+        if namespace is None:
+            return "1 = 1", ()
+        normalized_namespace = normalize_memory_namespace(namespace)
+        if not normalized_namespace:
+            if include_descendants:
+                return "1 = 1", ()
+            return f"{column} = ''", ()
+        if include_descendants:
+            escaped_namespace = _escape_like_value(normalized_namespace)
+            return (
+                f"({column} = ? OR {column} LIKE ? ESCAPE '\\')",
+                (normalized_namespace, f"{escaped_namespace}/%"),
+            )
+        return f"{column} = ?", (normalized_namespace,)
+
+    @staticmethod
+    def _fts_query(query: str) -> str | None:
+        """Convert free text into a safe FTS5 MATCH expression, or None.
+
+        Extracts word tokens, keeps at most the first 8, doubles embedded
+        quotes, and ORs the quoted terms so any token can match.
+        """
+        stripped = str(query).strip()
+        if not stripped:
+            return None
+        terms = [
+            item for item in re.findall(r"\w+", stripped, flags=re.UNICODE) if item
+        ]
+        if not terms:
+            return None
+        escaped_terms = [term.replace('"', '""') for term in terms[:8]]
+        return " OR ".join(f'"{term}"' for term in escaped_terms)
+
+    @staticmethod
+    def _safe_filename(value: str) -> str:
+        """Sanitize *value* into a filesystem-safe name; falls back to "default"."""
+        return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(value)).strip("._") or "default"
+
+    @staticmethod
+    def _import_faiss() -> Any:
+        # FAISS often ships without stable type stubs, so keep the lazy import
+        # boundary explicitly dynamic to avoid false-positive Pylance errors.
+        import faiss
+
+        return cast(Any, faiss)
+
+    @staticmethod
+    def _import_numpy():
+        # Lazy import: numpy is only needed on the FAISS code paths.
+        import numpy
+
+        return numpy
+
+    @classmethod
+    def _faiss_available(cls) -> bool:
+        """Return True when faiss + numpy import cleanly and faiss exposes the APIs we use."""
+        try:
+            faiss = cls._import_faiss()
+            cls._import_numpy()
+        except Exception:
+            return False
+        # Guard against partial/stub faiss builds missing required entry points.
+        required_attrs = (
+            "IndexFlatIP",
+            "IndexIDMap2",
+            "read_index",
+            "write_index",
+        )
+        return all(hasattr(faiss, attr) for attr in required_attrs)
+
+    def _vector_backend_label(self) -> str:
+        """Human-readable label of the active vector backend ("faiss" or "exact")."""
+        return "faiss" if self._faiss_available() else "exact"
+
+
+async def _maybe_await(value: Any) -> Any:
+    """Await *value* if it is a coroutine or Future; otherwise return it unchanged."""
+    if asyncio.iscoroutine(value) or isinstance(value, asyncio.Future):
+        return await value
+    return value
+
+
+def extend_memory_namespace(
+    base_namespace: str | None,
+    extra_namespace: str | None,
+) -> str:
+    """Join a base namespace with a relative namespace override."""
+
+    # Thin public alias over join_memory_namespace; kept for API clarity.
+    return join_memory_namespace(base_namespace, extra_namespace)
diff --git a/astrbot-sdk/src/astrbot_sdk/_message_types.py b/astrbot-sdk/src/astrbot_sdk/_message_types.py
new file mode 100644
index 0000000000..1d2df56040
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_message_types.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from typing import Any
+
+# Lower-cased spellings accepted for each canonical message type. Inputs are
+# normalized (str, strip, lower) before membership tests in
+# normalize_message_type below.
+_GROUP_MESSAGE_TYPES = {"group", "groupmessage", "group_message"}
+_PRIVATE_MESSAGE_TYPES = {
+    "private",
+    "privatemessage",
+    "private_message",
+    "friend",
+    "friendmessage",
+    "friend_message",
+}
+_OTHER_MESSAGE_TYPES = {"other", "othermessage", "other_message"}
+
+
+def normalize_message_type(
+    value: Any,
+    *,
+    group_id: str | None = None,
+    user_id: str | None = None,
+    empty_default: str = "",
+) -> str:
+    """Collapse SDK-visible message types to canonical values.
+
+    Accepts raw strings or enum-like objects (their ``.value`` is used).
+    When the spelling is unrecognized, falls back to context: a group_id
+    implies "group", a user_id implies "private", an empty input yields
+    *empty_default*, and anything else maps to "other".
+    """
+
+    normalized = str(getattr(value, "value", value) or "").strip().lower()
+    if normalized in _GROUP_MESSAGE_TYPES:
+        return "group"
+    if normalized in _PRIVATE_MESSAGE_TYPES:
+        return "private"
+    if normalized in _OTHER_MESSAGE_TYPES:
+        return "other"
+    if group_id:
+        return "group"
+    if user_id:
+        return "private"
+    if not normalized:
+        return empty_default
+    return "other"
diff --git a/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py b/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py
new file mode 100644
index 0000000000..5d2a3d9b17
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py
@@ -0,0 +1,3 @@
+from ._internal.plugin_logger import PluginLogEntry, PluginLogger
+
+__all__ = ["PluginLogEntry", "PluginLogger"]
diff --git a/astrbot-sdk/src/astrbot_sdk/_star_runtime.py b/astrbot-sdk/src/astrbot_sdk/_star_runtime.py
new file mode 100644
index 0000000000..d6d9fe215d
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_star_runtime.py
@@ -0,0 +1,13 @@
+from ._internal.star_runtime import (
+ bind_star_runtime,
+ current_runtime_context,
+ current_star_context,
+ current_star_instance,
+)
+
+__all__ = [
+ "bind_star_runtime",
+ "current_runtime_context",
+ "current_star_context",
+ "current_star_instance",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/_testing_support.py b/astrbot-sdk/src/astrbot_sdk/_testing_support.py
new file mode 100644
index 0000000000..1e945e8e06
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/_testing_support.py
@@ -0,0 +1,25 @@
+from ._internal.testing_support import (
+ InMemoryDB,
+ InMemoryMemory,
+ MockCapabilityRouter,
+ MockContext,
+ MockLLMClient,
+ MockMessageEvent,
+ MockPeer,
+ MockPlatformClient,
+ RecordedSend,
+ StdoutPlatformSink,
+)
+
+__all__ = [
+ "InMemoryDB",
+ "InMemoryMemory",
+ "MockCapabilityRouter",
+ "MockContext",
+ "MockLLMClient",
+ "MockMessageEvent",
+ "MockPeer",
+ "MockPlatformClient",
+ "RecordedSend",
+ "StdoutPlatformSink",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/cli.py b/astrbot-sdk/src/astrbot_sdk/cli.py
new file mode 100644
index 0000000000..eb7112c8c8
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/cli.py
@@ -0,0 +1,1579 @@
+"""AstrBot SDK 的命令行入口。
+
+本模块提供 astrbot-sdk 命令行工具的所有子命令,包括:
+- init: 创建新插件骨架,生成 plugin.yaml、main.py、README.md 等模板文件
+- validate: 校验插件清单、导入路径和 handler 发现是否正常
+- build: 将插件打包为 .zip 发布包
+- dev: 本地开发模式,支持 --local/--watch/--interactive 等调试选项
+- run: 启动插件主管进程(supervisor),通过 stdio 与 AstrBot 核心通信
+- worker: 内部命令,由 supervisor 调用以启动单个插件工作进程
+
+错误处理:
+所有 CLI 异常都会被分类并返回标准化的退出码和错误提示,
+便于 CI/CD 集成和用户快速定位问题。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import importlib.resources as resources
+import os
+import re
+import sys
+import typing
+import zipfile
+from collections.abc import Coroutine
+from dataclasses import dataclass, field
+from importlib.resources.abc import Traversable
+from pathlib import Path
+from textwrap import dedent
+from typing import Any
+
+import click
+
+from ._internal.sdk_logger import logger
+from .errors import AstrBotError
+from .runtime.bootstrap import run_plugin_worker, run_supervisor, run_websocket_server
+from .runtime.loader import load_plugin, load_plugin_spec, validate_plugin_spec
+
+# Process exit codes used by every CLI entrypoint (see _classify_cli_exception).
+EXIT_OK = 0
+EXIT_UNEXPECTED = 1
+EXIT_USAGE = 2
+EXIT_PLUGIN_LOAD = 3
+EXIT_RUNTIME = 4
+EXIT_PLUGIN_EXECUTION = 5
+# Directories/files excluded from both `build` packaging and `dev --watch`
+# polling, so artifacts and hot reload agree on what counts as plugin content.
+BUILD_EXCLUDED_DIRS = {
+    ".agents",
+    ".claude",
+    ".git",
+    ".idea",
+    ".mypy_cache",
+    ".opencode",
+    ".pytest_cache",
+    ".ruff_cache",
+    ".venv",
+    "__pycache__",
+    "dist",
+}
+BUILD_EXCLUDED_FILES = {
+    "AGENTS.md",
+    "CLAUDE.md",
+    ".astrbot-worker-state.json",
+}
+WATCH_POLL_INTERVAL_SECONDS = 0.5
+# Agent flavors `astr init` can scaffold skill/notes templates for.
+SUPPORTED_INIT_AGENTS = ("claude", "codex", "opencode")
+_TEMPLATE_VARIABLE_PATTERN = re.compile(r"{{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*}}")
+INIT_AGENT_SKILL_ROOTS = {
+    "claude": Path(".claude") / "skills",
+    "codex": Path(".agents") / "skills",
+    "opencode": Path(".opencode") / "skills",
+}
+INIT_AGENT_DISPLAY_NAMES = {
+    "claude": "Claude Code",
+    "codex": "Codex",
+    "opencode": "OpenCode",
+}
+INIT_SKILL_TEMPLATE_NAME = "astrbot-plugin-dev"
+INIT_PROJECT_NOTE_TEMPLATE_DIR = ("templates", "project_notes")
+INIT_PROJECT_NOTE_TEMPLATE_NAMES = ("AGENTS.md", "CLAUDE.md")
+
+
+class _CliPluginValidationError(RuntimeError):
+    """CLI-side plugin structure or packaging validation failed."""
+
+
+class _CliPluginLoadError(RuntimeError):
+    """CLI-side local-dev plugin loading failed."""
+
+
+class _CliPluginExecutionError(RuntimeError):
+    """CLI-side local-dev plugin execution failed."""
+
+
+@dataclass(slots=True)
+class _PluginTreeWatcher:
+    """Polling file watcher over a plugin directory.
+
+    Tracks (mtime_ns, size) per relative file path; `poll_changes` diffs the
+    current tree against the last snapshot.
+    """
+
+    plugin_dir: Path
+    # Maps posix-relative path -> (st_mtime_ns, st_size) from the last poll.
+    snapshot: dict[str, tuple[int, int]] = field(init=False, default_factory=dict)
+
+    def __post_init__(self) -> None:
+        # Seed the baseline so the first poll only reports real changes.
+        self.snapshot = _snapshot_watch_files(self.plugin_dir)
+
+    def poll_changes(self) -> list[str]:
+        """Return sorted relative paths added, removed, or modified since last poll."""
+        current = _snapshot_watch_files(self.plugin_dir)
+        changed = sorted(
+            path
+            for path in set(self.snapshot) | set(current)
+            if self.snapshot.get(path) != current.get(path)
+        )
+        self.snapshot = current
+        return changed
+
+
+@dataclass(slots=True)
+class _LocalDevState:
+    """Mutable per-session context for local dev dispatch.
+
+    The `/session`, `/user`, etc. meta commands mutate these fields between
+    dispatches (see _handle_dev_meta_command).
+    """
+
+    session_id: str
+    user_id: str
+    platform: str
+    group_id: str | None
+    event_type: str
+
+    def dispatch_kwargs(self) -> dict[str, Any]:
+        """Return the keyword arguments passed to harness dispatch calls."""
+        return {
+            "session_id": str(self.session_id),
+            "user_id": str(self.user_id),
+            "platform": str(self.platform),
+            "group_id": self.group_id,
+            "event_type": str(self.event_type),
+        }
+
+
+def setup_logger(verbose: bool = False) -> None:
+    """Initialize the logging configuration used by the CLI.
+
+    Replaces existing sinks with a single colorized stderr sink; *verbose*
+    lowers the threshold from INFO to DEBUG.
+    """
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        format="{time:HH:mm:ss} | {level: <8} | {message}",
+        level="DEBUG" if verbose else "INFO",
+        colorize=True,
+    )
+
+
+def _resolve_protocol_stdout(
+    protocol_stdout: str | None,
+) -> tuple[typing.TextIO, typing.TextIO | None]:
+    """Resolve where protocol output should be written.
+
+    Returns (stream, opened_stream); the second element is non-None only when
+    this function opened a file the caller must eventually close. Semantics:
+    empty/None -> real stdout, except a TTY stdout is redirected to os.devnull
+    (so interactive terminals are not flooded with protocol frames);
+    "console" -> stdout; "silent" -> os.devnull; anything else -> that path.
+    """
+    configured = str(protocol_stdout).strip() if protocol_stdout is not None else ""
+    if not configured:
+        stdout = sys.stdout
+        if callable(getattr(stdout, "isatty", None)) and stdout.isatty():
+            opened_stdout = open(os.devnull, "w", encoding="utf-8")
+            return opened_stdout, opened_stdout
+        return stdout, None
+    if configured.lower() == "console":
+        return sys.stdout, None
+    output_path = os.devnull if configured.lower() == "silent" else configured
+    opened_stdout = open(output_path, "w", encoding="utf-8")
+    return opened_stdout, opened_stdout
+
+
+def _handle_cli_entrypoint_failure(
+    exc: Exception,
+    *,
+    context: dict[str, Any] | None = None,
+) -> typing.NoReturn:
+    """Classify *exc*, print a standardized CLI error, and exit the process.
+
+    Always raises SystemExit with the mapped exit code; unexpected errors
+    additionally log a traceback.
+    """
+    exit_code, error_code, hint = _classify_cli_exception(exc)
+    # AstrBotError carries structured docs/details; other exceptions do not.
+    docs_url = exc.docs_url if isinstance(exc, AstrBotError) else ""
+    details = exc.details if isinstance(exc, AstrBotError) else None
+    _render_cli_error(
+        error_code=error_code,
+        message=str(exc),
+        hint=hint,
+        docs_url=docs_url,
+        details=details,
+        context=context,
+    )
+    if exit_code == EXIT_UNEXPECTED:
+        logger.exception("CLI 异常退出")
+    raise SystemExit(exit_code) from exc
+
+
+def _run_entrypoint(
+    runner: typing.Callable[[], object],
+    *,
+    log_message: str,
+    log_level: str = "info",
+    context: dict[str, Any] | None = None,
+) -> None:
+    """Run a CLI entrypoint with uniform logging and error handling.
+
+    Ctrl-C / click.Abort exit with code 130 (conventional SIGINT exit);
+    all other exceptions go through _handle_cli_entrypoint_failure.
+    """
+    getattr(logger, log_level)(log_message)
+    try:
+        runner()
+    except (click.Abort, KeyboardInterrupt):
+        click.echo("\n已中断操作", err=True)
+        raise SystemExit(130)
+    except Exception as exc:
+        _handle_cli_entrypoint_failure(exc, context=context)
+
+
+def _run_async_entrypoint(
+    entrypoint: Coroutine[Any, Any, object],
+    *,
+    log_message: str,
+    log_level: str = "info",
+    context: dict[str, Any] | None = None,
+) -> None:
+    """Async variant of _run_entrypoint: drives *entrypoint* via asyncio.run."""
+    _run_entrypoint(
+        lambda: asyncio.run(entrypoint),
+        log_message=log_message,
+        log_level=log_level,
+        context=context,
+    )
+
+
+def _run_sync_entrypoint(
+    entrypoint: typing.Callable[[], object],
+    *,
+    log_message: str,
+    log_level: str = "info",
+    context: dict[str, Any] | None = None,
+) -> None:
+    """Synchronous alias of _run_entrypoint, kept for naming symmetry with the async variant."""
+    _run_entrypoint(
+        entrypoint,
+        log_message=log_message,
+        log_level=log_level,
+        context=context,
+    )
+
+
+def _classify_cli_exception(exc: Exception) -> tuple[int, str, str]:
+    """Map an exception to (exit_code, error_code, user-facing hint).
+
+    Order matters: AstrBotError first (it carries its own code/hint), then
+    load-type failures, then LookupError (missing handler/capability), then
+    plugin execution failures, with everything else treated as unexpected.
+    """
+    if isinstance(exc, AstrBotError):
+        return (
+            EXIT_RUNTIME,
+            exc.code,
+            exc.hint or "请检查本地 mock core 与插件调用参数",
+        )
+    if isinstance(
+        exc,
+        (
+            _CliPluginValidationError,
+            _CliPluginLoadError,
+            FileNotFoundError,
+            ImportError,
+            ModuleNotFoundError,
+        ),
+    ):
+        return (
+            EXIT_PLUGIN_LOAD,
+            "plugin_load_error",
+            "请检查插件目录、plugin.yaml、requirements.txt(如有)和导入路径",
+        )
+    if isinstance(exc, LookupError):
+        return (
+            EXIT_RUNTIME,
+            "dispatch_error",
+            "请检查 handler 或 capability 是否已正确注册",
+        )
+    if isinstance(exc, _CliPluginExecutionError):
+        return (
+            EXIT_PLUGIN_EXECUTION,
+            "plugin_execution_error",
+            "请检查插件生命周期、handler 或 capability 的实现",
+        )
+    return (
+        EXIT_UNEXPECTED,
+        "unexpected_error",
+        "请查看详细日志,必要时使用 --verbose 重试",
+    )
+
+
+def _render_cli_error(
+    *,
+    error_code: str,
+    message: str,
+    hint: str = "",
+    docs_url: str = "",
+    details: dict[str, Any] | None = None,
+    context: dict[str, Any] | None = None,
+) -> None:
+    """Print a standardized multi-line error report to stderr.
+
+    Optional sections (hint, docs, details, context key/value lines) are
+    emitted only when non-empty.
+    """
+    click.echo(f"Error[{error_code}]: {message}", err=True)
+    if hint:
+        click.echo(f"Suggestion: {hint}", err=True)
+    if docs_url:
+        click.echo(f"Docs: {docs_url}", err=True)
+    if details:
+        click.echo(f"Details: {details}", err=True)
+    if not context:
+        return
+    for key, value in context.items():
+        click.echo(f"{key}: {value}", err=True)
+
+
+def _render_nonfatal_dev_error(
+    exc: Exception,
+    *,
+    context: dict[str, Any] | None = None,
+) -> None:
+    """Report an error during watch/dev mode without exiting the process.
+
+    Same classification/rendering as the fatal path, but returns so the watch
+    loop can keep running; unexpected errors still log a traceback.
+    """
+    exit_code, error_code, hint = _classify_cli_exception(exc)
+    _render_cli_error(
+        error_code=error_code,
+        message=str(exc),
+        hint=hint,
+        context=context,
+    )
+    if exit_code == EXIT_UNEXPECTED:
+        logger.exception("watch 模式收到未分类异常")
+
+
+def _should_include_plugin_file(
+    path: Path,
+    *,
+    plugin_root: Path,
+    output_root: Path | None = None,
+) -> bool:
+    """Decide whether *path* counts as plugin content for build/watch.
+
+    Excludes anything under *output_root*, anything inside an excluded
+    directory, excluded filenames, and compiled Python artifacts.
+    """
+    # Keep watch/build file selection on the same exclusion contract so hot
+    # reload and packaged artifacts do not silently drift apart.
+    if output_root is not None and _path_is_within(path, output_root):
+        return False
+    relative = path.relative_to(plugin_root)
+    # parts[:-1] checks only ancestor directories, not the filename itself.
+    if any(part in BUILD_EXCLUDED_DIRS for part in relative.parts[:-1]):
+        return False
+    if relative.name in BUILD_EXCLUDED_FILES:
+        return False
+    return path.suffix not in {".pyc", ".pyo"}
+
+
+def _iter_watch_files(plugin_dir: Path) -> typing.Iterator[Path]:
+    """Yield every watchable file under *plugin_dir*.
+
+    Iterative (stack-based) scandir walk that prunes excluded directories,
+    skips symlinked dirs, and tolerates directories disappearing mid-scan.
+    """
+    root = plugin_dir.resolve()
+    stack = [root]
+    while stack:
+        current_dir = stack.pop()
+        try:
+            with os.scandir(current_dir) as entries:
+                for entry in entries:
+                    entry_path = Path(entry.path)
+                    if entry.is_dir(follow_symlinks=False):
+                        if entry.name in BUILD_EXCLUDED_DIRS:
+                            continue
+                        stack.append(entry_path)
+                        continue
+                    if not _should_include_plugin_file(
+                        entry_path,
+                        plugin_root=root,
+                    ):
+                        continue
+                    yield entry_path
+        except FileNotFoundError:
+            # Directory removed between discovery and scan; skip it.
+            continue
+
+
+def _snapshot_watch_files(plugin_dir: Path) -> dict[str, tuple[int, int]]:
+    """Snapshot watchable files as {relative posix path: (mtime_ns, size)}.
+
+    Files deleted between listing and stat are silently skipped.
+    """
+    root = plugin_dir.resolve()
+    snapshot: dict[str, tuple[int, int]] = {}
+    for path in _iter_watch_files(root):
+        try:
+            stat = path.stat()
+        except FileNotFoundError:
+            continue
+        snapshot[path.relative_to(root).as_posix()] = (
+            stat.st_mtime_ns,
+            stat.st_size,
+        )
+    return snapshot
+
+
+def _format_watch_changes(changes: list[str], *, limit: int = 5) -> str:
+    """Render a short human-readable summary of changed paths, truncated at *limit*."""
+    if not changes:
+        return "未知文件"
+    preview = changes[:limit]
+    text = ", ".join(preview)
+    if len(changes) > limit:
+        text += f" 等 {len(changes)} 个文件"
+    return text
+
+
+class _ReloadableLocalDevRunner:
+    """Manages a plugin harness that can be stopped and rebuilt on file changes.
+
+    Error/harness classes are injected (rather than imported here) so the
+    testing module stays a lazy dependency of the CLI. An asyncio lock
+    serializes reload/dispatch/close so a dispatch never races a reload.
+    """
+
+    def __init__(
+        self,
+        *,
+        plugin_dir: Path,
+        state: _LocalDevState,
+        plugin_load_error: type[Exception],
+        plugin_execution_error: type[Exception],
+        plugin_harness,
+        stdout_platform_sink,
+    ) -> None:
+        self.plugin_dir = plugin_dir
+        self.state = state
+        self._plugin_load_error = plugin_load_error
+        self._plugin_execution_error = plugin_execution_error
+        self._plugin_harness = plugin_harness
+        self._stdout_platform_sink = stdout_platform_sink
+        # None until the first successful reload().
+        self._harness = None
+        self._lock = asyncio.Lock()
+
+    def _dispatch_kwargs(self) -> dict[str, Any]:
+        # Snapshot of the current dev-session context for dispatch calls.
+        return self.state.dispatch_kwargs()
+
+    async def close(self) -> None:
+        """Stop and discard the current harness, if any."""
+        async with self._lock:
+            await self._stop_harness()
+
+    async def reload(self) -> bool:
+        """Tear down the old harness and start a fresh one from the plugin dir.
+
+        Returns True on success; load/execution failures are reported as
+        non-fatal dev errors and leave the runner without a harness.
+        """
+        async with self._lock:
+            await self._stop_harness()
+            harness = self._plugin_harness.from_plugin_dir(
+                self.plugin_dir,
+                **self._dispatch_kwargs(),
+                platform_sink=self._stdout_platform_sink(stream=sys.stdout),
+            )
+            try:
+                await harness.start()
+            except self._plugin_load_error as exc:
+                _render_nonfatal_dev_error(
+                    _CliPluginLoadError(str(exc)),
+                    context={"plugin_dir": self.plugin_dir},
+                )
+                return False
+            except self._plugin_execution_error as exc:
+                _render_nonfatal_dev_error(
+                    _CliPluginExecutionError(str(exc)),
+                    context={"plugin_dir": self.plugin_dir},
+                )
+                return False
+            self._harness = harness
+            return True
+
+    async def dispatch_text(self, text: str) -> bool:
+        """Send *text* through the loaded harness; returns True on success.
+
+        All failures (including unexpected ones) are reported non-fatally so
+        the interactive/watch loop survives bad plugin code.
+        """
+        async with self._lock:
+            if self._harness is None:
+                click.echo("当前插件未成功加载,等待下一次文件变更后重试。")
+                return False
+            try:
+                await self._harness.dispatch_text(
+                    text,
+                    **self._dispatch_kwargs(),
+                )
+            except (self._plugin_load_error, self._plugin_execution_error) as exc:
+                _render_nonfatal_dev_error(
+                    _CliPluginExecutionError(str(exc)),
+                    context={"plugin_dir": self.plugin_dir},
+                )
+                return False
+            except Exception as exc:
+                _render_nonfatal_dev_error(
+                    exc,
+                    context={"plugin_dir": self.plugin_dir},
+                )
+                return False
+            return True
+
+    async def _stop_harness(self) -> None:
+        # Always clear the reference, even if stop() raises.
+        if self._harness is None:
+            return
+        try:
+            await self._harness.stop()
+        finally:
+            self._harness = None
+
+
+async def _run_local_dev_watch(
+    *,
+    runner: _ReloadableLocalDevRunner,
+    event_text: str | None,
+    interactive: bool,
+    watch_poll_interval: float,
+    max_watch_reloads: int | None = None,
+) -> None:
+    """Drive the local dev watch loop: reload on file changes, optionally re-dispatch.
+
+    When *event_text* is set it is re-dispatched after every successful
+    reload; *interactive* additionally reads stdin lines as messages/meta
+    commands. *max_watch_reloads* bounds the number of reloads (used by
+    tests); None means unbounded.
+    """
+    watcher = _PluginTreeWatcher(runner.plugin_dir)
+    reload_count = 0
+
+    async def reload_and_maybe_rerun(*, announce: str | None) -> None:
+        # Reload the plugin; on success, replay the fixed event if configured.
+        if announce:
+            click.echo(announce)
+        if not await runner.reload():
+            return
+        if event_text is not None:
+            await runner.dispatch_text(event_text)
+
+    async def watch_loop(stop_event: asyncio.Event) -> None:
+        # Poll the tree; each batch of changes triggers one reload.
+        nonlocal reload_count
+        while not stop_event.is_set():
+            await asyncio.sleep(watch_poll_interval)
+            changes = watcher.poll_changes()
+            if not changes:
+                continue
+            await reload_and_maybe_rerun(
+                announce=(
+                    f"检测到文件变更,重新加载插件:{_format_watch_changes(changes)}"
+                )
+            )
+            reload_count += 1
+            if max_watch_reloads is not None and reload_count >= max_watch_reloads:
+                stop_event.set()
+                return
+
+    stop_event = asyncio.Event()
+    watch_task: asyncio.Task[None] | None = None
+    try:
+        # Initial load before the watch task starts.
+        await reload_and_maybe_rerun(
+            announce=(
+                "watch 模式已启动,监听插件目录变更。"
+                if event_text is not None
+                else "watch 模式已启动,监听插件目录变更并按需热重载。"
+            )
+        )
+        if max_watch_reloads == 0:
+            return
+        watch_task = asyncio.create_task(watch_loop(stop_event))
+        if interactive:
+            click.echo(
+                "本地交互模式已启动。可用命令:/session /user /platform /group /private /event /exit"
+            )
+            # stdin reads run in a thread so the watch task keeps polling.
+            while not stop_event.is_set():
+                line = await asyncio.to_thread(sys.stdin.readline)
+                if not line:
+                    break
+                text = line.strip()
+                if not text:
+                    continue
+                if _handle_dev_meta_command(text, runner.state):
+                    if text in {"/exit", "/quit"}:
+                        break
+                    continue
+                await runner.dispatch_text(text)
+            stop_event.set()
+            return
+        await stop_event.wait()
+    finally:
+        # Always stop the watcher task and release the harness.
+        stop_event.set()
+        if watch_task is not None:
+            watch_task.cancel()
+            try:
+                await watch_task
+            except asyncio.CancelledError:
+                pass
+        await runner.close()
+
+
+async def _run_local_dev(
+    *,
+    plugin_dir: Path,
+    event_text: str | None,
+    interactive: bool,
+    watch: bool,
+    session_id: str,
+    user_id: str,
+    platform: str,
+    group_id: str | None,
+    event_type: str,
+    watch_poll_interval: float = WATCH_POLL_INTERVAL_SECONDS,
+    max_watch_reloads: int | None = None,
+) -> None:
+    """Run a plugin locally: one-shot dispatch, interactive REPL, or watch mode.
+
+    Testing-harness imports are deferred so the CLI does not pay for them on
+    non-dev commands. In non-watch mode, load/execution failures are rethrown
+    as CLI error types for exit-code classification.
+    """
+    from .testing import (
+        PluginHarness,
+        StdoutPlatformSink,
+        _PluginExecutionError,
+        _PluginLoadError,
+    )
+
+    state = _LocalDevState(
+        session_id=str(session_id),
+        user_id=str(user_id),
+        platform=str(platform),
+        group_id=group_id,
+        event_type=str(event_type),
+    )
+    if watch:
+        # Watch mode delegates lifecycle management to the reloadable runner.
+        runner = _ReloadableLocalDevRunner(
+            plugin_dir=plugin_dir,
+            state=state,
+            plugin_load_error=_PluginLoadError,
+            plugin_execution_error=_PluginExecutionError,
+            plugin_harness=PluginHarness,
+            stdout_platform_sink=StdoutPlatformSink,
+        )
+        await _run_local_dev_watch(
+            runner=runner,
+            event_text=event_text,
+            interactive=interactive,
+            watch_poll_interval=watch_poll_interval,
+            max_watch_reloads=max_watch_reloads,
+        )
+        return
+
+    sink = StdoutPlatformSink(stream=sys.stdout)
+    harness = PluginHarness.from_plugin_dir(
+        plugin_dir,
+        **state.dispatch_kwargs(),
+        platform_sink=sink,
+    )
+    try:
+        async with harness:
+            if interactive:
+                click.echo(
+                    "本地交互模式已启动。可用命令:/session /user /platform /group /private /event /exit"
+                )
+                while True:
+                    line = await asyncio.to_thread(sys.stdin.readline)
+                    if not line:
+                        break
+                    text = line.strip()
+                    if not text:
+                        continue
+                    if _handle_dev_meta_command(text, state):
+                        if text in {"/exit", "/quit"}:
+                            break
+                        continue
+                    await harness.dispatch_text(
+                        text,
+                        **state.dispatch_kwargs(),
+                    )
+                return
+            # CLI option validation guarantees event_text in one-shot mode.
+            assert event_text is not None
+            await harness.dispatch_text(event_text, **state.dispatch_kwargs())
+    except _PluginLoadError as exc:
+        raise _CliPluginLoadError(str(exc)) from exc
+    except _PluginExecutionError as exc:
+        raise _CliPluginExecutionError(str(exc)) from exc
+
+
def _handle_dev_meta_command(command: str, state: _LocalDevState) -> bool:
    """Handle a "/..." meta command in interactive dev mode.

    Returns True when *command* was a recognized meta command (the caller
    should not dispatch it to the plugin); False otherwise. The setter
    commands require a trailing space plus a value (e.g. "/session s1") —
    a bare "/session" falls through and is dispatched as plain text,
    matching the announced command list.
    """
    if command in {"/exit", "/quit"}:
        return True
    if command == "/private":
        state.group_id = None
        click.echo("已切换为私聊上下文")
        return True
    # Each setter command follows the same "<prefix> <value>" shape; a table
    # keeps the five branches from drifting apart.
    prefix_to_field = {
        "/session": "session_id",
        "/user": "user_id",
        "/platform": "platform",
        "/group": "group_id",
        "/event": "event_type",
    }
    for prefix, field in prefix_to_field.items():
        if command.startswith(prefix + " "):
            value = command.split(" ", 1)[1].strip()
            setattr(state, field, value)
            click.echo(f"切换 {field} -> {value}")
            return True
    return False
+
+
+def _slugify_plugin_name(value: str) -> str:
+ slug = re.sub(r"[^a-zA-Z0-9]+", "_", value).strip("_").lower()
+ return slug or "my_plugin"
+
+
+def _normalize_plugin_name(value: str) -> str:
+ normalized = _slugify_plugin_name(value)
+ if normalized.startswith("astrbot_plugin_"):
+ return normalized
+ normalized = normalized.removeprefix("astrbot_plugin")
+ normalized = normalized.strip("_")
+ suffix = normalized or "my_plugin"
+ return f"astrbot_plugin_{suffix}"
+
+
+def _class_name_for_plugin(value: str) -> str:
+ parts = [part for part in re.split(r"[^a-zA-Z0-9]+", value) if part]
+ if not parts:
+ return "MyPlugin"
+ return "".join(part[:1].upper() + part[1:] for part in parts)
+
+
+def _sanitize_build_part(value: str) -> str:
+ sanitized = re.sub(r"[^a-zA-Z0-9._-]+", "_", value).strip("._-")
+ return sanitized or "artifact"
+
+
def _parse_init_agents(
    _ctx: click.Context,
    _param: click.Parameter,
    value: str | None,
) -> tuple[str, ...]:
    """Click callback: parse --agents into a validated, de-duplicated tuple.

    Accepts a comma-separated, case-insensitive list; raises BadParameter
    listing every invalid entry (empty entries included as "").
    """
    if value is None:
        return ()

    accepted: dict[str, None] = {}
    rejected: list[str] = []
    for token in value.split(","):
        stripped = token.strip()
        candidate = stripped.lower()
        if candidate and candidate in SUPPORTED_INIT_AGENTS:
            # dict keys keep first-seen order while dropping duplicates
            accepted.setdefault(candidate, None)
        else:
            rejected.append(stripped)

    if rejected:
        supported = ", ".join(SUPPORTED_INIT_AGENTS)
        invalid = ", ".join(rejected)
        raise click.BadParameter(f"仅支持以下 agent: {supported};非法值: {invalid}")
    return tuple(accepted)
+
+
def _render_init_plugin_yaml(
    *,
    plugin_name: str,
    display_name: str,
    desc: str,
    author: str,
    repo: str,
    version: str,
) -> str:
    """Render the plugin.yaml manifest text for a freshly scaffolded plugin.

    The runtime python version is pinned to the interpreter running the CLI.
    NOTE(review): metadata values are interpolated into YAML unquoted; a desc
    or name containing ':' or '#' would corrupt the manifest — confirm inputs
    are constrained upstream.
    """
    python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
    class_name = _class_name_for_plugin(plugin_name)
    return dedent(
        f"""\
        name: {plugin_name}
        display_name: {display_name}
        desc: {desc}
        author: {author}
        repo: {repo}
        version: {version}
        runtime:
          python: "{python_version}"
        components:
          - class: main:{class_name}
        """
    )
+
+
def _render_init_main_py(*, plugin_name: str) -> str:
    """Render the hello-world main.py entry module for a new plugin."""
    class_name = _class_name_for_plugin(plugin_name)
    return dedent(
        f"""\
        from astrbot_sdk import Context, MessageEvent, Star, on_command


        class {class_name}(Star):
            @on_command("hello")
            async def hello(self, event: MessageEvent, ctx: Context) -> None:
                await event.reply("Hello, World!")
        """
    )
+
+
def _render_init_readme(*, plugin_name: str) -> str:
    """Render the README.md for a freshly scaffolded plugin."""
    return dedent(
        f"""\
        # {plugin_name}

        一个最小可运行的 AstrBot SDK 插件。

        ## 目录结构

        ```
        .
        ├── plugin.yaml
        ├── requirements.txt
        ├── main.py
        └── tests
            └── test_plugin.py
        ```

        ## 本地开发

        ```bash
        astrbot-sdk validate
        astrbot-sdk dev --local --event-text hello
        astrbot-sdk dev --local --watch --event-text hello
        ```

        ## 运行测试

        ```bash
        python -m pytest tests/test_plugin.py -v
        ```
        """
    )
+
+
def _render_init_gitignore() -> str:
    """Render the default .gitignore shipped with a freshly scaffolded plugin."""
    return dedent(
        """\
        # Python
        __pycache__/
        *.py[cod]
        *.pyo
        *.egg-info/
        dist/
        build/
        *.egg

        # 虚拟环境
        .venv/
        venv/
        env/

        # IDE
        .idea/
        .vscode/
        *.swp
        *.swo
        *~

        # OS
        .DS_Store
        Thumbs.db
        desktop.ini

        # 测试 / 检查缓存
        .pytest_cache/
        .ruff_cache/
        .mypy_cache/
        .coverage
        htmlcov/

        # 开发/构建工具
        /.claude/
        /.agents/
        /.opencode/

        # 图床配置(含 API 密钥等敏感信息)
        /image_host/config.json

        # 插件测试产物
        /.astrbot_sdk_testing/
        """
    )
+
+
def _render_init_test_py(*, plugin_name: str) -> str:
    """Render the starter pytest module exercising the hello handler."""
    class_name = _class_name_for_plugin(plugin_name)
    return dedent(
        f"""\
        from pathlib import Path

        import pytest

        from astrbot_sdk.testing import MockContext, MockMessageEvent, PluginHarness
        from main import {class_name}


        @pytest.mark.asyncio
        async def test_hello_handler():
            plugin = {class_name}()
            ctx = MockContext(
                plugin_id="{plugin_name}",
                plugin_metadata={{"display_name": "{class_name}"}},
            )
            event = MockMessageEvent(text="/hello", context=ctx)

            await plugin.hello(event, ctx)

            assert event.replies == ["Hello, World!"]
            ctx.platform.assert_sent("Hello, World!")


        @pytest.mark.asyncio
        async def test_hello_dispatch():
            plugin_dir = Path(__file__).resolve().parents[1]

            async with PluginHarness.from_plugin_dir(plugin_dir) as harness:
                records = await harness.dispatch_text("hello")

            assert any(record.text == "Hello, World!" for record in records)
        """
    )
+
+
def _plugin_root_hint_for_agent(agent: str) -> str:
    """Relative path from an agent's skill directory back to the plugin root."""
    skill_path = INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME
    depth = len(skill_path.parts)
    return "/".join([".."] * depth) if depth else "."
+
+
def _build_agent_template_context(
    *,
    plugin_name: str,
    display_name: str,
    agent: str,
) -> dict[str, str]:
    """Assemble the variable map used when rendering agent skill templates."""
    context: dict[str, str] = {
        "plugin_name": plugin_name,
        "display_name": display_name,
    }
    context["class_name"] = _class_name_for_plugin(plugin_name)
    context["skill_name"] = f"{plugin_name}_project"
    context["plugin_root"] = _plugin_root_hint_for_agent(agent)
    context["agent_name"] = agent
    context["agent_display_name"] = INIT_AGENT_DISPLAY_NAMES[agent]
    context["skill_dir_name"] = INIT_SKILL_TEMPLATE_NAME
    return context
+
+
def _render_template_text(template_text: str, context: dict[str, str]) -> str:
    """Substitute template variables from *context*; raise on undefined names."""

    def _substitute(match: re.Match[str]) -> str:
        variable = match.group(1)
        if variable in context:
            return context[variable]
        raise _CliPluginValidationError(f"agent 模板变量未定义:{variable}")

    return _TEMPLATE_VARIABLE_PATTERN.sub(_substitute, template_text)
+
+
def _copy_rendered_template_tree(
    source_dir: Traversable,
    target_dir: Path,
    *,
    context: dict[str, str],
) -> None:
    """Recursively mirror *source_dir* into *target_dir*, rendering each file
    through the template-variable substitution."""
    target_dir.mkdir(parents=True, exist_ok=True)
    # Sort by name for a deterministic traversal order.
    entries = sorted(source_dir.iterdir(), key=lambda item: item.name)
    for entry in entries:
        destination = target_dir / entry.name
        if entry.is_dir():
            _copy_rendered_template_tree(entry, destination, context=context)
        else:
            rendered = _render_template_text(
                entry.read_text(encoding="utf-8"), context
            )
            destination.write_text(rendered, encoding="utf-8")
+
+
def _render_init_agent_templates(
    *,
    target_dir: Path,
    plugin_name: str,
    display_name: str,
    agents: tuple[str, ...],
) -> None:
    """Render the packaged project-skill template once per requested agent."""
    if not agents:
        return

    template_root = resources.files("astrbot_sdk").joinpath(
        "templates",
        "skills",
        INIT_SKILL_TEMPLATE_NAME,
    )
    if not template_root.is_dir():
        raise _CliPluginValidationError(
            f"未找到项目级 skill 模板:{INIT_SKILL_TEMPLATE_NAME}"
        )

    for agent in agents:
        destination = (
            target_dir / INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME
        )
        _copy_rendered_template_tree(
            template_root,
            destination,
            context=_build_agent_template_context(
                plugin_name=plugin_name,
                display_name=display_name,
                agent=agent,
            ),
        )
+
+
def _render_init_project_notes(*, target_dir: Path) -> None:
    """Copy the packaged AGENTS.md / CLAUDE.md notes into the new project."""
    template_root = resources.files("astrbot_sdk").joinpath(
        *INIT_PROJECT_NOTE_TEMPLATE_DIR
    )
    if not template_root.is_dir():
        raise _CliPluginValidationError("未找到项目级说明模板:AGENTS.md / CLAUDE.md")

    for note_name in INIT_PROJECT_NOTE_TEMPLATE_NAMES:
        note_resource = template_root.joinpath(note_name)
        if not note_resource.is_file():
            raise _CliPluginValidationError(
                f"未找到项目级说明模板文件:{note_name}"
            )
        # Packaged resources keep `astr init` behaving identically from a
        # repo checkout, an sdist, and an installed wheel.
        content = note_resource.read_text(encoding="utf-8")
        (target_dir / note_name).write_text(content, encoding="utf-8")
+
+
def _ensure_plugin_dir_exists(plugin_dir: Path) -> Path:
    """Resolve *plugin_dir* and fail with a CLI validation error unless it is
    an existing directory."""
    resolved = plugin_dir.resolve()
    if resolved.is_dir():
        return resolved
    raise _CliPluginValidationError(f"插件目录不存在:{plugin_dir}")
+
+
def _resolve_dev_plugin_dir(plugin_dir: Path | None) -> Path:
    """Return the explicit --plugin-dir, or "." when CWD holds a plugin.yaml."""
    if plugin_dir is not None:
        return plugin_dir
    if (Path.cwd() / "plugin.yaml").exists():
        return Path(".")
    raise click.BadParameter(
        "未提供 --plugin-dir,且当前目录未找到 plugin.yaml",
        param_hint="--plugin-dir",
    )
+
+
def _load_validated_plugin(plugin_dir: Path) -> tuple[Any, Any]:
    """Load the plugin spec, validate it, and load its components.

    Returns (spec, loaded); raises _CliPluginValidationError on any failure,
    including the case of zero loadable component instances.
    """
    resolved_dir = _ensure_plugin_dir_exists(plugin_dir)
    spec = load_plugin_spec(resolved_dir)
    try:
        validate_plugin_spec(spec)
    except ValueError as exc:
        raise _CliPluginValidationError(str(exc)) from exc

    loaded = load_plugin(spec)
    if loaded.instances:
        return spec, loaded
    raise _CliPluginValidationError(
        "未找到可加载的组件,请检查 plugin.yaml 中的 components"
    )
+
+
+def _build_kind(plugin: Any) -> str:
+ return (
+ "legacy-main"
+ if bool(plugin.manifest_data.get("__legacy_main__"))
+ else "plugin-yaml"
+ )
+
+
+def _path_is_within(path: Path, root: Path) -> bool:
+ try:
+ path.resolve().relative_to(root.resolve())
+ except ValueError:
+ return False
+ return True
+
+
def _iter_build_files(plugin_dir: Path, output_dir: Path) -> list[Path]:
    """Collect, in sorted order, every plugin file eligible for packaging
    (directories and excluded paths filtered out)."""
    return [
        candidate
        for candidate in sorted(plugin_dir.rglob("*"))
        if not candidate.is_dir()
        and _should_include_plugin_file(
            candidate,
            plugin_root=plugin_dir,
            output_root=output_dir,
        )
    ]
+
+
def _prompt_nonempty_text(prompt: str) -> str:
    """Keep prompting until the user enters a non-blank value."""
    while True:
        answer = click.prompt(prompt, type=str, default="", show_default=False)
        answer = answer.strip()
        if answer:
            return answer
        click.echo("该字段不能为空,请重新输入。")
+
+
def _default_init_repo_name(plugin_name: str) -> str:
    """Default the repo field to the normalized astrbot_plugin_* name."""
    return _normalize_plugin_name(plugin_name)
+
+
def _collect_init_metadata(name: str | None) -> tuple[str, str, str, str, str]:
    """Gather (plugin_name, author, repo, desc, version) for `astr init`.

    *name* skips the interactive name prompt when provided. desc may be
    empty; version falls back to "1.0.0" when the user enters only
    whitespace.
    """
    plugin_name = name if name is not None else _prompt_nonempty_text("插件名字")
    author = _prompt_nonempty_text("作者")
    # repo is derived, not prompted: the normalized astrbot_plugin_* name.
    repo = _default_init_repo_name(plugin_name)
    desc = click.prompt("描述", type=str, default="", show_default=False).strip()
    version = click.prompt("版本", type=str, default="1.0.0", show_default=True).strip()
    return plugin_name, author, repo, desc, version or "1.0.0"
+
+
def _init_plugin(name: str | None, agents: tuple[str, ...] = ()) -> None:
    """Create a new plugin skeleton directory, then best-effort ``git init``.

    Prompts for any metadata not provided, renders the scaffold files plus
    optional per-agent skill templates, initializes a git repository (a
    missing/failing git degrades to a warning), and echoes next steps.

    Raises:
        _CliPluginValidationError: when the target directory already exists
            or a packaged template resource is missing.
    """
    raw_name, author, repo, desc, version = _collect_init_metadata(name)
    plugin_name = _normalize_plugin_name(raw_name)
    target_dir = Path(plugin_name)
    if target_dir.exists():
        raise _CliPluginValidationError(f"目标目录已存在:{target_dir}")

    display_name = raw_name.strip() or plugin_name
    _write_init_scaffold(
        target_dir,
        plugin_name=plugin_name,
        display_name=display_name,
        desc=desc,
        author=author,
        repo=repo,
        version=version,
    )
    _render_init_project_notes(target_dir=target_dir)
    _render_init_agent_templates(
        target_dir=target_dir,
        plugin_name=plugin_name,
        display_name=display_name,
        agents=agents,
    )
    _init_git_repository(target_dir)
    _echo_init_summary(target_dir, agents)


def _write_init_scaffold(
    target_dir: Path,
    *,
    plugin_name: str,
    display_name: str,
    desc: str,
    author: str,
    repo: str,
    version: str,
) -> None:
    """Create the directory tree and render every scaffold file."""
    target_dir.mkdir(parents=True, exist_ok=False)
    (target_dir / "tests").mkdir()
    (target_dir / "plugin.yaml").write_text(
        _render_init_plugin_yaml(
            plugin_name=plugin_name,
            display_name=display_name,
            desc=desc,
            author=author,
            repo=repo,
            version=version,
        ),
        encoding="utf-8",
    )
    (target_dir / "requirements.txt").write_text("", encoding="utf-8")
    (target_dir / "main.py").write_text(
        _render_init_main_py(plugin_name=plugin_name),
        encoding="utf-8",
    )
    (target_dir / "README.md").write_text(
        _render_init_readme(plugin_name=plugin_name),
        encoding="utf-8",
    )
    (target_dir / ".gitignore").write_text(
        _render_init_gitignore(),
        encoding="utf-8",
    )
    (target_dir / "tests" / "test_plugin.py").write_text(
        _render_init_test_py(plugin_name=plugin_name),
        encoding="utf-8",
    )


def _init_git_repository(target_dir: Path) -> None:
    """Best-effort ``git init``; any failure is reported as a warning only."""
    import subprocess

    try:
        process = subprocess.run(
            ["git", "init", str(target_dir)],
            capture_output=True,
            text=True,
        )
        if process.returncode != 0:
            stderr = process.stderr.strip()
            raise RuntimeError(
                f"Git 初始化失败(退出码 {process.returncode})"
                + (f": {stderr}" if stderr else "")
            )
        click.echo(f"Git 仓库已初始化: {target_dir}")
    except FileNotFoundError:
        click.echo("警告: 未找到 git 命令,请先安装 git 后手动执行 git init")
    except RuntimeError as e:
        click.echo(f"警告: {e}")


def _echo_init_summary(target_dir: Path, agents: tuple[str, ...]) -> None:
    """Echo what was created plus the suggested follow-up commands."""
    click.echo(f"已创建插件:{target_dir}")
    if agents:
        generated_paths = ", ".join(
            str(INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME)
            for agent in agents
        )
        click.echo(f"已生成项目级 skill:{generated_paths}")
    click.echo("后续命令:")
    click.echo(f" astrbot-sdk validate --plugin-dir {target_dir}")
    click.echo(
        f" astrbot-sdk dev --local --plugin-dir {target_dir} --event-text hello"
    )
+
+
def _validate_plugin(plugin_dir: Path) -> None:
    """Validate a plugin directory and echo a short summary report."""
    plugin, loaded = _load_validated_plugin(plugin_dir)
    summary_lines = [
        f"校验通过:{plugin.name}",
        f"kind: {_build_kind(plugin)}",
        f"plugin_dir: {plugin.plugin_dir}",
        f"handlers: {len(loaded.handlers)}",
        f"capabilities: {len(loaded.capabilities)}",
        f"instances: {len(loaded.instances)}",
    ]
    for line in summary_lines:
        click.echo(line)
+
+
def _build_plugin(plugin_dir: Path, output_dir: Path | None) -> None:
    """Validate a plugin and package its files into a versioned zip artifact."""
    plugin, _ = _load_validated_plugin(plugin_dir)
    build_dir = (output_dir or (plugin.plugin_dir / "dist")).resolve()
    build_dir.mkdir(parents=True, exist_ok=True)

    raw_version = str(plugin.manifest_data.get("version") or "0.0.0")
    archive_path = build_dir / (
        f"{_sanitize_build_part(plugin.name)}-{_sanitize_build_part(raw_version)}.zip"
    )

    with zipfile.ZipFile(
        archive_path,
        mode="w",
        compression=zipfile.ZIP_DEFLATED,
    ) as archive:
        # Archive paths are stored relative to the plugin root.
        for member in _iter_build_files(plugin.plugin_dir, build_dir):
            archive.write(member, arcname=member.relative_to(plugin.plugin_dir))

    click.echo(f"构建完成:{archive_path}")
    click.echo(f"artifact: {archive_path}")
+
+
def _run_websocket_worker_entrypoint(
    *,
    worker_id: str | None,
    plugin_dirs: tuple[Path, ...],
    host: str,
    port: int,
    path: str,
    tls_ca_file: Path,
    tls_cert_file: Path,
    tls_key_file: Path,
    wire_codec: str,
) -> None:
    """Shared entrypoint behind serve-worker and the deprecated websocket
    command."""
    # Default to the current directory when no plugin dirs were given.
    if plugin_dirs:
        resolved_plugin_dirs = list(plugin_dirs)
    else:
        resolved_plugin_dirs = [Path.cwd()]
    server_coro = run_websocket_server(
        worker_id=worker_id,
        plugin_dirs=resolved_plugin_dirs,
        host=host,
        port=port,
        path=path,
        tls_ca_file=tls_ca_file,
        tls_cert_file=tls_cert_file,
        tls_key_file=tls_key_file,
        wire_codec=wire_codec,
    )
    _run_async_entrypoint(
        server_coro,
        log_message=f"启动 WebSocket Worker,端口:{port}",
        context={
            "worker_id": worker_id,
            "plugin_dirs": resolved_plugin_dirs,
            "port": port,
            "path": path,
        },
    )
+
+
@click.group()
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose output")
@click.pass_context
def cli(ctx: click.Context, verbose: bool) -> None:
    """AstrBot SDK CLI。"""
    # Root group: stash the verbose flag on the click context for subcommands
    # and configure logging once before any command body runs.
    ctx.ensure_object(dict)
    ctx.obj["verbose"] = verbose
    setup_logger(verbose)
+
+
@cli.command()
@click.option(
    "--plugins-dir",
    default="plugins",
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    help="Directory containing plugin folders",
)
@click.option(
    "--workers-manifest",
    default=None,
    type=click.Path(file_okay=True, dir_okay=False, path_type=Path),
    help="Supervisor manifest describing remote websocket workers",
)
@click.option(
    "--protocol-stdout",
    default=None,
    type=str,
    help="Redirect runtime protocol stdout to console, silent, or a file path",
)
@click.option(
    "--wire-codec",
    type=click.Choice(["msgpack", "json"]),
    default="msgpack",
    show_default=True,
    help="Wire codec for runtime protocol",
)
def run(
    plugins_dir: Path,
    workers_manifest: Path | None,
    protocol_stdout: str | None,
    wire_codec: str,
) -> None:
    """Start the plugin supervisor over stdio."""
    # _resolve_protocol_stdout returns the stream to hand to the transport
    # and, when it opened a file itself, the handle this command must close.
    transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout)
    try:
        _run_async_entrypoint(
            run_supervisor(
                plugins_dir=plugins_dir,
                stdout=transport_stdout,
                workers_manifest=workers_manifest,
                wire_codec=wire_codec,
            ),
            log_message=f"启动插件主管进程,插件目录:{plugins_dir}",
            context={
                "plugins_dir": plugins_dir,
                "workers_manifest": workers_manifest,
            },
        )
    finally:
        # Close only a stream opened by _resolve_protocol_stdout — never the
        # process-wide stdout.
        if opened_stdout is not None:
            opened_stdout.close()
+
@cli.command()
@click.argument("name", type=str, required=False)
@click.option(
    "--agents",
    callback=_parse_init_agents,
    metavar="claude,codex,opencode",
    help="Generate per-agent project templates, comma-separated: claude,codex,opencode",
)
def init(name: str | None, agents: tuple[str, ...]) -> None:
    """Create a new plugin skeleton in the target directory."""
    # NAME is optional; _init_plugin prompts for it interactively when absent.
    _run_sync_entrypoint(
        lambda: _init_plugin(name, agents),
        log_message=f"创建插件:{name or ''}",
        context={"target": name or ""},
    )
+
+
@cli.command()
@click.option(
    "--plugin-dir",
    default=".",
    show_default=True,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    help="Plugin directory to validate",
)
def validate(plugin_dir: Path) -> None:
    """Validate plugin manifest, imports and handler discovery."""
    # Delegates to _validate_plugin via the shared sync entrypoint so errors
    # are reported consistently with the other commands.
    _run_sync_entrypoint(
        lambda: _validate_plugin(plugin_dir),
        log_message=f"校验插件目录:{plugin_dir}",
        context={"plugin_dir": plugin_dir},
    )
+
+
@cli.command()
@click.option(
    "--plugin-dir",
    default=".",
    show_default=True,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    help="Plugin directory to package",
)
@click.option(
    "--output-dir",
    default=None,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    help="Directory for the build artifact, defaults to /dist",
)
def build(plugin_dir: Path, output_dir: Path | None) -> None:
    """Validate and package a plugin into a zip artifact."""
    # _build_plugin validates first, then writes <name>-<version>.zip.
    _run_sync_entrypoint(
        lambda: _build_plugin(plugin_dir, output_dir),
        log_message=f"构建插件包:{plugin_dir}",
        context={"plugin_dir": plugin_dir, "output_dir": output_dir},
    )
+
+
@cli.command()
@click.option(
    "--plugin-dir",
    required=False,
    default=None,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    help="Plugin directory to run locally, defaults to current directory when plugin.yaml exists",
)
@click.option("--local", "local_mode", is_flag=True, help="Run against local mock core")
@click.option(
    "--standalone",
    "standalone_mode",
    is_flag=True,
    help="Deprecated alias of --local",
)
@click.option("--event-text", type=str, help="Single message text to dispatch")
@click.option("--interactive", is_flag=True, help="Read follow-up messages from stdin")
@click.option(
    "--watch",
    is_flag=True,
    help="Reload the local harness when plugin files change",
)
@click.option("--session-id", default="local-session", show_default=True)
@click.option("--user-id", default="local-user", show_default=True)
@click.option("--platform", "platform_name", default="test", show_default=True)
@click.option("--group-id", default=None)
@click.option("--event-type", default="message", show_default=True)
def dev(
    plugin_dir: Path | None,
    local_mode: bool,
    standalone_mode: bool,
    event_text: str | None,
    interactive: bool,
    watch: bool,
    session_id: str,
    user_id: str,
    platform_name: str,
    group_id: str | None,
    event_type: str,
) -> None:
    """Run a plugin against the local mock core for development."""
    # Flag validation: only the local mock mode exists today (--standalone is
    # a legacy alias), and exactly one input source must be chosen —
    # --event-text xor --interactive.
    if not (local_mode or standalone_mode):
        raise click.BadParameter("当前 dev 只支持 --local/--standalone 模式")
    if interactive and event_text:
        raise click.BadParameter("--interactive 与 --event-text 不能同时使用")
    if not interactive and not event_text:
        raise click.BadParameter("请提供 --event-text,或改用 --interactive")
    resolved_plugin_dir = _resolve_dev_plugin_dir(plugin_dir)
    _run_async_entrypoint(
        _run_local_dev(
            plugin_dir=resolved_plugin_dir,
            event_text=event_text,
            interactive=interactive,
            watch=watch,
            session_id=session_id,
            user_id=user_id,
            platform=platform_name,
            group_id=group_id,
            event_type=event_type,
        ),
        log_message=f"启动本地开发模式:{resolved_plugin_dir}",
        context={
            "plugin_dir": resolved_plugin_dir,
            "session_id": session_id,
            "platform": platform_name,
            "event_type": event_type,
        },
    )
+
+
@cli.command(hidden=True)
@click.option(
    "--plugin-dir",
    required=False,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
)
@click.option(
    "--group-metadata",
    required=False,
    type=click.Path(file_okay=True, dir_okay=False, path_type=Path),
)
@click.option(
    "--protocol-stdout",
    default=None,
    type=str,
    help="Redirect runtime protocol stdout to console, silent, or a file path",
)
@click.option(
    "--wire-codec",
    type=click.Choice(["msgpack", "json"]),
    default="msgpack",
    show_default=True,
    help="Wire codec for runtime protocol",
)
def worker(
    plugin_dir: Path | None,
    group_metadata: Path | None,
    protocol_stdout: str | None,
    wire_codec: str,
) -> None:
    """Internal command used by the supervisor to start a worker."""
    # Exactly one of --plugin-dir / --group-metadata selects the worker shape.
    if plugin_dir is None and group_metadata is None:
        raise click.UsageError("Either --plugin-dir or --group-metadata is required")
    if plugin_dir is not None and group_metadata is not None:
        raise click.UsageError(
            "--plugin-dir and --group-metadata are mutually exclusive"
        )

    target = str(group_metadata or plugin_dir)
    transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout)
    # The branch mirrors run_plugin_worker's two keyword forms — presumably
    # it rejects receiving both arguments (confirm against its signature).
    if group_metadata is not None:
        entrypoint = run_plugin_worker(
            group_metadata=group_metadata,
            stdout=transport_stdout,
            wire_codec=wire_codec,
        )
    else:
        entrypoint = run_plugin_worker(
            plugin_dir=plugin_dir,
            stdout=transport_stdout,
            wire_codec=wire_codec,
        )
    try:
        _run_async_entrypoint(
            entrypoint,
            log_message=f"启动插件工作进程:{target}",
            log_level="debug",
            context={"plugin_dir": plugin_dir},
        )
    finally:
        # Close only a stream opened by _resolve_protocol_stdout.
        if opened_stdout is not None:
            opened_stdout.close()
+
+
@cli.command("serve-worker")
@click.option("--worker-id", default=None, type=str, help="Stable websocket worker id")
@click.option(
    "--plugin-dir",
    "plugin_dirs",
    multiple=True,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    help="Plugin directory to serve; repeat to host multiple plugins in one worker",
)
@click.option("--host", default="127.0.0.1", show_default=True)
@click.option("--port", default=8765, type=int, show_default=True)
@click.option("--path", default="/", show_default=True)
@click.option(
    "--tls-ca-file",
    required=True,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
)
@click.option(
    "--tls-cert-file",
    required=True,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
)
@click.option(
    "--tls-key-file",
    required=True,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
)
@click.option(
    "--wire-codec",
    type=click.Choice(["msgpack", "json"]),
    default="msgpack",
    show_default=True,
    help="Wire codec for runtime protocol",
)
def serve_worker(
    worker_id: str | None,
    plugin_dirs: tuple[Path, ...],
    host: str,
    port: int,
    path: str,
    tls_ca_file: Path,
    tls_cert_file: Path,
    tls_key_file: Path,
    wire_codec: str,
) -> None:
    """Serve one or more plugins as a standalone websocket worker."""
    # Thin CLI wrapper; all work happens in the shared websocket entrypoint.
    _run_websocket_worker_entrypoint(
        worker_id=worker_id,
        plugin_dirs=plugin_dirs,
        host=host,
        port=port,
        path=path,
        tls_ca_file=tls_ca_file,
        tls_cert_file=tls_cert_file,
        tls_key_file=tls_key_file,
        wire_codec=wire_codec,
    )
+
+
@cli.command(hidden=True)
@click.option("--worker-id", default=None, type=str)
@click.option(
    "--plugin-dir",
    "plugin_dirs",
    multiple=True,
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
)
@click.option("--host", default="127.0.0.1", show_default=True)
@click.option("--port", default=8765, type=int, show_default=True)
@click.option("--path", default="/", show_default=True)
@click.option(
    "--tls-ca-file",
    required=True,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
)
@click.option(
    "--tls-cert-file",
    required=True,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
)
@click.option(
    "--tls-key-file",
    required=True,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
)
@click.option(
    "--wire-codec",
    type=click.Choice(["msgpack", "json"]),
    default="msgpack",
    show_default=True,
    help="Wire codec for runtime protocol",
)
def websocket(
    worker_id: str | None,
    plugin_dirs: tuple[Path, ...],
    host: str,
    port: int,
    path: str,
    tls_ca_file: Path,
    tls_cert_file: Path,
    tls_key_file: Path,
    wire_codec: str,
) -> None:
    """Deprecated websocket runtime wrapper for standalone worker scenarios."""
    # Kept (hidden) for backward compatibility: warn, then delegate to the
    # same entrypoint serve-worker uses.
    logger.warning("'astr websocket' is deprecated; use 'astr serve-worker' instead")
    _run_websocket_worker_entrypoint(
        worker_id=worker_id,
        plugin_dirs=plugin_dirs,
        host=host,
        port=port,
        path=path,
        tls_ca_file=tls_ca_file,
        tls_cert_file=tls_cert_file,
        tls_key_file=tls_key_file,
        wire_codec=wire_codec,
    )
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/__init__.py b/astrbot-sdk/src/astrbot_sdk/clients/__init__.py
new file mode 100644
index 0000000000..da7677a183
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/__init__.py
@@ -0,0 +1,98 @@
"""Native astrbot-sdk capability clients.

These clients give ``Context`` a narrow, typed surface for invoking remote
capabilities. They handle capability names, payload shaping, and result
decoding without exposing protocol or transport details.

Migration shims and higher-level orchestration are deliberately kept out of
these native capability clients so the ``Context`` surface stays small and
stable.

Currently exported clients:
    - LLMClient: text / structured / streaming LLM calls
    - MemoryClient: memory search, save, read, delete
    - DBClient: key-value get/set/delete/list
    - PlatformClient: platform message sending and member queries
    - ProviderClient: provider metadata and dedicated provider proxies
    - PersonaManagerClient: persona management
    - ConversationManagerClient: conversation management
    - KnowledgeBaseManagerClient: knowledge-base management
    - HTTPClient: web API registration
    - MetadataClient: plugin metadata queries
    - SkillClient: runtime registration of plugin skills
"""

from .db import DBClient
from .http import HTTPClient
from .llm import ChatMessage, LLMClient, LLMResponse
from .managers import (
    ConversationCreateParams,
    ConversationManagerClient,
    ConversationRecord,
    ConversationUpdateParams,
    KnowledgeBaseCreateParams,
    KnowledgeBaseManagerClient,
    KnowledgeBaseRecord,
    MessageHistoryManagerClient,
    MessageHistoryPage,
    MessageHistoryRecord,
    MessageHistorySender,
    PersonaCreateParams,
    PersonaManagerClient,
    PersonaRecord,
    PersonaUpdateParams,
)
from .memory import MemoryClient
from .metadata import MetadataClient, PluginMetadata, StarMetadata
from .permission import PermissionCheckResult, PermissionClient, PermissionManagerClient
from .platform import PlatformClient, PlatformError, PlatformStats, PlatformStatus
from .provider import (
    ManagedProviderRecord,
    ProviderChangeEvent,
    ProviderClient,
    ProviderManagerClient,
)
from .registry import HandlerMetadata, RegistryClient
from .session import SessionPluginManager, SessionServiceManager
from .skills import SkillClient, SkillRegistration

# Kept alphabetically sorted so additions and omissions are easy to spot.
__all__ = [
    "ChatMessage",
    "ConversationCreateParams",
    "ConversationManagerClient",
    "ConversationRecord",
    "ConversationUpdateParams",
    "DBClient",
    "HTTPClient",
    "HandlerMetadata",
    "KnowledgeBaseCreateParams",
    "KnowledgeBaseManagerClient",
    "KnowledgeBaseRecord",
    "LLMClient",
    "LLMResponse",
    "ManagedProviderRecord",
    "MemoryClient",
    "MessageHistoryManagerClient",
    "MessageHistoryPage",
    "MessageHistoryRecord",
    "MessageHistorySender",
    "MetadataClient",
    "PermissionCheckResult",
    "PermissionClient",
    "PermissionManagerClient",
    "PersonaCreateParams",
    "PersonaManagerClient",
    "PersonaRecord",
    "PersonaUpdateParams",
    "PlatformClient",
    "PlatformError",
    "PlatformStats",
    "PlatformStatus",
    "PluginMetadata",
    "ProviderChangeEvent",
    "ProviderClient",
    "ProviderManagerClient",
    "RegistryClient",
    "SessionPluginManager",
    "SessionServiceManager",
    "SkillClient",
    "SkillRegistration",
    "StarMetadata",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/_errors.py b/astrbot-sdk/src/astrbot_sdk/clients/_errors.py
new file mode 100644
index 0000000000..e926321b25
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/_errors.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from ..errors import AstrBotError
+
+
+def client_call_label(
+ client_name: str,
+ method_name: str,
+ details: str | None = None,
+) -> str:
+ label = f"{client_name}.{method_name}"
+ if details:
+ return f"{label} ({details})"
+ return label
+
+
def wrap_client_exception(
    *,
    client_name: str,
    method_name: str,
    exc: Exception,
    details: str | None = None,
) -> Exception:
    """Return *exc* re-wrapped with a "Client.method failed: ..." message.

    AstrBotError instances are rebuilt field-for-field with the new message;
    other exceptions are re-instantiated from their own class when possible,
    with RuntimeError as the fallback.
    """
    message = f"{client_call_label(client_name, method_name, details)} failed: {exc}"
    if isinstance(exc, AstrBotError):
        return AstrBotError(
            code=exc.code,
            message=message,
            hint=exc.hint,
            retryable=exc.retryable,
            docs_url=exc.docs_url,
            details=exc.details,
        )
    # Some exception classes cannot be constructed from a single string.
    try:
        rebuilt = type(exc)(message)
    except Exception:
        return RuntimeError(message)
    if isinstance(rebuilt, Exception):
        return rebuilt
    return RuntimeError(message)


__all__ = ["client_call_label", "wrap_client_exception"]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py b/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py
new file mode 100644
index 0000000000..4a6e9db7d9
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py
@@ -0,0 +1,188 @@
+"""能力代理模块。
+
+提供 CapabilityProxy 类,作为客户端与 Peer 之间的中间层,负责:
+- 检查远程能力是否可用
+- 验证流式调用支持
+- 统一封装 invoke 和 invoke_stream 调用
+
+设计说明:
+ CapabilityProxy 是新版架构的核心组件。每个专用客户端 (LLMClient, DBClient 等)
+ 都通过 CapabilityProxy 与远程通信,并在发起调用时绑定当前插件身份,
+ 让运行时把调用者信息放进协议层而不是业务 payload。
+
+使用示例:
+ proxy = CapabilityProxy(peer)
+
+ # 普通调用
+ result = await proxy.call("llm.chat", {"prompt": "hello"})
+
+ # 流式调用
+ async for delta in proxy.stream("llm.stream_chat", {"prompt": "hello"}):
+ print(delta["text"])
+"""
+
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Mapping
+from typing import Any, Protocol
+
+from .._internal.invocation_context import caller_plugin_scope
+from ..errors import AstrBotError
+
+
+class _CapabilityDescriptorLike(Protocol):
+ supports_stream: bool | None
+
+
+class _CapabilityPeerLike(Protocol):
+ remote_capability_map: Mapping[str, _CapabilityDescriptorLike]
+ remote_peer: Any | None
+
+ async def invoke(
+ self,
+ capability: str,
+ payload: dict[str, Any],
+ *,
+ stream: bool = False,
+ request_id: str | None = None,
+ ) -> dict[str, Any]: ...
+
+ async def invoke_stream(
+ self,
+ capability: str,
+ payload: dict[str, Any],
+ *,
+ request_id: str | None = None,
+ ) -> AsyncIterator[Any]: ...
+
+
+class CapabilityProxy:
+ """能力代理类,封装 Peer 的能力调用接口。
+
+ 负责在调用前验证能力可用性和流式支持,提供统一的 call/stream 接口。
+
+ Attributes:
+        _peer: 底层 Peer 实例(RPC 通信);_caller_plugin_id、_request_scope_id: 调用方插件 ID 与请求作用域 ID(均可为 None)
+ """
+
+ def __init__(
+ self,
+ peer: _CapabilityPeerLike,
+ caller_plugin_id: str | None = None,
+ request_scope_id: str | None = None,
+ ) -> None:
+ """初始化能力代理。
+
+ Args:
+            peer: Peer 实例(提供 remote_capability_map 与 invoke/invoke_stream);caller_plugin_id、request_scope_id 为可选的调用方插件 ID 与请求作用域 ID
+ """
+ self._peer = peer
+ self._caller_plugin_id = caller_plugin_id
+ self._request_scope_id = request_scope_id
+
+ def _get_descriptor(self, name: str) -> _CapabilityDescriptorLike | None:
+ """获取能力描述符。
+
+ Args:
+ name: 能力名称,如 "llm.chat"
+
+ Returns:
+ 能力描述符,若不存在则返回 None
+ """
+ capability_map = getattr(self._peer, "remote_capability_map", {})
+ if not isinstance(capability_map, Mapping):
+ return None
+ return capability_map.get(name)
+
+ def _remote_initialized(self) -> bool:
+ peer_attrs = getattr(self._peer, "__dict__", None)
+ if not isinstance(peer_attrs, dict):
+ return False
+
+ # Avoid getattr() here: MagicMock synthesizes truthy child attributes and
+ # makes an uninitialized peer look ready.
+ remote_peer = peer_attrs.get("remote_peer")
+ capability_map = peer_attrs.get("remote_capability_map")
+ return bool(remote_peer) or (
+ isinstance(capability_map, Mapping) and bool(capability_map)
+ )
+
+ def _ensure_available(self, name: str, *, stream: bool) -> None:
+ """确保能力可用且支持指定的调用模式。
+
+ Args:
+ name: 能力名称
+ stream: 是否需要流式支持
+
+ Raises:
+ AstrBotError: 能力不存在或流式不支持
+ """
+ descriptor = self._get_descriptor(name)
+ if descriptor is None:
+ if self._remote_initialized():
+ raise AstrBotError.capability_not_found(name)
+ return
+ if stream and not descriptor.supports_stream:
+ raise AstrBotError.invalid_input(f"{name} 不支持 stream=true")
+
+ def _prepare_payload(self, name: str, payload: dict[str, Any]) -> dict[str, Any]:
+ if (
+ not isinstance(self._request_scope_id, str)
+ or not self._request_scope_id
+ or not name.startswith("system.event.")
+ ):
+ return payload
+ scoped_payload = dict(payload)
+ scoped_payload.setdefault("_request_scope_id", self._request_scope_id)
+ return scoped_payload
+
+ async def call(self, name: str, payload: dict[str, Any]) -> dict[str, Any]:
+ """执行普通能力调用(非流式)。
+
+ Args:
+ name: 能力名称,如 "llm.chat", "db.get"
+ payload: 调用参数字典
+
+ Returns:
+ 调用结果字典
+
+ Raises:
+ AstrBotError: 能力不存在或调用失败
+
+ 示例:
+ result = await proxy.call("llm.chat", {"prompt": "hello"})
+ print(result["text"])
+ """
+ self._ensure_available(name, stream=False)
+ prepared_payload = self._prepare_payload(name, payload)
+ with caller_plugin_scope(self._caller_plugin_id):
+ return await self._peer.invoke(name, prepared_payload, stream=False)
+
+ async def stream(
+ self,
+ name: str,
+ payload: dict[str, Any],
+ ) -> AsyncIterator[dict[str, Any]]:
+ """执行流式能力调用。
+
+ Args:
+ name: 能力名称,如 "llm.stream_chat"
+ payload: 调用参数字典
+
+ Yields:
+ 每个增量数据块(phase="delta" 时的 data 字段)
+
+ Raises:
+ AstrBotError: 能力不存在或不支持流式
+
+ 示例:
+ async for delta in proxy.stream("llm.stream_chat", {"prompt": "hello"}):
+ print(delta["text"], end="")
+ """
+ self._ensure_available(name, stream=True)
+ prepared_payload = self._prepare_payload(name, payload)
+ with caller_plugin_scope(self._caller_plugin_id):
+ event_stream = await self._peer.invoke_stream(name, prepared_payload)
+ async for event in event_stream:
+ if event.phase == "delta":
+ yield event.data
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/db.py b/astrbot-sdk/src/astrbot_sdk/clients/db.py
new file mode 100644
index 0000000000..bf2783490d
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/db.py
@@ -0,0 +1,161 @@
+"""数据库客户端模块。
+
+提供键值存储能力,用于持久化插件数据。
+
+功能说明:
+ - 数据永久存储,除非用户显式删除
+ - 值类型支持任意 JSON 数据
+ - 支持前缀查询键列表
+ - 支持批量读写
+ - 支持订阅变更事件
+"""
+
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Mapping, Sequence
+from typing import Any
+
+from ._proxy import CapabilityProxy
+
+
+class DBClient:
+ """键值数据库客户端。
+
+ 提供插件数据的持久化存储能力,数据永久保存直到显式删除。
+
+ Attributes:
+ _proxy: CapabilityProxy 实例,用于远程能力调用
+ """
+
+ def __init__(self, proxy: CapabilityProxy) -> None:
+ """初始化数据库客户端。
+
+ Args:
+ proxy: CapabilityProxy 实例
+ """
+ self._proxy = proxy
+
+ async def get(self, key: str) -> Any | None:
+ """获取指定键的值。
+
+ Args:
+ key: 数据键名
+
+ Returns:
+ 存储的值,若键不存在则返回 None
+
+ 示例:
+ data = await ctx.db.get("user_settings")
+ if data:
+ print(data["theme"])
+ """
+ output = await self._proxy.call("db.get", {"key": key})
+ return output.get("value")
+
+ async def set(self, key: str, value: Any) -> None:
+ """设置键值对。
+
+ Args:
+ key: 数据键名
+ value: 要存储的 JSON 值
+
+ 示例:
+ await ctx.db.set("user_settings", {"theme": "dark", "lang": "zh"})
+ await ctx.db.set("greeted", True)
+ """
+ await self._proxy.call("db.set", {"key": key, "value": value})
+
+ async def delete(self, key: str) -> None:
+ """删除指定键的数据。
+
+ Args:
+ key: 要删除的数据键名
+
+ 示例:
+ await ctx.db.delete("user_settings")
+ """
+ await self._proxy.call("db.delete", {"key": key})
+
+ async def list(self, prefix: str | None = None) -> list[str]:
+ """列出匹配前缀的所有键。
+
+ Args:
+ prefix: 键前缀过滤,None 表示列出所有键
+
+ Returns:
+ 匹配的键名列表
+
+ 示例:
+ # 列出所有用户设置相关的键
+ keys = await ctx.db.list("user_")
+ # ["user_settings", "user_profile", "user_history"]
+ """
+ output = await self._proxy.call("db.list", {"prefix": prefix})
+ keys = output.get("keys")
+ if not isinstance(keys, (list, tuple)):
+ return []
+ return [str(item) for item in keys]
+
+ async def get_many(self, keys: Sequence[str]) -> dict[str, Any | None]:
+ """批量获取多个键的值。
+
+ Args:
+ keys: 要读取的键列表
+
+ Returns:
+ 一个 dict,key 为键名,value 为对应值(不存在则为 None)
+
+ 示例:
+ values = await ctx.db.get_many(["user:1", "user:2"])
+ if values["user:1"] is None:
+ print("user:1 missing")
+ """
+ output = await self._proxy.call("db.get_many", {"keys": list(keys)})
+ items = output.get("items")
+ if not isinstance(items, (list, tuple)):
+ return {}
+ result: dict[str, Any | None] = {}
+ for item in items:
+ if not isinstance(item, dict):
+ continue
+ key = item.get("key")
+ if not isinstance(key, str):
+ continue
+ result[key] = item.get("value")
+ return result
+
+ async def set_many(
+ self, items: Mapping[str, Any] | Sequence[tuple[str, Any]]
+ ) -> None:
+ """批量写入多个键值对。
+
+ Args:
+ items: 键值对集合(dict 或二元组序列)
+
+ 示例:
+ await ctx.db.set_many({"user:1": {"name": "a"}, "user:2": {"name": "b"}})
+ """
+ if isinstance(items, Mapping):
+ pairs = list(items.items())
+ else:
+ pairs = list(items)
+
+ payload_items: list[dict[str, Any]] = [
+ {"key": str(key), "value": value} for key, value in pairs
+ ]
+ await self._proxy.call("db.set_many", {"items": payload_items})
+
+ def watch(self, prefix: str | None = None) -> AsyncIterator[dict[str, Any]]:
+ """订阅 KV 变更事件(流式)。
+
+ Args:
+ prefix: 键前缀过滤;None 表示订阅所有键
+
+ Yields:
+ 变更事件 dict:{"op": "set"|"delete", "key": str, "value": Any|None}
+
+ 示例:
+ async for event in ctx.db.watch("user:"):
+ print(event["op"], event["key"])
+ """
+ return self._proxy.stream("db.watch", {"prefix": prefix})
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/http.py b/astrbot-sdk/src/astrbot_sdk/clients/http.py
new file mode 100644
index 0000000000..84c7417af6
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/http.py
@@ -0,0 +1,187 @@
+"""HTTP 客户端模块。
+
+提供 HTTP API 注册能力。
+
+功能说明:
+ - 注册自定义 Web API 端点
+ - 支持异步请求处理
+ - 与宿主 Web 服务器集成
+
+设计说明:
+ 由于跨进程架构,handler 函数无法直接序列化传递。
+ 插件需要先声明处理 HTTP 请求的 capability,然后注册路由到 capability 的映射。
+ 当前插件身份由运行时在协议层透传,客户端 payload 不暴露 `plugin_id`。
+
+ 调用流程:
+ HTTP 请求 → 宿主 Web 服务器 → 查找 route 映射 → invoke capability → Worker 执行 handler → 返回响应
+
+示例:
+ # 插件声明处理 HTTP 请求的 capability
+ @provide_capability(
+ name="my_plugin.http_handler",
+ description="处理 /my_plugin/api 的 HTTP 请求",
+ input_schema={...},
+ output_schema={...}
+ )
+ async def handle_http_request(request_id: str, payload: dict, cancel_token):
+ return {"status": 200, "body": {"result": "ok"}}
+
+ # 注册路由 → capability 映射
+ await ctx.http.register_api(
+ route="/my_plugin/api",
+ methods=["GET", "POST"],
+ handler_capability="my_plugin.http_handler",
+ description="我的 API"
+ )
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ..decorators import get_capability_meta
+from ..errors import AstrBotError
+from ._errors import wrap_client_exception
+from ._proxy import CapabilityProxy
+
+
+def _resolve_handler_capability(
+ handler_capability: str | None,
+ handler: Any | None,
+) -> str:
+ if handler_capability and handler is not None:
+ raise AstrBotError.invalid_input(
+ "register_api 不能同时提供 handler_capability 和 handler",
+ hint="请二选一:传 capability 名称字符串,或传 @provide_capability 标记的方法",
+ )
+ if handler_capability:
+ return handler_capability
+ if handler is None:
+ raise AstrBotError.invalid_input(
+ "register_api 需要提供 handler_capability 或 handler",
+ hint="示例:handler_capability='demo.http_handler' 或 handler=self.http_handler_capability",
+ )
+ target = getattr(handler, "__func__", handler)
+ meta = get_capability_meta(target)
+ if meta is None:
+ raise AstrBotError.invalid_input(
+ "register_api(handler=...) 需要传入使用 @provide_capability 声明的方法",
+ hint="请先用 @provide_capability(name='demo.http_handler', ...) 标记该方法",
+ )
+ return meta.descriptor.name
+
+
+class HTTPClient:
+ """HTTP 能力客户端。
+
+ 提供 Web API 注册能力,允许插件暴露自定义 HTTP 端点。
+
+ Attributes:
+ _proxy: CapabilityProxy 实例,用于远程能力调用
+ """
+
+ def __init__(self, proxy: CapabilityProxy) -> None:
+ """初始化 HTTP 客户端。
+
+ Args:
+ proxy: CapabilityProxy 实例
+ """
+ self._proxy = proxy
+
+ async def register_api(
+ self,
+ route: str,
+ handler_capability: str | None = None,
+ *,
+ handler: Any | None = None,
+ methods: list[str] | None = None,
+ description: str = "",
+ ) -> None:
+ """注册 Web API 端点。
+
+ Args:
+ route: API 路由路径(必须使用 "/{plugin_id}" 或 "/{plugin_id}/...")
+ handler_capability: 处理此路由的 capability 名称
+ handler: 使用 @provide_capability 标记的方法引用
+ methods: HTTP 方法列表,默认 ["GET"]
+ description: API 描述
+
+ 示例:
+ await ctx.http.register_api(
+ route="/my_plugin/api",
+ handler_capability="my_plugin.http_handler",
+ methods=["GET", "POST"],
+ description="我的 API"
+ )
+ """
+ if methods is None:
+ methods = ["GET"]
+ resolved_handler = _resolve_handler_capability(handler_capability, handler)
+ try:
+ await self._proxy.call(
+ "http.register_api",
+ {
+ "route": route,
+ "methods": methods,
+ "handler_capability": resolved_handler,
+ "description": description,
+ },
+ )
+ except Exception as exc:
+ raise wrap_client_exception(
+ client_name="HTTPClient",
+ method_name="register_api",
+ details=f"route={route!r}, methods={list(methods)!r}",
+ exc=exc,
+ ) from exc
+
+ async def unregister_api(
+ self, route: str, methods: list[str] | None = None
+ ) -> None:
+ """注销 Web API 端点。
+
+ Args:
+ route: API 路由路径
+ methods: HTTP 方法列表,None 表示所有方法
+
+ 示例:
+ await ctx.http.unregister_api("/my_plugin/api")
+ """
+ if methods is None:
+ methods = []
+ try:
+ await self._proxy.call(
+ "http.unregister_api",
+ {"route": route, "methods": methods},
+ )
+ except Exception as exc:
+ raise wrap_client_exception(
+ client_name="HTTPClient",
+ method_name="unregister_api",
+ details=f"route={route!r}, methods={list(methods)!r}",
+ exc=exc,
+ ) from exc
+
+ async def list_apis(self) -> list[dict[str, Any]]:
+ """列出当前插件注册的所有 API。
+
+ Returns:
+ API 列表,每项包含 route, methods, description
+
+ 示例:
+ apis = await ctx.http.list_apis()
+ for api in apis:
+ print(f"{api['route']}: {api['methods']}")
+ """
+ try:
+ output = await self._proxy.call(
+ "http.list_apis",
+ {},
+ )
+ except Exception as exc:
+ raise wrap_client_exception(
+ client_name="HTTPClient",
+ method_name="list_apis",
+ exc=exc,
+ ) from exc
+ return output.get("apis", [])
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/llm.py b/astrbot-sdk/src/astrbot_sdk/clients/llm.py
new file mode 100644
index 0000000000..62ff86d32c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/llm.py
@@ -0,0 +1,293 @@
+"""大语言模型客户端模块。
+
+提供 astrbot-sdk 原生的 LLM 能力调用接口。
+
+设计边界:
+ - `chat()` 是便捷文本接口,返回最终文本
+ - `chat_raw()` 返回完整结构化响应
+ - `stream_chat()` 返回文本增量
+ - Agent 循环、动态工具注册等更高层 orchestration 不放在客户端内,
+ 由上层运行时或独立迁移入口承接
+"""
+
+from __future__ import annotations
+
+from collections.abc import AsyncGenerator, Mapping, Sequence
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from ._proxy import CapabilityProxy
+
+
+class ChatMessage(BaseModel):
+ """聊天消息模型。
+
+ 用于构建对话历史,传递给 LLM。
+
+ Attributes:
+ role: 消息角色,如 "user", "assistant", "system"
+ content: 消息内容
+
+ 示例:
+ history = [
+ ChatMessage(role="user", content="你好"),
+ ChatMessage(role="assistant", content="你好!有什么可以帮助你的?"),
+ ChatMessage(role="user", content="今天天气怎么样?"),
+ ]
+ """
+
+ role: str
+ content: str
+
+
+ChatHistoryItem = ChatMessage | Mapping[str, Any]
+
+
+def _serialize_history(
+ history: Sequence[ChatHistoryItem] | None,
+) -> list[dict[str, Any]]:
+ if history is None:
+ return []
+
+ serialized: list[dict[str, Any]] = []
+ for item in history:
+ if isinstance(item, ChatMessage):
+ serialized.append(item.model_dump())
+ continue
+ if isinstance(item, Mapping):
+ serialized.append(dict(item))
+ continue
+ raise TypeError("history 项必须是 ChatMessage 或 mapping")
+ return serialized
+
+
+def _normalize_chat_context_payload(
+ *,
+ history: Sequence[ChatHistoryItem] | None = None,
+ contexts: Sequence[ChatHistoryItem] | None = None,
+) -> dict[str, list[dict[str, Any]]]:
+ if contexts is not None:
+ return {"contexts": _serialize_history(contexts)}
+ if history is not None:
+ return {"contexts": _serialize_history(history)}
+ return {}
+
+
+def _build_chat_payload(
+ prompt: str,
+ *,
+ system: str | None = None,
+ history: Sequence[ChatHistoryItem] | None = None,
+ contexts: Sequence[ChatHistoryItem] | None = None,
+ provider_id: str | None = None,
+ tool_calls_result: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ temperature: float | None = None,
+ extra: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+ payload: dict[str, Any] = {"prompt": prompt}
+ if system is not None:
+ payload["system"] = system
+ payload.update(_normalize_chat_context_payload(history=history, contexts=contexts))
+ if provider_id is not None:
+ payload["provider_id"] = provider_id
+ if tool_calls_result is not None:
+ payload["tool_calls_result"] = [dict(item) for item in tool_calls_result]
+ if model is not None:
+ payload["model"] = model
+ if temperature is not None:
+ payload["temperature"] = temperature
+ if extra:
+ payload.update(extra)
+ return payload
+
+
+class LLMResponse(BaseModel):
+ """LLM 响应模型。
+
+ 包含完整的 LLM 响应信息,用于 chat_raw() 方法返回。
+
+ Attributes:
+ text: 生成的文本内容
+ usage: Token 使用统计,如 {"prompt_tokens": 10, "completion_tokens": 20}
+ finish_reason: 结束原因,如 "stop", "length", "tool_calls"
+ tool_calls: 工具调用列表(如果 LLM 决定调用工具)
+ """
+
+ text: str
+ usage: dict[str, Any] | None = None
+ finish_reason: str | None = None
+ tool_calls: list[dict[str, Any]] = Field(default_factory=list)
+ role: str | None = None
+ reasoning_content: str | None = None
+ reasoning_signature: str | None = None
+
+
+class LLMClient:
+ """大语言模型客户端。
+
+ 提供与 LLM 交互的能力,支持普通聊天和流式聊天。
+
+ Attributes:
+ _proxy: CapabilityProxy 实例,用于远程能力调用
+ """
+
+ def __init__(self, proxy: CapabilityProxy) -> None:
+ """初始化 LLM 客户端。
+
+ Args:
+ proxy: CapabilityProxy 实例
+ """
+ self._proxy = proxy
+
+ async def chat(
+ self,
+ prompt: str,
+ *,
+ system: str | None = None,
+ history: Sequence[ChatHistoryItem] | None = None,
+ contexts: Sequence[ChatHistoryItem] | None = None,
+ provider_id: str | None = None,
+ tool_calls_result: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ temperature: float | None = None,
+ **kwargs: Any,
+ ) -> str:
+ """发送聊天请求并返回文本响应。
+
+ 这是简化的聊天接口,仅返回生成的文本内容。
+ 如需完整响应信息(包括 usage、tool_calls),请使用 chat_raw()。
+
+        Args:
+            prompt: 用户输入的提示文本
+            system: 系统提示词,用于指导 LLM 行为
+            history: 对话历史(同时提供 contexts 时优先使用 contexts)
+            provider_id: 指定提供商 ID;tool_calls_result: 工具调用结果回传
+            model: 指定模型名称(可选);temperature: 生成温度(0-1)
+            **kwargs: 额外透传参数,如 `image_urls`、`tools`
+
+ Returns:
+ LLM 生成的文本内容
+
+ 示例:
+ # 简单对话
+ reply = await ctx.llm.chat("你好,介绍一下自己")
+
+ # 带历史的对话
+ history = [
+ ChatMessage(role="user", content="我叫小明"),
+ ChatMessage(role="assistant", content="你好小明!"),
+ ]
+ reply = await ctx.llm.chat("你记得我的名字吗?", history=history)
+ """
+ output = await self._proxy.call(
+ "llm.chat",
+ _build_chat_payload(
+ prompt,
+ system=system,
+ history=history,
+ contexts=contexts,
+ provider_id=provider_id,
+ tool_calls_result=tool_calls_result,
+ model=model,
+ temperature=temperature,
+ extra=kwargs,
+ ),
+ )
+ return str(output.get("text", ""))
+
+ async def chat_raw(
+ self,
+ prompt: str,
+ *,
+ system: str | None = None,
+ history: Sequence[ChatHistoryItem] | None = None,
+ contexts: Sequence[ChatHistoryItem] | None = None,
+ provider_id: str | None = None,
+ tool_calls_result: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ temperature: float | None = None,
+ **kwargs: Any,
+ ) -> LLMResponse:
+ """发送聊天请求并返回完整响应。
+
+ 与 chat() 不同,此方法返回完整的 LLMResponse 对象,
+ 包含 usage、finish_reason、tool_calls 等信息。
+
+        Args:
+            prompt: 用户输入的提示文本(其余关键字参数含义与 chat() 相同)
+            **kwargs: 额外透传参数,如 `image_urls`、`tools`
+
+ Returns:
+ LLMResponse 对象,包含完整响应信息
+
+ 示例:
+ response = await ctx.llm.chat_raw("写一首诗", temperature=0.8)
+ print(f"生成文本: {response.text}")
+ print(f"Token 使用: {response.usage}")
+ """
+ payload = _build_chat_payload(
+ prompt,
+ system=system,
+ history=history,
+ contexts=contexts,
+ provider_id=provider_id,
+ tool_calls_result=tool_calls_result,
+ model=model,
+ temperature=temperature,
+ extra=kwargs,
+ )
+ output = await self._proxy.call(
+ "llm.chat_raw",
+ payload,
+ )
+ return LLMResponse.model_validate(output)
+
+ async def stream_chat(
+ self,
+ prompt: str,
+ *,
+ system: str | None = None,
+ history: Sequence[ChatHistoryItem] | None = None,
+ contexts: Sequence[ChatHistoryItem] | None = None,
+ provider_id: str | None = None,
+ tool_calls_result: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ temperature: float | None = None,
+ **kwargs: Any,
+ ) -> AsyncGenerator[str, None]:
+ """流式聊天,逐块返回响应文本。
+
+ 适用于需要实时显示生成内容的场景,如聊天界面。
+
+        Args:
+            prompt: 用户输入的提示文本
+            system: 系统提示词
+            history: 对话历史(同时提供 contexts 时优先使用 contexts)
+            provider_id: 指定提供商 ID;tool_calls_result: 工具调用结果回传
+            model: 指定模型;temperature: 采样温度(0-1)
+            **kwargs: 额外透传参数,如 `image_urls`、`tools`
+
+ Yields:
+ 每个生成的文本块
+
+ 示例:
+ async for chunk in ctx.llm.stream_chat("讲一个故事"):
+ print(chunk, end="", flush=True)
+ """
+ async for data in self._proxy.stream(
+ "llm.stream_chat",
+ _build_chat_payload(
+ prompt,
+ system=system,
+ history=history,
+ contexts=contexts,
+ provider_id=provider_id,
+ tool_calls_result=tool_calls_result,
+ model=model,
+ temperature=temperature,
+ extra=kwargs,
+ ),
+ ):
+ yield str(data.get("text", ""))
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/managers.py b/astrbot-sdk/src/astrbot_sdk/clients/managers.py
new file mode 100644
index 0000000000..c87b91541a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/managers.py
@@ -0,0 +1,885 @@
+"""Typed SDK manager clients for persona, conversation, and knowledge base."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from ..errors import AstrBotError, ErrorCodes
+from ..message.components import (
+ BaseMessageComponent,
+ component_to_payload_sync,
+ payload_to_component,
+)
+from ..message.session import MessageSession
+from ._proxy import CapabilityProxy
+
+
+class _ManagerModel(BaseModel):
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+ def to_payload(self) -> dict[str, Any]:
+ return self.model_dump(exclude_none=True)
+
+ def to_update_payload(self) -> dict[str, Any]:
+ return self.model_dump(exclude_unset=True)
+
+
+def _normalize_session(session: str | MessageSession) -> str:
+ return str(session)
+
+
+def _require_message_history_session(
+ session: MessageSession,
+) -> dict[str, str]:
+ if not isinstance(session, MessageSession):
+ raise TypeError(
+ "message_history requires astrbot_sdk.message.session.MessageSession"
+ )
+ return {
+ "platform_id": str(session.platform_id),
+ "message_type": str(session.message_type),
+ "session_id": str(session.session_id),
+ }
+
+
+def _normalize_message_history_parts(
+ parts: list[BaseMessageComponent],
+) -> list[dict[str, Any]]:
+ normalized: list[dict[str, Any]] = []
+ for part in parts:
+ if not isinstance(part, BaseMessageComponent):
+ raise TypeError(
+ "message_history.append requires BaseMessageComponent items in parts"
+ )
+ normalized.append(component_to_payload_sync(part))
+ return normalized
+
+
+def _normalize_message_history_boundary(value: datetime) -> str:
+ if not isinstance(value, datetime):
+ raise TypeError("message_history boundary requires datetime")
+ normalized = value
+ if normalized.tzinfo is None:
+ normalized = normalized.replace(tzinfo=timezone.utc)
+ else:
+ normalized = normalized.astimezone(timezone.utc)
+ return normalized.isoformat()
+
+
+class PersonaRecord(_ManagerModel):
+ persona_id: str
+ system_prompt: str
+ begin_dialogs: list[str] = Field(default_factory=list)
+ tools: list[str] | None = None
+ skills: list[str] | None = None
+ custom_error_message: str | None = None
+ folder_id: str | None = None
+ sort_order: int = 0
+ created_at: str | None = None
+ updated_at: str | None = None
+
+ @classmethod
+ def from_payload(cls, payload: dict[str, Any] | None) -> PersonaRecord | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class PersonaCreateParams(_ManagerModel):
+ persona_id: str
+ system_prompt: str
+ begin_dialogs: list[str] = Field(default_factory=list)
+ tools: list[str] | None = None
+ skills: list[str] | None = None
+ custom_error_message: str | None = None
+ folder_id: str | None = None
+ sort_order: int = 0
+
+
+class PersonaUpdateParams(_ManagerModel):
+ system_prompt: str | None = None
+ begin_dialogs: list[str] | None = None
+ tools: list[str] | None = None
+ skills: list[str] | None = None
+ custom_error_message: str | None = None
+
+
+class ConversationRecord(_ManagerModel):
+ conversation_id: str
+ session: str
+ platform_id: str
+ history: list[dict[str, Any]] = Field(default_factory=list)
+ title: str | None = None
+ persona_id: str | None = None
+ created_at: str | None = None
+ updated_at: str | None = None
+ token_usage: int | None = None
+
+ @classmethod
+ def from_payload(cls, payload: dict[str, Any] | None) -> ConversationRecord | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class ConversationCreateParams(_ManagerModel):
+ platform_id: str | None = None
+ history: list[dict[str, Any]] | None = None
+ title: str | None = None
+ persona_id: str | None = None
+
+
+class ConversationUpdateParams(_ManagerModel):
+ history: list[dict[str, Any]] | None = None
+ title: str | None = None
+ persona_id: str | None = None
+ token_usage: int | None = None
+
+
+class MessageHistorySender(_ManagerModel):
+ sender_id: str | None = None
+ sender_name: str | None = None
+
+ @classmethod
+ def from_payload(
+ cls,
+ payload: dict[str, Any] | None,
+ ) -> MessageHistorySender | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class MessageHistoryRecord(_ManagerModel):
+ id: int
+ session: MessageSession
+ sender: MessageHistorySender = Field(default_factory=MessageHistorySender)
+ parts: list[BaseMessageComponent] = Field(default_factory=list)
+ metadata: dict[str, Any] = Field(default_factory=dict)
+ created_at: datetime | None = None
+ updated_at: datetime | None = None
+ idempotency_key: str | None = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def _normalize_payload(cls, value: Any) -> Any:
+ if not isinstance(value, dict):
+ return value
+ normalized = dict(value)
+
+ session_payload = normalized.get("session")
+ if isinstance(session_payload, dict):
+ normalized["session"] = MessageSession(
+ platform_id=str(session_payload.get("platform_id", "")),
+ message_type=str(session_payload.get("message_type", "")),
+ session_id=str(session_payload.get("session_id", "")),
+ )
+
+ sender_payload = normalized.get("sender")
+ if isinstance(sender_payload, dict):
+ normalized["sender"] = MessageHistorySender.model_validate(sender_payload)
+ elif sender_payload is None:
+ normalized["sender"] = MessageHistorySender()
+
+ parts_payload = normalized.get("parts")
+ if isinstance(parts_payload, list):
+ normalized["parts"] = [
+ item
+ if isinstance(item, BaseMessageComponent)
+ else payload_to_component(item)
+ for item in parts_payload
+ ]
+
+ metadata_payload = normalized.get("metadata")
+ if not isinstance(metadata_payload, dict):
+ normalized["metadata"] = {}
+
+ return normalized
+
+ @classmethod
+ def from_payload(
+ cls,
+ payload: dict[str, Any] | None,
+ ) -> MessageHistoryRecord | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class MessageHistoryPage(_ManagerModel):
+ records: list[MessageHistoryRecord] = Field(default_factory=list)
+ next_cursor: str | None = None
+ total: int | None = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def _normalize_payload(cls, value: Any) -> Any:
+ if not isinstance(value, dict):
+ return value
+ normalized = dict(value)
+ records_payload = normalized.get("records")
+ if isinstance(records_payload, list):
+ normalized["records"] = [
+ record
+ for record in (
+ item
+ if isinstance(item, MessageHistoryRecord)
+ else MessageHistoryRecord.from_payload(item)
+ for item in records_payload
+ )
+ if record is not None
+ ]
+ return normalized
+
+ @classmethod
+ def from_payload(
+ cls,
+ payload: dict[str, Any] | None,
+ ) -> MessageHistoryPage | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class KnowledgeBaseRecord(_ManagerModel):
+ kb_id: str
+ kb_name: str
+ description: str | None = None
+ emoji: str | None = None
+ embedding_provider_id: str
+ rerank_provider_id: str | None = None
+ chunk_size: int | None = None
+ chunk_overlap: int | None = None
+ top_k_dense: int | None = None
+ top_k_sparse: int | None = None
+ top_m_final: int | None = None
+ doc_count: int = 0
+ chunk_count: int = 0
+ created_at: str | None = None
+ updated_at: str | None = None
+
+ @classmethod
+ def from_payload(cls, payload: dict[str, Any] | None) -> KnowledgeBaseRecord | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class KnowledgeBaseCreateParams(_ManagerModel):
+ kb_name: str
+ embedding_provider_id: str
+ description: str | None = None
+ emoji: str | None = None
+ rerank_provider_id: str | None = None
+ chunk_size: int | None = None
+ chunk_overlap: int | None = None
+ top_k_dense: int | None = None
+ top_k_sparse: int | None = None
+ top_m_final: int | None = None
+
+
+class KnowledgeBaseUpdateParams(_ManagerModel):
+ kb_name: str | None = None
+ embedding_provider_id: str | None = None
+ description: str | None = None
+ emoji: str | None = None
+ rerank_provider_id: str | None = None
+ chunk_size: int | None = None
+ chunk_overlap: int | None = None
+ top_k_dense: int | None = None
+ top_k_sparse: int | None = None
+ top_m_final: int | None = None
+
+
+class KnowledgeBaseDocumentRecord(_ManagerModel):
+ doc_id: str
+ kb_id: str
+ doc_name: str
+ file_type: str
+ file_size: int
+ file_path: str = ""
+ chunk_count: int = 0
+ media_count: int = 0
+ created_at: str | None = None
+ updated_at: str | None = None
+
+ @classmethod
+ def from_payload(
+ cls,
+ payload: dict[str, Any] | None,
+ ) -> KnowledgeBaseDocumentRecord | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class KnowledgeBaseRetrieveResultItem(_ManagerModel):
+ chunk_id: str
+ doc_id: str
+ kb_id: str
+ kb_name: str
+ doc_name: str
+ chunk_index: int
+ content: str
+ score: float
+ char_count: int
+
+ @classmethod
+ def from_payload(
+ cls,
+ payload: dict[str, Any] | None,
+ ) -> KnowledgeBaseRetrieveResultItem | None:
+ if not isinstance(payload, dict):
+ return None
+ return cls.model_validate(payload)
+
+
+class KnowledgeBaseRetrieveResult(_ManagerModel):
+ context_text: str
+ results: list[KnowledgeBaseRetrieveResultItem] = Field(default_factory=list)
+
+ @classmethod
+ def from_payload(
+ cls,
+ payload: dict[str, Any] | None,
+ ) -> KnowledgeBaseRetrieveResult | None:
+ if not isinstance(payload, dict):
+ return None
+ items = payload.get("results")
+ normalized_items = (
+ [
+ item.model_dump()
+ for item in (
+ KnowledgeBaseRetrieveResultItem.from_payload(candidate)
+ if isinstance(candidate, dict)
+ else None
+ for candidate in items
+ )
+ if item is not None
+ ]
+ if isinstance(items, list)
+ else []
+ )
+ return cls.model_validate(
+ {
+ "context_text": str(payload.get("context_text", "")),
+ "results": normalized_items,
+ }
+ )
+
+
+class KnowledgeBaseDocumentUploadParams(_ManagerModel):
+ file_token: str | None = None
+ url: str | None = None
+ text: str | None = None
+ file_name: str | None = None
+ file_type: str | None = None
+ chunk_size: int | None = None
+ chunk_overlap: int | None = None
+ batch_size: int | None = None
+ tasks_limit: int | None = None
+ max_retries: int | None = None
+ enable_cleaning: bool | None = None
+ cleaning_provider_id: str | None = None
+
+ @model_validator(mode="after")
+ def _validate_source(self) -> KnowledgeBaseDocumentUploadParams:
+ if any(
+ isinstance(value, str) and value.strip()
+ for value in (self.file_token, self.url, self.text)
+ ):
+ return self
+ raise ValueError(
+ "knowledge base document upload requires file_token, url, or text"
+ )
+
+
+class PersonaManagerClient:
+    """Client for persona management capabilities (``persona.*``)."""
+
+    def __init__(self, proxy: CapabilityProxy) -> None:
+        self._proxy = proxy
+
+    async def get_persona(self, persona_id: str) -> PersonaRecord:
+        """Return the persona identified by ``persona_id``.
+
+        Raises:
+            ValueError: If the persona does not exist. The bridge's
+                INVALID_INPUT error is translated into ValueError; other
+                ``AstrBotError`` codes propagate unchanged.
+        """
+        try:
+            output = await self._proxy.call(
+                "persona.get",
+                {"persona_id": str(persona_id)},
+            )
+        except AstrBotError as exc:
+            if exc.code == ErrorCodes.INVALID_INPUT:
+                raise ValueError(f"persona not found: {persona_id}") from exc
+            raise
+        persona = PersonaRecord.from_payload(output.get("persona"))
+        if persona is None:
+            raise ValueError(f"persona not found: {persona_id}")
+        return persona
+
+    async def get_all_personas(self) -> list[PersonaRecord]:
+        """Return all personas; malformed payload entries are dropped."""
+        output = await self._proxy.call("persona.list", {})
+        items = output.get("personas")
+        if not isinstance(items, list):
+            return []
+        return [
+            persona
+            for persona in (
+                PersonaRecord.from_payload(item) if isinstance(item, dict) else None
+                for item in items
+            )
+            if persona is not None
+        ]
+
+    async def create_persona(self, params: PersonaCreateParams) -> PersonaRecord:
+        """Create a persona and return the stored record.
+
+        Raises:
+            ValueError: If the bridge returns no persona payload.
+        """
+        output = await self._proxy.call(
+            "persona.create",
+            {"persona": params.to_payload()},
+        )
+        persona = PersonaRecord.from_payload(output.get("persona"))
+        if persona is None:
+            raise ValueError("persona.create returned no persona")
+        return persona
+
+    async def update_persona(
+        self,
+        persona_id: str,
+        params: PersonaUpdateParams,
+    ) -> PersonaRecord | None:
+        """Apply a partial update; returns the updated record or None."""
+        output = await self._proxy.call(
+            "persona.update",
+            {"persona_id": str(persona_id), "persona": params.to_update_payload()},
+        )
+        return PersonaRecord.from_payload(output.get("persona"))
+
+    async def delete_persona(self, persona_id: str) -> None:
+        """Delete the persona identified by ``persona_id``."""
+        await self._proxy.call("persona.delete", {"persona_id": str(persona_id)})
+
+
+class ConversationManagerClient:
+    """Client for conversation management capabilities (``conversation.*``).
+
+    Sessions are accepted either as raw session strings or as
+    ``MessageSession`` objects; ``_normalize_session`` converts both into
+    the wire representation.
+    """
+
+    def __init__(self, proxy: CapabilityProxy) -> None:
+        self._proxy = proxy
+
+    async def new_conversation(
+        self,
+        session: str | MessageSession,
+        params: ConversationCreateParams | None = None,
+    ) -> str:
+        """Create a conversation for the session and return its id.
+
+        Returns an empty string if the bridge omits ``conversation_id``.
+        """
+        output = await self._proxy.call(
+            "conversation.new",
+            {
+                "session": _normalize_session(session),
+                "conversation": (params.to_payload() if params is not None else {}),
+            },
+        )
+        return str(output.get("conversation_id", ""))
+
+    async def switch_conversation(
+        self,
+        session: str | MessageSession,
+        conversation_id: str,
+    ) -> None:
+        """Make ``conversation_id`` the session's current conversation."""
+        await self._proxy.call(
+            "conversation.switch",
+            {
+                "session": _normalize_session(session),
+                "conversation_id": str(conversation_id),
+            },
+        )
+
+    async def delete_conversation(
+        self,
+        session: str | MessageSession,
+        conversation_id: str | None = None,
+    ) -> None:
+        """Delete one conversation for the session.
+
+        When ``conversation_id`` is ``None``, this deletes the current selected
+        conversation for the session only. It does not delete all conversations
+        under the session.
+        """
+
+        await self._proxy.call(
+            "conversation.delete",
+            {
+                "session": _normalize_session(session),
+                "conversation_id": conversation_id,
+            },
+        )
+
+    async def get_conversation(
+        self,
+        session: str | MessageSession,
+        conversation_id: str,
+        *,
+        create_if_not_exists: bool = False,
+    ) -> ConversationRecord | None:
+        """Fetch a conversation by id; optionally create it when missing."""
+        output = await self._proxy.call(
+            "conversation.get",
+            {
+                "session": _normalize_session(session),
+                "conversation_id": str(conversation_id),
+                "create_if_not_exists": bool(create_if_not_exists),
+            },
+        )
+        return ConversationRecord.from_payload(output.get("conversation"))
+
+    async def get_current_conversation(
+        self,
+        session: str | MessageSession,
+        *,
+        create_if_not_exists: bool = False,
+    ) -> ConversationRecord | None:
+        """Fetch the session's current conversation, or None if unset."""
+        output = await self._proxy.call(
+            "conversation.get_current",
+            {
+                "session": _normalize_session(session),
+                "create_if_not_exists": bool(create_if_not_exists),
+            },
+        )
+        return ConversationRecord.from_payload(output.get("conversation"))
+
+    async def get_conversations(
+        self,
+        session: str | MessageSession | None = None,
+        *,
+        platform_id: str | None = None,
+    ) -> list[ConversationRecord]:
+        """List conversations, optionally filtered by session/platform.
+
+        Malformed payload entries are dropped from the result.
+        """
+        output = await self._proxy.call(
+            "conversation.list",
+            {
+                "session": (
+                    _normalize_session(session) if session is not None else None
+                ),
+                "platform_id": platform_id,
+            },
+        )
+        items = output.get("conversations")
+        if not isinstance(items, list):
+            return []
+        return [
+            conversation
+            for conversation in (
+                ConversationRecord.from_payload(item)
+                if isinstance(item, dict)
+                else None
+                for item in items
+            )
+            if conversation is not None
+        ]
+
+    async def update_conversation(
+        self,
+        session: str | MessageSession,
+        conversation_id: str | None = None,
+        params: ConversationUpdateParams | None = None,
+    ) -> None:
+        """Apply a partial update; ``None`` id targets the current conversation."""
+        await self._proxy.call(
+            "conversation.update",
+            {
+                "session": _normalize_session(session),
+                "conversation_id": conversation_id,
+                "conversation": (
+                    params.to_update_payload() if params is not None else {}
+                ),
+            },
+        )
+
+    async def unset_persona(
+        self,
+        session: str | MessageSession,
+        conversation_id: str | None = None,
+    ) -> None:
+        """Clear the persona binding; ``None`` id targets the current conversation."""
+        await self._proxy.call(
+            "conversation.unset_persona",
+            {
+                "session": _normalize_session(session),
+                "conversation_id": conversation_id,
+            },
+        )
+
+
+class MessageHistoryManagerClient:
+    """Client for message history capabilities (``message_history.*``).
+
+    Unlike the conversation client, these methods require a full
+    ``MessageSession`` (validated by ``_require_message_history_session``).
+    """
+
+    def __init__(self, proxy: CapabilityProxy) -> None:
+        self._proxy = proxy
+
+    async def list(
+        self,
+        session: MessageSession,
+        *,
+        cursor: str | None = None,
+        limit: int = 50,
+    ) -> MessageHistoryPage:
+        """Return one page of history records (cursor-based pagination).
+
+        Raises:
+            ValueError: If the bridge returns no page payload.
+        """
+        output = await self._proxy.call(
+            "message_history.list",
+            {
+                "session": _require_message_history_session(session),
+                "cursor": str(cursor) if cursor is not None else None,
+                "limit": int(limit),
+            },
+        )
+        page = MessageHistoryPage.from_payload(output.get("page"))
+        if page is None:
+            raise ValueError("message_history.list returned no page")
+        return page
+
+    async def get(
+        self,
+        session: MessageSession,
+        record_id: int,
+    ) -> MessageHistoryRecord | None:
+        """Fetch one record by numeric id, or None if absent."""
+        output = await self._proxy.call(
+            "message_history.get_by_id",
+            {
+                "session": _require_message_history_session(session),
+                "record_id": int(record_id),
+            },
+        )
+        return MessageHistoryRecord.from_payload(output.get("record"))
+
+    async def get_by_id(
+        self,
+        session: MessageSession,
+        record_id: int,
+    ) -> MessageHistoryRecord | None:
+        """Alias of :meth:`get`, kept for naming symmetry with the capability."""
+        return await self.get(session, record_id)
+
+    async def append(
+        self,
+        session: MessageSession,
+        *,
+        parts: list[BaseMessageComponent],
+        sender: MessageHistorySender | dict[str, Any],
+        metadata: dict[str, Any] | None = None,
+        idempotency_key: str | None = None,
+    ) -> MessageHistoryRecord:
+        """Append one record and return it as stored.
+
+        ``sender`` may be a ``MessageHistorySender`` or a plain dict (which
+        is validated into one).
+
+        Raises:
+            TypeError: If ``sender`` is neither of the accepted types.
+            ValueError: If the bridge returns no record payload.
+        """
+        if isinstance(sender, MessageHistorySender):
+            sender_payload = sender.to_payload()
+        elif isinstance(sender, dict):
+            sender_payload = MessageHistorySender.model_validate(sender).to_payload()
+        else:
+            raise TypeError(
+                "message_history.append requires MessageHistorySender for sender"
+            )
+        output = await self._proxy.call(
+            "message_history.append",
+            {
+                "session": _require_message_history_session(session),
+                "sender": sender_payload,
+                "parts": _normalize_message_history_parts(parts),
+                "metadata": dict(metadata or {}),
+                "idempotency_key": (
+                    str(idempotency_key) if idempotency_key is not None else None
+                ),
+            },
+        )
+        record = MessageHistoryRecord.from_payload(output.get("record"))
+        if record is None:
+            raise ValueError("message_history.append returned no record")
+        return record
+
+    async def delete_before(
+        self,
+        session: MessageSession,
+        *,
+        before: datetime,
+    ) -> int:
+        """Delete records older than ``before``; returns the deleted count."""
+        output = await self._proxy.call(
+            "message_history.delete_before",
+            {
+                "session": _require_message_history_session(session),
+                "before": _normalize_message_history_boundary(before),
+            },
+        )
+        return int(output.get("deleted_count", 0) or 0)
+
+    async def delete_after(
+        self,
+        session: MessageSession,
+        *,
+        after: datetime,
+    ) -> int:
+        """Delete records newer than ``after``; returns the deleted count."""
+        output = await self._proxy.call(
+            "message_history.delete_after",
+            {
+                "session": _require_message_history_session(session),
+                "after": _normalize_message_history_boundary(after),
+            },
+        )
+        return int(output.get("deleted_count", 0) or 0)
+
+    async def delete_all(self, session: MessageSession) -> int:
+        """Delete all records for the session; returns the deleted count."""
+        output = await self._proxy.call(
+            "message_history.delete_all",
+            {"session": _require_message_history_session(session)},
+        )
+        return int(output.get("deleted_count", 0) or 0)
+
+
+class KnowledgeBaseManagerClient:
+    """Client for knowledge base capabilities (``kb.*`` / ``kb.document.*``)."""
+
+    def __init__(self, proxy: CapabilityProxy) -> None:
+        self._proxy = proxy
+
+    async def list_kbs(self) -> list[KnowledgeBaseRecord]:
+        """Return all knowledge bases; malformed payload entries are dropped."""
+        output = await self._proxy.call("kb.list", {})
+        items = output.get("kbs")
+        if not isinstance(items, list):
+            return []
+        return [
+            kb
+            for kb in (
+                KnowledgeBaseRecord.from_payload(item)
+                if isinstance(item, dict)
+                else None
+                for item in items
+            )
+            if kb is not None
+        ]
+
+    async def get_kb(self, kb_id: str) -> KnowledgeBaseRecord | None:
+        """Fetch one knowledge base by id, or None if absent."""
+        output = await self._proxy.call("kb.get", {"kb_id": str(kb_id)})
+        return KnowledgeBaseRecord.from_payload(output.get("kb"))
+
+    async def create_kb(
+        self,
+        params: KnowledgeBaseCreateParams,
+    ) -> KnowledgeBaseRecord:
+        """Create a knowledge base and return the stored record.
+
+        Raises:
+            ValueError: If the bridge returns no kb payload.
+        """
+        output = await self._proxy.call("kb.create", {"kb": params.to_payload()})
+        kb = KnowledgeBaseRecord.from_payload(output.get("kb"))
+        if kb is None:
+            raise ValueError("kb.create returned no knowledge base")
+        return kb
+
+    async def update_kb(
+        self,
+        kb_id: str,
+        params: KnowledgeBaseUpdateParams,
+    ) -> KnowledgeBaseRecord | None:
+        """Apply a partial update; returns the updated record or None."""
+        output = await self._proxy.call(
+            "kb.update",
+            {"kb_id": str(kb_id), "kb": params.to_update_payload()},
+        )
+        return KnowledgeBaseRecord.from_payload(output.get("kb"))
+
+    async def delete_kb(self, kb_id: str) -> bool:
+        """Delete a knowledge base; True if something was deleted."""
+        output = await self._proxy.call("kb.delete", {"kb_id": str(kb_id)})
+        return bool(output.get("deleted", False))
+
+    async def retrieve(
+        self,
+        query: str,
+        *,
+        kb_ids: list[str] | None = None,
+        kb_names: list[str] | None = None,
+        top_k_fusion: int | None = None,
+        top_m_final: int | None = None,
+    ) -> KnowledgeBaseRetrieveResult | None:
+        """Run retrieval over the given knowledge bases.
+
+        ``top_k_fusion`` / ``top_m_final`` are only sent when provided so
+        the bridge's defaults apply otherwise.
+        """
+        request_payload: dict[str, Any] = {
+            "query": str(query),
+            "kb_ids": [str(item) for item in (kb_ids or [])],
+            "kb_names": [str(item) for item in (kb_names or [])],
+        }
+        if top_k_fusion is not None:
+            request_payload["top_k_fusion"] = int(top_k_fusion)
+        if top_m_final is not None:
+            request_payload["top_m_final"] = int(top_m_final)
+        output = await self._proxy.call(
+            "kb.retrieve",
+            request_payload,
+        )
+        return KnowledgeBaseRetrieveResult.from_payload(output.get("result"))
+
+    async def upload_document(
+        self,
+        kb_id: str,
+        params: KnowledgeBaseDocumentUploadParams,
+    ) -> KnowledgeBaseDocumentRecord:
+        """Upload one document and return its stored record.
+
+        Raises:
+            ValueError: If the bridge returns no document payload.
+        """
+        output = await self._proxy.call(
+            "kb.document.upload",
+            {"kb_id": str(kb_id), "document": params.to_payload()},
+        )
+        document = KnowledgeBaseDocumentRecord.from_payload(output.get("document"))
+        if document is None:
+            raise ValueError("kb.document.upload returned no document")
+        return document
+
+    async def list_documents(
+        self,
+        kb_id: str,
+        *,
+        offset: int = 0,
+        limit: int = 100,
+    ) -> list[KnowledgeBaseDocumentRecord]:
+        """Return a page of documents; malformed payload entries are dropped."""
+        output = await self._proxy.call(
+            "kb.document.list",
+            {"kb_id": str(kb_id), "offset": int(offset), "limit": int(limit)},
+        )
+        items = output.get("documents")
+        if not isinstance(items, list):
+            return []
+        return [
+            document
+            for document in (
+                KnowledgeBaseDocumentRecord.from_payload(item)
+                if isinstance(item, dict)
+                else None
+                for item in items
+            )
+            if document is not None
+        ]
+
+    async def get_document(
+        self,
+        kb_id: str,
+        doc_id: str,
+    ) -> KnowledgeBaseDocumentRecord | None:
+        """Fetch one document by id, or None if absent."""
+        output = await self._proxy.call(
+            "kb.document.get",
+            {"kb_id": str(kb_id), "doc_id": str(doc_id)},
+        )
+        return KnowledgeBaseDocumentRecord.from_payload(output.get("document"))
+
+    async def delete_document(
+        self,
+        kb_id: str,
+        doc_id: str,
+    ) -> bool:
+        """Delete one document; True if something was deleted."""
+        output = await self._proxy.call(
+            "kb.document.delete",
+            {"kb_id": str(kb_id), "doc_id": str(doc_id)},
+        )
+        return bool(output.get("deleted", False))
+
+    async def refresh_document(
+        self,
+        kb_id: str,
+        doc_id: str,
+    ) -> KnowledgeBaseDocumentRecord | None:
+        """Re-ingest one document; returns the refreshed record or None."""
+        output = await self._proxy.call(
+            "kb.document.refresh",
+            {"kb_id": str(kb_id), "doc_id": str(doc_id)},
+        )
+        return KnowledgeBaseDocumentRecord.from_payload(output.get("document"))
+
+
+# Public API of this module, kept alphabetically sorted.
+__all__ = [
+    "ConversationCreateParams",
+    "ConversationManagerClient",
+    "ConversationRecord",
+    "ConversationUpdateParams",
+    "KnowledgeBaseCreateParams",
+    "KnowledgeBaseDocumentRecord",
+    "KnowledgeBaseDocumentUploadParams",
+    "KnowledgeBaseManagerClient",
+    "KnowledgeBaseRecord",
+    "KnowledgeBaseRetrieveResult",
+    "KnowledgeBaseRetrieveResultItem",
+    "KnowledgeBaseUpdateParams",
+    "MessageHistoryManagerClient",
+    "MessageHistoryPage",
+    "MessageHistoryRecord",
+    "MessageHistorySender",
+    "PersonaCreateParams",
+    "PersonaManagerClient",
+    "PersonaRecord",
+    "PersonaUpdateParams",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/memory.py b/astrbot-sdk/src/astrbot_sdk/clients/memory.py
new file mode 100644
index 0000000000..55d302ca4f
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/memory.py
@@ -0,0 +1,426 @@
+"""记忆客户端模块。
+
+提供 AI 记忆存储能力,用于存储和检索对话记忆、用户偏好等上下文数据。
+
+设计说明:
+ MemoryClient 与 DBClient 的区别:
+ - DBClient: 简单的键值存储,精确匹配
+ - MemoryClient: 支持基于当前 bridge 行为的记忆检索,适合 AI 上下文管理
+
+ 记忆系统可用于:
+ - 存储用户偏好和设置
+ - 记录对话摘要
+ - 缓存 AI 推理结果
+"""
+
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from .._internal.memory_utils import join_memory_namespace
+from ._proxy import CapabilityProxy
+
+
+def _normalize_search_item(item: Any) -> dict[str, Any] | None:
+    """Normalize one raw search hit into a flat dict.
+
+    Non-dict items yield ``None``. When the item carries a dict ``value``,
+    its entries are promoted to the top level without overwriting existing
+    keys (via ``setdefault``), so callers can read fields directly.
+    """
+    if not isinstance(item, dict):
+        return None
+    normalized = dict(item)
+    value = normalized.get("value")
+    if isinstance(value, dict):
+        for key, payload_value in value.items():
+            normalized.setdefault(str(key), payload_value)
+    return normalized
+
+
+class MemoryClient:
+    """Memory capability client.
+
+    Stores and retrieves AI memory items (user preferences, conversation
+    summaries, cached inference results). Unlike the plain key-value
+    ``DBClient``, this client also supports relevance-based retrieval via
+    :meth:`search` in addition to exact :meth:`get`.
+
+    Attributes:
+        _proxy: CapabilityProxy used for remote capability calls.
+        _namespace: Normalized root namespace applied to all operations.
+    """
+
+    def __init__(
+        self,
+        proxy: CapabilityProxy,
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        """Initialize the memory client.
+
+        Args:
+            proxy: CapabilityProxy instance.
+            namespace: Optional root namespace scoping every call.
+        """
+        self._proxy = proxy
+        self._namespace = join_memory_namespace(namespace)
+
+    def namespace(self, *parts: Any) -> MemoryClient:
+        """Return a derived client scoped to a child namespace."""
+
+        return MemoryClient(
+            self._proxy,
+            namespace=join_memory_namespace(self._namespace, *parts),
+        )
+
+    def _resolve_exact_namespace(self, namespace: str | None) -> str:
+        # None means "this client's own namespace"; otherwise join below it.
+        if namespace is None:
+            return self._namespace
+        return join_memory_namespace(self._namespace, namespace)
+
+    def _resolve_scope_namespace(self, namespace: str | None) -> tuple[bool, str]:
+        # Returns (has_namespace, namespace). With neither an explicit
+        # namespace nor a client-level one the scope is unrestricted
+        # (False, "") and no namespace filter is sent.
+        if namespace is None:
+            if self._namespace:
+                return True, self._namespace
+            return False, ""
+        return True, join_memory_namespace(self._namespace, namespace)
+
+    async def search(
+        self,
+        query: str,
+        *,
+        mode: Literal["auto", "keyword", "vector", "hybrid"] = "auto",
+        limit: int | None = None,
+        min_score: float | None = None,
+        provider_id: str | None = None,
+        namespace: str | None = None,
+        include_descendants: bool = True,
+    ) -> list[dict[str, Any]]:
+        """Search memory items.
+
+        In ``auto`` mode the bridge performs hybrid retrieval when an
+        embedding provider is available and falls back to keyword search
+        otherwise. Each result carries ``score`` and ``match_type`` fields.
+
+        Args:
+            query: Search query text.
+            mode: Search mode: ``auto`` / ``keyword`` / ``vector`` / ``hybrid``.
+            limit: Maximum number of results.
+            min_score: Minimum score threshold.
+            provider_id: Embedding provider id; defaults to the active one.
+            namespace: Optional namespace scope relative to this client.
+            include_descendants: Whether child namespaces are searched too.
+
+        Returns:
+            Matching memory items, ordered by relevance.
+
+        Example:
+            results = await ctx.memory.search(
+                "favorite color",
+                mode="hybrid",
+                limit=5,
+            )
+            for item in results:
+                print(item["key"], item["score"], item["match_type"])
+        """
+        payload: dict[str, Any] = {"query": query, "mode": mode}
+        if limit is not None:
+            payload["limit"] = limit
+        if min_score is not None:
+            payload["min_score"] = min_score
+        if provider_id is not None:
+            payload["provider_id"] = provider_id
+        has_namespace, resolved_namespace = self._resolve_scope_namespace(namespace)
+        if has_namespace:
+            payload["namespace"] = resolved_namespace
+            payload["include_descendants"] = bool(include_descendants)
+        output = await self._proxy.call("memory.search", payload)
+        items = output.get("items")
+        if not isinstance(items, (list, tuple)):
+            return []
+        normalized_items: list[dict[str, Any]] = []
+        for item in items:
+            normalized = _normalize_search_item(item)
+            if normalized is not None:
+                normalized_items.append(normalized)
+        return normalized_items
+
+    async def save(
+        self,
+        key: str,
+        value: dict[str, Any] | None = None,
+        namespace: str | None = None,
+        **extra: Any,
+    ) -> None:
+        """Save a memory item.
+
+        The stored data can later be found via :meth:`search` or fetched
+        exactly via :meth:`get`. ``extra`` keyword arguments are merged
+        into ``value``.
+
+        Args:
+            key: Unique key of the memory item.
+            value: Data dict to store.
+            namespace: Optional namespace relative to this client.
+            **extra: Extra key/value pairs merged into ``value``.
+
+        Raises:
+            TypeError: If ``value`` is not a dict.
+
+        Example:
+            await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
+            await ctx.memory.save("note", None, content="important", tags=["work"])
+            # An explicit "embedding_text" entry controls the retrieval text.
+        """
+        if value is not None and not isinstance(value, dict):
+            raise TypeError("memory.save 的 value 必须是 dict")
+        payload = dict(value or {})
+        if extra:
+            payload.update(extra)
+        request: dict[str, Any] = {"key": key, "value": payload}
+        request["namespace"] = self._resolve_exact_namespace(namespace)
+        await self._proxy.call("memory.save", request)
+
+    async def get(
+        self,
+        key: str,
+        *,
+        namespace: str | None = None,
+    ) -> dict[str, Any] | None:
+        """Fetch a single memory item by exact key (no search matching).
+
+        Args:
+            key: Unique key of the memory item.
+            namespace: Optional namespace relative to this client.
+
+        Returns:
+            The stored dict, or ``None`` if the key is absent.
+        """
+        payload: dict[str, Any] = {"key": key}
+        payload["namespace"] = self._resolve_exact_namespace(namespace)
+        output = await self._proxy.call("memory.get", payload)
+        value = output.get("value")
+        return value if isinstance(value, dict) else None
+
+    async def list_keys(
+        self,
+        *,
+        namespace: str | None = None,
+    ) -> list[str]:
+        """List all keys in the given exact namespace."""
+
+        payload: dict[str, Any] = {
+            "namespace": self._resolve_exact_namespace(namespace)
+        }
+        output = await self._proxy.call("memory.list_keys", payload)
+        keys = output.get("keys")
+        if not isinstance(keys, (list, tuple)):
+            return []
+        return [str(item) for item in keys]
+
+    async def exists(
+        self,
+        key: str,
+        *,
+        namespace: str | None = None,
+    ) -> bool:
+        """Return whether ``key`` exists in the given exact namespace."""
+
+        payload: dict[str, Any] = {"key": key}
+        payload["namespace"] = self._resolve_exact_namespace(namespace)
+        output = await self._proxy.call("memory.exists", payload)
+        return bool(output.get("exists", False))
+
+    async def delete(
+        self,
+        key: str,
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        """Delete one memory item by key.
+
+        Args:
+            key: Key of the item to delete.
+            namespace: Optional namespace relative to this client.
+        """
+        payload: dict[str, Any] = {"key": key}
+        payload["namespace"] = self._resolve_exact_namespace(namespace)
+        await self._proxy.call("memory.delete", payload)
+
+    async def clear_namespace(
+        self,
+        *,
+        namespace: str | None = None,
+        include_descendants: bool = False,
+    ) -> int:
+        """Clear a namespace, optionally recursing into child namespaces.
+
+        Returns the number of deleted items.
+        """
+
+        payload: dict[str, Any] = {
+            "namespace": self._resolve_exact_namespace(namespace),
+            "include_descendants": bool(include_descendants),
+        }
+        output = await self._proxy.call("memory.clear_namespace", payload)
+        return int(output.get("deleted_count", 0))
+
+    async def save_with_ttl(
+        self,
+        key: str,
+        value: dict[str, Any],
+        ttl_seconds: int,
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        """Save a memory item with an expiration time.
+
+        Unlike :meth:`save`, a time-to-live is attached; the item is
+        removed automatically after it expires.
+
+        Args:
+            key: Unique key of the memory item.
+            value: Data dict to store.
+            ttl_seconds: Time to live in seconds; must be > 0.
+            namespace: Optional namespace relative to this client.
+
+        Raises:
+            TypeError: If ``value`` is not a dict.
+            ValueError: If ``ttl_seconds`` is less than 1.
+
+        Example:
+            # Transient session state, expires after one hour.
+            await ctx.memory.save_with_ttl(
+                "session_temp",
+                {"state": "waiting"},
+                ttl_seconds=3600,
+            )
+        """
+        if not isinstance(value, dict):
+            raise TypeError("memory.save_with_ttl 的 value 必须是 dict")
+        if ttl_seconds < 1:
+            raise ValueError("ttl_seconds 必须大于 0")
+        payload: dict[str, Any] = {
+            "key": key,
+            "value": value,
+            "ttl_seconds": ttl_seconds,
+        }
+        payload["namespace"] = self._resolve_exact_namespace(namespace)
+        await self._proxy.call("memory.save_with_ttl", payload)
+
+    async def get_many(
+        self,
+        keys: list[str],
+        *,
+        namespace: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """Fetch several memory items in one call (cheaper than N gets).
+
+        Args:
+            keys: Keys to fetch.
+            namespace: Optional namespace relative to this client.
+
+        Returns:
+            A list of dicts with ``key`` and ``value`` fields; missing keys
+            come back with ``value`` set to ``None``.
+        """
+        payload: dict[str, Any] = {"keys": keys}
+        payload["namespace"] = self._resolve_exact_namespace(namespace)
+        output = await self._proxy.call("memory.get_many", payload)
+        items = output.get("items")
+        if not isinstance(items, (list, tuple)):
+            return []
+        return [dict(item) for item in items if isinstance(item, dict)]
+
+    async def delete_many(
+        self,
+        keys: list[str],
+        *,
+        namespace: str | None = None,
+    ) -> int:
+        """Delete several memory items in one call.
+
+        Args:
+            keys: Keys to delete.
+            namespace: Optional namespace relative to this client.
+
+        Returns:
+            The number of items actually deleted.
+        """
+        payload: dict[str, Any] = {"keys": keys}
+        payload["namespace"] = self._resolve_exact_namespace(namespace)
+        output = await self._proxy.call("memory.delete_many", payload)
+        return int(output.get("deleted_count", 0))
+
+    async def count(
+        self,
+        *,
+        namespace: str | None = None,
+        include_descendants: bool = False,
+    ) -> int:
+        """Count items in a namespace, optionally including descendants."""
+
+        payload: dict[str, Any] = {
+            "namespace": self._resolve_exact_namespace(namespace),
+            "include_descendants": bool(include_descendants),
+        }
+        output = await self._proxy.call("memory.count", payload)
+        return int(output.get("count", 0))
+
+    async def stats(
+        self,
+        *,
+        namespace: str | None = None,
+        include_descendants: bool = True,
+    ) -> dict[str, Any]:
+        """Return statistics about the memory store.
+
+        Returns:
+            A dict always containing ``total_items`` and ``total_bytes``
+            (the latter may be ``None``); optional fields such as
+            ``ttl_entries``, ``indexed_items``, ``embedded_items`` and
+            ``dirty_items`` are copied through only when the bridge
+            reports them.
+
+        Example:
+            stats = await ctx.memory.stats()
+            print(f"{stats['total_items']} items stored")
+            if "embedded_items" in stats:
+                print(f"{stats['embedded_items']} items vectorized")
+        """
+        payload: dict[str, Any] = {
+            "include_descendants": bool(include_descendants),
+        }
+        has_namespace, resolved_namespace = self._resolve_scope_namespace(namespace)
+        if has_namespace:
+            payload["namespace"] = resolved_namespace
+        output = await self._proxy.call("memory.stats", payload)
+        stats = {
+            "total_items": output.get("total_items", 0),
+            "total_bytes": output.get("total_bytes"),
+        }
+        # Pass optional diagnostics through only when present in the reply.
+        for key in (
+            "namespace",
+            "namespace_count",
+            "fts_enabled",
+            "vector_backend",
+            "vector_indexes",
+            "plugin_id",
+            "ttl_entries",
+            "indexed_items",
+            "embedded_items",
+            "dirty_items",
+        ):
+            if key in output:
+                stats[key] = output.get(key)
+        return stats
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/metadata.py b/astrbot-sdk/src/astrbot_sdk/clients/metadata.py
new file mode 100644
index 0000000000..9d68314b22
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/metadata.py
@@ -0,0 +1,145 @@
+"""元数据客户端模块。
+
+提供插件元数据查询能力。
+
+功能说明:
+ - 查询已加载插件信息
+ - 获取插件列表
+ - 访问当前插件配置
+
+安全边界:
+ 插件身份由运行时透传到协议层;客户端只暴露业务参数,不接受外部指定调用者。
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from ._errors import wrap_client_exception
+from ._proxy import CapabilityProxy
+
+
+@dataclass
+class StarMetadata:
+    """Plugin metadata record as exposed by the ``metadata.*`` capabilities."""
+
+    name: str
+    display_name: str
+    description: str
+    repo: str
+    author: str
+    version: str
+    enabled: bool = True
+    support_platforms: list[str] = field(default_factory=list)
+    astrbot_version: str | None = None
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> StarMetadata:
+        """Build a StarMetadata from a raw payload dict.
+
+        Missing fields fall back to defaults; ``desc`` is accepted as an
+        alias for ``description``, ``display_name`` falls back to ``name``,
+        and non-string platform entries are dropped.
+        """
+        raw_support_platforms = data.get("support_platforms")
+        support_platforms = (
+            [str(item) for item in raw_support_platforms if isinstance(item, str)]
+            if isinstance(raw_support_platforms, list)
+            else []
+        )
+        return cls(
+            name=str(data.get("name", "")),
+            display_name=str(data.get("display_name", data.get("name", ""))),
+            description=str(data.get("desc", data.get("description", ""))),
+            repo=str(data.get("repo", "")),
+            author=str(data.get("author", "")),
+            version=str(data.get("version", "0.0.0")),
+            enabled=bool(data.get("enabled", True)),
+            support_platforms=support_platforms,
+            astrbot_version=(
+                str(data.get("astrbot_version"))
+                if data.get("astrbot_version") is not None
+                else None
+            ),
+        )
+
+
+# Backwards-compatible alias for the older public name.
+PluginMetadata = StarMetadata
+
+
+class MetadataClient:
+    """Metadata capability client.
+
+    The caller's plugin identity is injected by the runtime
+    (``plugin_id``); external callers cannot impersonate another plugin.
+    """
+
+    def __init__(self, proxy: CapabilityProxy, plugin_id: str) -> None:
+        self._proxy = proxy
+        self._plugin_id = plugin_id
+
+    async def get_plugin(self, name: str) -> StarMetadata | None:
+        """Fetch metadata for the named plugin, or None if unknown."""
+        try:
+            output = await self._proxy.call(
+                "metadata.get_plugin",
+                {"name": name},
+            )
+        except Exception as exc:
+            # Wrap so the error carries client/method/argument context.
+            raise wrap_client_exception(
+                client_name="MetadataClient",
+                method_name="get_plugin",
+                details=f"name={name!r}",
+                exc=exc,
+            ) from exc
+        data = output.get("plugin")
+        if data is None:
+            return None
+        return StarMetadata.from_dict(data)
+
+    async def list_plugins(self) -> list[StarMetadata]:
+        """Return metadata for all loaded plugins (malformed entries dropped)."""
+        try:
+            output = await self._proxy.call("metadata.list_plugins", {})
+        except Exception as exc:
+            raise wrap_client_exception(
+                client_name="MetadataClient",
+                method_name="list_plugins",
+                exc=exc,
+            ) from exc
+        items = output.get("plugins", [])
+        return [
+            StarMetadata.from_dict(item) for item in items if isinstance(item, dict)
+        ]
+
+    async def get_current_plugin(self) -> StarMetadata | None:
+        """Return metadata for the calling plugin itself."""
+        return await self.get_plugin(self._plugin_id)
+
+    async def get_plugin_config(self, name: str | None = None) -> dict[str, Any] | None:
+        """Return the calling plugin's own config, or None when unset.
+
+        Raises:
+            PermissionError: If ``name`` names any plugin other than the
+                caller — configs of other plugins are never readable here.
+        """
+        target = name or self._plugin_id
+        if target != self._plugin_id:
+            raise PermissionError(
+                "get_plugin_config 只允许访问当前插件自己的配置,"
+                f"请求的插件 '{target}' 不是当前插件 '{self._plugin_id}'"
+            )
+        try:
+            output = await self._proxy.call(
+                "metadata.get_plugin_config",
+                {"name": target},
+            )
+        except Exception as exc:
+            raise wrap_client_exception(
+                client_name="MetadataClient",
+                method_name="get_plugin_config",
+                details=f"name={target!r}",
+                exc=exc,
+            ) from exc
+        config = output.get("config")
+        return dict(config) if isinstance(config, dict) else None
+
+    async def save_plugin_config(self, config: dict[str, Any]) -> dict[str, Any]:
+        """Persist the calling plugin's config; returns the saved config.
+
+        Raises:
+            TypeError: If ``config`` is not a dict.
+        """
+        if not isinstance(config, dict):
+            raise TypeError("save_plugin_config requires a dict payload")
+        try:
+            output = await self._proxy.call(
+                "metadata.save_plugin_config",
+                {"config": dict(config)},
+            )
+        except Exception as exc:
+            # Only key names are logged, never config values.
+            raise wrap_client_exception(
+                client_name="MetadataClient",
+                method_name="save_plugin_config",
+                details=f"keys={sorted(str(key) for key in config)!r}",
+                exc=exc,
+            ) from exc
+        saved = output.get("config")
+        return dict(saved) if isinstance(saved, dict) else {}
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/permission.py b/astrbot-sdk/src/astrbot_sdk/clients/permission.py
new file mode 100644
index 0000000000..546c8ea589
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/permission.py
@@ -0,0 +1,100 @@
+"""权限能力客户端。"""
+
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from pydantic import BaseModel, ConfigDict
+
+from ._proxy import CapabilityProxy
+
+
+class PermissionCheckResult(BaseModel):
+    """Result of a permission check: admin flag plus coarse role."""
+
+    model_config = ConfigDict(extra="forbid")
+
+    is_admin: bool
+    role: Literal["member", "admin"]
+
+    @classmethod
+    def from_payload(
+        cls,
+        payload: dict[str, Any] | None,
+    ) -> PermissionCheckResult | None:
+        """Validate a raw payload dict; None for non-dict input."""
+        if not isinstance(payload, dict):
+            return None
+        return cls.model_validate(payload)
+
+
+class PermissionClient:
+    """Read-only permission query client (``permission.*``)."""
+
+    def __init__(self, proxy: CapabilityProxy) -> None:
+        self._proxy = proxy
+
+    async def check(
+        self,
+        user_id: str,
+        session_id: str | None = None,
+    ) -> PermissionCheckResult:
+        """Check a user's permission, optionally scoped to a session.
+
+        Falls back to a non-admin ``member`` result when the bridge's
+        reply cannot be parsed, so callers always get a usable object.
+        """
+        payload: dict[str, Any] = {"user_id": str(user_id)}
+        if session_id is not None:
+            payload["session_id"] = str(session_id)
+        output = await self._proxy.call("permission.check", payload)
+        result = PermissionCheckResult.from_payload(output)
+        if result is None:
+            return PermissionCheckResult(is_admin=False, role="member")
+        return result
+
+    async def get_admins(self) -> list[str]:
+        """Return the list of admin user ids (empty on malformed reply)."""
+        output = await self._proxy.call("permission.get_admins", {})
+        admins = output.get("admins")
+        if not isinstance(admins, list):
+            return []
+        return [str(item) for item in admins]
+
+
+class PermissionManagerClient:
+    """Admin-list management client (``permission.manager.*``)."""
+
+    def __init__(
+        self,
+        proxy: CapabilityProxy,
+        *,
+        source_event_payload: dict[str, Any] | None = None,
+    ) -> None:
+        self._proxy = proxy
+        # Snapshot of the triggering event; only its is_admin flag is read.
+        self._source_event_payload = (
+            dict(source_event_payload) if isinstance(source_event_payload, dict) else {}
+        )
+
+    def _caller_is_admin(self) -> bool:
+        # NOTE(review): the admin flag is taken from the client-held event
+        # payload and forwarded to the bridge; the bridge is presumably the
+        # enforcement point — confirm it does not trust this field blindly.
+        return bool(self._source_event_payload.get("is_admin", False))
+
+    async def add_admin(self, user_id: str) -> bool:
+        """Grant admin to ``user_id``; True if the admin list changed."""
+        output = await self._proxy.call(
+            "permission.manager.add_admin",
+            {
+                "user_id": str(user_id),
+                "_caller_is_admin": self._caller_is_admin(),
+            },
+        )
+        return bool(output.get("changed", False))
+
+    async def remove_admin(self, user_id: str) -> bool:
+        """Revoke admin from ``user_id``; True if the admin list changed."""
+        output = await self._proxy.call(
+            "permission.manager.remove_admin",
+            {
+                "user_id": str(user_id),
+                "_caller_is_admin": self._caller_is_admin(),
+            },
+        )
+        return bool(output.get("changed", False))
+
+
+# Public API of this module.
+__all__ = [
+    "PermissionCheckResult",
+    "PermissionClient",
+    "PermissionManagerClient",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/platform.py b/astrbot-sdk/src/astrbot_sdk/clients/platform.py
new file mode 100644
index 0000000000..7a4bcccacf
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/platform.py
@@ -0,0 +1,339 @@
+"""平台客户端模块。
+
+提供 astrbot-sdk 原生的平台能力调用。
+
+设计边界:
+ - `PlatformClient` 只负责直接的平台 capability
+ - 迁移期消息桥接由独立迁移入口承接,不放进原生客户端
+ - 富消息链通过 `platform.send_chain` 发送,链构建能力位于专门的消息模块
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from enum import Enum
+from typing import Any, cast
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from ..message.components import BaseMessageComponent, Plain
+from ..message.result import MessageChain
+from ..message.session import MessageSession
+from ..protocol.descriptors import SessionRef
+from ._errors import wrap_client_exception
+from ._proxy import CapabilityProxy
+
+
class _PlatformModel(BaseModel):
    """Shared base for platform payload models; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")
+
+
class PlatformStatus(str, Enum):
    """Lifecycle states reported for a platform adapter instance."""

    PENDING = "pending"
    RUNNING = "running"
    ERROR = "error"
    STOPPED = "stopped"

    @classmethod
    def from_value(cls, value: Any) -> PlatformStatus:
        """Coerce an arbitrary payload value to a status, defaulting to PENDING."""
        if isinstance(value, cls):
            return value
        normalized = str(value).strip().lower()
        try:
            return cls(normalized)
        except ValueError:
            # Unknown or malformed statuses are treated as not-yet-started.
            return cls.PENDING
+
+
class PlatformError(_PlatformModel):
    """Last-error record attached to a platform's runtime stats."""

    # Human-readable error message from the adapter.
    message: str
    # When the error was recorded (string form; exact format not established here).
    timestamp: str
    # Optional formatted traceback for debugging.
    traceback: str | None = None

    @classmethod
    def from_payload(cls, payload: dict[str, Any] | None) -> PlatformError | None:
        """Validate a payload dict into a PlatformError; None for non-dict input."""
        if not isinstance(payload, dict):
            return None
        return cls.model_validate(payload)
+
+
class PlatformStats(_PlatformModel):
    """Snapshot of one platform adapter's identity and runtime state."""

    id: str
    type: str
    display_name: str
    status: PlatformStatus
    started_at: str | None = None
    error_count: int
    last_error: PlatformError | None = None
    unified_webhook: bool
    meta: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_payload(cls, payload: dict[str, Any] | None) -> PlatformStats | None:
        """Validate a payload into PlatformStats; None for non-dict input.

        Loosely-typed fields are normalized before strict validation:
        ``status`` via PlatformStatus.from_value, ``last_error`` via
        PlatformError.from_payload, and ``meta`` coerced to a plain dict
        (non-dict values become {}).
        """
        if not isinstance(payload, dict):
            return None
        normalized = dict(payload)
        normalized["status"] = PlatformStatus.from_value(payload.get("status"))
        normalized["last_error"] = PlatformError.from_payload(payload.get("last_error"))
        meta = payload.get("meta")
        normalized["meta"] = dict(meta) if isinstance(meta, dict) else {}
        return cls.model_validate(normalized)
+
+
class PlatformClient:
    """Platform messaging client.

    Provides the ability to send messages to a chat platform and to query
    platform-side information.

    Attributes:
        _proxy: CapabilityProxy instance used for remote capability calls
    """

    def __init__(self, proxy: CapabilityProxy) -> None:
        """Initialize the platform client.

        Args:
            proxy: CapabilityProxy instance
        """
        self._proxy = proxy

    def _build_target_payload(
        self,
        session: str | SessionRef | MessageSession,
    ) -> tuple[str, dict[str, Any]]:
        """Split a session designator into ``(session_id, extra_payload)``.

        A SessionRef contributes a structured ``target`` entry alongside its
        session string; MessageSession and plain strings are stringified with
        no extra payload.
        """
        if isinstance(session, SessionRef):
            return session.session, {"target": session.to_payload()}
        if isinstance(session, MessageSession):
            return str(session), {}
        return str(session), {}

    async def _coerce_chain_payload(
        self,
        content: (
            str
            | MessageChain
            | Sequence[BaseMessageComponent]
            | Sequence[dict[str, Any]]
        ),
    ) -> list[dict[str, Any]]:
        """Normalize any accepted content form into a serialized chain payload.

        Accepted forms, tried in order: a plain string (wrapped into a single
        Plain component), a MessageChain, a sequence of message components,
        or a sequence of already-serialized payload dicts.

        Raises:
            TypeError: if *content* matches none of the accepted forms.
        """
        if isinstance(content, str):
            return await MessageChain(
                [Plain(content, convert=False)]
            ).to_payload_async()
        if isinstance(content, MessageChain):
            return await content.to_payload_async()
        # NOTE: an empty sequence satisfies all() and is therefore sent as an
        # empty component chain.
        if (
            isinstance(content, Sequence)
            and not isinstance(content, (str, bytes))
            and all(isinstance(item, BaseMessageComponent) for item in content)
        ):
            components = cast(Sequence[BaseMessageComponent], content)
            return await MessageChain(list(components)).to_payload_async()
        if (
            isinstance(content, Sequence)
            and not isinstance(content, (str, bytes))
            and all(isinstance(item, dict) for item in content)
        ):
            payload_items = cast(Sequence[dict[str, Any]], content)
            # Shallow-copy each item so callers' dicts are not shared with
            # the outgoing payload.
            return [dict(item) for item in payload_items]
        raise TypeError(
            "content must be str, MessageChain, sequence of message components, "
            "or sequence of platform.send_chain payload dicts"
        )

    async def send(
        self,
        session: str | SessionRef | MessageSession,
        text: str,
    ) -> dict[str, Any]:
        """Send a plain-text message.

        Sends text to the given session (a user or a group).

        Args:
            session: unified message origin (UMO), e.g. "platform:instance:user_id"
            text: the text content to send

        Returns:
            The send result; may include a message id and similar info

        Example:
            # Reply to the current session
            await ctx.platform.send(event.session_id, "Got your message!")
        """
        session_id, extra = self._build_target_payload(session)
        try:
            return await self._proxy.call(
                "platform.send",
                {"session": session_id, "text": text, **extra},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="PlatformClient",
                method_name="send",
                details=f"session={session_id!r}",
                exc=exc,
            ) from exc

    async def send_image(
        self,
        session: str | SessionRef | MessageSession,
        image_url: str,
    ) -> dict[str, Any]:
        """Send an image message.

        Sends an image to the given session; both URLs and local paths are
        accepted.

        Args:
            session: unified message origin (UMO)
            image_url: image URL or local file path

        Returns:
            The send result

        Example:
            await ctx.platform.send_image(
                event.session_id,
                "https://example.com/image.png"
            )
        """
        session_id, extra = self._build_target_payload(session)
        try:
            return await self._proxy.call(
                "platform.send_image",
                {"session": session_id, "image_url": image_url, **extra},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="PlatformClient",
                method_name="send_image",
                details=f"session={session_id!r}",
                exc=exc,
            ) from exc

    async def send_chain(
        self,
        session: str | SessionRef | MessageSession,
        chain: MessageChain | Sequence[BaseMessageComponent] | Sequence[dict[str, Any]],
    ) -> dict[str, Any]:
        """Send a rich message chain.

        Args:
            session: unified message origin (UMO)
            chain: the message chain; serialized via _coerce_chain_payload

        Returns:
            The send result
        """
        session_id, extra = self._build_target_payload(session)
        chain_payload = await self._coerce_chain_payload(chain)
        try:
            return await self._proxy.call(
                "platform.send_chain",
                {"session": session_id, "chain": chain_payload, **extra},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="PlatformClient",
                method_name="send_chain",
                details=f"session={session_id!r}, items={len(chain_payload)!r}",
                exc=exc,
            ) from exc

    async def send_by_session(
        self,
        session: str | MessageSession,
        content: (
            str
            | MessageChain
            | Sequence[BaseMessageComponent]
            | Sequence[dict[str, Any]]
        ),
    ) -> dict[str, Any]:
        """Proactively send a message chain to the given session.

        A ``Sequence[dict]`` uses exactly the same structure as
        ``platform.send_chain``: every item should be
        ``{"type": "...", "data": {...}}``.
        """
        chain_payload = await self._coerce_chain_payload(content)
        session_id = str(session)
        try:
            return await self._proxy.call(
                "platform.send_by_session",
                {"session": session_id, "chain": chain_payload},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="PlatformClient",
                method_name="send_by_session",
                details=f"session={session_id!r}, items={len(chain_payload)!r}",
                exc=exc,
            ) from exc

    async def send_by_id(
        self,
        platform_id: str,
        session_id: str,
        content: (
            str
            | MessageChain
            | Sequence[BaseMessageComponent]
            | Sequence[dict[str, Any]]
        ),
        *,
        message_type: str = "private",
    ) -> dict[str, Any]:
        """Proactively send a message to a session addressed by platform id.

        Builds a MessageSession from the parts and delegates to
        send_by_session.
        """
        session = MessageSession(
            platform_id=str(platform_id),
            message_type=str(message_type),
            session_id=str(session_id),
        )
        return await self.send_by_session(session, content)

    async def get_members(
        self,
        session: str | SessionRef | MessageSession,
    ) -> list[dict[str, Any]]:
        """Get the member list of a group.

        Fetches member info for the given group session. Only valid for
        group sessions.

        Args:
            session: the group session's unified message origin (UMO)

        Returns:
            A list of member dicts; each entry may include:
            - user_id: user id
            - nickname: display name
            - role: role (owner, admin, member)

        Example:
            members = await ctx.platform.get_members(event.session_id)
            for member in members:
                print(f"{member['nickname']} ({member['user_id']})")
        """
        session_id, extra = self._build_target_payload(session)
        try:
            output = await self._proxy.call(
                "platform.get_members",
                {"session": session_id, **extra},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="PlatformClient",
                method_name="get_members",
                details=f"session={session_id!r}",
                exc=exc,
            ) from exc
        # A malformed (non-sequence) "members" field degrades to [].
        members = output.get("members")
        if not isinstance(members, (list, tuple)):
            return []
        return list(members)
+
+
# Public names re-exported by the platform client module.
__all__ = [
    "PlatformClient",
    "PlatformError",
    "PlatformStats",
    "PlatformStatus",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/provider.py b/astrbot-sdk/src/astrbot_sdk/clients/provider.py
new file mode 100644
index 0000000000..7142efee0a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/provider.py
@@ -0,0 +1,353 @@
+"""Provider discovery and provider-management clients."""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import inspect
+from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict
+
+from ..llm.entities import ProviderMeta, ProviderType
+from ..llm.providers import (
+ ProviderProxy,
+ STTProvider,
+ TTSProvider,
+ provider_proxy_from_meta,
+)
+from ._proxy import CapabilityProxy
+
+
class _ProviderModel(BaseModel):
    """Shared base for provider payload models; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    def to_payload(self) -> dict[str, Any]:
        """Serialize to a plain dict, omitting None-valued fields."""
        return self.model_dump(exclude_none=True)
+
+
class ManagedProviderRecord(_ProviderModel):
    """Manager-side view of one provider instance."""

    id: str
    model: str | None = None
    type: str
    provider_type: ProviderType
    loaded: bool
    enabled: bool
    # Id of the configuration source this instance was created from, if any.
    provider_source_id: str | None = None

    @classmethod
    def from_payload(
        cls,
        payload: dict[str, Any] | None,
    ) -> ManagedProviderRecord | None:
        """Validate a payload dict into a record; None for non-dict input."""
        if not isinstance(payload, dict):
            return None
        return cls.model_validate(payload)
+
+
class ProviderChangeEvent(_ProviderModel):
    """One provider-selection change emitted by the manager's watch stream."""

    provider_id: str
    provider_type: ProviderType
    # Unified message origin the change is scoped to; None for global changes
    # (scope semantics set by the core — TODO confirm).
    umo: str | None = None

    @classmethod
    def from_payload(
        cls,
        payload: dict[str, Any] | None,
    ) -> ProviderChangeEvent | None:
        """Validate a payload dict into an event; None for non-dict input."""
        if not isinstance(payload, dict):
            return None
        return cls.model_validate(payload)
+
+
class ProviderClient:
    """Read-only provider discovery client.

    Wraps the ``provider.*`` query capabilities and turns returned provider
    metadata into callable proxies where applicable.
    """

    def __init__(self, proxy: CapabilityProxy) -> None:
        """Bind the client to the capability proxy used for remote calls."""
        self._proxy = proxy

    @staticmethod
    def _provider_meta_list(items: Any) -> list[ProviderMeta]:
        """Parse a list payload into ProviderMeta entries, skipping bad items."""
        if not isinstance(items, list):
            return []
        providers: list[ProviderMeta] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            provider = ProviderMeta.from_payload(item)
            if provider is not None:
                providers.append(provider)
        return providers

    async def list_all(self) -> list[ProviderMeta]:
        """List every provider registered with the core."""
        output = await self._proxy.call("provider.list_all", {})
        return self._provider_meta_list(output.get("providers"))

    async def list_tts(self) -> list[ProviderMeta]:
        """List text-to-speech providers."""
        output = await self._proxy.call("provider.list_all_tts", {})
        return self._provider_meta_list(output.get("providers"))

    async def list_stt(self) -> list[ProviderMeta]:
        """List speech-to-text providers."""
        output = await self._proxy.call("provider.list_all_stt", {})
        return self._provider_meta_list(output.get("providers"))

    async def list_embedding(self) -> list[ProviderMeta]:
        """List embedding providers."""
        output = await self._proxy.call("provider.list_all_embedding", {})
        return self._provider_meta_list(output.get("providers"))

    async def list_rerank(self) -> list[ProviderMeta]:
        """List rerank providers."""
        output = await self._proxy.call("provider.list_all_rerank", {})
        return self._provider_meta_list(output.get("providers"))

    async def _get_tts_support_stream(self, provider_id: str) -> bool:
        """Ask the core whether a TTS provider supports streaming output."""
        output = await self._proxy.call(
            "provider.tts.support_stream",
            {"provider_id": str(provider_id)},
        )
        return bool(output.get("supported", False))

    async def _build_proxy(self, meta: ProviderMeta | None) -> ProviderProxy | None:
        """Wrap provider metadata in a callable proxy; None passes through.

        TTS providers need their stream-support flag resolved up front so the
        proxy can expose it.
        """
        if meta is None:
            return None
        tts_supports_stream = None
        if meta.provider_type == ProviderType.TEXT_TO_SPEECH:
            tts_supports_stream = await self._get_tts_support_stream(meta.id)
        return provider_proxy_from_meta(
            self._proxy,
            meta,
            tts_supports_stream=tts_supports_stream,
        )

    async def get(self, provider_id: str) -> ProviderProxy | None:
        """Fetch a provider by id as a callable proxy; None when unknown."""
        output = await self._proxy.call(
            "provider.get_by_id",
            {"provider_id": str(provider_id)},
        )
        return await self._build_proxy(
            ProviderMeta.from_payload(output.get("provider"))
        )

    async def get_using_chat(self, umo: str | None = None) -> ProviderMeta | None:
        """Return metadata (not a proxy) for the chat provider in use."""
        output = await self._proxy.call("provider.get_using", {"umo": umo})
        return ProviderMeta.from_payload(output.get("provider"))

    async def get_using_tts(self, umo: str | None = None) -> TTSProvider | None:
        """Return the TTS provider in use as a proxy; None when unavailable."""
        output = await self._proxy.call("provider.get_using_tts", {"umo": umo})
        provider = await self._build_proxy(
            ProviderMeta.from_payload(output.get("provider"))
        )
        return provider if isinstance(provider, TTSProvider) else None

    async def get_using_stt(self, umo: str | None = None) -> STTProvider | None:
        """Return the STT provider in use as a proxy; None when unavailable."""
        output = await self._proxy.call("provider.get_using_stt", {"umo": umo})
        provider = await self._build_proxy(
            ProviderMeta.from_payload(output.get("provider"))
        )
        return provider if isinstance(provider, STTProvider) else None
+
+
class ProviderManagerClient:
    """Provider management client (reserved/system-level operations).

    Wraps the ``provider.manager.*`` capabilities for creating, updating,
    deleting and switching providers, and layers change-hook registration on
    top of the manager's watch stream.
    """

    def __init__(
        self,
        proxy: CapabilityProxy,
        *,
        plugin_id: str | None = None,
        logger: Any | None = None,
    ) -> None:
        """Bind the client to a capability proxy.

        Args:
            proxy: transport used for capability calls.
            plugin_id: owning plugin id; only used to label log messages.
            logger: optional logger-like object. ``debug``/``exception`` are
                looked up dynamically, so any (or no) logger is tolerated.
        """
        self._proxy = proxy
        self._plugin_id = plugin_id
        self._logger = logger
        # Strong references so hook tasks are not garbage-collected mid-run.
        self._change_hook_tasks: set[asyncio.Task[None]] = set()

    @staticmethod
    def _provider_type_value(provider_type: ProviderType | str) -> str:
        """Normalize a ProviderType enum or raw string to its wire value."""
        if isinstance(provider_type, ProviderType):
            return provider_type.value
        return str(provider_type).strip()

    @staticmethod
    def _record_from_output(output: dict[str, Any]) -> ManagedProviderRecord | None:
        """Parse the ``provider`` field of a capability response, if present."""
        return ManagedProviderRecord.from_payload(output.get("provider"))

    async def set_provider(
        self,
        provider_id: str,
        provider_type: ProviderType | str,
        umo: str | None = None,
    ) -> None:
        """Select the active provider of the given type (optionally per-UMO)."""
        await self._proxy.call(
            "provider.manager.set",
            {
                "provider_id": str(provider_id),
                "provider_type": self._provider_type_value(provider_type),
                "umo": umo,
            },
        )

    async def get_provider_by_id(
        self,
        provider_id: str,
    ) -> ManagedProviderRecord | None:
        """Fetch one managed provider record by id; None when unknown."""
        output = await self._proxy.call(
            "provider.manager.get_by_id",
            {"provider_id": str(provider_id)},
        )
        return self._record_from_output(output)

    async def get_merged_provider_config(
        self,
        provider_id: str,
    ) -> dict[str, Any] | None:
        """Return the provider's effective (merged) config dict, if any."""
        output = await self._proxy.call(
            "provider.manager.get_merged_provider_config",
            {"provider_id": str(provider_id).strip()},
        )
        config = output.get("config")
        return dict(config) if isinstance(config, dict) else None

    async def load_provider(
        self,
        provider_config: dict[str, Any],
    ) -> ManagedProviderRecord | None:
        """Load a provider instance from a config dict."""
        output = await self._proxy.call(
            "provider.manager.load",
            {"provider_config": dict(provider_config)},
        )
        return self._record_from_output(output)

    async def terminate_provider(self, provider_id: str) -> None:
        """Terminate a loaded provider instance."""
        await self._proxy.call(
            "provider.manager.terminate",
            {"provider_id": str(provider_id)},
        )

    async def create_provider(
        self,
        provider_config: dict[str, Any],
    ) -> ManagedProviderRecord | None:
        """Create (persist) a new provider from a config dict."""
        output = await self._proxy.call(
            "provider.manager.create",
            {"provider_config": dict(provider_config)},
        )
        return self._record_from_output(output)

    async def update_provider(
        self,
        origin_provider_id: str,
        new_config: dict[str, Any],
    ) -> ManagedProviderRecord | None:
        """Replace an existing provider's configuration."""
        output = await self._proxy.call(
            "provider.manager.update",
            {
                "origin_provider_id": str(origin_provider_id),
                "new_config": dict(new_config),
            },
        )
        return self._record_from_output(output)

    async def delete_provider(
        self,
        provider_id: str | None = None,
        provider_source_id: str | None = None,
    ) -> None:
        """Delete a provider by instance id and/or by its source id."""
        await self._proxy.call(
            "provider.manager.delete",
            {
                "provider_id": provider_id,
                "provider_source_id": provider_source_id,
            },
        )

    async def get_insts(self) -> list[ManagedProviderRecord]:
        """List all provider instances known to the manager."""
        output = await self._proxy.call("provider.manager.get_insts", {})
        items = output.get("providers")
        if not isinstance(items, list):
            return []
        records: list[ManagedProviderRecord] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            record = ManagedProviderRecord.from_payload(item)
            if record is not None:
                records.append(record)
        return records

    async def watch_changes(self) -> AsyncIterator[ProviderChangeEvent]:
        """Yield provider-change events from the manager's watch stream."""
        async for chunk in self._proxy.stream("provider.manager.watch_changes", {}):
            event = ProviderChangeEvent.from_payload(chunk)
            if event is not None:
                yield event

    async def register_provider_change_hook(
        self,
        callback: Callable[
            [str, ProviderType, str | None],
            Awaitable[None] | None,
        ],
    ) -> asyncio.Task[None]:
        """Invoke *callback* for every change event; returns the watcher task.

        The callback may be sync or async; async results are awaited. Use
        unregister_provider_change_hook (or cancel the task) to stop it.
        """

        async def runner() -> None:
            async for event in self.watch_changes():
                result = callback(
                    event.provider_id,
                    event.provider_type,
                    event.umo,
                )
                if inspect.isawaitable(result):
                    await result

        task = asyncio.create_task(runner())
        self._change_hook_tasks.add(task)
        task.add_done_callback(self._log_change_hook_result)
        return task

    async def unregister_provider_change_hook(
        self,
        task: asyncio.Task[None],
    ) -> None:
        """Cancel and await a hook task previously returned by register."""
        if task not in self._change_hook_tasks:
            return
        self._change_hook_tasks.discard(task)
        if not task.done():
            task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task

    def _log_hook_cancelled(self) -> None:
        """Emit the cancellation debug log, if a debug-capable logger exists."""
        debug_logger = getattr(self._logger, "debug", None)
        if callable(debug_logger):
            debug_logger(
                "Provider change hook cancelled: plugin_id={}",
                self._plugin_id,
            )

    def _log_change_hook_result(self, task: asyncio.Task[None]) -> None:
        """Done-callback: drop bookkeeping, log cancellation or failure."""
        self._change_hook_tasks.discard(task)
        if task.cancelled():
            self._log_hook_cancelled()
            return
        try:
            task.result()
        except asyncio.CancelledError:
            # Defensive: result() may still surface CancelledError even when
            # cancelled() was False above.
            self._log_hook_cancelled()
        except Exception:
            exception_logger = getattr(self._logger, "exception", None)
            if callable(exception_logger):
                exception_logger(
                    "Provider change hook failed: plugin_id={}",
                    self._plugin_id,
                )
+
+
# Public names re-exported by the provider client module.
__all__ = [
    "ManagedProviderRecord",
    "ProviderChangeEvent",
    "ProviderClient",
    "ProviderManagerClient",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/registry.py b/astrbot-sdk/src/astrbot_sdk/clients/registry.py
new file mode 100644
index 0000000000..7cb9288b13
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/registry.py
@@ -0,0 +1,167 @@
+"""只读 handler 注册表客户端。"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from ._errors import wrap_client_exception
+from ._proxy import CapabilityProxy
+
+
+def _coerce_int(value: Any, default: int = 0) -> int:
+ try:
+ return int(value)
+ except (TypeError, ValueError):
+ return default
+
+
+@dataclass(slots=True)
+class HandlerMetadata:
+ plugin_name: str
+ handler_full_name: str
+ trigger_type: str
+ description: str | None = None
+ event_types: list[str] = field(default_factory=list)
+ enabled: bool = True
+ group_path: list[str] = field(default_factory=list)
+ priority: int = 0
+ kind: str = "handler"
+ require_admin: bool = False
+ required_role: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> HandlerMetadata:
+ return cls(
+ plugin_name=str(data.get("plugin_name", "")),
+ handler_full_name=str(data.get("handler_full_name", "")),
+ trigger_type=str(data.get("trigger_type", "")),
+ description=(
+ None
+ if data.get("description") is None
+ else str(data.get("description", "")).strip() or None
+ ),
+ event_types=[
+ str(item)
+ for item in data.get("event_types", [])
+ if isinstance(item, str)
+ ],
+ enabled=bool(data.get("enabled", True)),
+ group_path=[
+ str(item)
+ for item in data.get("group_path", [])
+ if isinstance(item, str)
+ ],
+ priority=_coerce_int(data.get("priority", 0), 0),
+ kind=str(data.get("kind", "handler") or "handler"),
+ require_admin=bool(data.get("require_admin", False)),
+ required_role=(
+ None
+ if data.get("required_role") is None
+ else str(data.get("required_role", "")).strip() or None
+ ),
+ )
+
+
class RegistryClient:
    """Read-only handler registry client.

    Wraps the ``registry.*`` query capabilities and the system-level handler
    whitelist controls; transport failures are re-raised via
    wrap_client_exception.
    """

    def __init__(self, proxy: CapabilityProxy) -> None:
        """Bind the client to the capability proxy used for remote calls."""
        self._proxy = proxy

    async def get_handlers_by_event_type(
        self,
        event_type: str,
    ) -> list[HandlerMetadata]:
        """List handler metadata for every handler bound to *event_type*."""
        try:
            output = await self._proxy.call(
                "registry.get_handlers_by_event_type",
                {"event_type": event_type},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="RegistryClient",
                method_name="get_handlers_by_event_type",
                details=f"event_type={event_type!r}",
                exc=exc,
            ) from exc
        # Non-dict entries in the response are silently skipped.
        return [
            HandlerMetadata.from_dict(item)
            for item in output.get("handlers", [])
            if isinstance(item, dict)
        ]

    async def get_handler_by_full_name(
        self,
        full_name: str,
    ) -> HandlerMetadata | None:
        """Look up a single handler by full name; None when not found."""
        try:
            output = await self._proxy.call(
                "registry.get_handler_by_full_name",
                {"full_name": full_name},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="RegistryClient",
                method_name="get_handler_by_full_name",
                details=f"full_name={full_name!r}",
                exc=exc,
            ) from exc
        handler = output.get("handler")
        if not isinstance(handler, dict):
            return None
        return HandlerMetadata.from_dict(handler)

    async def set_handler_whitelist(
        self,
        plugin_names: list[str] | set[str] | None,
    ) -> list[str] | None:
        """Replace the handler whitelist; None clears it.

        Names are stringified, blank entries dropped, de-duplicated, and
        sorted before sending. Returns the whitelist echoed by the core, or
        None when the response has no list-shaped ``plugin_names``.
        """
        names = None
        if plugin_names is not None:
            names = sorted({str(item) for item in plugin_names if str(item).strip()})
        try:
            output = await self._proxy.call(
                "system.event.handler_whitelist.set",
                {"plugin_names": names},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="RegistryClient",
                method_name="set_handler_whitelist",
                details=f"plugin_names={names!r}",
                exc=exc,
            ) from exc
        result = output.get("plugin_names")
        if not isinstance(result, list):
            return None
        return [str(item) for item in result]

    async def get_handler_whitelist(self) -> list[str] | None:
        """Return the current whitelist, or None when the response lacks one."""
        try:
            output = await self._proxy.call("system.event.handler_whitelist.get", {})
        except Exception as exc:
            raise wrap_client_exception(
                client_name="RegistryClient",
                method_name="get_handler_whitelist",
                exc=exc,
            ) from exc
        result = output.get("plugin_names")
        if not isinstance(result, list):
            return None
        return [str(item) for item in result]

    async def clear_handler_whitelist(self) -> None:
        """Clear the handler whitelist.

        Uses the same capability as set_handler_whitelist(None); kept as a
        separate method so wrapped errors report this method name.
        """
        try:
            await self._proxy.call(
                "system.event.handler_whitelist.set",
                {"plugin_names": None},
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="RegistryClient",
                method_name="clear_handler_whitelist",
                exc=exc,
            ) from exc
+
+
+__all__ = ["HandlerMetadata", "RegistryClient"]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/session.py b/astrbot-sdk/src/astrbot_sdk/clients/session.py
new file mode 100644
index 0000000000..0c8894cc1f
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/session.py
@@ -0,0 +1,133 @@
+"""Session-scoped SDK managers."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ..events import MessageEvent
+from ..message.session import MessageSession
+from ._proxy import CapabilityProxy
+from .registry import HandlerMetadata
+
+
def _normalize_session(session: str | MessageSession | MessageEvent) -> str:
    """Collapse any accepted session designator to its UMO string form."""
    if isinstance(session, MessageEvent):
        # Events carry their origin; everything else stringifies directly.
        session = session.unified_msg_origin
    return str(session)
+
+
+def _handler_to_payload(handler: HandlerMetadata) -> dict[str, Any]:
+ return {
+ "plugin_name": handler.plugin_name,
+ "handler_full_name": handler.handler_full_name,
+ "trigger_type": handler.trigger_type,
+ "description": handler.description,
+ "event_types": list(handler.event_types),
+ "enabled": handler.enabled,
+ "group_path": list(handler.group_path),
+ "priority": handler.priority,
+ "kind": handler.kind,
+ "require_admin": handler.require_admin,
+ }
+
+
class SessionPluginManager:
    """Session-scoped plugin status manager."""

    def __init__(self, proxy: CapabilityProxy) -> None:
        """Bind the manager to the capability proxy used for remote calls."""
        self._proxy = proxy

    async def is_plugin_enabled_for_session(
        self,
        session: str | MessageSession | MessageEvent,
        plugin_name: str,
    ) -> bool:
        """Return whether *plugin_name* is enabled in the given session."""
        output = await self._proxy.call(
            "session.plugin.is_enabled",
            {
                "session": _normalize_session(session),
                "plugin_name": str(plugin_name),
            },
        )
        return bool(output.get("enabled", False))

    async def filter_handlers_by_session(
        self,
        session: str | MessageSession | MessageEvent,
        handlers: list[HandlerMetadata],
    ) -> list[HandlerMetadata]:
        """Return the subset of *handlers* allowed to run in this session.

        Handlers are serialized, filtered by the core, and the response
        items re-parsed into HandlerMetadata. A malformed response yields [].
        """
        output = await self._proxy.call(
            "session.plugin.filter_handlers",
            {
                "session": _normalize_session(session),
                "handlers": [_handler_to_payload(handler) for handler in handlers],
            },
        )
        items = output.get("handlers")
        if not isinstance(items, list):
            return []
        return [
            HandlerMetadata.from_dict(item) for item in items if isinstance(item, dict)
        ]
+
+
class SessionServiceManager:
    """Session-scoped LLM/TTS service status manager."""

    def __init__(self, proxy: CapabilityProxy) -> None:
        """Bind the manager to the capability proxy used for remote calls."""
        self._proxy = proxy

    async def _read_flag(
        self,
        capability: str,
        session: str | MessageSession | MessageEvent,
    ) -> bool:
        # Shared reader for the boolean "enabled" capabilities.
        output = await self._proxy.call(
            capability,
            {"session": _normalize_session(session)},
        )
        return bool(output.get("enabled", False))

    async def _write_flag(
        self,
        capability: str,
        session: str | MessageSession | MessageEvent,
        enabled: bool,
    ) -> None:
        # Shared writer for the boolean "enabled" capabilities.
        await self._proxy.call(
            capability,
            {"session": _normalize_session(session), "enabled": bool(enabled)},
        )

    async def is_llm_enabled_for_session(
        self,
        session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Return whether LLM processing is enabled for this session."""
        return await self._read_flag("session.service.is_llm_enabled", session)

    async def set_llm_status_for_session(
        self,
        session: str | MessageSession | MessageEvent,
        enabled: bool,
    ) -> None:
        """Enable or disable LLM processing for this session."""
        await self._write_flag("session.service.set_llm_status", session, enabled)

    async def should_process_llm_request(
        self,
        event_or_session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Alias of is_llm_enabled_for_session, kept for API compatibility."""
        return await self.is_llm_enabled_for_session(event_or_session)

    async def is_tts_enabled_for_session(
        self,
        session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Return whether TTS is enabled for this session."""
        return await self._read_flag("session.service.is_tts_enabled", session)

    async def set_tts_status_for_session(
        self,
        session: str | MessageSession | MessageEvent,
        enabled: bool,
    ) -> None:
        """Enable or disable TTS for this session."""
        await self._write_flag("session.service.set_tts_status", session, enabled)

    async def should_process_tts_request(
        self,
        event_or_session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Alias of is_tts_enabled_for_session, kept for API compatibility."""
        return await self.is_tts_enabled_for_session(event_or_session)
+
+
+__all__ = ["SessionPluginManager", "SessionServiceManager"]
diff --git a/astrbot-sdk/src/astrbot_sdk/clients/skills.py b/astrbot-sdk/src/astrbot_sdk/clients/skills.py
new file mode 100644
index 0000000000..54115a2bfb
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/clients/skills.py
@@ -0,0 +1,90 @@
+"""技能注册客户端。"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+from ._errors import wrap_client_exception
+from ._proxy import CapabilityProxy
+
+
+@dataclass(slots=True)
+class SkillRegistration:
+ """已注册技能的元数据。"""
+
+ name: str
+ description: str
+ path: str
+ skill_dir: str
+
+ @classmethod
+ def from_dict(cls, data: dict[str, Any]) -> SkillRegistration:
+ return cls(
+ name=str(data.get("name", "")),
+ description=str(data.get("description", "") or ""),
+ path=str(data.get("path", "")),
+ skill_dir=str(data.get("skill_dir", "")),
+ )
+
+
class SkillClient:
    """Skill management capability client."""

    def __init__(self, proxy: CapabilityProxy) -> None:
        """Bind the client to the capability proxy used for remote calls."""
        self._proxy = proxy

    async def register(
        self,
        *,
        name: str,
        path: str,
        description: str = "",
    ) -> SkillRegistration:
        """Register a skill with the core and return its stored metadata.

        Args:
            name: skill name.
            path: skill location (interpretation is up to the core).
            description: optional human-readable summary.
        """
        try:
            output = await self._proxy.call(
                "skill.register",
                {
                    "name": name,
                    "path": path,
                    "description": description,
                },
            )
        except Exception as exc:
            raise wrap_client_exception(
                client_name="SkillClient",
                method_name="register",
                details=f"name={name!r}, path={path!r}",
                exc=exc,
            ) from exc
        return SkillRegistration.from_dict(output)

    async def unregister(self, name: str) -> bool:
        """Remove a registered skill; True when something was removed."""
        try:
            output = await self._proxy.call("skill.unregister", {"name": name})
        except Exception as exc:
            raise wrap_client_exception(
                client_name="SkillClient",
                method_name="unregister",
                details=f"name={name!r}",
                exc=exc,
            ) from exc
        return bool(output.get("removed", False))

    async def list(self) -> list[SkillRegistration]:
        """List all skills currently registered with the core."""
        try:
            output = await self._proxy.call("skill.list", {})
        except Exception as exc:
            raise wrap_client_exception(
                client_name="SkillClient",
                method_name="list",
                exc=exc,
            ) from exc
        # Non-dict entries in the response are silently skipped.
        return [
            SkillRegistration.from_dict(item)
            for item in output.get("skills", [])
            if isinstance(item, dict)
        ]
+
+
+__all__ = ["SkillClient", "SkillRegistration"]
diff --git a/astrbot-sdk/src/astrbot_sdk/commands.py b/astrbot-sdk/src/astrbot_sdk/commands.py
new file mode 100644
index 0000000000..1d4f278e1c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/commands.py
@@ -0,0 +1,161 @@
+"""SDK-native command group helpers.
+
+本模块提供命令分组工具,用于组织具有层级关系的命令。
+
+CommandGroup 允许以嵌套方式定义命令树,例如:
+ admin
+ ├── user
+ │ ├── add
+ │ └── remove
+ └── config
+ ├── get
+ └── set
+
+特性:
+- 支持命令别名,自动展开父级路径的所有别名组合
+- 自动生成命令树的可视化输出 (print_cmd_tree)
+- 与 @on_command 装饰器无缝集成
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from itertools import product
+from typing import Any
+
+from .decorators import on_command, set_command_route_meta
+from .protocol.descriptors import CommandRouteSpec
+
+
@dataclass(slots=True)
class _CommandNode:
    """Internal tree node backing a CommandGroup's registered children."""

    name: str
    aliases: list[str] = field(default_factory=list)
    description: str | None = None
    # Child groups attach themselves here via CommandGroup.group().
    subgroups: list[CommandGroup] = field(default_factory=list)
    # (command name, optional description) pairs registered on this group.
    commands: list[tuple[str, str | None]] = field(default_factory=list)
+
+
class CommandGroup:
    """Hierarchical command group for registering nested commands.

    Groups nest via group() and register leaf commands via the command()
    decorator factory; the full command string is the space-joined path of
    group names plus the leaf name. Aliases at every level expand to all
    combinations (see _expand_aliases).
    """

    def __init__(
        self,
        name: str,
        *,
        aliases: list[str] | None = None,
        description: str | None = None,
        parent: CommandGroup | None = None,
    ) -> None:
        """Create a group node; *parent* is supplied by group() for subgroups."""
        self.name = name
        self.aliases = list(aliases or [])
        self.description = description
        self.parent = parent
        self._tree = _CommandNode(
            name=name, aliases=self.aliases, description=description
        )

    def group(
        self,
        name: str,
        *,
        aliases: list[str] | None = None,
        description: str | None = None,
    ) -> CommandGroup:
        """Create and attach a child group under this one."""
        child = CommandGroup(
            name,
            aliases=aliases,
            description=description,
            parent=self,
        )
        self._tree.subgroups.append(child)
        return child

    def command(
        self,
        name: str,
        *,
        aliases: list[str] | None = None,
        description: str | None = None,
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Return a decorator registering *name* as a command of this group.

        The decorator wires the handler through on_command using the fully
        qualified command string plus every alias combination, then attaches
        routing metadata describing the group path.
        """
        full_command = " ".join([*self.path, name])
        full_aliases = self._expand_aliases(name=name, aliases=aliases or [])
        display_command = full_command
        route = CommandRouteSpec(
            group_path=self.path,
            display_command=display_command,
            group_help=self.description,
        )

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            decorated = on_command(
                full_command,
                aliases=full_aliases,
                description=description,
            )(func)
            # The command appears in the printed tree only once the
            # decorator has actually been applied to a handler.
            self._tree.commands.append((name, description))
            set_command_route_meta(decorated, route)
            return decorated

        return decorator

    @property
    def path(self) -> list[str]:
        """Group names from the root down to this group, in order."""
        if self.parent is None:
            return [self.name]
        return [*self.parent.path, self.name]

    def print_cmd_tree(self) -> str:
        """Render this group's subtree as an indented multi-line string."""
        lines: list[str] = []
        self._append_tree_lines(lines, indent=0)
        return "\n".join(lines)

    def _append_tree_lines(self, lines: list[str], *, indent: int) -> None:
        """Append this group's label, its commands, then subgroups to *lines*."""
        prefix = "  " * indent
        label = self.name
        if self.aliases:
            label += f" ({', '.join(self.aliases)})"
        lines.append(f"{prefix}{label}")
        for command_name, description in self._tree.commands:
            command_label = f"{prefix}  - {command_name}"
            if description:
                command_label += f": {description}"
            lines.append(command_label)
        for subgroup in self._tree.subgroups:
            subgroup._append_tree_lines(lines, indent=indent + 1)

    def _expand_aliases(self, *, name: str, aliases: list[str]) -> list[str]:
        """Build every alias route for *name* under this group.

        Takes the cartesian product of (name + aliases) for each ancestor
        group and the leaf, then drops the canonical full command so only
        true aliases remain. The result is sorted for determinism.
        """
        group_segments: list[list[str]] = []
        cursor: CommandGroup | None = self
        ancestry: list[CommandGroup] = []
        while cursor is not None:
            ancestry.append(cursor)
            cursor = cursor.parent
        # Root-first order so segments align with the command path.
        for group in reversed(ancestry):
            group_segments.append([group.name, *group.aliases])
        leaf_segments = [name, *aliases]
        expanded: set[str] = set()
        for parts in product(*group_segments, leaf_segments):
            route = " ".join(parts)
            if route != " ".join([*self.path, name]):
                expanded.add(route)
        return sorted(expanded)
+
+
def command_group(
    name: str,
    *,
    aliases: list[str] | None = None,
    description: str | None = None,
) -> CommandGroup:
    """Create a new top-level command group (its parent is None)."""
    return CommandGroup(name, aliases=aliases, description=description)
+
+
def print_cmd_tree(group: CommandGroup) -> str:
    """Render *group*'s command tree as an indented multi-line string."""
    return group.print_cmd_tree()
+
+
+__all__ = ["CommandGroup", "command_group", "print_cmd_tree"]
diff --git a/astrbot-sdk/src/astrbot_sdk/context.py b/astrbot-sdk/src/astrbot_sdk/context.py
new file mode 100644
index 0000000000..82007d7c02
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/context.py
@@ -0,0 +1,900 @@
+"""astrbot-sdk 原生运行时上下文。
+
+`Context` 是插件与 AstrBot Core 交互的主要入口,
+负责组合所有 capability 客户端并提供统一的访问接口。
+
+每个 handler 调用都会创建一个新的 Context 实例,
+绑定到当前的 Peer、插件 ID 和取消令牌。
+
+Attributes:
+ llm: LLM 能力客户端,用于 AI 对话
+ memory: 记忆能力客户端,用于语义存储
+ db: 数据库客户端,用于 KV 持久化
+ platform: 平台客户端,用于发送消息
+ permission: 权限客户端,用于查询用户角色
+ providers: Provider 客户端,用于查询和调用专用 Provider
+ provider_manager: Provider 管理客户端,用于 reserved/system 级操作
+ permission_manager: 权限管理客户端,用于 reserved/system 级管理员维护
+ personas: 人格管理客户端
+ conversations: 对话管理客户端
+ kbs: 知识库管理客户端
+ message_history: 消息历史管理客户端
+ http: HTTP 客户端,用于注册 API 端点
+ metadata: 元数据客户端,用于查询插件信息
+ skills: Skill 客户端,用于向 AstrBot 注册插件技能
+ plugin_id: 当前插件的唯一标识
+ logger: 绑定了插件 ID 的日志器
+ cancel_token: 取消令牌,用于处理请求取消
+"""
+
+from __future__ import annotations
+
+import asyncio
+from collections.abc import Awaitable, Callable, Sequence
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+from ._internal.plugin_logger import PluginLogger
+from ._internal.sdk_logger import logger as base_logger
+from ._internal.star_runtime import current_star_instance
+from ._message_types import normalize_message_type
+from .clients import (
+ DBClient,
+ HTTPClient,
+ LLMClient,
+ MemoryClient,
+ MetadataClient,
+ PermissionClient,
+ PermissionManagerClient,
+ PlatformClient,
+ PlatformError,
+ PlatformStats,
+ PlatformStatus,
+ RegistryClient,
+ SkillClient,
+)
+from .clients._proxy import CapabilityProxy
+from .clients.llm import LLMResponse
+from .clients.managers import (
+ ConversationManagerClient,
+ KnowledgeBaseManagerClient,
+ MessageHistoryManagerClient,
+ PersonaManagerClient,
+)
+from .clients.provider import ProviderClient, ProviderManagerClient
+from .clients.session import SessionPluginManager, SessionServiceManager
+from .clients.skills import SkillRegistration
+from .errors import AstrBotError
+from .llm.entities import LLMToolSpec, ProviderMeta, ProviderRequest
+from .llm.tools import LLMToolManager
+from .message.components import BaseMessageComponent
+from .message.result import MessageChain
+from .message.session import MessageSession
+from .session_waiter import (
+ _mark_session_waiter_background_task,
+ _unmark_session_waiter_background_task,
+)
+
# Content payloads accepted by the compat send helpers: plain text, a
# MessageChain, or a sequence of components / raw component dicts.
PlatformCompatContent = (
    str | MessageChain | Sequence[BaseMessageComponent] | Sequence[dict[str, Any]]
)
+
+
+def _context_call_label(method_name: str, details: str | None = None) -> str:
+ label = f"Context.{method_name}"
+ if details:
+ return f"{label} ({details})"
+ return label
+
+
def _wrap_context_exception(
    *,
    method_name: str,
    exc: Exception,
    details: str | None = None,
) -> Exception:
    """Re-wrap *exc* with a ``Context.<method> failed`` message.

    An :class:`AstrBotError` keeps its structured fields (code, hint, …)
    with only the message replaced; any other exception becomes a
    plain ``RuntimeError``.
    """
    message = f"{_context_call_label(method_name, details)} failed: {exc}"
    if not isinstance(exc, AstrBotError):
        return RuntimeError(message)
    return AstrBotError(
        code=exc.code,
        message=message,
        hint=exc.hint,
        retryable=exc.retryable,
        docs_url=exc.docs_url,
        details=exc.details,
    )
+
+
async def _call_proxy_with_context(
    proxy: CapabilityProxy,
    capability: str,
    payload: dict[str, Any],
    *,
    method_name: str,
    details: str | None = None,
) -> dict[str, Any]:
    """Invoke *capability* on *proxy*, relabeling any failure as ``Context.<method>``.

    Raises:
        AstrBotError | RuntimeError: the wrapped form of whatever the proxy raised.
    """
    try:
        return await proxy.call(capability, payload)
    except Exception as exc:
        raise _wrap_context_exception(
            method_name=method_name,
            details=details,
            exc=exc,
        ) from exc
+
+
+def _normalize_platform_instance_payload(payload: Any) -> dict[str, Any] | None:
+ if not isinstance(payload, dict):
+ return None
+ platform_id = str(payload.get("id", "")).strip()
+ platform_type = str(payload.get("type", "")).strip()
+ if not platform_id or not platform_type:
+ return None
+ # Normalize platform records once at the runtime boundary so later lookups
+ # do not each need to remember the same string cleanup rules.
+ return {
+ "id": platform_id,
+ "name": str(payload.get("name", platform_id)).strip() or platform_id,
+ "type": platform_type,
+ "status": PlatformStatus.from_value(payload.get("status")),
+ }
+
+
@dataclass(slots=True)
class PlatformCompatFacade:
    """Compat-layer platform entry exposing only safe metadata and proactive send APIs."""

    _ctx: Context  # owning runtime context; all host calls go through it
    id: str  # platform instance id (stable lookup key)
    name: str  # display name (updated from host snapshots)
    type: str  # platform adapter type string
    status: PlatformStatus = PlatformStatus.PENDING  # last known status snapshot
    errors: list[PlatformError] = field(default_factory=list)  # parsed error records
    last_error: PlatformError | None = None  # most recent error, if any
    unified_webhook: bool = False  # whether this instance uses the unified webhook
    # Serializes refresh()/clear_errors() so snapshot fields update atomically.
    _state_lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)

    async def send_by_session(
        self,
        session: str | MessageSession,
        content: PlatformCompatContent,
    ) -> dict[str, Any]:
        """Send *content* to a fully qualified session (delegates to the platform client)."""
        return await self._ctx.platform.send_by_session(session, content)

    async def send_by_id(
        self,
        session_id: str,
        content: PlatformCompatContent,
        *,
        message_type: str = "private",
    ) -> dict[str, Any]:
        """Send *content* to a bare session id on this platform instance."""
        return await self._ctx.platform.send_by_id(
            self.id,
            session_id,
            content,
            message_type=message_type,
        )

    async def send(
        self,
        session: str | MessageSession,
        content: PlatformCompatContent,
        *,
        message_type: str = "private",
    ) -> dict[str, Any]:
        """Send *content*, picking the session- or id-based path automatically.

        A string containing ``:`` is treated as a fully qualified session;
        anything else is treated as a bare session id on this platform.
        """
        if isinstance(session, MessageSession):
            return await self.send_by_session(session, content)
        session_text = str(session).strip()
        if ":" in session_text:
            return await self.send_by_session(session_text, content)
        return await self.send_by_id(
            session_text,
            content,
            message_type=message_type,
        )

    async def refresh(self) -> None:
        """Re-fetch this platform's snapshot from the host (serialized by the lock)."""
        async with self._state_lock:
            await self._refresh_locked()

    async def clear_errors(self) -> None:
        """Ask the host to clear recorded errors, then refresh the local snapshot."""
        async with self._state_lock:
            await self._call_platform_manager(
                "platform.manager.clear_errors",
                {"platform_id": self.id},
                method_name="platform.clear_errors",
                details=f"platform_id={self.id!r}",
            )
            await self._refresh_locked()

    async def get_stats(self) -> PlatformStats | None:
        """Fetch stats for this platform; None when the payload is absent or invalid."""
        output = await self._call_platform_manager(
            "platform.manager.get_stats",
            {"platform_id": self.id},
            method_name="platform.get_stats",
            details=f"platform_id={self.id!r}",
        )
        return PlatformStats.from_payload(output.get("stats"))

    def _apply_snapshot(self, payload: Any) -> None:
        """Copy snapshot fields from a host payload; silently ignores non-dicts."""
        if not isinstance(payload, dict):
            return
        self.name = str(payload.get("name", self.name))
        self.type = str(payload.get("type", self.type))
        self.status = PlatformStatus.from_value(payload.get("status"))
        errors_payload = payload.get("errors")
        if isinstance(errors_payload, list):
            # Keep only entries that parse into PlatformError records.
            self.errors = [
                error
                for error in (
                    PlatformError.from_payload(item) if isinstance(item, dict) else None
                    for item in errors_payload
                )
                if error is not None
            ]
        self.last_error = PlatformError.from_payload(payload.get("last_error"))
        self.unified_webhook = bool(payload.get("unified_webhook", False))

    async def _refresh_locked(self) -> None:
        """Fetch and apply the latest snapshot; caller must hold ``_state_lock``."""
        output = await self._call_platform_manager(
            "platform.manager.get_by_id",
            {"platform_id": self.id},
            method_name="platform.refresh",
            details=f"platform_id={self.id!r}",
        )
        self._apply_snapshot(output.get("platform"))

    async def _call_platform_manager(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        method_name: str,
        details: str | None = None,
    ) -> dict[str, Any]:
        """Route a platform.manager call through the owning context.

        Prefers the context's ``_call_proxy`` wrapper so error labeling stays
        consistent; falls back to the module-level helper otherwise.
        """
        call_proxy = getattr(self._ctx, "_call_proxy", None)
        if callable(call_proxy):
            return await call_proxy(
                capability,
                payload,
                method_name=method_name,
                details=details,
            )
        return await _call_proxy_with_context(
            self._ctx._proxy,
            capability,
            payload,
            method_name=method_name,
            details=details,
        )
+
+
@dataclass(slots=True)
class CancelToken:
    """Request cancellation token.

    Coordinates cancellation of long-running operations. When the user
    cancels a request or an upstream timeout fires, the token is triggered,
    letting handlers clean up their resources promptly.

    Example:
        async def long_operation(ctx: Context):
            for item in large_list:
                ctx.cancel_token.raise_if_cancelled()
                await process(item)
    """

    # Single underlying event; set once when cancellation fires, never cleared.
    _cancelled: asyncio.Event

    def __init__(self) -> None:
        self._cancelled = asyncio.Event()

    def cancel(self) -> None:
        """Trigger the cancellation signal (idempotent)."""
        self._cancelled.set()

    @property
    def cancelled(self) -> bool:
        """Whether the token has been cancelled."""
        return self._cancelled.is_set()

    async def wait(self) -> None:
        """Block until the cancellation signal fires."""
        await self._cancelled.wait()

    def raise_if_cancelled(self) -> None:
        """Raise ``CancelledError`` if the token has already been cancelled.

        Raises:
            asyncio.CancelledError: if the token has been cancelled.
        """
        if self.cancelled:
            raise asyncio.CancelledError
+
+
+class Context:
+ """插件运行时上下文。
+
+ 组合所有 capability 客户端,提供统一的访问接口。
+ 每个 handler 调用都会创建新的 Context 实例。
+
+ Attributes:
+ peer: 协议对等端,用于底层通信
+ llm: LLM 客户端
+ memory: 记忆客户端
+ db: 数据库客户端
+ platform: 平台客户端
+ permission: 权限客户端
+ providers: Provider 客户端
+ provider_manager: Provider 管理客户端
+ permission_manager: 权限管理客户端
+ personas: 人格管理客户端
+ conversations: 对话管理客户端
+ kbs: 知识库管理客户端
+ message_history: 消息历史管理客户端
+ http: HTTP 客户端
+ metadata: 元数据客户端
+ registry: 能力注册客户端
+ skills: 技能客户端
+ session_plugins: 会话插件管理器
+ session_services: 会话服务管理器
+ plugin_id: 当前插件 ID
+ logger: 日志器
+ cancel_token: 取消令牌
+ """
+
    def __init__(
        self,
        *,
        peer,
        plugin_id: str,
        request_id: str | None = None,
        cancel_token: CancelToken | None = None,
        logger: Any | None = None,
        source_event_payload: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the context.

        Args:
            peer: Protocol peer instance used for all capability calls.
            plugin_id: Current plugin id; also scopes proxy calls and logging.
            request_id: Optional request scope id forwarded to the proxy.
            cancel_token: Cancellation token; a fresh one is created when None.
            logger: Logger; when None, the base logger is bound to plugin_id.
            source_event_payload: Originating event payload, copied defensively.
        """
        proxy = CapabilityProxy(
            peer,
            caller_plugin_id=plugin_id,
            request_scope_id=request_id,
        )
        if isinstance(logger, PluginLogger):
            bound_logger = logger
        else:
            bound_logger = logger or base_logger.bind(plugin_id=plugin_id)
        self._proxy = proxy
        self.peer = peer
        # One capability client per domain, all sharing the same proxy.
        self.llm = LLMClient(proxy)
        self.memory = MemoryClient(proxy)
        self.db = DBClient(proxy)
        self.platform = PlatformClient(proxy)
        self.permission = PermissionClient(proxy)
        self.providers = ProviderClient(proxy)
        self.provider_manager = ProviderManagerClient(
            proxy,
            plugin_id=plugin_id,
            logger=bound_logger,
        )
        self.permission_manager = PermissionManagerClient(
            proxy,
            source_event_payload=source_event_payload,
        )
        self.personas = PersonaManagerClient(proxy)
        self.conversations = ConversationManagerClient(proxy)
        self.kbs = KnowledgeBaseManagerClient(proxy)
        self.message_history = MessageHistoryManagerClient(proxy)
        self.http = HTTPClient(proxy)
        self.metadata = MetadataClient(proxy, plugin_id)
        self.registry = RegistryClient(proxy)
        self.skills = SkillClient(proxy)
        self.session_plugins = SessionPluginManager(proxy)
        self.session_services = SessionServiceManager(proxy)
        # Legacy attribute aliases kept for older plugin code.
        self.persona_manager = self.personas
        self.conversation_manager = self.conversations
        self.kb_manager = self.kbs
        self.message_history_manager = self.message_history
        self._llm_tool_manager = LLMToolManager(proxy)
        self.plugin_id = plugin_id
        # self.logger is always a PluginLogger wrapper, whatever was passed in.
        self.logger: PluginLogger = (
            bound_logger
            if isinstance(bound_logger, PluginLogger)
            else PluginLogger(plugin_id=plugin_id, logger=bound_logger)
        )
        self.cancel_token = cancel_token or CancelToken()
        self.request_id = request_id
        # Defensive copy; non-dict payloads degrade to an empty dict.
        self._source_event_payload = (
            dict(source_event_payload) if isinstance(source_event_payload, dict) else {}
        )
+
    async def _call_proxy(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        method_name: str,
        details: str | None = None,
    ) -> dict[str, Any]:
        """Invoke *capability* via this context's proxy with ``Context.<method>`` error labels."""
        return await _call_proxy_with_context(
            self._proxy,
            capability,
            payload,
            method_name=method_name,
            details=details,
        )
+
+ @staticmethod
+ def _platform_lookup_target(value: str) -> tuple[str, str]:
+ normalized_value = str(value).strip()
+ return normalized_value, normalized_value.lower()
+
+ @staticmethod
+ def _match_platform_instance(
+ platform_payload: dict[str, Any],
+ *,
+ platform_id: str | None = None,
+ platform_alias: str | None = None,
+ ) -> bool:
+ if platform_id is not None and platform_payload.get("id") == platform_id:
+ return True
+ if platform_alias is None:
+ return False
+ return (
+ str(platform_payload.get("type", "")).strip().lower() == platform_alias
+ or str(platform_payload.get("name", "")).strip().lower() == platform_alias
+ )
+
    async def get_data_dir(self) -> Path:
        """Return the plugin-scoped data directory path.

        Resolved by the host via the ``system.get_data_dir`` capability.
        """
        output = await self._call_proxy(
            "system.get_data_dir",
            {},
            method_name="get_data_dir",
        )
        return Path(str(output.get("path", "")))

    async def text_to_image(
        self,
        text: str,
        *,
        return_url: bool = True,
    ) -> str:
        """Render plain text into an image using the host renderer.

        Args:
            text: Plain text to render.
            return_url: When True, ask the host to return a URL-style result.

        Returns:
            The rendered image location as reported by the host.
        """
        output = await self._call_proxy(
            "system.text_to_image",
            {"text": text, "return_url": return_url},
            method_name="text_to_image",
            details=f"return_url={return_url!r}",
        )
        return str(output.get("result", ""))

    async def html_render(
        self,
        tmpl: str,
        data: dict[str, Any],
        *,
        return_url: bool = True,
        options: dict[str, Any] | None = None,
    ) -> str:
        """Render an HTML template using the host renderer.

        Args:
            tmpl: HTML template source.
            data: Template variables (copied before sending).
            return_url: When True, ask the host to return a URL-style result.
            options: Optional renderer-specific options passed through as-is.
        """
        output = await self._call_proxy(
            "system.html_render",
            {
                "tmpl": tmpl,
                "data": dict(data),
                "return_url": return_url,
                "options": options,
            },
            method_name="html_render",
            details=f"tmpl={tmpl!r}, return_url={return_url!r}",
        )
        return str(output.get("result", ""))
+
    async def get_using_provider(self, umo: str | None = None) -> ProviderMeta | None:
        """Compat alias: the chat provider currently in use for *umo*."""
        return await self.providers.get_using_chat(umo)

    async def get_current_chat_provider_id(self, umo: str | None = None) -> str | None:
        """Return the id of the chat provider selected for *umo*, or None."""
        output = await self._call_proxy(
            "provider.get_current_chat_provider_id",
            {"umo": umo},
            method_name="get_current_chat_provider_id",
            details=f"umo={umo!r}",
        )
        value = output.get("provider_id")
        # Falsy values (None, "") are reported as "no provider selected".
        return str(value) if value else None

    async def get_all_providers(self) -> list[ProviderMeta]:
        """Compat alias for ``providers.list_all()``."""
        return await self.providers.list_all()

    async def get_all_tts_providers(self) -> list[ProviderMeta]:
        """Compat alias for ``providers.list_tts()``."""
        return await self.providers.list_tts()

    async def get_all_stt_providers(self) -> list[ProviderMeta]:
        """Compat alias for ``providers.list_stt()``."""
        return await self.providers.list_stt()

    async def get_all_embedding_providers(self) -> list[ProviderMeta]:
        """Compat alias for ``providers.list_embedding()``."""
        return await self.providers.list_embedding()

    async def get_all_rerank_providers(self) -> list[ProviderMeta]:
        """Compat alias for ``providers.list_rerank()``."""
        return await self.providers.list_rerank()

    async def get_using_tts_provider(
        self, umo: str | None = None
    ) -> ProviderMeta | None:
        """Meta of the TTS provider in use for *umo*, or None when unset."""
        provider = await self.providers.get_using_tts(umo)
        return provider.meta() if provider is not None else None

    async def get_using_stt_provider(
        self, umo: str | None = None
    ) -> ProviderMeta | None:
        """Meta of the STT provider in use for *umo*, or None when unset."""
        provider = await self.providers.get_using_stt(umo)
        return provider.meta() if provider is not None else None
+
    async def send_message(
        self,
        session: str | MessageSession,
        content: PlatformCompatContent,
    ) -> dict[str, Any]:
        """Compat wrapper: send *content* to a session via the platform client."""
        return await self.platform.send_by_session(session, content)

    async def send_message_by_id(
        self,
        type: str,
        id: str,
        content: PlatformCompatContent,
        *,
        platform: str,
    ) -> dict[str, Any]:
        """Compat wrapper: send to a raw target id on an explicitly named platform.

        Note: ``type`` and ``id`` shadow builtins, but the names are kept for
        legacy-call compatibility.

        Raises:
            AstrBotError: unresolvable/ambiguous platform or invalid type.
        """
        # Resolve the platform by exact id first, then by type/name alias.
        platform_payload = await self._resolve_platform_target(platform)
        return await self.platform.send_by_id(
            str(platform_payload.get("id", "")),
            str(id),
            content,
            message_type=self._normalize_compat_message_type(type),
        )
+
+ @staticmethod
+ def _normalize_compat_message_type(value: str) -> str:
+ normalized = normalize_message_type(value)
+ if not normalized:
+ raise AstrBotError.invalid_input("send_message_by_id requires type")
+ return normalized
+
    async def _resolve_platform_target(self, platform: str) -> dict[str, Any]:
        """Resolve *platform* (exact id, or type/name alias) to a unique record.

        Resolution order: exact id match first, then a case-insensitive
        type/name alias match.

        Raises:
            AstrBotError: empty input, ambiguous alias, or no match.
        """
        target, normalized_target = self._platform_lookup_target(platform)
        if not target:
            raise AstrBotError.invalid_input(
                "send_message_by_id requires explicit platform"
            )
        instances = await self._list_platform_instances()
        id_matches = [
            item
            for item in instances
            if self._match_platform_instance(item, platform_id=target)
        ]
        if len(id_matches) == 1:
            return id_matches[0]
        alias_matches = [
            item
            for item in instances
            if self._match_platform_instance(item, platform_alias=normalized_target)
        ]
        if len(alias_matches) == 1:
            return alias_matches[0]
        if len(alias_matches) > 1:
            raise AstrBotError.invalid_input(
                f"send_message_by_id platform '{target}' is ambiguous"
            )
        raise AstrBotError.invalid_input(
            f"send_message_by_id cannot resolve platform '{target}'"
        )
+
    def get_llm_tool_manager(self) -> LLMToolManager:
        """Return the shared LLM tool manager for this context."""
        return self._llm_tool_manager

    async def activate_llm_tool(self, name: str) -> bool:
        """Activate the named LLM tool; returns the manager's result."""
        return await self._llm_tool_manager.activate(name)

    async def deactivate_llm_tool(self, name: str) -> bool:
        """Deactivate the named LLM tool; returns the manager's result."""
        return await self._llm_tool_manager.deactivate(name)

    async def add_llm_tools(self, *tools: LLMToolSpec) -> list[str]:
        """Register the given tool specs; returns the manager's result list."""
        return await self._llm_tool_manager.add(*tools)
+
    async def register_llm_tool(
        self,
        name: str,
        parameters_schema: dict[str, Any],
        desc: str,
        func_obj: Callable[..., Any] | Callable[..., Awaitable[Any]],
        *,
        active: bool = True,
    ) -> list[str]:
        """Register a dynamic LLM tool backed by *func_obj*.

        Args:
            name: Tool name (stripped; must be non-empty).
            parameters_schema: Schema dict describing the tool parameters.
            desc: Tool description.
            func_obj: Callable invoked when the tool fires (sync or async).
            active: Whether the tool starts active.

        Returns:
            The tool manager's result list for the added spec.

        Raises:
            TypeError: non-callable ``func_obj`` or non-dict schema.
            AstrBotError: empty name, or a wrapped manager registration failure.
        """
        if not callable(func_obj):
            raise TypeError("register_llm_tool requires a callable func_obj")
        tool_name = str(name).strip()
        if not tool_name:
            raise AstrBotError.invalid_input("register_llm_tool requires name")
        if not isinstance(parameters_schema, dict):
            raise TypeError("register_llm_tool requires parameters_schema dict")

        # Synthetic handler reference; the dispatcher maps it back to func_obj.
        handler_ref = f"__dynamic_llm_tool__:{tool_name}"
        tool_spec = LLMToolSpec.create(
            name=tool_name,
            description=str(desc),
            parameters_schema=dict(parameters_schema),
            handler_ref=handler_ref,
            active=bool(active),
        )
        # Prefer the callable's bound instance, falling back to the current star.
        owner = getattr(func_obj, "__self__", None) or current_star_instance()
        dispatcher = getattr(self.peer, "_sdk_capability_dispatcher", None)
        if dispatcher is not None and hasattr(dispatcher, "add_dynamic_llm_tool"):
            dispatcher.add_dynamic_llm_tool(
                plugin_id=self.plugin_id,
                spec=tool_spec,
                callable_obj=func_obj,
                owner=owner,
            )
        try:
            return await self._llm_tool_manager.add(tool_spec)
        except Exception as exc:
            # Roll back the dispatcher-side registration if the manager rejects it.
            if dispatcher is not None and hasattr(dispatcher, "remove_llm_tool"):
                dispatcher.remove_llm_tool(self.plugin_id, tool_name)
            raise _wrap_context_exception(
                method_name="register_llm_tool",
                details=f"name={tool_name!r}, active={bool(active)!r}",
                exc=exc,
            ) from exc
+
+ async def unregister_llm_tool(self, name: str) -> bool:
+ removed = await self._llm_tool_manager.remove(str(name))
+ dispatcher = getattr(self.peer, "_sdk_capability_dispatcher", None)
+ if dispatcher is not None and hasattr(dispatcher, "remove_llm_tool"):
+ dispatcher.remove_llm_tool(self.plugin_id, str(name))
+ return removed
+
    async def register_skill(
        self,
        *,
        name: str,
        path: str | Path,
        description: str = "",
    ) -> SkillRegistration:
        """Register a plugin skill located at *path* with AstrBot.

        Failures are re-raised with a ``Context.register_skill`` label.
        """
        try:
            return await self.skills.register(
                name=name,
                path=str(path),
                description=description,
            )
        except Exception as exc:
            raise _wrap_context_exception(
                method_name="register_skill",
                details=f"name={name!r}, path={str(path)!r}",
                exc=exc,
            ) from exc

    async def unregister_skill(self, name: str) -> bool:
        """Unregister a previously registered skill; returns the client's result.

        Failures are re-raised with a ``Context.unregister_skill`` label.
        """
        try:
            return await self.skills.unregister(name)
        except Exception as exc:
            raise _wrap_context_exception(
                method_name="unregister_skill",
                details=f"name={name!r}",
                exc=exc,
            ) from exc
+
    async def tool_loop_agent(
        self,
        request: ProviderRequest | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Run the host's tool-loop agent for *request*, optionally overridden by kwargs.

        Keyword overrides are merged via ``model_dump``/``model_validate``, so
        they receive the same validation as a directly constructed request.
        """
        provider_request = request or ProviderRequest()
        if kwargs:
            merged = provider_request.model_dump()
            merged.update(kwargs)
            provider_request = ProviderRequest.model_validate(merged)
        payload = provider_request.to_payload()
        target_payload = self._source_event_payload.get("target")
        if isinstance(target_payload, dict):
            # Preserve the original message target so core can recover the
            # dispatch token for message-bound tool loop execution.
            payload["target"] = dict(target_payload)
        output = await self._call_proxy(
            "agent.tool_loop.run",
            payload,
            method_name="tool_loop_agent",
            details=(
                f"session_id={provider_request.session_id!r}, "
                f"contexts={len(provider_request.contexts)!r}"
            ),
        )
        return LLMResponse.model_validate(output)
+
+ def _source_event_type(self) -> str:
+ event_type = self._source_event_payload.get("event_type")
+ if isinstance(event_type, str) and event_type.strip():
+ return event_type.strip()
+ fallback_type = self._source_event_payload.get("type")
+ if isinstance(fallback_type, str) and fallback_type.strip():
+ return fallback_type.strip()
+ raw_payload = self._source_event_payload.get("raw")
+ if isinstance(raw_payload, dict):
+ raw_event_type = raw_payload.get("event_type")
+ if isinstance(raw_event_type, str) and raw_event_type.strip():
+ return raw_event_type.strip()
+ return ""
+
    async def register_commands(
        self,
        command_name: str,
        handler_full_name: str,
        *,
        desc: str = "",
        priority: int = 0,
        use_regex: bool = False,
        ignore_prefix: bool = False,
    ) -> None:
        """Dynamically register a command route with core.

        Only usable while handling ``astrbot_loaded`` / ``platform_loaded``
        events; ``ignore_prefix=True`` is rejected as unsupported here.

        Raises:
            AstrBotError: wrong event context, ``ignore_prefix=True``, or a
                non-integer priority.
        """
        source_event_type = self._source_event_type()
        if source_event_type not in {"astrbot_loaded", "platform_loaded"}:
            raise AstrBotError.invalid_input(
                "register_commands is only available in astrbot_loaded/platform_loaded events"
            )
        if ignore_prefix:
            raise AstrBotError.invalid_input(
                "register_commands(ignore_prefix=True) is unsupported in SDK runtime"
            )
        # bool is an int subclass, so reject it explicitly before the int check.
        if isinstance(priority, bool) or not isinstance(priority, int):
            raise AstrBotError.invalid_input(
                "register_commands priority must be an integer"
            )
        normalized_command_name = str(command_name)
        normalized_handler_name = str(handler_full_name)
        await self._call_proxy(
            "registry.command.register",
            {
                "command_name": normalized_command_name,
                "handler_full_name": normalized_handler_name,
                "source_event_type": source_event_type,
                "desc": str(desc),
                "priority": priority,
                "use_regex": bool(use_regex),
                "ignore_prefix": False,
            },
            method_name="register_commands",
            details=(
                f"command_name={normalized_command_name!r}, "
                f"handler_full_name={normalized_handler_name!r}"
            ),
        )
+
    async def register_task(
        self,
        task: Awaitable[Any],
        desc: str,
    ) -> asyncio.Task[Any]:
        """Register a background task owned by the current SDK context.

        This is the recommended way to launch follow-up work that should outlive
        the current handler dispatch, including `session_waiter(...)` flows.
        Directly awaiting a waiter inside the current handler keeps the original
        dispatch open until the next message arrives.

        Args:
            task: An ``asyncio.Task``, future, or coroutine to run in background.
            desc: Human-readable label used in the completion/failure logs.

        Returns:
            The scheduled ``asyncio.Task``.

        Raises:
            TypeError: when *task* is not a task, future, or coroutine.

        Example:
            await event.reply("请输入用户名:")
            await ctx.register_task(
                self.collect_username(event),
                "waiter:collect_username",
            )
        """
        task_desc = str(desc)

        async def _wrap_future(future: asyncio.Future[Any]) -> Any:
            # Bare futures are wrapped so create_task accepts them.
            return await future

        if isinstance(task, asyncio.Task):
            background_task = task
        elif asyncio.isfuture(task):
            background_task = asyncio.create_task(_wrap_future(task))
        elif asyncio.iscoroutine(task):
            background_task = asyncio.create_task(task)
        else:
            raise TypeError(
                "Context.register_task requires an awaitable task object; "
                f"got {type(task).__name__} for desc={task_desc!r}"
            )

        # Tag the task so session-waiter bookkeeping can recognize it.
        _mark_session_waiter_background_task(background_task)

        def _on_done(done_task: asyncio.Task[Any]) -> None:
            _unmark_session_waiter_background_task(done_task)
            if done_task.cancelled():
                # getattr guard: logger may be user-supplied without debug().
                debug_logger = getattr(self.logger, "debug", None)
                if callable(debug_logger):
                    debug_logger(
                        "SDK background task cancelled: plugin_id={} desc={}",
                        self.plugin_id,
                        task_desc,
                    )
                return
            try:
                # result() re-raises any exception the task ended with.
                done_task.result()
            except Exception:
                exception_logger = getattr(self.logger, "exception", None)
                if callable(exception_logger):
                    exception_logger(
                        "SDK background task failed: plugin_id={} desc={}",
                        self.plugin_id,
                        task_desc,
                    )

        background_task.add_done_callback(_on_done)
        return background_task
+
+ async def _list_platform_instances(self) -> list[dict[str, Any]]:
+ output = await self._call_proxy(
+ "platform.list_instances",
+ {},
+ method_name="list_platforms",
+ )
+ items = output.get("platforms")
+ if not isinstance(items, list):
+ return []
+ normalized: list[dict[str, Any]] = []
+ for item in items:
+ normalized_item = _normalize_platform_instance_payload(item)
+ if normalized_item is not None:
+ normalized.append(normalized_item)
+ return normalized
+
    def _build_platform_facade(
        self,
        platform_payload: dict[str, Any],
    ) -> PlatformCompatFacade:
        """Wrap a normalized platform record in a compat facade bound to this context."""
        return PlatformCompatFacade(
            _ctx=self,
            id=str(platform_payload.get("id", "")),
            name=str(platform_payload.get("name", "")),
            type=str(platform_payload.get("type", "")),
            status=PlatformStatus.from_value(platform_payload.get("status")),
        )

    async def list_platforms(self) -> list[PlatformCompatFacade]:
        """Return compat facades for every visible platform instance.

        Example:
            for platform in await ctx.list_platforms():
                print(platform.id, platform.status)
        """
        return [
            self._build_platform_facade(item)
            for item in await self._list_platform_instances()
        ]

    async def get_platform(self, platform_type: str) -> PlatformCompatFacade | None:
        """First platform whose type/name matches (case-insensitive), or None."""
        _, target_type = self._platform_lookup_target(platform_type)
        if not target_type:
            return None
        for item in await self._list_platform_instances():
            if self._match_platform_instance(item, platform_alias=target_type):
                return self._build_platform_facade(item)
        return None

    async def get_platform_inst(self, platform_id: str) -> PlatformCompatFacade | None:
        """The platform instance with exactly this id, or None."""
        target_id, _ = self._platform_lookup_target(platform_id)
        if not target_id:
            return None
        for item in await self._list_platform_instances():
            if self._match_platform_instance(item, platform_id=target_id):
                return self._build_platform_facade(item)
        return None
diff --git a/astrbot-sdk/src/astrbot_sdk/conversation.py b/astrbot-sdk/src/astrbot_sdk/conversation.py
new file mode 100644
index 0000000000..78e3cd9095
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/conversation.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any
+
+from .context import Context
+from .events import MessageEvent
+from .message.components import BaseMessageComponent
+from .message.result import MessageChain
+from .session_waiter import SessionWaiterManager
+
+DEFAULT_BUSY_MESSAGE = "当前会话已有进行中的交互,请先完成后再试。"
+
+
class ConversationState(str, Enum):
    """Lifecycle states of a ConversationSession."""

    ACTIVE = "active"  # session is live; ask/reply/send are allowed
    REJECTED_BUSY = "rejected_busy"  # presumably: refused because one was active — confirm with caller
    REPLACED = "replaced"  # superseded by a newer session (see mark_replaced)
    TIMEOUT = "timeout"  # ask() timed out waiting for the next event
    COMPLETED = "completed"  # ended normally via end()
    CANCELLED = "cancelled"  # the wait was cancelled (non-replacement case)
+
+
class ConversationReplaced(RuntimeError):
    """Raised from ask() when this session was replaced by a newer one."""

    pass


class ConversationClosed(RuntimeError):
    """Raised when a session is used after close, or outside its owner task."""

    pass
+
+
@dataclass(slots=True)
class ConversationSession:
    """One interactive conversation bound to a message event and an owner task.

    State transitions are one-way: once the session leaves ACTIVE, the only
    further transition allowed is to REPLACED (see close()).
    """

    ctx: Context  # runtime context used for proactive sends
    event: MessageEvent  # the triggering message event
    waiter_manager: SessionWaiterManager  # waits for follow-up events in this session
    timeout: int  # default ask() timeout -- assumed seconds, confirm with waiter_manager
    state: ConversationState = ConversationState.ACTIVE  # current lifecycle state
    _owner_task: asyncio.Task[Any] | None = None  # only task allowed to drive the session

    def __post_init__(self) -> None:
        # Coerce loose inputs (None or raw strings) into ConversationState.
        if self.state is None:
            self.state = ConversationState.ACTIVE
            return
        if not isinstance(self.state, ConversationState):
            self.state = ConversationState(str(self.state))

    def bind_owner_task(self, task: asyncio.Task[Any]) -> None:
        """Restrict session usage to *task* (enforced by _ensure_usable)."""
        self._owner_task = task

    @property
    def session_key(self) -> str:
        """Unified message origin identifying this session."""
        return self.event.unified_msg_origin

    @property
    def active(self) -> bool:
        """True while the session is still in the ACTIVE state."""
        return self.state == ConversationState.ACTIVE

    async def ask(self, prompt: str, timeout: int | None = None) -> MessageEvent:
        """Send *prompt* (when non-empty) and wait for the user's next event.

        Raises:
            asyncio.TimeoutError: no follow-up in time (state -> TIMEOUT).
            ConversationReplaced: a newer session replaced this one.
            asyncio.CancelledError: the wait was cancelled (state -> CANCELLED).
            ConversationClosed: session already closed or wrong task.
        """
        self._ensure_usable("ask")
        if prompt:
            await self.reply(prompt)
        try:
            return await self.waiter_manager.wait_for_event(
                event=self.event,
                timeout=timeout or self.timeout,
                record_history_chains=False,
            )
        except asyncio.TimeoutError:
            self.close(ConversationState.TIMEOUT)
            raise
        except asyncio.CancelledError as exc:
            # REPLACED set before cancellation means a newer session took over;
            # surface that distinctly instead of a bare CancelledError.
            if self.state == ConversationState.REPLACED:
                raise ConversationReplaced(
                    "conversation replaced by a newer session"
                ) from exc
            self.close(ConversationState.CANCELLED)
            raise

    async def reply(self, text: str) -> None:
        """Reply to the triggering event with plain text."""
        self._ensure_usable("reply")
        await self.event.reply(text)

    async def reply_chain(
        self,
        chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]],
    ) -> None:
        """Reply to the triggering event with a full message chain."""
        self._ensure_usable("reply_chain")
        await self.event.reply_chain(chain)

    async def send_message(
        self,
        content: str | MessageChain | list[BaseMessageComponent] | list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Proactively send *content* to this session (not as a direct reply)."""
        self._ensure_usable("send_message")
        return await self.ctx.platform.send_by_session(self.event.session_id, content)

    def end(self) -> None:
        """Close the session as COMPLETED."""
        self.close(ConversationState.COMPLETED)

    def mark_replaced(self) -> None:
        """Close the session as REPLACED (permitted even after another close)."""
        self.close(ConversationState.REPLACED)

    def close(self, state: ConversationState) -> None:
        """Transition to *state*; a non-ACTIVE session only accepts REPLACED."""
        if self.state != ConversationState.ACTIVE and state == self.state:
            return
        if (
            self.state != ConversationState.ACTIVE
            and state != ConversationState.REPLACED
        ):
            return
        self.state = state

    def _ensure_usable(self, action: str) -> None:
        """Guard: session must be ACTIVE and driven from its owner task.

        Raises:
            ConversationClosed: owner-task mismatch or already-closed session.
        """
        if (
            self._owner_task is not None
            and asyncio.current_task() is not self._owner_task
        ):
            raise ConversationClosed(
                f"ConversationSession cannot be used outside its owner task during {action}"
            )
        if not self.active:
            raise ConversationClosed(
                f"ConversationSession is already closed ({self.state.value}) during {action}"
            )
+
+
# Public API of this module.
__all__ = [
    "ConversationClosed",
    "ConversationReplaced",
    "ConversationSession",
    "ConversationState",
    "DEFAULT_BUSY_MESSAGE",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/decorators.py b/astrbot-sdk/src/astrbot_sdk/decorators.py
new file mode 100644
index 0000000000..49d69985ab
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/decorators.py
@@ -0,0 +1,1332 @@
+"""astrbot-sdk 原生装饰器。
+
+提供声明式的方法来注册 handler 和 capability。
+装饰器会在方法上附加元数据,由 Star.__init_subclass__ 自动收集。
+
+触发器装饰器:
+ - @on_command: 命令触发器
+ - @on_message: 消息触发器(关键词/正则)
+ - @on_event: 事件触发器
+ - @on_schedule: 定时任务触发器
+ - @conversation_command: 带会话生命周期的命令触发器
+
+权限与过滤装饰器:
+ - @require_admin / @admin_only: 管理员权限标记
+ - @require_permission: 通用角色权限标记
+ - @platforms: 限定平台
+ - @group_only / @private_only: 群聊/私聊限定
+ - @message_types: 消息类型过滤
+
+限流装饰器:
+ - @rate_limit: 滑动窗口限流
+ - @cooldown: 冷却时间
+
+优先级装饰器:
+ - @priority: 设置执行优先级
+
+能力导出装饰器:
+ - @provide_capability: 声明对外暴露的能力
+ - @register_llm_tool: 注册 LLM 工具
+ - @register_agent: 注册 Agent
+
+Example:
+ class MyPlugin(Star):
+ @on_command("hello", aliases=["hi"])
+ async def hello(self, event: MessageEvent, ctx: Context):
+ await event.reply("Hello!")
+
+ @on_message(keywords=["help"])
+ async def help(self, event: MessageEvent, ctx: Context):
+ await event.reply("Help info...")
+
+ @provide_capability("my_plugin.calculate", description="计算")
+ async def calculate(self, payload: dict, ctx: Context):
+ return {"result": payload["x"] * 2}
+"""
+
+from __future__ import annotations
+
+import inspect
+import typing
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any, Literal, TypeVar, cast
+
+from pydantic import BaseModel
+
+from ._internal.typing_utils import unwrap_optional
+from .llm.agents import AgentSpec, BaseAgentRunner
+from .llm.entities import LLMToolSpec
+from .protocol.descriptors import (
+ RESERVED_CAPABILITY_PREFIXES,
+ CapabilityDescriptor,
+ CommandRouteSpec,
+ CommandTrigger,
+ EventTrigger,
+ FilterSpec,
+ MessageTrigger,
+ MessageTypeFilterSpec,
+ Permissions,
+ PlatformFilterSpec,
+ ScheduleTrigger,
+)
+
# Loose alias for any handler-shaped callable, plus the TypeVar that lets the
# decorators return their input unchanged without losing its static type.
HandlerCallable = Callable[..., Any]
_HandlerT = TypeVar("_HandlerT", bound=Callable[..., Any])

# Attribute names under which decorator metadata is stashed on the decorated
# function/class; Star.__init_subclass__ scans these to collect registrations.
HANDLER_META_ATTR = "__astrbot_handler_meta__"
CAPABILITY_META_ATTR = "__astrbot_capability_meta__"
LLM_TOOL_META_ATTR = "__astrbot_llm_tool_meta__"
AGENT_META_ATTR = "__astrbot_agent_meta__"
HTTP_API_META_ATTR = "__astrbot_http_api_meta__"
VALIDATE_CONFIG_META_ATTR = "__astrbot_validate_config_meta__"
PROVIDER_CHANGE_META_ATTR = "__astrbot_provider_change_meta__"
BACKGROUND_TASK_META_ATTR = "__astrbot_background_task_meta__"
SKILL_META_ATTR = "__astrbot_skill_meta__"

# Closed vocabularies for limiter bucketing/behavior and conversation policy.
LimiterScope = Literal["session", "user", "group", "global"]
LimiterBehavior = Literal["hint", "silent", "error"]
ConversationMode = Literal["replace", "reject"]
+
+
@dataclass(slots=True)
class LimiterMeta:
    """Limiter metadata recorded by @rate_limit / @cooldown on a handler."""

    kind: Literal["rate_limit", "cooldown"]  # which decorator produced this
    limit: int  # max invocations allowed inside the window
    window: float  # window length in seconds
    scope: LimiterScope = "session"  # key used to bucket invocations
    behavior: LimiterBehavior = "hint"  # reaction when the limit is exceeded
    message: str | None = None  # optional custom notice text
+
+
@dataclass(slots=True)
class ConversationMeta:
    """Conversation lifecycle options recorded by @conversation_command."""

    timeout: int = 60  # session timeout in seconds
    mode: ConversationMode = "replace"  # conflict policy: replace or reject
    busy_message: str | None = None  # notice used when mode == "reject"
    grace_period: float = 1.0  # grace period (seconds) for lifecycle handling
+
+
@dataclass(slots=True)
class HandlerMeta:
    """Handler metadata.

    Stored on the method under the ``__astrbot_handler_meta__`` attribute.

    Attributes:
        trigger: trigger spec (command / message / event / schedule)
        kind: handler kind identifier
        contract: optional contract type
        description: optional human-readable description
        priority: execution priority (larger values run earlier)
        permissions: permission requirements
        filters: declarative filter specs (platform, message type, ...)
        local_filters: locally bound filter objects
        command_route: command-group routing info when the command is grouped
        limiter: rate-limit / cooldown settings, if any
        conversation: conversation lifecycle settings, if any
        decorator_sources: records which decorator supplied a constraint so
            conflicting declarations can be detected
    """

    trigger: CommandTrigger | MessageTrigger | EventTrigger | ScheduleTrigger | None = (
        None
    )
    kind: str = "handler"
    contract: str | None = None
    description: str | None = None
    priority: int = 0
    permissions: Permissions = field(default_factory=Permissions)
    filters: list[FilterSpec] = field(default_factory=list)
    local_filters: list[Any] = field(default_factory=list)
    command_route: CommandRouteSpec | None = None
    limiter: LimiterMeta | None = None
    conversation: ConversationMeta | None = None
    decorator_sources: dict[str, str] = field(default_factory=dict)
+
+
@dataclass(slots=True)
class CapabilityMeta:
    """Capability metadata.

    Stored on the method under the ``__astrbot_capability_meta__`` attribute.

    Attributes:
        descriptor: the capability descriptor
    """

    descriptor: CapabilityDescriptor
+
+
@dataclass(slots=True)
class LLMToolMeta:
    """Wrapper holding an LLM tool spec (see @register_llm_tool)."""

    spec: LLMToolSpec
+
+
@dataclass(slots=True)
class AgentMeta:
    """Wrapper holding an agent spec (see @register_agent)."""

    spec: AgentSpec
+
+
@dataclass(slots=True)
class HttpApiMeta:
    """HTTP endpoint metadata recorded by @http_api."""

    route: str  # URL route path
    methods: list[str] = field(default_factory=lambda: ["GET"])  # upper-cased verbs
    description: str = ""
    # Optional capability name associated with the route — TODO confirm how
    # the runtime consumes this.
    capability_name: str | None = None
+
+
@dataclass(slots=True)
class ValidateConfigMeta:
    """Config-validator metadata recorded by @validate_config.

    Exactly one of the two fields is populated (enforced by the decorator).
    """

    model: type[BaseModel] | None = None  # pydantic model for validation
    schema: dict[str, Any] | None = None  # lightweight {field: {"type": ...}} schema
+
+
+def _is_valid_validate_config_expected_type(value: Any) -> bool:
+ if isinstance(value, type):
+ return True
+ return (
+ isinstance(value, tuple)
+ and len(value) > 0
+ and all(isinstance(item, type) for item in value)
+ )
+
+
+def _validate_validate_config_schema(schema: dict[str, Any]) -> None:
+ for field_name, field_schema in schema.items():
+ if not isinstance(field_schema, dict):
+ raise TypeError(
+ f"validate_config schema field {field_name!r} must be a dict"
+ )
+ expected_type = field_schema.get("type")
+ if expected_type is not None and not _is_valid_validate_config_expected_type(
+ expected_type
+ ):
+ raise TypeError(
+ "validate_config schema field "
+ f"{field_name!r} has invalid 'type' entry {expected_type!r}; "
+ "expected a type or tuple of types"
+ )
+
+
@dataclass(slots=True)
class ProviderChangeMeta:
    """Metadata recorded by @on_provider_change."""

    provider_types: list[str] = field(default_factory=list)  # lower-cased; empty = all
+
+
@dataclass(slots=True)
class BackgroundTaskMeta:
    """Metadata recorded by @background_task."""

    description: str = ""
    auto_start: bool = True  # start the task automatically
    on_error: Literal["log", "restart"] = "log"  # failure policy
+
+
@dataclass(slots=True)
class SkillMeta:
    """One skill declaration recorded by @register_skill (decorators may stack)."""

    name: str  # non-empty skill name
    path: str  # non-empty skill path
    description: str = ""
+
+
def _get_or_create_meta(func: HandlerCallable) -> HandlerMeta:
    """Return the handler metadata on *func*, attaching a fresh one if absent."""
    existing = getattr(func, HANDLER_META_ATTR, None)
    if existing is not None:
        return existing
    fresh = HandlerMeta()
    setattr(func, HANDLER_META_ATTR, fresh)
    return fresh
+
+
def get_handler_meta(func: HandlerCallable) -> HandlerMeta | None:
    """Return the handler metadata attached to a method.

    Args:
        func: the callable to inspect

    Returns:
        The HandlerMeta instance, or None if absent.
    """
    return getattr(func, HANDLER_META_ATTR, None)
+
+
def get_capability_meta(func: HandlerCallable) -> CapabilityMeta | None:
    """Return the capability metadata attached to a method.

    Args:
        func: the callable to inspect

    Returns:
        The CapabilityMeta instance, or None if absent.
    """
    return getattr(func, CAPABILITY_META_ATTR, None)
+
+
def get_llm_tool_meta(func: HandlerCallable) -> LLMToolMeta | None:
    """Return the LLM-tool metadata attached to *func*, or None."""
    return getattr(func, LLM_TOOL_META_ATTR, None)
+
+
def get_agent_meta(obj: Any) -> AgentMeta | None:
    """Return the agent metadata attached to *obj*, or None."""
    return getattr(obj, AGENT_META_ATTR, None)
+
+
def get_http_api_meta(func: HandlerCallable) -> HttpApiMeta | None:
    """Return the HTTP-endpoint metadata attached to *func*, or None."""
    return getattr(func, HTTP_API_META_ATTR, None)
+
+
def get_validate_config_meta(func: HandlerCallable) -> ValidateConfigMeta | None:
    """Return the config-validator metadata attached to *func*, or None."""
    return getattr(func, VALIDATE_CONFIG_META_ATTR, None)
+
+
def get_provider_change_meta(func: HandlerCallable) -> ProviderChangeMeta | None:
    """Return the provider-change hook metadata attached to *func*, or None."""
    return getattr(func, PROVIDER_CHANGE_META_ATTR, None)
+
+
def get_background_task_meta(func: HandlerCallable) -> BackgroundTaskMeta | None:
    """Return the background-task metadata attached to *func*, or None."""
    return getattr(func, BACKGROUND_TASK_META_ATTR, None)
+
+
def get_skill_meta(obj: Any) -> list[SkillMeta]:
    """Return the SkillMeta entries registered on *obj*.

    Non-list containers and non-SkillMeta entries are silently dropped,
    so the result is always a (possibly empty) list of SkillMeta.
    """
    raw = getattr(obj, SKILL_META_ATTR, None)
    if isinstance(raw, list):
        return [entry for entry in raw if isinstance(entry, SkillMeta)]
    return []
+
+
+def _append_list_meta(obj: Any, attr_name: str, value: Any) -> None:
+ values = getattr(obj, attr_name, None)
+ if not isinstance(values, list):
+ values = []
+ setattr(obj, attr_name, values)
+ values.append(value)
+
+
+def _replace_filter(meta: HandlerMeta, spec: FilterSpec) -> None:
+ kind = getattr(spec, "kind", None)
+ meta.filters = [
+ item for item in meta.filters if getattr(item, "kind", None) != kind
+ ]
+ meta.filters.append(spec)
+
+
+def _has_filter_kind(meta: HandlerMeta, kind: str) -> bool:
+ return any(getattr(item, "kind", None) == kind for item in meta.filters)
+
+
def _set_platform_filter(
    meta: HandlerMeta,
    values: list[str],
    *,
    source: str,
) -> None:
    """Install a platform filter on *meta*, rejecting conflicting sources.

    Args:
        meta: handler metadata to mutate
        values: platform names; trimmed, de-duplicated (order preserved),
            blanks dropped — a fully-empty result is a silent no-op
        source: which decorator supplied the constraint; recorded so that
            platforms(...) and on_message(platforms=...) cannot be mixed

    Raises:
        ValueError: when a different source already declared platforms, or
            when a platform filter exists that was not recorded here.
    """
    normalized = [
        value for value in dict.fromkeys(str(item).strip() for item in values) if value
    ]
    if not normalized:
        return
    existing = meta.decorator_sources.get("platforms")
    # A different decorator already declared platforms -> conflict.
    if existing is not None and existing != source:
        raise ValueError("platforms(...) 不能与 on_message(platforms=...) 混用")
    # A platform filter exists that we did not record -> conflict too.
    if existing is None and _has_filter_kind(meta, "platform"):
        raise ValueError("platforms(...) 不能与已有平台过滤器混用")
    meta.decorator_sources["platforms"] = source
    _replace_filter(meta, PlatformFilterSpec(platforms=normalized))
+
+
def _set_message_type_filter(
    meta: HandlerMeta,
    values: list[str],
    *,
    source: str,
) -> None:
    """Install a message-type filter on *meta*, rejecting conflicting sources.

    Args:
        meta: handler metadata to mutate
        values: message types; trimmed, lower-cased, de-duplicated (order
            preserved), blanks dropped — a fully-empty result is a no-op
        source: which decorator supplied the constraint; recorded so the
            various message-type decorators cannot be mixed

    Raises:
        ValueError: when a different source already declared message types,
            or a message-type filter exists that was not recorded here.
    """
    normalized = [
        value
        for value in dict.fromkeys(str(item).strip().lower() for item in values)
        if value
    ]
    if not normalized:
        return
    existing = meta.decorator_sources.get("message_types")
    if existing is not None and existing != source:
        raise ValueError(
            "group_only()/private_only()/message_types(...) 不能与已有消息类型约束混用"
        )
    if existing is None and _has_filter_kind(meta, "message_type"):
        raise ValueError(
            "group_only()/private_only()/message_types(...) 不能与已有消息类型过滤器混用"
        )
    meta.decorator_sources["message_types"] = source
    _replace_filter(meta, MessageTypeFilterSpec(message_types=normalized))
+
+
+def _validate_message_trigger_compatibility(meta: HandlerMeta) -> None:
+ if meta.limiter is None or meta.trigger is None:
+ return
+ trigger_type = getattr(meta.trigger, "type", None)
+ if trigger_type not in {"command", "message"}:
+ raise ValueError(
+ "rate_limit(...) 和 cooldown(...) 只适用于 on_command/on_message"
+ )
+
+
+def _set_required_role(
+ meta: HandlerMeta,
+ role: Literal["member", "admin"],
+) -> None:
+ current = meta.permissions.required_role
+ if current is not None and current != role:
+ raise ValueError(
+ f"require_permission({role!r}) 与已有权限要求 {current!r} 冲突"
+ )
+ meta.permissions.required_role = role
+ meta.permissions.require_admin = role == "admin"
+
+
+def _normalize_description(description: str | None) -> str | None:
+ if description is None:
+ return None
+ text = str(description).strip()
+ return text or None
+
+
+def _require_handler_callable(
+ target: Any,
+ *,
+ decorator_name: str,
+) -> None:
+ if not callable(target):
+ raise TypeError(f"{decorator_name} can only decorate callables")
+
+
+def _validate_limiter_args(
+ *,
+ kind: str,
+ limit: int,
+ window: float,
+ scope: LimiterScope,
+ behavior: LimiterBehavior,
+) -> None:
+ if isinstance(limit, bool) or int(limit) <= 0:
+ raise ValueError(f"{kind} requires a positive limit")
+ if float(window) <= 0:
+ raise ValueError(f"{kind} requires a positive window")
+ if scope not in {"session", "user", "group", "global"}:
+ raise ValueError(f"unsupported limiter scope: {scope}")
+ if behavior not in {"hint", "silent", "error"}:
+ raise ValueError(f"unsupported limiter behavior: {behavior}")
+
+
def _set_limiter(
    func: _HandlerT,
    limiter: LimiterMeta,
) -> _HandlerT:
    """Attach *limiter* to *func*'s handler metadata.

    A handler may carry at most one limiter; stacking raises ValueError.
    Trigger compatibility is re-checked once the limiter is recorded.
    """
    handler_meta = _get_or_create_meta(func)
    if handler_meta.limiter is None:
        handler_meta.limiter = limiter
        _validate_message_trigger_compatibility(handler_meta)
        return func
    raise ValueError("rate_limit(...) 和 cooldown(...) 不能叠加在同一个 handler 上")
+
+
+def _model_to_schema(
+ model: type[BaseModel] | None,
+ *,
+ label: str,
+) -> dict[str, Any] | None:
+ """将 pydantic 模型转换为 JSON Schema。
+
+ Args:
+ model: pydantic BaseModel 子类
+ label: 错误消息中的字段名
+
+ Returns:
+ JSON Schema 字典,如果 model 为 None 则返回 None
+
+ Raises:
+ TypeError: 如果 model 不是 BaseModel 子类
+ """
+ if model is None:
+ return None
+ if not isinstance(model, type) or not issubclass(model, BaseModel):
+ raise TypeError(f"{label} 必须是 pydantic BaseModel 子类")
+ return cast(dict[str, Any], model.model_json_schema())
+
+
def on_command(
    command: str | typing.Sequence[str],
    *,
    aliases: list[str] | None = None,
    description: str | None = None,
    group: str | typing.Sequence[str] | None = None,
    group_help: str | None = None,
) -> Callable[[_HandlerT], _HandlerT]:
    """Register a command handler.

    Fires when a user issues the command; whether a ``/`` prefix is needed
    depends on platform configuration.

    Args:
        command: command name (without prefix), or a sequence whose first
            item is the canonical name and the rest become aliases
        aliases: extra alias names
        description: command description for help output
        group: command-group path — ``"admin"`` for one level, or
            ``["admin", "user"]`` for nested groups; the effective command
            then becomes e.g. ``"admin command"``
        group_help: command-group description for help output

    Returns:
        The decorator.

    Raises:
        TypeError: when aliases is not a list
        ValueError: when no non-empty command name remains after trimming

    Example:
        @on_command("echo", aliases=["repeat"], description="重复消息")
        async def echo(self, event: MessageEvent, ctx: Context):
            await event.reply(event.text)

        @on_command("ban", group="admin", description="封禁用户")
        async def admin_ban(self, event: MessageEvent, ctx: Context):
            await event.reply("已封禁")
    """
    if aliases is not None and not isinstance(aliases, list):
        raise TypeError("on_command aliases must be a list of strings")

    raw_names = [command] if isinstance(command, str) else list(command)
    names = [text for text in (str(item).strip() for item in raw_names) if text]
    if not names:
        raise ValueError("on_command requires at least one non-empty command name")

    path: list[str] = []
    if group is not None:
        raw_group = [group] if isinstance(group, str) else list(group)
        path = [text for text in (str(item).strip() for item in raw_group) if text]

    canonical = names[0]
    # With a group path, both the command and every alias get the path
    # prepended; display == canonical when no group is involved.
    display = " ".join([*path, canonical]) if path else canonical
    plain_aliases = [
        item
        for item in dict.fromkeys([*names[1:], *(aliases or [])])
        if isinstance(item, str) and item and item != canonical
    ]
    routed_aliases = (
        [" ".join([*path, alias]) for alias in plain_aliases]
        if path
        else plain_aliases
    )

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="on_command(...)")
        meta = _get_or_create_meta(func)
        desc = _normalize_description(description)
        meta.trigger = CommandTrigger(
            command=display,
            aliases=routed_aliases,
            description=desc,
        )
        meta.description = desc
        if path:
            meta.command_route = CommandRouteSpec(
                group_path=path,
                display_command=display,
                group_help=_normalize_description(group_help),
            )
        _validate_message_trigger_compatibility(meta)
        return func

    return decorator
+
+
def on_message(
    *,
    regex: str | None = None,
    keywords: list[str] | None = None,
    platforms: list[str] | None = None,
    message_types: list[str] | None = None,
    description: str | None = None,
) -> Callable[[_HandlerT], _HandlerT]:
    """Register a message handler.

    Fires when an incoming message matches the given conditions; supports
    regex or keyword matching.

    Args:
        regex: regular-expression pattern
        keywords: keyword list (any single match triggers)
        platforms: restrict to these platforms (e.g. ["qq", "wechat"])
        message_types: restrict to these message types (e.g. ["group"])
        description: human-readable description

    Returns:
        The decorator.

    Raises:
        TypeError: when keywords/platforms/message_types is not a list
        ValueError: when neither regex nor any non-empty keyword is given

    Note:
        At least one of regex / keywords must be provided.

    Example:
        @on_message(keywords=["help", "帮助"])
        async def help(self, event: MessageEvent, ctx: Context):
            await event.reply("帮助信息")

        @on_message(regex=r"\\d+")  # match digits
        async def number_handler(self, event: MessageEvent, ctx: Context):
            await event.reply("收到了数字")
    """

    if keywords is not None and not isinstance(keywords, list):
        raise TypeError("on_message keywords must be a list of strings")
    if platforms is not None and not isinstance(platforms, list):
        raise TypeError("on_message platforms must be a list of strings")
    if message_types is not None and not isinstance(message_types, list):
        raise TypeError("on_message message_types must be a list of strings")

    normalized_regex = None if regex is None else str(regex).strip()
    normalized_keywords = [
        str(item).strip() for item in (keywords or []) if str(item).strip()
    ]
    if not normalized_regex and not normalized_keywords:
        raise ValueError("on_message(...) requires regex or at least one keyword")

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="on_message(...)")
        meta = _get_or_create_meta(func)
        meta.trigger = MessageTrigger(
            regex=normalized_regex,
            keywords=normalized_keywords,
            platforms=platforms or [],
            message_types=message_types or [],
        )
        meta.description = _normalize_description(description)
        # Record trigger-supplied constraints as filters too, tagging their
        # source so they cannot be mixed with the standalone decorators.
        if platforms:
            _set_platform_filter(meta, list(platforms), source="trigger.platforms")
        if message_types:
            _set_message_type_filter(
                meta,
                list(message_types),
                source="trigger.message_types",
            )
        _validate_message_trigger_compatibility(meta)
        return func

    return decorator
+
+
def append_filter_meta(
    func: _HandlerT,
    *,
    specs: list[FilterSpec] | None = None,
    local_bindings: list[Any] | None = None,
) -> _HandlerT:
    """Append declarative filter specs and/or local filter bindings to *func*."""
    meta = _get_or_create_meta(func)
    for spec in specs or []:
        meta.filters.append(spec)
    for binding in local_bindings or []:
        meta.local_filters.append(binding)
    return func
+
+
def set_command_route_meta(
    func: _HandlerT,
    route: CommandRouteSpec,
) -> _HandlerT:
    """Record command-group routing info on *func*'s handler metadata."""
    _get_or_create_meta(func).command_route = route
    return func
+
+
def on_event(
    event_type: str,
    *,
    description: str | None = None,
) -> Callable[[_HandlerT], _HandlerT]:
    """Register an event handler.

    Fires when an event of the given type occurs; meant for non-message
    events such as group-member changes or friend requests.

    Args:
        event_type: event type identifier (must be non-empty after trimming)
        description: human-readable description

    Returns:
        The decorator.

    Raises:
        ValueError: when event_type is empty after trimming

    Example:
        @on_event("group_member_join")
        async def on_join(self, event, ctx):
            await ctx.platform.send(event.group_id, "欢迎新人!")
    """
    cleaned = str(event_type).strip()
    if not cleaned:
        raise ValueError("on_event(...) requires a non-empty event_type")

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="on_event(...)")
        meta = _get_or_create_meta(func)
        meta.trigger = EventTrigger(event_type=cleaned)
        meta.description = _normalize_description(description)
        _validate_message_trigger_compatibility(meta)
        return func

    return decorator
+
+
def on_schedule(
    *,
    name: str | None = None,
    cron: str | None = None,
    interval_seconds: int | None = None,
    timezone: str | None = None,
    description: str | None = None,
) -> Callable[[_HandlerT], _HandlerT]:
    """Register a scheduled task.

    The task runs on the given schedule.

    Args:
        name: schedule name; defaults at runtime to a plugin-id/handler-id
            combination when omitted
        cron: cron expression (e.g. "0 8 * * *" for 08:00 daily)
        interval_seconds: execution interval in seconds (positive)
        timezone: IANA time-zone name (e.g. "Asia/Shanghai")
        description: human-readable description

    Returns:
        The decorator.

    Raises:
        ValueError: when neither cron nor interval_seconds is given, or
            interval_seconds is not a positive integer

    Note:
        At least one of cron / interval_seconds must be provided.

    Example:
        @on_schedule(cron="0 8 * * *")  # daily at 08:00
        async def morning_greeting(self, ctx):
            await ctx.platform.send("group_123", "早上好!")

        @on_schedule(interval_seconds=3600)  # hourly
        async def hourly_check(self, ctx):
            pass
    """

    def _clean(value: str | None) -> str | None:
        return None if value is None else (str(value).strip() or None)

    schedule_name = _clean(name)
    cron_expr = _clean(cron)
    tz_name = _clean(timezone)
    if cron_expr is None and interval_seconds is None:
        raise ValueError("on_schedule(...) requires cron or interval_seconds")
    if interval_seconds is not None:
        # bool subclasses int and must be rejected explicitly.
        if isinstance(interval_seconds, bool) or int(interval_seconds) <= 0:
            raise ValueError(
                "on_schedule(...) interval_seconds must be a positive integer"
            )

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="on_schedule(...)")
        meta = _get_or_create_meta(func)
        interval = None if interval_seconds is None else int(interval_seconds)
        meta.trigger = ScheduleTrigger(
            name=schedule_name,
            cron=cron_expr,
            interval_seconds=interval,
            timezone=tz_name,
        )
        meta.description = _normalize_description(description)
        _validate_message_trigger_compatibility(meta)
        return func

    return decorator
+
+
def http_api(
    route: str,
    *,
    methods: list[str] | None = None,
    description: str = "",
    capability_name: str | None = None,
) -> Callable[[HandlerCallable], HandlerCallable]:
    """Expose a handler as an HTTP endpoint.

    Args:
        route: URL route (must be non-empty after trimming)
        methods: HTTP verbs, upper-cased; defaults to ["GET"]
        description: endpoint description
        capability_name: optional capability name for the route

    Returns:
        The decorator.

    Raises:
        ValueError: for an empty route, or when no non-blank method remains
    """
    clean_route = str(route).strip()
    if not clean_route:
        raise ValueError("http_api(...) requires a non-empty route")
    verbs = [
        str(item).strip().upper()
        for item in (methods or ["GET"])
        if str(item).strip()
    ]
    if not verbs:
        raise ValueError("http_api(...) requires at least one HTTP method")
    clean_capability = (
        None if capability_name is None else str(capability_name).strip()
    )

    def decorator(func: HandlerCallable) -> HandlerCallable:
        _require_handler_callable(func, decorator_name="http_api(...)")
        meta = HttpApiMeta(
            route=clean_route,
            methods=verbs,
            description=str(description),
            capability_name=clean_capability,
        )
        setattr(func, HTTP_API_META_ATTR, meta)
        return func

    return decorator
+
+
def validate_config(
    *,
    model: type[BaseModel] | None = None,
    schema: dict[str, Any] | None = None,
) -> Callable[[HandlerCallable], HandlerCallable]:
    """Declare a config validator via a pydantic model OR a lightweight schema.

    Exactly one of *model* and *schema* must be provided; *schema* is
    shallow-copied so later mutations by the caller do not leak in.

    Raises:
        ValueError: when neither or both of model/schema are given
        TypeError: for a non-BaseModel model, a non-dict schema, or a
            malformed schema entry
    """
    if model is None and schema is None:
        raise ValueError("validate_config(...) requires model or schema")
    if model is not None and schema is not None:
        raise ValueError("validate_config(...) cannot accept model and schema together")
    if model is not None and (
        not isinstance(model, type) or not issubclass(model, BaseModel)
    ):
        raise TypeError("validate_config model must be a pydantic BaseModel subclass")
    if schema is not None:
        if not isinstance(schema, dict):
            raise TypeError("validate_config schema must be a dict")
        _validate_validate_config_schema(schema)

    def decorator(func: HandlerCallable) -> HandlerCallable:
        _require_handler_callable(func, decorator_name="validate_config(...)")
        snapshot = None if schema is None else dict(schema)
        setattr(
            func,
            VALIDATE_CONFIG_META_ATTR,
            ValidateConfigMeta(model=model, schema=snapshot),
        )
        return func

    return decorator
+
+
def on_provider_change(
    *,
    provider_types: list[str] | tuple[str, ...] | None = None,
) -> Callable[[HandlerCallable], HandlerCallable]:
    """Register a hook that fires when the LLM provider changes.

    Args:
        provider_types: optional provider-type names restricting the hook;
            entries are trimmed and lower-cased, blanks dropped, and an
            empty result means "all providers".

    Returns:
        The decorator.
    """
    cleaned: list[str] = []
    for item in provider_types or []:
        text = str(item).strip()
        if text:
            cleaned.append(text.lower())

    def decorator(func: HandlerCallable) -> HandlerCallable:
        _require_handler_callable(func, decorator_name="on_provider_change(...)")
        setattr(
            func,
            PROVIDER_CHANGE_META_ATTR,
            ProviderChangeMeta(provider_types=cleaned),
        )
        return func

    return decorator
+
+
def background_task(
    *,
    description: str = "",
    auto_start: bool = True,
    on_error: Literal["log", "restart"] = "log",
) -> Callable[[HandlerCallable], HandlerCallable]:
    """Register a long-running background task.

    Args:
        description: task description
        auto_start: whether the task starts automatically
        on_error: "log" to log failures, "restart" to restart the task

    Returns:
        The decorator.

    Raises:
        ValueError: when on_error is neither "log" nor "restart"
    """
    if on_error not in ("log", "restart"):
        raise ValueError("background_task on_error must be 'log' or 'restart'")

    def decorator(func: HandlerCallable) -> HandlerCallable:
        _require_handler_callable(func, decorator_name="background_task(...)")
        meta = BackgroundTaskMeta(
            description=str(description),
            auto_start=bool(auto_start),
            on_error=on_error,
        )
        setattr(func, BACKGROUND_TASK_META_ATTR, meta)
        return func

    return decorator
+
+
def register_skill(
    *,
    name: str,
    path: str,
    description: str = "",
):
    """Attach a skill declaration to a plugin class or function.

    Multiple @register_skill decorators may stack; each appends one entry.

    Args:
        name: skill name (non-empty after trimming)
        path: skill path (non-empty after trimming)
        description: skill description

    Raises:
        ValueError: when name or path is empty after trimming
    """
    skill_name = str(name).strip()
    skill_path = str(path).strip()
    if not skill_name:
        raise ValueError("register_skill(...) requires a non-empty name")
    if not skill_path:
        raise ValueError("register_skill(...) requires a non-empty path")

    entry = SkillMeta(
        name=skill_name,
        path=skill_path,
        description=str(description),
    )

    def decorator(target):
        _append_list_meta(target, SKILL_META_ATTR, entry)
        return target

    return decorator
+
+
def require_admin(func: _HandlerT) -> _HandlerT:
    """Mark a handler as requiring admin permission.

    The handler will not be invoked for non-admin users.

    Args:
        func: the method to mark

    Returns:
        The same method, with the admin requirement recorded.

    Example:
        @on_command("admin")
        @require_admin
        async def admin_only(self, event: MessageEvent, ctx: Context):
            await event.reply("管理员命令执行成功")
    """
    _require_handler_callable(func, decorator_name="require_admin")
    meta = _get_or_create_meta(func)
    _set_required_role(meta, "admin")
    return func
+
+
def admin_only(func: _HandlerT) -> _HandlerT:
    """Alias of require_admin."""
    return require_admin(func)
+
+
def require_permission(
    role: Literal["member", "admin"],
) -> Callable[[_HandlerT], _HandlerT]:
    """Require the given role ("member" or "admin") to run the handler.

    The role string is trimmed and lower-cased before validation.

    Raises:
        ValueError: for any role other than "member"/"admin"
    """
    cleaned = str(role).strip().lower()
    if cleaned not in ("member", "admin"):
        raise ValueError("require_permission(...) 只支持 'member' 或 'admin'")

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="require_permission(...)")
        _set_required_role(
            _get_or_create_meta(func),
            cast(Literal["member", "admin"], cleaned),
        )
        return func

    return decorator
+
+
def platforms(*names: str) -> Callable[[_HandlerT], _HandlerT]:
    """Restrict a handler to the given platform names.

    Raises:
        ValueError: when no non-empty platform name is supplied
    """
    cleaned = [text for text in (str(item).strip() for item in names) if text]
    if not cleaned:
        raise ValueError("platforms(...) requires at least one non-empty platform name")

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="platforms(...)")
        _set_platform_filter(
            _get_or_create_meta(func),
            cleaned,
            source="decorator.platforms",
        )
        return func

    return decorator
+
+
def message_types(*types: str) -> Callable[[_HandlerT], _HandlerT]:
    """Restrict a handler to the given message types.

    Raises:
        ValueError: when no non-empty type is supplied
    """
    cleaned = [text for text in (str(item).strip() for item in types) if text]
    if not cleaned:
        raise ValueError("message_types(...) requires at least one non-empty type")

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="message_types(...)")
        _set_message_type_filter(
            _get_or_create_meta(func),
            cleaned,
            source="decorator.message_types",
        )
        return func

    return decorator
+
+
def group_only() -> Callable[[_HandlerT], _HandlerT]:
    """Restrict a handler to group-chat messages (message type "group")."""

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="group_only()")
        meta = _get_or_create_meta(func)
        _set_message_type_filter(meta, ["group"], source="decorator.group_only")
        return func

    return decorator
+
+
def private_only() -> Callable[[_HandlerT], _HandlerT]:
    """Restrict a handler to private messages (message type "private")."""

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="private_only()")
        meta = _get_or_create_meta(func)
        _set_message_type_filter(meta, ["private"], source="decorator.private_only")
        return func

    return decorator
+
+
def priority(value: int) -> Callable[[_HandlerT], _HandlerT]:
    """Set a handler's execution priority (larger values run earlier).

    Raises:
        ValueError: when value is not an int (bool is rejected explicitly
            because bool subclasses int)
    """
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError("priority(...) requires an integer")

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="priority(...)")
        _get_or_create_meta(func).priority = value
        return func

    return decorator
+
+
def rate_limit(
    limit: int,
    window: float,
    *,
    scope: LimiterScope = "session",
    behavior: LimiterBehavior = "hint",
    message: str | None = None,
) -> Callable[[_HandlerT], _HandlerT]:
    """Apply sliding-window rate limiting to a command/message handler.

    Args:
        limit: max invocations within the window (positive)
        window: window length in seconds (positive)
        scope: bucketing key — session/user/group/global
        behavior: reaction when limited — hint/silent/error
        message: optional custom notice text
    """
    _validate_limiter_args(
        kind="rate_limit",
        limit=limit,
        window=window,
        scope=scope,
        behavior=behavior,
    )

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="rate_limit(...)")
        # Build one LimiterMeta per decorated handler (never shared).
        limiter = LimiterMeta(
            kind="rate_limit",
            limit=int(limit),
            window=float(window),
            scope=scope,
            behavior=behavior,
            message=message,
        )
        return _set_limiter(func, limiter)

    return decorator
+
+
def cooldown(
    seconds: float,
    *,
    scope: LimiterScope = "session",
    behavior: LimiterBehavior = "hint",
    message: str | None = None,
) -> Callable[[_HandlerT], _HandlerT]:
    """Apply a cooldown: at most one invocation every *seconds* seconds.

    Implemented as a rate limit of 1 within a window of *seconds*.

    Args:
        seconds: cooldown length in seconds (positive)
        scope: bucketing key — session/user/group/global
        behavior: reaction when limited — hint/silent/error
        message: optional custom notice text
    """
    _validate_limiter_args(
        kind="cooldown",
        limit=1,
        window=seconds,
        scope=scope,
        behavior=behavior,
    )

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="cooldown(...)")
        # Build one LimiterMeta per decorated handler (never shared).
        limiter = LimiterMeta(
            kind="cooldown",
            limit=1,
            window=float(seconds),
            scope=scope,
            behavior=behavior,
            message=message,
        )
        return _set_limiter(func, limiter)

    return decorator
+
+
def conversation_command(
    command: str | typing.Sequence[str],
    *,
    aliases: list[str] | None = None,
    description: str | None = None,
    group: str | typing.Sequence[str] | None = None,
    group_help: str | None = None,
    timeout: int = 60,
    mode: ConversationMode = "replace",
    busy_message: str | None = None,
    grace_period: float = 1.0,
) -> Callable[[_HandlerT], _HandlerT]:
    """Register a command handler with conversation lifecycle support.

    Builds on ``on_command`` and additionally records session metadata:
    timeout, concurrency policy, busy notice, and a grace period.

    Args:
        command: command name or sequence (first item canonical, rest aliases)
        aliases: extra alias names
        description: command description
        group: command-group path, e.g. ``"admin"`` or ``["admin", "user"]``
        group_help: command-group description for help output
        timeout: session timeout in seconds; must be a positive integer
        mode: behavior on session conflict:
            - ``"replace"``: replace the current session
            - ``"reject"``: reject the new request
        busy_message: notice sent when a new request is rejected
        grace_period: grace period in seconds for lifecycle handling

    Returns:
        The decorator.

    Raises:
        ValueError: for an invalid mode, a non-positive timeout, or a
            non-positive grace_period

    Example:
        @conversation_command("chat", timeout=120, mode="reject", busy_message="请稍后再试")
        async def chat(self, event: MessageEvent, ctx: Context):
            await event.reply("开始对话...")
    """
    if mode not in ("replace", "reject"):
        raise ValueError("conversation_command mode must be 'replace' or 'reject'")
    # bool subclasses int and must be rejected explicitly.
    if isinstance(timeout, bool) or int(timeout) <= 0:
        raise ValueError("conversation_command timeout must be a positive integer")
    if float(grace_period) <= 0:
        raise ValueError("conversation_command grace_period must be positive")

    apply_command = on_command(
        command,
        aliases=aliases,
        description=description,
        group=group,
        group_help=group_help,
    )

    def decorator(func: _HandlerT) -> _HandlerT:
        _require_handler_callable(func, decorator_name="conversation_command(...)")
        wrapped = apply_command(func)
        _get_or_create_meta(wrapped).conversation = ConversationMeta(
            timeout=int(timeout),
            mode=mode,
            busy_message=busy_message,
            grace_period=float(grace_period),
        )
        return wrapped

    return decorator
+
+
def provide_capability(
    name: str,
    *,
    description: str,
    input_schema: dict[str, Any] | None = None,
    output_schema: dict[str, Any] | None = None,
    input_model: type[BaseModel] | None = None,
    output_model: type[BaseModel] | None = None,
    supports_stream: bool = False,
    cancelable: bool = False,
) -> Callable[[HandlerCallable], HandlerCallable]:
    """Declare a capability exported by the plugin.

    Lets other plugins or Core invoke this method by capability name.
    Inputs/outputs may be described by JSON Schema or by pydantic models.

    Args:
        name: capability name (must not use a reserved namespace; at runtime
            it must additionally be prefixed with the current plugin_id)
        description: capability description (non-empty)
        input_schema: input JSON Schema
        output_schema: output JSON Schema
        input_model: input pydantic model (mutually exclusive with input_schema)
        output_model: output pydantic model (mutually exclusive with output_schema)
        supports_stream: whether streaming output is supported
        cancelable: whether the capability can be cancelled

    Returns:
        The decorator.

    Raises:
        ValueError: for an empty name/description, a reserved namespace, or
            when a schema and a model are provided for the same direction
        TypeError: when a schema argument is not a dict

    Example:
        @provide_capability(
            "my_plugin.calculate",
            description="执行计算",
            input_model=CalculateInput,
            output_model=CalculateOutput,
        )
        async def calculate(self, payload: dict, ctx: Context):
            return {"result": payload["x"] * 2}
    """

    normalized_name = str(name).strip()
    if not normalized_name:
        raise ValueError("provide_capability(...) requires a non-empty name")
    normalized_description = _normalize_description(description)
    if normalized_description is None:
        raise ValueError("provide_capability(...) requires a non-empty description")
    if input_schema is not None and not isinstance(input_schema, dict):
        raise TypeError("input_schema must be a dict")
    if output_schema is not None and not isinstance(output_schema, dict):
        raise TypeError("output_schema must be a dict")

    def decorator(func: HandlerCallable) -> HandlerCallable:
        _require_handler_callable(func, decorator_name="provide_capability(...)")
        # Reserved namespaces are for Core; plugins may not export into them.
        if normalized_name.startswith(RESERVED_CAPABILITY_PREFIXES):
            raise ValueError(
                f"保留 capability 命名空间不能用于插件导出:{normalized_name}"
            )
        if input_schema is not None and input_model is not None:
            raise ValueError("input_schema 和 input_model 不能同时提供")
        if output_schema is not None and output_model is not None:
            raise ValueError("output_schema 和 output_model 不能同时提供")
        descriptor = CapabilityDescriptor(
            name=normalized_name,
            description=normalized_description,
            input_schema=(
                input_schema
                if input_schema is not None
                else _model_to_schema(input_model, label="input_model")
            ),
            output_schema=(
                output_schema
                if output_schema is not None
                else _model_to_schema(output_model, label="output_model")
            ),
            supports_stream=supports_stream,
            cancelable=cancelable,
        )
        setattr(func, CAPABILITY_META_ATTR, CapabilityMeta(descriptor=descriptor))
        return func

    return decorator
+
+
def _annotation_to_schema(annotation: Any) -> dict[str, Any]:
    """Translate a type annotation into a minimal JSON Schema fragment.

    Optional wrappers are removed first. ``list`` annotations recurse into
    their element type; anything unrecognized falls back to the permissive
    ``{"type": "string"}`` schema.
    """
    unwrapped, _optional = unwrap_optional(annotation)
    origin = typing.get_origin(unwrapped)
    if unwrapped is bool:
        return {"type": "boolean"}
    if unwrapped is int:
        return {"type": "integer"}
    if unwrapped is float:
        return {"type": "number"}
    if unwrapped is str:
        return {"type": "string"}
    if origin is dict or unwrapped is dict:
        return {"type": "object"}
    if origin is list or unwrapped is list:
        type_args = typing.get_args(unwrapped)
        element_schema = _annotation_to_schema(type_args[0]) if type_args else {}
        return {"type": "array", "items": element_schema}
    # Unknown annotations (custom classes, unions, ...) degrade to string.
    return {"type": "string"}
+
+
def _callable_parameters_schema(func: HandlerCallable) -> dict[str, Any]:
    """Derive a JSON Schema object for *func*'s positional parameters.

    ``self`` and the runtime-injected ``event``/``ctx``/``context`` names are
    excluded. A parameter is listed in ``required`` when it has no default and
    its annotation is not Optional. Unresolvable type hints degrade to an
    empty hint map rather than raising.
    """
    try:
        hints: dict[str, Any] = typing.get_type_hints(func)
    except Exception:
        hints = {}

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    injected_names = {"self", "event", "ctx", "context"}

    properties: dict[str, Any] = {}
    required: list[str] = []
    for param in inspect.signature(func).parameters.values():
        if param.kind not in positional_kinds:
            continue
        if param.name in injected_names:
            continue
        unwrapped, is_optional = unwrap_optional(hints.get(param.name))
        properties[param.name] = _annotation_to_schema(unwrapped)
        if param.default is inspect.Parameter.empty and not is_optional:
            required.append(param.name)

    schema: dict[str, Any] = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema
+
+
def register_llm_tool(
    name: str | None = None,
    *,
    description: str | None = None,
    parameters_schema: dict[str, Any] | None = None,
    active: bool = True,
) -> Callable[[HandlerCallable], HandlerCallable]:
    """Register a handler as an LLM tool.

    Args:
        name: Tool name; defaults to the decorated function's name.
        description: Tool description; defaults to the first line of the
            function's docstring when omitted.
        parameters_schema: Explicit JSON Schema for the tool parameters;
            derived from the function signature when omitted.
        active: Whether the tool starts in the active state.

    Returns:
        A decorator that attaches LLM tool metadata to the handler.

    Raises:
        TypeError: If ``parameters_schema`` is not a dict or ``active`` is
            not a bool.
        ValueError: If the resolved tool name is empty.
    """
    if parameters_schema is not None and not isinstance(parameters_schema, dict):
        raise TypeError("register_llm_tool parameters_schema must be a dict")
    if not isinstance(active, bool):
        raise TypeError("register_llm_tool active must be a bool")

    def decorator(func: HandlerCallable) -> HandlerCallable:
        _require_handler_callable(func, decorator_name="register_llm_tool(...)")
        tool_name = str(name or func.__name__).strip()
        if not tool_name:
            raise ValueError("LLM tool name must not be empty")
        # Bug fix: the previous expression
        #   description or first_doc_line if inspect.getdoc(func) else ""
        # parsed as `(description or first_doc_line) if getdoc(func) else ""`,
        # so an explicit `description` was silently discarded whenever the
        # function had no docstring. Resolve explicitly: explicit description
        # wins, then the first docstring line, then "".
        doc = inspect.getdoc(func)
        resolved_description = description or (doc.splitlines()[0] if doc else "")
        setattr(
            func,
            LLM_TOOL_META_ATTR,
            LLMToolMeta(
                spec=LLMToolSpec.create(
                    name=tool_name,
                    description=resolved_description,
                    parameters_schema=parameters_schema
                    or _callable_parameters_schema(func),
                    handler_ref=tool_name,
                    active=active,
                )
            ),
        )
        return func

    return decorator
+
+
def register_agent(
    name: str,
    *,
    description: str = "",
    tool_names: list[str] | None = None,
) -> Callable[[type[BaseAgentRunner]], type[BaseAgentRunner]]:
    """Class decorator that attaches agent metadata to a BaseAgentRunner subclass.

    Args:
        name: Agent name (must be non-empty after stripping whitespace).
        description: Human-readable agent description.
        tool_names: Tool names the agent may use; duplicates are removed
            (first occurrence wins) and blank entries are dropped.

    Raises:
        TypeError: If ``tool_names`` is not a list, or the decorated object
            is not a BaseAgentRunner subclass.
        ValueError: If ``name`` is empty.
    """
    if tool_names is not None and not isinstance(tool_names, list):
        raise TypeError("register_agent tool_names must be a list of strings")
    agent_name = str(name).strip()
    if not agent_name:
        raise ValueError("register_agent(...) requires a non-empty name")

    # dict.fromkeys de-duplicates while preserving first-seen order.
    cleaned_tool_names: list[str] = []
    for raw_tool_name in dict.fromkeys(tool_names or []):
        candidate = str(raw_tool_name).strip()
        if candidate:
            cleaned_tool_names.append(candidate)

    def decorator(cls: type[BaseAgentRunner]) -> type[BaseAgentRunner]:
        if not inspect.isclass(cls) or not issubclass(cls, BaseAgentRunner):
            raise TypeError("@register_agent() 只接受 BaseAgentRunner 子类")
        meta = AgentMeta(
            spec=AgentSpec(
                name=agent_name,
                description=description,
                tool_names=cleaned_tool_names,
                runner_class=f"{cls.__module__}.{cls.__qualname__}",
            )
        )
        setattr(cls, AGENT_META_ATTR, meta)
        return cls

    return decorator
diff --git a/astrbot-sdk/src/astrbot_sdk/errors.py b/astrbot-sdk/src/astrbot_sdk/errors.py
new file mode 100644
index 0000000000..c33244f387
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/errors.py
@@ -0,0 +1,311 @@
+"""跨运行时边界传递的统一错误模型。
+
+AstrBotError 是 SDK 中所有可预期错误的标准格式,
+支持跨进程传递(通过 to_payload/from_payload 序列化)。
+
+错误处理流程:
+ 1. 运行时抛出 AstrBotError 子类或实例
+ 2. 错误被捕获并序列化为 payload
+ 3. 跨进程传输后反序列化
+ 4. 在 on_error 钩子中统一处理
+
+Example:
+ # 抛出错误
+ raise AstrBotError.invalid_input("参数不能为空")
+
+ # 捕获并处理
+ try:
+ await some_operation()
+ except AstrBotError as e:
+ if e.retryable:
+ # 可重试的错误
+ await retry()
+ else:
+ # 不可重试的错误
+ await event.reply(e.hint or e.message)
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+
class ErrorCodes:
    """Stable error-code constants for the AstrBot SDK.

    These codes are stable at the protocol layer and must not be changed
    casually. New codes should be appended at the end of their category.

    Categories:
        - Non-retryable errors (retryable=False): configuration problems,
          permission failures, etc.
        - Retryable errors (retryable=True): network timeouts, transient
          faults, etc.
    """

    UNKNOWN_ERROR = "unknown_error"

    # Non-retryable errors - configuration or usage problems
    LLM_NOT_CONFIGURED = "llm_not_configured"
    CAPABILITY_NOT_FOUND = "capability_not_found"
    PERMISSION_DENIED = "permission_denied"
    LLM_ERROR = "llm_error"
    INVALID_INPUT = "invalid_input"
    CANCELLED = "cancelled"
    PROTOCOL_VERSION_MISMATCH = "protocol_version_mismatch"
    PROTOCOL_ERROR = "protocol_error"
    INTERNAL_ERROR = "internal_error"
    RATE_LIMITED = "rate_limited"
    COOLDOWN_ACTIVE = "cooldown_active"

    # Retryable errors - transient faults
    CAPABILITY_TIMEOUT = "capability_timeout"
    NETWORK_ERROR = "network_error"
    LLM_TEMPORARY_ERROR = "llm_temporary_error"
+
+
+@dataclass(slots=True)
+class AstrBotError(Exception):
+ """AstrBot SDK 的标准错误类型。
+
+ 所有可预期的错误都应使用此类或其工厂方法创建。
+ 支持跨进程传递,包含用户友好的提示信息。
+
+ Attributes:
+ code: 错误码,来自 ErrorCodes 常量
+ message: 错误消息,面向开发者
+ hint: 用户提示,面向终端用户
+ retryable: 是否可重试
+
+ Example:
+ # 使用工厂方法创建错误
+ raise AstrBotError.invalid_input("参数格式错误", hint="请使用 JSON 格式")
+
+ # 检查错误类型
+ try:
+ await operation()
+ except AstrBotError as e:
+ if e.code == ErrorCodes.CAPABILITY_NOT_FOUND:
+ logger.error(f"能力不存在: {e.message}")
+ """
+
+ code: str
+ message: str
+ hint: str = ""
+ retryable: bool = False
+ docs_url: str = ""
+ details: dict[str, Any] | None = None
+
+ def __str__(self) -> str:
+ return self.message
+
+ @classmethod
+ def cancelled(cls, message: str = "调用被取消") -> AstrBotError:
+ """创建取消错误。
+
+ Args:
+ message: 错误消息
+
+ Returns:
+ AstrBotError 实例
+ """
+ return cls(
+ code=ErrorCodes.CANCELLED,
+ message=message,
+ hint="",
+ retryable=False,
+ )
+
+ @classmethod
+ def capability_not_found(cls, name: str) -> AstrBotError:
+ """创建能力未找到错误。
+
+ Args:
+ name: 未找到的能力名称
+
+ Returns:
+ AstrBotError 实例
+ """
+ return cls(
+ code=ErrorCodes.CAPABILITY_NOT_FOUND,
+ message=f"未找到能力:{name}",
+ hint="请确认 AstrBot Core 是否已注册该 capability",
+ retryable=False,
+ )
+
+ @classmethod
+ def invalid_input(
+ cls,
+ message: str,
+ *,
+ hint: str = "请检查调用参数",
+ docs_url: str = "",
+ details: dict[str, Any] | None = None,
+ ) -> AstrBotError:
+ """创建输入无效错误。
+
+ Args:
+ message: 详细错误消息
+ hint: 用户提示
+
+ Returns:
+ AstrBotError 实例
+ """
+ return cls(
+ code=ErrorCodes.INVALID_INPUT,
+ message=message,
+ hint=hint,
+ retryable=False,
+ docs_url=docs_url,
+ details=details,
+ )
+
+ @classmethod
+ def protocol_version_mismatch(cls, message: str) -> AstrBotError:
+ """创建协议版本不匹配错误。
+
+ Args:
+ message: 详细错误消息
+
+ Returns:
+ AstrBotError 实例
+ """
+ return cls(
+ code=ErrorCodes.PROTOCOL_VERSION_MISMATCH,
+ message=message,
+ hint="请升级 astrbot_sdk 至最新版本",
+ retryable=False,
+ )
+
+ @classmethod
+ def protocol_error(cls, message: str) -> AstrBotError:
+ """创建协议错误。
+
+ Args:
+ message: 详细错误消息
+
+ Returns:
+ AstrBotError 实例
+ """
+ return cls(
+ code=ErrorCodes.PROTOCOL_ERROR,
+ message=message,
+ hint="请检查通信双方的协议实现",
+ retryable=False,
+ )
+
+ @classmethod
+ def internal_error(
+ cls,
+ message: str,
+ *,
+ hint: str = "请联系插件作者",
+ docs_url: str = "",
+ details: dict[str, Any] | None = None,
+ ) -> AstrBotError:
+ """创建内部错误。
+
+ Args:
+ message: 详细错误消息
+ hint: 用户提示
+
+ Returns:
+ AstrBotError 实例
+ """
+ return cls(
+ code=ErrorCodes.INTERNAL_ERROR,
+ message=message,
+ hint=hint,
+ retryable=False,
+ docs_url=docs_url,
+ details=details,
+ )
+
+ @classmethod
+ def network_error(
+ cls,
+ message: str,
+ *,
+ hint: str = "网络请求失败,请稍后重试",
+ docs_url: str = "",
+ details: dict[str, Any] | None = None,
+ ) -> AstrBotError:
+ return cls(
+ code=ErrorCodes.NETWORK_ERROR,
+ message=message,
+ hint=hint,
+ retryable=True,
+ docs_url=docs_url,
+ details=details,
+ )
+
+ @classmethod
+ def rate_limited(
+ cls,
+ *,
+ hint: str = "操作过于频繁,请稍后再试。",
+ details: dict[str, Any] | None = None,
+ ) -> AstrBotError:
+ return cls(
+ code=ErrorCodes.RATE_LIMITED,
+ message="handler invocation is rate limited",
+ hint=hint,
+ retryable=False,
+ details=details,
+ )
+
+ @classmethod
+ def cooldown_active(
+ cls,
+ *,
+ hint: str,
+ details: dict[str, Any] | None = None,
+ ) -> AstrBotError:
+ return cls(
+ code=ErrorCodes.COOLDOWN_ACTIVE,
+ message="handler cooldown is active",
+ hint=hint,
+ retryable=False,
+ details=details,
+ )
+
+ def to_payload(self) -> dict[str, object]:
+ """序列化为可传输的字典格式。
+
+ 用于跨进程传递错误信息。
+
+ Returns:
+ 包含错误信息的字典
+ """
+ return {
+ "code": self.code,
+ "message": self.message,
+ "hint": self.hint,
+ "retryable": self.retryable,
+ "docs_url": self.docs_url,
+ "details": dict(self.details) if isinstance(self.details, dict) else None,
+ }
+
+ @classmethod
+ def from_payload(cls, payload: dict[str, object]) -> AstrBotError:
+ """从字典反序列化错误实例。
+
+ Args:
+ payload: 包含错误信息的字典
+
+ Returns:
+ AstrBotError 实例
+ """
+ details_payload = payload.get("details")
+ details = (
+ {str(key): value for key, value in details_payload.items()}
+ if isinstance(details_payload, dict)
+ else None
+ )
+ return cls(
+ code=str(payload.get("code", ErrorCodes.UNKNOWN_ERROR)),
+ message=str(payload.get("message", "未知错误")),
+ hint=str(payload.get("hint", "")),
+ retryable=bool(payload.get("retryable", False)),
+ docs_url=str(payload.get("docs_url", "")),
+ details=details,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/events.py b/astrbot-sdk/src/astrbot_sdk/events.py
new file mode 100644
index 0000000000..492d000a3d
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/events.py
@@ -0,0 +1,731 @@
+"""astrbot-sdk 原生事件对象。
+
+顶层 ``MessageEvent`` 保持精简,只承载 astrbot-sdk 运行时真正需要的基础能力。
+迁移期扩展事件能力放在独立模块中,而不是继续塞回顶层事件类型。
+
+MessageEvent 是 handler 接收的主要事件类型,封装了:
+ - 消息文本内容
+ - 发送者信息(user_id, group_id)
+ - 平台标识
+ - 回复能力(reply, reply_image, reply_chain)
+"""
+
+from __future__ import annotations
+
+import json
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, TypeVar
+
+from ._message_types import normalize_message_type
+from .message.components import (
+ At,
+ BaseMessageComponent,
+ File,
+ Image,
+ Plain,
+ component_to_payload_sync,
+ payloads_to_components,
+)
+from .message.result import EventResultType, MessageChain, MessageEventResult
+from .protocol.descriptors import SessionRef
+
+if TYPE_CHECKING:
+ from .context import Context
+
+
@dataclass(slots=True)
class PlainTextResult:
    """Plain-text result.

    Used by handlers that return a simple text result.
    """

    # The text payload of the result.
    text: str
+
+
+ReplyHandler = Callable[[str], Awaitable[None]]
+_MessageComponentT = TypeVar("_MessageComponentT", bound=BaseMessageComponent)
+
+_JSON_DROP = object()
+
+
+def _coerce_str(value: Any) -> str:
+ if value is None:
+ return ""
+ if isinstance(value, str):
+ return value
+ return str(value)
+
+
+def _coerce_optional_str(value: Any) -> str | None:
+ if value is None:
+ return None
+ text = value if isinstance(value, str) else str(value)
+ return text or None
+
+
def _json_safe_value(value: Any) -> Any:
    """Return a JSON-serializable version of *value*, or the ``_JSON_DROP`` sentinel.

    Scalars pass through untouched. Lists/tuples and dicts are rebuilt
    recursively with unserializable entries removed (tuples become lists, dict
    keys become strings). Objects exposing ``model_dump()`` (pydantic-style)
    are converted through that method; anything else is kept only if
    ``json.dumps`` accepts it.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    if isinstance(value, (list, tuple)):
        kept: list[Any] = []
        for element in value:
            safe_element = _json_safe_value(element)
            if safe_element is not _JSON_DROP:
                kept.append(safe_element)
        return kept

    if isinstance(value, dict):
        safe_dict: dict[str, Any] = {}
        for raw_key, element in value.items():
            safe_element = _json_safe_value(element)
            if safe_element is not _JSON_DROP:
                safe_dict[str(raw_key)] = safe_element
        return safe_dict

    dumper = getattr(value, "model_dump", None)
    if callable(dumper):
        try:
            return _json_safe_value(dumper())
        except Exception:
            return _JSON_DROP

    try:
        json.dumps(value)
    except (TypeError, ValueError):
        return _JSON_DROP
    return value
+
+
def _json_safe_mapping(value: Any) -> dict[str, Any]:
    """Normalize *value* into a JSON-safe dict; non-dict inputs yield ``{}``."""
    if not isinstance(value, dict):
        return {}
    candidates = ((str(key), _json_safe_value(item)) for key, item in value.items())
    return {key: item for key, item in candidates if item is not _JSON_DROP}
+
+
def _resolve_message_target(
    payload: dict[str, Any],
) -> tuple[SessionRef | None, Any, Any]:
    """Extract ``(target, session_id, platform)`` from a message payload.

    When the payload carries a structured ``target`` dict, it is validated
    into a SessionRef and used to backfill a missing session_id/platform.
    """
    session_id = payload.get("session_id")
    platform = payload.get("platform")
    raw_target = payload.get("target")
    if not isinstance(raw_target, dict):
        return None, session_id, platform
    target = SessionRef.model_validate(raw_target)
    return target, session_id or target.session, platform or target.platform
+
+
class MessageEvent:
    """Inbound message event.

    Wraps a received message and exposes convenient reply helpers.
    A fresh MessageEvent instance is created for every handler invocation.

    Attributes:
        text: Message text content.
        user_id: Sender user id; empty string when missing.
        group_id: Group id (None in private chats).
        platform: Platform identifier (e.g. "qq", "wechat"); empty string when missing.
        session_id: Session id (usually group_id or user_id; empty string when missing).
        raw: Raw message payload.

    Example:
        @on_command("echo")
        async def echo(self, event: MessageEvent, ctx: Context):
            await event.reply(f"你说: {event.text}")
    """

    text: str
    user_id: str
    group_id: str | None
    platform: str
    session_id: str
    self_id: str
    platform_id: str
    message_type: str
    sender_name: str
    raw: dict[str, Any]
    # Internal state; the *_present flags remember whether the corresponding
    # key existed in the incoming raw payload so to_payload() can round-trip
    # "key present but empty" vs "key absent".
    _is_admin: bool
    _stopped: bool
    _host_extras: dict[str, Any]
    _host_extras_present: bool
    _sdk_local_extras: dict[str, Any]
    _sdk_local_extras_present: bool
    _sdk_local_extras_dirty: bool
    _messages: list[BaseMessageComponent]
    _messages_present: bool
    _message_outline: str
    _sent_messages: list[BaseMessageComponent]
    _sent_messages_present: bool
    _sent_message_outline: str
    _sent_message_outline_present: bool
    _context: Context | None
    _reply_handler: ReplyHandler | None

    def __init__(
        self,
        *,
        text: str = "",
        user_id: str | None = None,
        group_id: str | None = None,
        platform: str | None = None,
        session_id: str | None = None,
        self_id: str | None = None,
        platform_id: str | None = None,
        message_type: str | None = None,
        sender_name: str | None = None,
        is_admin: bool = False,
        raw: dict[str, Any] | None = None,
        context: Context | None = None,
        reply_handler: ReplyHandler | None = None,
    ) -> None:
        """Initialize the message event.

        Args:
            text: Message text.
            user_id: User id.
            group_id: Group id.
            platform: Platform identifier.
            session_id: Session id; inferred from group_id/user_id when None.
            raw: Raw message payload.
            context: Runtime context.
            reply_handler: Custom reply handler.
        """
        normalized_user_id = _coerce_str(user_id)
        normalized_group_id = _coerce_optional_str(group_id)
        normalized_platform = _coerce_str(platform)
        normalized_session_id = _coerce_str(session_id)

        self.text = text
        self.user_id = normalized_user_id
        self.group_id = normalized_group_id
        self.platform = normalized_platform
        # Fall back to group/user id so replies still have a routing key.
        self.session_id = (
            normalized_session_id or normalized_group_id or normalized_user_id or ""
        )
        self.self_id = _coerce_str(self_id)
        self.platform_id = _coerce_str(platform_id) or normalized_platform
        self.message_type = normalize_message_type(
            message_type,
            group_id=normalized_group_id,
            user_id=normalized_user_id,
        )
        self.sender_name = _coerce_str(sender_name)
        self._is_admin = bool(is_admin)
        self.raw = raw or {}
        self._stopped = False
        # "host_extras" wins when it is a dict; otherwise fall back to the
        # payload's "extras" key.
        host_extras = self.raw.get("host_extras")
        raw_extras = self.raw.get("extras")
        self._host_extras = _json_safe_mapping(
            host_extras if isinstance(host_extras, dict) else raw_extras
        )
        self._host_extras_present = "host_extras" in self.raw or "extras" in self.raw
        sdk_local_extras = self.raw.get("sdk_local_extras")
        self._sdk_local_extras = _json_safe_mapping(sdk_local_extras)
        self._sdk_local_extras_present = "sdk_local_extras" in self.raw
        self._sdk_local_extras_dirty = False
        messages_payload = self.raw.get("messages")
        self._messages = (
            payloads_to_components(messages_payload)
            if isinstance(messages_payload, list)
            else []
        )
        self._messages_present = "messages" in self.raw
        self._message_outline = str(self.raw.get("message_outline", self.text))
        sent_messages_payload = self.raw.get("sent_messages")
        self._sent_messages = (
            payloads_to_components(sent_messages_payload)
            if isinstance(sent_messages_payload, list)
            else []
        )
        self._sent_messages_present = "sent_messages" in self.raw
        self._sent_message_outline = str(self.raw.get("sent_message_outline", ""))
        self._sent_message_outline_present = "sent_message_outline" in self.raw
        self._context = context
        self._reply_handler = reply_handler
        # Default reply path: when no explicit handler is given but a runtime
        # context exists, route text through the platform sender.
        if self._reply_handler is None and context is not None:
            self._reply_handler = lambda text: context.platform.send(
                self.session_ref or self.session_id,
                text,
            )

    def _require_runtime_context(self, action: str) -> Context:
        """Return the bound runtime context, raising if none is attached."""
        if self._context is None:
            raise RuntimeError(f"MessageEvent 未绑定运行时上下文,无法 {action}")
        return self._context

    def _reply_target(self) -> SessionRef | str:
        """Return the reply target: a SessionRef when available, else the session id."""
        return self.session_ref or self.session_id

    @classmethod
    def from_payload(
        cls,
        payload: dict[str, Any],
        *,
        context: Context | None = None,
        reply_handler: ReplyHandler | None = None,
    ) -> MessageEvent:
        """Create an event instance from a protocol payload.

        Args:
            payload: Message data passed by the protocol layer.
            context: Runtime context.
            reply_handler: Custom reply handler.

        Returns:
            A new MessageEvent instance.
        """
        target, session_id, platform = _resolve_message_target(payload)
        return cls(
            text=str(payload.get("text", "")),
            user_id=payload.get("user_id"),
            group_id=payload.get("group_id"),
            platform=platform,
            session_id=session_id,
            self_id=payload.get("self_id"),
            platform_id=payload.get("platform_id"),
            message_type=payload.get("message_type"),
            sender_name=payload.get("sender_name"),
            is_admin=bool(payload.get("is_admin", False)),
            raw=payload,
            context=context,
            reply_handler=reply_handler,
        )

    @staticmethod
    def session_key_from_payload(payload: dict[str, Any]) -> str:
        """Derive a session key from a payload without building a full event.

        Prefers the resolved session id, then the target's conversation id;
        returns "" when neither is available.
        """
        target, session_id, _ = _resolve_message_target(payload)
        if session_id:
            return str(session_id)
        if target is not None and target.conversation_id:
            return str(target.conversation_id)
        return ""

    def to_payload(self) -> dict[str, Any]:
        """Convert the event back into protocol-payload form.

        Starts from a copy of ``raw`` and overlays the normalized fields.
        Keys like extras/messages are re-emitted only when they carry data or
        were present in the original payload, and removed otherwise.

        Returns:
            A serializable dict.
        """
        payload = dict(self.raw)
        payload.update(
            {
                "text": self.text,
                "user_id": self.user_id,
                "group_id": self.group_id,
                "platform": self.platform,
                "session_id": self.session_id,
                "self_id": self.self_id,
                "platform_id": self.platform_id,
                "message_type": self.message_type,
                "sender_name": self.sender_name,
                "is_admin": self._is_admin,
            }
        )
        if self.session_ref is not None:
            payload["target"] = self.session_ref.to_payload()
        # "extras" is the merged view: host extras overlaid with SDK-local ones.
        merged_extras = dict(self._host_extras)
        merged_extras.update(self._sdk_local_extras_payload())
        if merged_extras:
            payload["extras"] = merged_extras
        elif self._host_extras_present:
            payload["extras"] = {}
        else:
            payload.pop("extras", None)
        if self._host_extras or self._host_extras_present:
            payload["host_extras"] = dict(self._host_extras)
        else:
            payload.pop("host_extras", None)
        sdk_local_extras = self._sdk_local_extras_payload()
        if sdk_local_extras or self._should_serialize_sdk_local_extras():
            payload["sdk_local_extras"] = sdk_local_extras
        else:
            payload.pop("sdk_local_extras", None)
        if self._messages or self._messages_present:
            payload["messages"] = [
                component_to_payload_sync(component) for component in self._messages
            ]
        else:
            payload.pop("messages", None)
        payload["message_outline"] = self._message_outline
        if self._sent_messages or self._sent_messages_present:
            payload["sent_messages"] = [
                component_to_payload_sync(component)
                for component in self._sent_messages
            ]
        else:
            payload.pop("sent_messages", None)
        if self._sent_message_outline or self._sent_message_outline_present:
            payload["sent_message_outline"] = self._sent_message_outline
        else:
            payload.pop("sent_message_outline", None)
        return payload

    @property
    def session_ref(self) -> SessionRef | None:
        """Return a session reference object.

        Returns:
            A SessionRef instance, or None when there is no valid session_id.
        """
        if not self.session_id:
            return None
        return SessionRef(
            conversation_id=self.session_id,
            platform=self.platform,
            raw=self.raw or None,
        )

    @property
    def target(self) -> SessionRef | None:
        """Alias of ``session_ref``."""
        return self.session_ref

    @property
    def unified_msg_origin(self) -> str:
        """Unified message origin string."""
        return self.session_id

    def is_private_chat(self) -> bool:
        """Whether the current event belongs to a private chat."""
        # Trust the explicit message_type when set; otherwise infer from
        # the absence of a group id.
        if self.message_type:
            return self.message_type == "private"
        return not bool(self.group_id)

    def is_group_chat(self) -> bool:
        """Whether the current event belongs to a group chat."""
        if self.message_type:
            return self.message_type == "group"
        return bool(self.group_id)

    def get_platform_id(self) -> str:
        """Get the platform instance identifier."""
        return self.platform_id

    def get_message_type(self) -> str:
        """Get the normalized message type."""
        return self.message_type

    def get_session_id(self) -> str:
        """Get the current session identifier."""
        return self.session_id

    def is_admin(self) -> bool:
        """Whether the sender has admin permission."""
        return self._is_admin

    def has_admin_permission(self) -> bool:
        """Return whether the sender currently has administrator permission."""
        return self.is_admin()

    def get_messages(self) -> list[BaseMessageComponent]:
        """Return SDK message components for the current event."""
        return list(self._messages)

    def get_sent_messages(self) -> list[BaseMessageComponent]:
        """Return outbound SDK message components for after-send events."""
        return list(self._sent_messages)

    def has_component(self, type_: type[BaseMessageComponent]) -> bool:
        """Return whether any inbound component is an instance of *type_*."""
        return any(isinstance(component, type_) for component in self._messages)

    def get_components(
        self,
        type_: type[_MessageComponentT],
    ) -> list[_MessageComponentT]:
        """Return all inbound components that are instances of *type_*."""
        return [
            component for component in self._messages if isinstance(component, type_)
        ]

    def get_images(self) -> list[Image]:
        """Return all inbound Image components."""
        return self.get_components(Image)

    def get_files(self) -> list[File]:
        """Return all inbound File components."""
        return self.get_components(File)

    def extract_plain_text(self) -> str:
        """Join the text of all Plain components with single spaces."""
        return " ".join(
            component.text
            for component in self._messages
            if isinstance(component, Plain)
        )

    def get_at_users(self) -> list[str]:
        """Return ids of explicitly at-mentioned users, excluding "@all"."""
        return [
            str(component.qq)
            for component in self._messages
            if isinstance(component, At) and str(component.qq).lower() != "all"
        ]

    def get_message_outline(self) -> str:
        """Return the normalized message outline."""
        return self._message_outline

    def get_sent_message_outline(self) -> str:
        """Return the outbound message outline for after-send events."""
        return self._sent_message_outline

    async def get_group(self) -> dict[str, Any] | None:
        """Get current-group metadata for the bound message request.

        Returns:
            A dict of group metadata supplied by the host, or None when the
            host returns nothing dict-shaped.

        Raises:
            RuntimeError: If no runtime context is bound.
        """
        context = self._require_runtime_context("get_group")
        output = await context._proxy.call(  # noqa: SLF001
            "platform.get_group",
            {
                "session": self.session_id,
                "target": (
                    self.session_ref.to_payload()
                    if self.session_ref is not None
                    else None
                ),
            },
        )
        payload = output.get("group")
        if not isinstance(payload, dict):
            return None
        return dict(payload)

    def set_extra(self, key: str, value: Any) -> None:
        """Store SDK-local transient event data.

        Values written here are immediately available through ``get_extra()``
        inside the current handler invocation. If you expect the value to remain
        available after the event crosses the SDK bridge into a later handler or
        lifecycle event, store only JSON-serializable data.

        Recommended approach:
            - Keep values to ``dict`` / ``list`` / ``str`` / ``int`` / ``float`` /
              ``bool`` / ``None`` and nested combinations of those types.
            - Convert framework objects into payloads before storing them. For
              message components, use ``component_to_payload_sync()`` before
              ``set_extra()`` and ``payload_to_component()`` after ``get_extra()``.

        Non-serializable values may still be readable in the current handler,
        but they will be dropped when the SDK bridge serializes extras for a
        later event.
        """
        self._sdk_local_extras[key] = value
        self._sdk_local_extras_dirty = True

    def get_extra(self, key: str | None = None, default: Any = None) -> Any:
        """Read SDK-local transient event data.

        Extras returned here merge host-provided extras with values previously
        written via ``set_extra()`` (local values win on key collision). If a
        key was written with a non-serializable value, it may disappear after
        the event is serialized across the SDK bridge. In that case, persist a
        JSON-safe payload instead of the original object.

        Args:
            key: Specific extras key to read; None returns the merged dict.
            default: Value returned when *key* is absent.
        """
        extras = dict(self._host_extras)
        extras.update(self._sdk_local_extras)
        if key is None:
            return extras
        return extras.get(key, default)

    def clear_extra(self) -> None:
        """Clear SDK-local transient event data."""
        self._sdk_local_extras.clear()
        self._sdk_local_extras_dirty = True

    def _sdk_local_extras_payload(self) -> dict[str, Any]:
        """Return the JSON-safe projection of the SDK-local extras."""
        return _json_safe_mapping(self._sdk_local_extras)

    def _should_serialize_sdk_local_extras(self) -> bool:
        """Whether to_payload() must emit the sdk_local_extras key at all."""
        return (
            self._sdk_local_extras_present
            or self._sdk_local_extras_dirty
            or bool(self._sdk_local_extras)
        )

    def stop_event(self) -> None:
        """Mark the SDK-local event as stopped."""
        self._stopped = True

    def continue_event(self) -> None:
        """Clear the SDK-local stop flag."""
        self._stopped = False

    def is_stopped(self) -> bool:
        """Return whether the SDK-local event is stopped."""
        return self._stopped

    async def reply(self, text: str) -> None:
        """Reply with a text message.

        Args:
            text: The text content to send.

        Raises:
            RuntimeError: If no reply handler is bound.
        """
        if self._reply_handler is None:
            raise RuntimeError("MessageEvent 未绑定 reply handler,无法 reply")
        await self._reply_handler(text)

    async def reply_image(self, image_url: str) -> None:
        """Reply with an image message.

        Args:
            image_url: The image URL.

        Raises:
            RuntimeError: If no runtime context is bound.
        """
        context = self._require_runtime_context("reply_image")
        await context.platform.send_image(self._reply_target(), image_url)

    async def reply_chain(
        self,
        chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]],
    ) -> None:
        """Reply with a message chain (mixed component types).

        Args:
            chain: The message-chain components to send.

        Raises:
            RuntimeError: If no runtime context is bound.
        """
        context = self._require_runtime_context("reply_chain")
        await context.platform.send_chain(self._reply_target(), chain)

    async def react(self, emoji: str) -> bool:
        """Send a platform reaction when supported.

        Returns:
            True when the host reports reaction support, else False.
        """
        context = self._require_runtime_context("react")
        output = await context._proxy.call(  # noqa: SLF001
            "system.event.react",
            {
                "target": (
                    self.session_ref.to_payload()
                    if self.session_ref is not None
                    else None
                ),
                "emoji": emoji,
            },
        )
        return bool(output.get("supported", False))

    async def send_typing(self) -> bool:
        """Emit typing state when the host platform supports it."""
        context = self._require_runtime_context("send_typing")
        output = await context._proxy.call(  # noqa: SLF001
            "system.event.send_typing",
            {
                "target": (
                    self.session_ref.to_payload()
                    if self.session_ref is not None
                    else None
                ),
            },
        )
        return bool(output.get("supported", False))

    async def send_streaming(
        self,
        generator,
        use_fallback: bool = False,
    ) -> bool:
        """Replay normalized chunks through the host streaming pathway.

        Args:
            generator: Async iterable yielding str / MessageChain /
                MessageEventResult / SDK message components.
            use_fallback: Forwarded to the host; semantics are host-defined.

        Returns:
            True when the host supported the stream end-to-end, else False.
        """
        context = self._require_runtime_context("send_streaming")
        output = await context._proxy.call(  # noqa: SLF001
            "system.event.send_streaming",
            {
                "target": (
                    self.session_ref.to_payload()
                    if self.session_ref is not None
                    else None
                ),
                "use_fallback": use_fallback,
            },
        )
        if not bool(output.get("supported", False)):
            return False

        stream_id = str(output.get("stream_id", ""))
        if not stream_id:
            return False

        try:
            async for item in generator:
                if isinstance(item, str):
                    chain = MessageChain([Plain(item, convert=False)])
                else:
                    chain = self._coerce_chain_or_raise(item)
                await context._proxy.call(  # noqa: SLF001
                    "system.event.send_streaming_chunk",
                    {
                        "stream_id": stream_id,
                        "chain": await chain.to_payload_async(),
                    },
                )
        finally:
            # Always close the stream on the host, even if the generator raised.
            output = await context._proxy.call(  # noqa: SLF001
                "system.event.send_streaming_close",
                {"stream_id": stream_id},
            )
        return bool(output.get("supported", False))

    def bind_reply_handler(self, reply_handler: ReplyHandler) -> None:
        """Bind a custom reply handler.

        Args:
            reply_handler: The reply callback to use for ``reply()``.
        """
        self._reply_handler = reply_handler

    def plain_result(self, text: str) -> PlainTextResult:
        """Create a plain-text result.

        Args:
            text: The result text.

        Returns:
            A PlainTextResult instance.
        """
        return PlainTextResult(text=text)

    def make_result(self) -> MessageEventResult:
        """Create an empty SDK-local result wrapper."""
        return MessageEventResult(type=EventResultType.EMPTY)

    def image_result(self, url_or_path: str) -> MessageEventResult:
        """Create a chain result that contains one image component."""
        # Dispatch on the source scheme: URL, inline base64, or local file.
        if url_or_path.startswith(("http://", "https://")):
            image = Image.fromURL(url_or_path)
        elif url_or_path.startswith("base64://"):
            image = Image.fromBase64(url_or_path.removeprefix("base64://"))
        else:
            image = Image.fromFileSystem(url_or_path)
        return MessageEventResult(
            type=EventResultType.CHAIN,
            chain=MessageChain([image]),
        )

    def chain_result(
        self,
        chain: MessageChain | list[BaseMessageComponent],
    ) -> MessageEventResult:
        """Create a chain result from SDK components."""
        normalized = (
            chain if isinstance(chain, MessageChain) else MessageChain(list(chain))
        )
        return MessageEventResult(type=EventResultType.CHAIN, chain=normalized)

    @staticmethod
    def _coerce_chain_or_raise(item: Any) -> MessageChain:
        """Coerce a streaming item into a MessageChain or raise TypeError."""
        if isinstance(item, MessageEventResult):
            # NOTE(review): assumes result.chain is populated for streamed
            # results — a CHAIN-type result; confirm upstream guarantees this.
            return item.chain
        if isinstance(item, MessageChain):
            return item
        if isinstance(item, BaseMessageComponent):
            return MessageChain([item])
        if isinstance(item, list) and all(
            isinstance(component, BaseMessageComponent) for component in item
        ):
            return MessageChain(list(item))
        raise TypeError(
            "send_streaming only accepts str, MessageChain, MessageEventResult or SDK message components"
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/filters.py b/astrbot-sdk/src/astrbot_sdk/filters.py
new file mode 100644
index 0000000000..4704f46dd0
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/filters.py
@@ -0,0 +1,234 @@
+"""SDK-native filter declarations.
+
+本模块提供事件过滤器的声明式 API,用于在 handler 执行前进行条件判断。
+
+内置过滤器类型:
+- PlatformFilter: 按平台名称过滤(如 qq、wechat)
+- MessageTypeFilter: 按消息类型过滤(如 group、private)
+- CustomFilter: 用户自定义的同步布尔函数
+
+组合操作:
+- all_of(*filters): 所有过滤器都通过才执行(AND 逻辑)
+- any_of(*filters): 任一过滤器通过即可执行(OR 逻辑)
+- 支持 & 和 | 运算符进行链式组合
+
+例子:
+@custom_filter(
+ all_of(
+ PlatformFilter(["qq"]),
+ MessageTypeFilter(["group"]),
+ CustomFilter(lambda event: "hello" in event.text),
+ )
+)
+
+过滤器在本地(SDK worker 进程内)求值,避免不必要的跨进程调用。
+"""
+
+from __future__ import annotations
+
+import inspect
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any, Literal, TypeAlias, TypeVar
+
+from .decorators import append_filter_meta
+from .protocol.descriptors import (
+ CompositeFilterSpec,
+ FilterSpec,
+ LocalFilterRefSpec,
+ MessageTypeFilterSpec,
+ PlatformFilterSpec,
+)
+
+FilterOperator: TypeAlias = Literal["and", "or"]
+_HandlerT = TypeVar("_HandlerT", bound=Callable[..., Any])
+
+
+@dataclass(slots=True)
+class LocalFilterBinding:
+ filter_id: str
+ callable: Callable[..., bool]
+ args: dict[str, Any] = field(default_factory=dict)
+ _accepts_event: bool = field(init=False, repr=False)
+ _accepts_ctx: bool = field(init=False, repr=False)
+
+ def __post_init__(self) -> None:
+ parameters = inspect.signature(self.callable).parameters
+ self._accepts_event = "event" in parameters
+ self._accepts_ctx = "ctx" in parameters
+
+ def evaluate(self, *, event=None, ctx=None) -> bool:
+ kwargs: dict[str, Any] = {}
+ if self._accepts_event:
+ kwargs["event"] = event
+ if self._accepts_ctx:
+ kwargs["ctx"] = ctx
+ result = self.callable(**kwargs)
+ if inspect.isawaitable(result):
+ raise TypeError("CustomFilter must return a synchronous bool")
+ if not isinstance(result, bool):
+ raise TypeError("CustomFilter must return bool")
+ return result
+
+
class FilterBinding:
    """Base class for declarative filters composable with ``&`` and ``|``."""

    def __and__(self, other: FilterBinding) -> CompositeFilter:
        # AND composition: both operands must pass.
        return CompositeFilter("and", [self, other])

    def __or__(self, other: FilterBinding) -> CompositeFilter:
        # OR composition: either operand passing is enough.
        return CompositeFilter("or", [self, other])

    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
        """Return the serializable spec plus any worker-local bindings."""
        raise NotImplementedError
+
+
@dataclass(slots=True)
class PlatformFilter(FilterBinding):
    """Pass only events whose platform name is in ``platforms``."""

    platforms: list[str]

    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
        # Copy the list so later mutation of this binding cannot alter the spec.
        return PlatformFilterSpec(platforms=list(self.platforms)), []
+
+
@dataclass(slots=True)
class MessageTypeFilter(FilterBinding):
    """Pass only events whose message type is in ``message_types``."""

    message_types: list[str]

    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
        # Copy the list so later mutation of this binding cannot alter the spec.
        return MessageTypeFilterSpec(message_types=list(self.message_types)), []
+
+
@dataclass(slots=True)
class CustomFilter(FilterBinding):
    """Wrap a user-provided synchronous boolean callable as a filter."""

    callable: Callable[..., bool]
    filter_id: str | None = None

    def __post_init__(self) -> None:
        # Derive a stable default id from the callable's qualified name so the
        # published spec can reference the same worker-local binding.
        if self.filter_id is None:
            self.filter_id = f"{self.callable.__module__}.{getattr(self.callable, '__qualname__', self.callable.__name__)}"

    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
        assert self.filter_id is not None
        return LocalFilterRefSpec(filter_id=self.filter_id), [
            LocalFilterBinding(filter_id=self.filter_id, callable=self.callable),
        ]
+
+
@dataclass(slots=True)
class CompositeFilter(FilterBinding):
    """Boolean combination ("and"/"or") of child filter bindings."""

    operator: FilterOperator
    children: list[FilterBinding]

    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
        """Compile children, collapsing to a local closure when needed.

        If any child requires a worker-local callable, the whole composite is
        published as a single LocalFilterRefSpec whose closure evaluates every
        child spec in-process; otherwise a pure CompositeFilterSpec is emitted.
        """
        compiled_children: list[FilterSpec] = []
        local_bindings: list[LocalFilterBinding] = []
        for child in self.children:
            spec, locals_for_child = child.compile()
            compiled_children.append(spec)
            local_bindings.extend(locals_for_child)

        if local_bindings:
            # Deterministic id derived from child ids plus the operator, so
            # recompiling the same composite maps to the same binding id.
            filter_id = (
                "composite:"
                + ":".join(binding.filter_id for binding in local_bindings)
                + f":{self.operator}"
            )

            def _evaluate(*, event=None, ctx=None) -> bool:
                # Evaluate every child locally; pure specs (platform /
                # message-type) are interpreted against the event in-process.
                results = [
                    _evaluate_filter_spec_locally(
                        spec, local_bindings, event=event, ctx=ctx
                    )
                    for spec in compiled_children
                ]
                if self.operator == "and":
                    return all(results)
                return any(results)

            return (
                LocalFilterRefSpec(filter_id=filter_id),
                [LocalFilterBinding(filter_id=filter_id, callable=_evaluate)],
            )

        return CompositeFilterSpec(kind=self.operator, children=compiled_children), []
+
+
def _evaluate_filter_spec_locally(
    spec: FilterSpec,
    local_bindings: list[LocalFilterBinding],
    *,
    event=None,
    ctx=None,
) -> bool:
    """Interpret a FilterSpec tree in-process against ``event``/``ctx``.

    Missing context and unknown spec kinds fail open (return True) so a
    filter that cannot be evaluated never suppresses a runnable handler.
    """
    if isinstance(spec, PlatformFilterSpec):
        if event is None:
            # No event to inspect — fail open.
            return True
        platform = getattr(event, "platform", "") or ""
        return platform in spec.platforms
    if isinstance(spec, MessageTypeFilterSpec):
        if event is None:
            return True
        message_type = getattr(event, "message_type", "") or ""
        return message_type in spec.message_types
    if isinstance(spec, LocalFilterRefSpec):
        binding = next(
            (item for item in local_bindings if item.filter_id == spec.filter_id),
            None,
        )
        if binding is None:
            # A LocalFilterRefSpec is only truly executable when this worker
            # holds a local binding with the same id. A missing binding usually
            # means the descriptor came from a remote peer or a test snapshot,
            # so we stay fail-open rather than wrongly filtering out a handler
            # just because its in-process callable cannot be invoked here.
            return True
        return binding.evaluate(event=event, ctx=ctx)
    if isinstance(spec, CompositeFilterSpec):
        results = [
            _evaluate_filter_spec_locally(
                child,
                local_bindings,
                event=event,
                ctx=ctx,
            )
            for child in spec.children
        ]
        if spec.kind == "and":
            return all(results)
        return any(results)
    # Unknown spec kinds fail open.
    return True
+
+
def custom_filter(
    binding: FilterBinding,
) -> Callable[[_HandlerT], _HandlerT]:
    """Attach a filter declaration to a handler.

    The binding is compiled once at decoration time; the resulting spec and
    any worker-local callables are recorded in the handler's filter metadata.
    """

    def decorator(func: _HandlerT) -> _HandlerT:
        spec, local_bindings = binding.compile()
        append_filter_meta(
            func,
            specs=[spec],
            local_bindings=local_bindings,
        )
        return func

    return decorator
+
+
def all_of(*bindings: FilterBinding) -> CompositeFilter:
    """Combine filters with AND semantics: every filter must pass."""
    return CompositeFilter("and", list(bindings))
+
+
def any_of(*bindings: FilterBinding) -> CompositeFilter:
    """Combine filters with OR semantics: one passing filter suffices."""
    return CompositeFilter("or", list(bindings))
+
+
+__all__ = [
+ "CustomFilter",
+ "FilterBinding",
+ "LocalFilterBinding",
+ "MessageTypeFilter",
+ "PlatformFilter",
+ "all_of",
+ "any_of",
+ "custom_filter",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/llm/__init__.py b/astrbot-sdk/src/astrbot_sdk/llm/__init__.py
new file mode 100644
index 0000000000..02e15b9d2f
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/llm/__init__.py
@@ -0,0 +1,105 @@
+"""Canonical SDK LLM/tool/provider entrypoints for P0.5."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from .agents import AgentSpec, BaseAgentRunner
+ from .entities import (
+ LLMToolSpec,
+ ProviderMeta,
+ ProviderRequest,
+ ProviderType,
+ RerankResult,
+ ToolCallsResult,
+ )
+ from .providers import (
+ EmbeddingProvider,
+ ProviderProxy,
+ RerankProvider,
+ STTProvider,
+ TTSAudioChunk,
+ TTSProvider,
+ )
+ from .tools import LLMToolManager
+
+__all__ = [
+ "AgentSpec",
+ "BaseAgentRunner",
+ "EmbeddingProvider",
+ "LLMToolManager",
+ "LLMToolSpec",
+ "ProviderMeta",
+ "ProviderProxy",
+ "ProviderRequest",
+ "ProviderType",
+ "RerankProvider",
+ "RerankResult",
+ "STTProvider",
+ "TTSAudioChunk",
+ "TTSProvider",
+ "ToolCallsResult",
+]
+
+
def __getattr__(name: str) -> Any:
    """Lazily import and return a public name (PEP 562 module __getattr__).

    Keeps importing ``astrbot_sdk.llm`` cheap by deferring submodule imports
    until one of their names is first accessed.
    """
    if name in {"AgentSpec", "BaseAgentRunner"}:
        from .agents import AgentSpec, BaseAgentRunner

        return {"AgentSpec": AgentSpec, "BaseAgentRunner": BaseAgentRunner}[name]
    if name in {
        "LLMToolSpec",
        "ProviderMeta",
        "ProviderRequest",
        "ProviderType",
        "RerankResult",
        "ToolCallsResult",
    }:
        from .entities import (
            LLMToolSpec,
            ProviderMeta,
            ProviderRequest,
            ProviderType,
            RerankResult,
            ToolCallsResult,
        )

        return {
            "LLMToolSpec": LLMToolSpec,
            "ProviderMeta": ProviderMeta,
            "ProviderRequest": ProviderRequest,
            "ProviderType": ProviderType,
            "RerankResult": RerankResult,
            "ToolCallsResult": ToolCallsResult,
        }[name]
    if name in {
        "EmbeddingProvider",
        "ProviderProxy",
        "RerankProvider",
        "STTProvider",
        "TTSAudioChunk",
        "TTSProvider",
    }:
        from .providers import (
            EmbeddingProvider,
            ProviderProxy,
            RerankProvider,
            STTProvider,
            TTSAudioChunk,
            TTSProvider,
        )

        return {
            "EmbeddingProvider": EmbeddingProvider,
            "ProviderProxy": ProviderProxy,
            "RerankProvider": RerankProvider,
            "STTProvider": STTProvider,
            "TTSAudioChunk": TTSAudioChunk,
            "TTSProvider": TTSProvider,
        }[name]
    if name == "LLMToolManager":
        from .tools import LLMToolManager

        return LLMToolManager
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/astrbot-sdk/src/astrbot_sdk/llm/agents.py b/astrbot-sdk/src/astrbot_sdk/llm/agents.py
new file mode 100644
index 0000000000..c2d6b21e62
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/llm/agents.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from .entities import ProviderRequest
+
+if TYPE_CHECKING:
+ from ..context import Context
+
+
class AgentSpec(BaseModel):
    """Serializable registration metadata for an SDK-declared agent."""

    model_config = ConfigDict(extra="forbid")

    # Unique agent name used for registration.
    name: str
    description: str = ""
    # Names of tools the agent may use.
    tool_names: list[str] = Field(default_factory=list)
    # Reference to the BaseAgentRunner implementation (presumably an import
    # path — confirm against the registration consumer).
    runner_class: str

    def to_payload(self) -> dict[str, Any]:
        """Serialize to a plain dict, omitting None-valued fields."""
        return self.model_dump(exclude_none=True)

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> AgentSpec:
        """Validate and build an AgentSpec from a payload dict."""
        return cls.model_validate(payload)
+
+
class BaseAgentRunner(ABC):
    """Agent registration surface.

    Only supports agent registration metadata. Actual execution remains
    owned by the core tool loop and is not directly callable from SDK plugins.
    """

    @abstractmethod
    async def run(self, ctx: Context, request: ProviderRequest) -> Any:
        """Execution entry point subclasses must implement."""
        raise NotImplementedError
diff --git a/astrbot-sdk/src/astrbot_sdk/llm/entities.py b/astrbot-sdk/src/astrbot_sdk/llm/entities.py
new file mode 100644
index 0000000000..ba252db24b
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/llm/entities.py
@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+import enum
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
class _EntityModel(BaseModel):
    """Shared base: strict field validation plus payload serialization."""

    # Reject unknown keys so payload drift is surfaced early.
    model_config = ConfigDict(extra="forbid")

    def to_payload(self) -> dict[str, Any]:
        """Serialize to a plain dict, omitting None-valued fields."""
        return self.model_dump(exclude_none=True)
+
+
class ProviderType(str, enum.Enum):
    """Kinds of model providers addressable through the SDK."""

    CHAT_COMPLETION = "chat_completion"
    SPEECH_TO_TEXT = "speech_to_text"
    TEXT_TO_SPEECH = "text_to_speech"
    EMBEDDING = "embedding"
    RERANK = "rerank"
+
+
class ProviderMeta(_EntityModel):
    """Identity and capability metadata describing a provider instance."""

    id: str
    model: str | None = None
    # NOTE(review): the distinction between this free-form `type` string and
    # `provider_type` is not visible here — confirm against the core registry.
    type: str
    provider_type: ProviderType = ProviderType.CHAT_COMPLETION

    @classmethod
    def from_payload(cls, payload: dict[str, Any] | None) -> ProviderMeta | None:
        """Build from a payload dict; returns None for non-dict input."""
        if not isinstance(payload, dict):
            return None
        return cls.model_validate(payload)
+
+
class ToolCallsResult(_EntityModel):
    """Outcome of one tool invocation fed back into the LLM loop."""

    tool_call_id: str | None = None
    tool_name: str
    content: str
    success: bool = True

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> ToolCallsResult:
        """Validate and build from a payload dict."""
        return cls.model_validate(payload)
+
+
class RerankResult(_EntityModel):
    """One reranked document with its score and original index."""

    index: int
    score: float
    document: str

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> RerankResult:
        """Validate and build from a payload dict."""
        return cls.model_validate(payload)
+
+
class LLMToolSpec(_EntityModel):
    """Declarative description of an LLM-callable tool."""

    name: str
    description: str = ""
    # JSON-schema-shaped parameter declaration; defaults to "no parameters".
    parameters_schema: dict[str, Any] = Field(
        default_factory=lambda: {"type": "object", "properties": {}}
    )
    handler_ref: str | None = Field(
        default=None,
        description="Worker-side handler reference used to resolve the tool callable.",
    )
    handler_capability: str | None = Field(
        default=None,
        description="Optional capability name override for executing this tool handler.",
    )
    active: bool = True

    @classmethod
    def create(
        cls,
        *,
        name: str,
        description: str = "",
        parameters_schema: dict[str, Any] | None = None,
        handler_ref: str | None = None,
        handler_capability: str | None = None,
        active: bool = True,
    ) -> LLMToolSpec:
        """Explicit-keyword factory for building a spec.

        Keeps an explicit signature so static analyzers do not depend on
        Pydantic's generated __init__ when SDK call sites construct tool specs.
        """
        payload: dict[str, Any] = {
            "name": name,
            "description": description,
            "parameters_schema": parameters_schema
            if parameters_schema is not None
            else {"type": "object", "properties": {}},
            "active": active,
        }
        # Optional fields are only included when provided so validation of the
        # payload matches the field defaults.
        if handler_ref is not None:
            payload["handler_ref"] = handler_ref
        if handler_capability is not None:
            payload["handler_capability"] = handler_capability
        return cls.from_payload(payload)

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> LLMToolSpec:
        """Validate and build from a payload dict."""
        return cls.model_validate(payload)
+
+
class ProviderRequest(_EntityModel):
    """Parameters for one LLM provider call."""

    prompt: str | None = None
    system_prompt: str | None = None
    session_id: str | None = None
    # Prior conversation turns (presumably provider message dicts — confirm).
    contexts: list[dict[str, Any]] = Field(default_factory=list)
    image_urls: list[str] = Field(default_factory=list)
    tool_names: list[str] | None = None
    tool_calls_result: list[ToolCallsResult] = Field(default_factory=list)
    provider_id: str | None = None
    model: str | None = None
    temperature: float | None = None
    max_steps: int | None = None
    tool_call_timeout: int | None = None

    def to_payload(self) -> dict[str, Any]:
        """Serialize, re-encoding nested tool results via their own to_payload."""
        payload = super().to_payload()
        payload["tool_calls_result"] = [
            item.to_payload() for item in self.tool_calls_result
        ]
        return payload

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> ProviderRequest:
        """Build from a payload dict, coercing nested tool-result dicts.

        Non-dict entries in ``tool_calls_result`` are silently dropped.
        """
        normalized = dict(payload)
        raw_results = normalized.get("tool_calls_result")
        if isinstance(raw_results, list):
            normalized["tool_calls_result"] = [
                ToolCallsResult.from_payload(item)
                for item in raw_results
                if isinstance(item, dict)
            ]
        return cls.model_validate(normalized)
diff --git a/astrbot-sdk/src/astrbot_sdk/llm/providers.py b/astrbot-sdk/src/astrbot_sdk/llm/providers.py
new file mode 100644
index 0000000000..591e1d57d5
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/llm/providers.py
@@ -0,0 +1,199 @@
+"""Provider-facing SDK entities and typed proxy helpers."""
+
+from __future__ import annotations
+
+import base64
+from collections.abc import AsyncIterable, AsyncIterator
+from dataclasses import dataclass
+
+from ..clients._proxy import CapabilityProxy
+from .entities import ProviderMeta, ProviderType, RerankResult
+
+
@dataclass(slots=True)
class TTSAudioChunk:
    """One chunk of streamed TTS output: raw audio bytes plus optional text."""

    audio: bytes
    text: str | None = None
+
+
class _BaseProviderProxy:
    """Common plumbing for typed provider proxies: capability proxy + metadata."""

    def __init__(self, proxy: CapabilityProxy, meta: ProviderMeta) -> None:
        self._proxy = proxy
        self._meta = meta

    @property
    def id(self) -> str:
        """Provider instance id."""
        return self._meta.id

    @property
    def model(self) -> str | None:
        """Configured model name, if any."""
        return self._meta.model

    @property
    def type(self) -> str:
        """Raw provider type string from metadata."""
        return self._meta.type

    @property
    def provider_type(self) -> ProviderType:
        """High-level provider category."""
        return self._meta.provider_type

    def meta(self) -> ProviderMeta:
        """Return the full metadata object."""
        return self._meta
+
+
class STTProvider(_BaseProviderProxy):
    """Speech-to-text provider proxy."""

    async def get_text(self, audio_url: str) -> str:
        """Transcribe the audio at *audio_url*; empty string when absent."""
        output = await self._proxy.call(
            "provider.stt.get_text",
            {"provider_id": self.id, "audio_url": str(audio_url)},
        )
        return str(output.get("text", ""))
+
+
class TTSProvider(_BaseProviderProxy):
    """Text-to-speech provider proxy with optional streaming support."""

    def __init__(
        self,
        proxy: CapabilityProxy,
        meta: ProviderMeta,
        *,
        supports_stream: bool = False,
    ) -> None:
        super().__init__(proxy, meta)
        # Capability flag supplied by the factory; queried via support_stream().
        self._supports_stream = supports_stream

    async def get_audio(self, text: str) -> str:
        """Synthesize *text* and return the resulting audio file path."""
        output = await self._proxy.call(
            "provider.tts.get_audio",
            {"provider_id": self.id, "text": str(text)},
        )
        return str(output.get("audio_path", ""))

    def support_stream(self) -> bool:
        """Whether streaming synthesis is advertised for this provider."""
        return self._supports_stream

    async def get_audio_stream(
        self,
        text: str | AsyncIterable[str],
    ) -> AsyncIterator[TTSAudioChunk]:
        """Stream synthesized audio chunks for *text*.

        Accepts a complete string or an async iterable of text chunks (fully
        drained into the request payload before the call is made).
        """
        payload = await self._build_stream_payload(text)
        async for chunk in self._proxy.stream("provider.tts.get_audio_stream", payload):
            audio_base64 = str(chunk.get("audio_base64", ""))
            yield TTSAudioChunk(
                audio=base64.b64decode(audio_base64) if audio_base64 else b"",
                text=(
                    str(chunk.get("text")) if chunk.get("text") is not None else None
                ),
            )

    async def _build_stream_payload(
        self,
        text: str | AsyncIterable[str],
    ) -> dict[str, object]:
        """Build the streaming request payload for either input shape."""
        payload: dict[str, object] = {"provider_id": self.id}
        if isinstance(text, str):
            payload["text"] = text
            return payload
        payload["text_chunks"] = [str(item) async for item in text]
        return payload
+
+
class EmbeddingProvider(_BaseProviderProxy):
    """Embedding provider proxy."""

    async def get_embedding(self, text: str) -> list[float]:
        """Embed one text; returns [] when the core reply is malformed."""
        output = await self._proxy.call(
            "provider.embedding.get_embedding",
            {"provider_id": self.id, "text": str(text)},
        )
        embedding = output.get("embedding")
        if not isinstance(embedding, list):
            return []
        return [float(item) for item in embedding]

    async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts; malformed rows are silently dropped."""
        output = await self._proxy.call(
            "provider.embedding.get_embeddings",
            {
                "provider_id": self.id,
                "texts": [str(item) for item in texts],
            },
        )
        embeddings = output.get("embeddings")
        if not isinstance(embeddings, list):
            return []
        return [
            [float(value) for value in item]
            for item in embeddings
            if isinstance(item, list)
        ]

    async def get_dim(self) -> int:
        """Return the embedding dimensionality (0 when unavailable)."""
        output = await self._proxy.call(
            "provider.embedding.get_dim",
            {"provider_id": self.id},
        )
        return int(output.get("dim", 0))
+
+
class RerankProvider(_BaseProviderProxy):
    """Rerank provider proxy."""

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rank *documents* against *query*; malformed entries are dropped."""
        output = await self._proxy.call(
            "provider.rerank.rerank",
            {
                "provider_id": self.id,
                "query": str(query),
                "documents": [str(item) for item in documents],
                "top_n": top_n,
            },
        )
        results = output.get("results")
        if not isinstance(results, list):
            return []
        return [
            RerankResult.from_payload(item)
            for item in results
            if isinstance(item, dict)
        ]
+
+
+ProviderProxy = STTProvider | TTSProvider | EmbeddingProvider | RerankProvider
+
+
def provider_proxy_from_meta(
    proxy: CapabilityProxy,
    meta: ProviderMeta | None,
    *,
    tts_supports_stream: bool | None = None,
) -> ProviderProxy | None:
    """Instantiate the typed proxy matching ``meta.provider_type``.

    Returns None when metadata is absent or the provider type has no typed
    proxy (e.g. chat completion).
    """
    if meta is None:
        return None
    provider_type = meta.provider_type
    if provider_type == ProviderType.TEXT_TO_SPEECH:
        # TTS is special-cased: it carries the optional streaming flag.
        return TTSProvider(proxy, meta, supports_stream=bool(tts_supports_stream))
    simple_proxies = {
        ProviderType.SPEECH_TO_TEXT: STTProvider,
        ProviderType.EMBEDDING: EmbeddingProvider,
        ProviderType.RERANK: RerankProvider,
    }
    proxy_cls = simple_proxies.get(provider_type)
    if proxy_cls is None:
        return None
    return proxy_cls(proxy, meta)
+
+
+__all__ = [
+ "EmbeddingProvider",
+ "ProviderMeta",
+ "ProviderProxy",
+ "ProviderType",
+ "RerankProvider",
+ "RerankResult",
+ "STTProvider",
+ "TTSAudioChunk",
+ "TTSProvider",
+ "provider_proxy_from_meta",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/llm/tools.py b/astrbot-sdk/src/astrbot_sdk/llm/tools.py
new file mode 100644
index 0000000000..d1a67b30c7
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/llm/tools.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .entities import LLMToolSpec
+
+if TYPE_CHECKING:
+ from ..clients._proxy import CapabilityProxy
+
+
class LLMToolManager:
    """Typed facade over the core ``llm_tool.manager.*`` capability calls."""

    def __init__(self, proxy: CapabilityProxy) -> None:
        self._proxy = proxy

    async def _fetch_specs(self, bucket: str) -> list[LLMToolSpec]:
        # Both listings come from the same capability call; only the payload
        # key ("registered" vs "active") differs.
        output = await self._proxy.call("llm_tool.manager.get", {})
        items = output.get(bucket)
        if not isinstance(items, list):
            return []
        return [
            LLMToolSpec.from_payload(item) for item in items if isinstance(item, dict)
        ]

    async def list_registered(self) -> list[LLMToolSpec]:
        """Return every registered tool spec."""
        return await self._fetch_specs("registered")

    async def list_active(self) -> list[LLMToolSpec]:
        """Return only the currently active tool specs."""
        return await self._fetch_specs("active")

    async def activate(self, name: str) -> bool:
        """Activate a tool by name; True when the core confirmed it."""
        output = await self._proxy.call("llm_tool.manager.activate", {"name": name})
        return bool(output.get("activated", False))

    async def deactivate(self, name: str) -> bool:
        """Deactivate a tool by name; True when the core confirmed it."""
        output = await self._proxy.call("llm_tool.manager.deactivate", {"name": name})
        return bool(output.get("deactivated", False))

    async def add(self, *tools: LLMToolSpec) -> list[str]:
        """Register tool specs; returns the names accepted by the core."""
        payload = {"tools": [tool.to_payload() for tool in tools]}
        output = await self._proxy.call("llm_tool.manager.add", payload)
        names = output.get("names")
        if not isinstance(names, list):
            return []
        return [str(item) for item in names]

    async def remove(self, name: str) -> bool:
        """Remove a tool by name; True when the core confirmed removal."""
        output = await self._proxy.call("llm_tool.manager.remove", {"name": name})
        return bool(output.get("removed", False))

    async def get(self, name: str) -> LLMToolSpec | None:
        """Find a registered spec by name, or None."""
        for tool in await self.list_registered():
            if tool.name == name:
                return tool
        return None
diff --git a/astrbot-sdk/src/astrbot_sdk/message/__init__.py b/astrbot-sdk/src/astrbot_sdk/message/__init__.py
new file mode 100644
index 0000000000..4125a0db12
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message/__init__.py
@@ -0,0 +1,103 @@
+"""Message component, result, and session subpackage."""
+
+from .components import (
+ At as At,
+)
+from .components import (
+ AtAll as AtAll,
+)
+from .components import (
+ BaseMessageComponent as BaseMessageComponent,
+)
+from .components import (
+ File as File,
+)
+from .components import (
+ Forward as Forward,
+)
+from .components import (
+ Image as Image,
+)
+from .components import (
+ MediaHelper as MediaHelper,
+)
+from .components import (
+ Plain as Plain,
+)
+from .components import (
+ Poke as Poke,
+)
+from .components import (
+ Record as Record,
+)
+from .components import (
+ Reply as Reply,
+)
+from .components import (
+ UnknownComponent as UnknownComponent,
+)
+from .components import (
+ Video as Video,
+)
+from .components import (
+ build_media_component_from_url as build_media_component_from_url,
+)
+from .components import (
+ component_to_payload as component_to_payload,
+)
+from .components import (
+ component_to_payload_sync as component_to_payload_sync,
+)
+from .components import (
+ is_message_component as is_message_component,
+)
+from .components import (
+ payload_to_component as payload_to_component,
+)
+from .components import (
+ payloads_to_components as payloads_to_components,
+)
+from .result import (
+ EventResultType as EventResultType,
+)
+from .result import (
+ MessageBuilder as MessageBuilder,
+)
+from .result import (
+ MessageChain as MessageChain,
+)
+from .result import (
+ MessageEventResult as MessageEventResult,
+)
+from .result import (
+ coerce_message_chain as coerce_message_chain,
+)
+from .session import MessageSession as MessageSession
+
+__all__ = [
+ "At",
+ "AtAll",
+ "BaseMessageComponent",
+ "EventResultType",
+ "File",
+ "Forward",
+ "Image",
+ "MediaHelper",
+ "MessageBuilder",
+ "MessageChain",
+ "MessageEventResult",
+ "MessageSession",
+ "Plain",
+ "Poke",
+ "Record",
+ "Reply",
+ "UnknownComponent",
+ "Video",
+ "build_media_component_from_url",
+ "coerce_message_chain",
+ "component_to_payload",
+ "component_to_payload_sync",
+ "is_message_component",
+ "payload_to_component",
+ "payloads_to_components",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/message/components.py b/astrbot-sdk/src/astrbot_sdk/message/components.py
new file mode 100644
index 0000000000..bd00708ac2
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message/components.py
@@ -0,0 +1,513 @@
+"""SDK message component compatibility layer.
+
+该模块有意避免在导入时导入遗留核心组件模块。
+SDK工作线程应该保持轻量级并且不能依赖于主机核心引导程序
+仅用于构造消息对象的路径。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import os
+from collections.abc import Mapping
+from pathlib import Path
+from typing import Any
+from urllib.parse import urlparse
+from urllib.request import urlretrieve
+
+from ..errors import AstrBotError
+
+_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
+_RECORD_SUFFIXES = {".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a"}
+_VIDEO_SUFFIXES = {".mp4", ".webm", ".mov", ".mkv", ".avi"}
+
+
+def _stringify_mapping(mapping: Mapping[Any, Any]) -> dict[str, Any]:
+ return {str(key): value for key, value in mapping.items()}
+
+
def _reply_chain_payloads_sync(value: Any) -> list[dict[str, Any]]:
    """Serialize a quoted-reply chain synchronously; non-lists become []."""
    if not isinstance(value, list):
        return []
    return [component_to_payload_sync(item) for item in value]
+
+
async def _reply_chain_payloads(value: Any) -> list[dict[str, Any]]:
    """Async variant of _reply_chain_payloads_sync; non-lists become []."""
    if not isinstance(value, list):
        return []
    return [await component_to_payload(item) for item in value]
+
+
def _coerce_reply_chain(value: Any) -> list[BaseMessageComponent]:
    """Normalize a reply chain into component instances.

    Accepts either a non-empty list of components (returned as a copy) or a
    list of raw payload dicts (converted). Anything else yields [].
    """
    if not isinstance(value, list):
        return []
    if value and all(isinstance(item, BaseMessageComponent) for item in value):
        return list(value)
    return payloads_to_components(value)
+
+
+def _component_type_name(component: Any) -> str:
+ raw_type = getattr(component, "type", "unknown")
+ normalized = getattr(raw_type, "value", raw_type)
+ return str(normalized or "unknown").lower()
+
+
+def _plain_payload(text: Any) -> dict[str, Any]:
+ return {"type": "text", "data": {"text": str(text)}}
+
+
+def _reply_payload_data(
+ component: Any,
+ *,
+ chain_payloads: list[dict[str, Any]],
+) -> dict[str, Any]:
+ return {
+ "id": getattr(component, "id", ""),
+ "chain": chain_payloads,
+ "sender_id": getattr(component, "sender_id", 0),
+ "sender_nickname": getattr(component, "sender_nickname", ""),
+ "time": getattr(component, "time", 0),
+ "message_str": getattr(component, "message_str", ""),
+ "text": getattr(component, "text", ""),
+ "qq": getattr(component, "qq", 0),
+ "seq": getattr(component, "seq", 0),
+ }
+
+
def _resolve_media_kind(url: str, kind: str = "auto") -> str:
    """Resolve the media kind for *url*, explicitly or by path suffix.

    An explicit non-"auto" kind wins; otherwise the URL path suffix is
    matched against the known image/audio/video extension sets, falling back
    to "file".
    """
    normalized_kind = str(kind).strip().lower() or "auto"
    if normalized_kind != "auto":
        return normalized_kind
    suffix = Path(urlparse(url).path).suffix.lower()
    if suffix in _IMAGE_SUFFIXES:
        return "image"
    if suffix in _RECORD_SUFFIXES:
        return "record"
    if suffix in _VIDEO_SUFFIXES:
        return "video"
    return "file"
+
+
def build_media_component_from_url(
    url: str,
    *,
    kind: str = "auto",
) -> BaseMessageComponent:
    """Build the media component matching *url*'s explicit or inferred kind.

    Args:
        url: Remote URL of the media; must be non-empty after stripping.
        kind: "auto" (infer from the path suffix), "image", "record"/"audio",
            "video", or "file".

    Raises:
        AstrBotError: For an empty url or an unrecognized explicit kind.
    """
    url_text = str(url).strip()
    if not url_text:
        raise AstrBotError.invalid_input(
            "MediaHelper.from_url requires a non-empty url"
        )
    resolved_kind = _resolve_media_kind(url_text, kind=kind)
    if resolved_kind == "image":
        return Image.fromURL(url_text)
    if resolved_kind in {"record", "audio"}:
        return Record.fromURL(url_text)
    if resolved_kind == "video":
        return Video.fromURL(url_text)
    if resolved_kind == "file":
        # Use the URL basename as the display name.
        return File(name=_filename_from_url(url_text), url=url_text)
    raise AstrBotError.invalid_input(
        f"Unsupported media kind: {kind}",
        details={"kind": kind, "url": url_text},
    )
+
+
+def _filename_from_url(url: str) -> str:
+ name = Path(urlparse(url).path).name
+ return name or "download"
+
+
class BaseMessageComponent:
    """Minimal payload-serializable message segment base."""

    type: str = "unknown"

    def toDict(self) -> dict[str, Any]:
        """Serialize to the wire shape ``{"type": ..., "data": {...}}``.

        None-valued attributes and any instance attribute literally named
        "type" are omitted; a ``_type`` attribute is emitted under the
        "type" key inside ``data``.
        """
        data: dict[str, Any] = {}
        for attr_name, attr_value in self.__dict__.items():
            if attr_name == "type" or attr_value is None:
                continue
            payload_key = "type" if attr_name == "_type" else attr_name
            data[payload_key] = attr_value
        return {"type": str(self.type).lower(), "data": data}

    async def to_dict(self) -> dict[str, Any]:
        """Async variant; same output as :meth:`toDict`."""
        return self.toDict()
+
+
class Plain(BaseMessageComponent):
    """Plain text segment (serialized under the wire type "text")."""

    type = "plain"

    def __init__(self, text: str, convert: bool = True, **_: Any) -> None:
        self.text = text
        # Kept for legacy constructor compatibility; not used in serialization.
        self.convert = convert

    def toDict(self) -> dict[str, Any]:
        return _plain_payload(self.text)

    async def to_dict(self) -> dict[str, Any]:
        return _plain_payload(self.text)
+
+
class At(BaseMessageComponent):
    """Mention (@) a specific user by id."""

    type = "at"

    def __init__(self, qq: int | str, name: str | None = "", **_: Any) -> None:
        self.qq = qq
        self.name = name or ""

    def toDict(self) -> dict[str, Any]:
        # Only the target id goes on the wire; the display name is local-only.
        return {"type": "at", "data": {"qq": str(self.qq)}}
+
+
class AtAll(At):
    """Mention everyone; serialized as an @ with the sentinel id "all"."""

    def __init__(self, **_: Any) -> None:
        super().__init__(qq="all")
+
+
class Reply(BaseMessageComponent):
    """Quoted-reply segment referencing an earlier message."""

    type = "reply"

    def __init__(self, **kwargs: Any) -> None:
        self.id = kwargs.get("id", "")
        # The quoted chain may arrive as components or as raw payload dicts.
        self.chain = _coerce_reply_chain(kwargs.get("chain", []))
        self.sender_id = kwargs.get("sender_id", 0)
        self.sender_nickname = kwargs.get("sender_nickname", "")
        self.time = kwargs.get("time", 0)
        self.message_str = kwargs.get("message_str", "")
        self.text = kwargs.get("text", "")
        self.qq = kwargs.get("qq", 0)
        self.seq = kwargs.get("seq", 0)

    def toDict(self) -> dict[str, Any]:
        """Serialize synchronously (nested chain serialized synchronously too)."""
        return {
            "type": "reply",
            "data": _reply_payload_data(
                self,
                chain_payloads=_reply_chain_payloads_sync(self.chain),
            ),
        }

    async def to_dict(self) -> dict[str, Any]:
        """Async serialization; lets nested components do async payload work."""
        return {
            "type": "reply",
            "data": _reply_payload_data(
                self,
                chain_payloads=await _reply_chain_payloads(self.chain),
            ),
        }
+
+
class Image(BaseMessageComponent):
    """Image segment addressed by URL, local path, or base64 data."""

    type = "image"

    def __init__(self, file: str | None, **kwargs: Any) -> None:
        self.file = file or ""
        self._type = kwargs.get("_type", "")
        self.subType = kwargs.get("subType", 0)
        self.url = kwargs.get("url", "")
        self.cache = kwargs.get("cache", True)
        # Magic defaults (id=40000, c=2) mirror the legacy component schema —
        # confirm before changing.
        self.id = kwargs.get("id", 40000)
        self.c = kwargs.get("c", 2)
        self.path = kwargs.get("path", "")
        self.file_unique = kwargs.get("file_unique", "")

    @staticmethod
    def fromURL(url: str, **kwargs: Any) -> Image:
        """Build an image referencing a remote URL."""
        return Image(url, **kwargs)

    @staticmethod
    def fromFileSystem(path: str, **kwargs: Any) -> Image:
        """Build an image referencing a local file.

        NOTE(review): on POSIX, os.path.abspath() already starts with "/", so
        this yields "file:////..." — presumably tolerated downstream; confirm.
        """
        return Image(f"file:///{os.path.abspath(path)}", path=path, **kwargs)

    @staticmethod
    def fromBase64(base64_data: str, **kwargs: Any) -> Image:
        """Build an image from base64 data (without the "base64://" prefix)."""
        return Image(f"base64://{base64_data}", **kwargs)
+
+
class Record(BaseMessageComponent):
    """Voice/audio segment."""

    type = "record"

    def __init__(self, file: str | None, **kwargs: Any) -> None:
        self.file = file or ""
        self.magic = kwargs.get("magic", False)
        self.url = kwargs.get("url", "")
        self.cache = kwargs.get("cache", True)
        self.proxy = kwargs.get("proxy", True)
        self.timeout = kwargs.get("timeout", 0)
        # Optional associated text (presumably a transcript — confirm) and
        # resolved local path, when known.
        self.text = kwargs.get("text")
        self.path = kwargs.get("path")

    @staticmethod
    def fromFileSystem(path: str, **kwargs: Any) -> Record:
        """Build a record referencing a local file."""
        return Record(f"file:///{os.path.abspath(path)}", path=path, **kwargs)

    @staticmethod
    def fromURL(url: str, **kwargs: Any) -> Record:
        """Build a record referencing a remote URL."""
        return Record(url, **kwargs)
+
+
class Video(BaseMessageComponent):
    """Video segment."""

    type = "video"

    def __init__(self, file: str, **kwargs: Any) -> None:
        self.file = file
        # Optional cover image reference.
        self.cover = kwargs.get("cover", "")
        # Magic default c=2 mirrors the legacy component schema — confirm.
        self.c = kwargs.get("c", 2)
        self.path = kwargs.get("path", "")

    @staticmethod
    def fromFileSystem(path: str, **kwargs: Any) -> Video:
        """Build a video referencing a local file."""
        return Video(f"file:///{os.path.abspath(path)}", path=path, **kwargs)

    @staticmethod
    def fromURL(url: str, **kwargs: Any) -> Video:
        """Build a video referencing a remote URL."""
        return Video(url, **kwargs)
+
+
class File(BaseMessageComponent):
    """Generic file segment with a display name and a path or URL source."""

    type = "file"

    def __init__(self, name: str, file: str = "", url: str = "") -> None:
        self.name = name
        # NOTE(review): __init__ writes file_ directly, bypassing the setter's
        # http(s) redirection below — File(name, file="http://…") lands in
        # file_, not url. toDict still prefers url; confirm this is intended.
        self.file_ = file
        self.url = url

    @property
    def file(self) -> str:
        return self.file_

    @file.setter
    def file(self, value: str) -> None:
        # URLs assigned via .file are routed to .url instead of the local path.
        if value.startswith(("http://", "https://")):
            self.url = value
        else:
            self.file_ = value

    def toDict(self) -> dict[str, Any]:
        # The wire payload prefers the URL over the local path.
        payload_file = self.url or self.file_
        return {
            "type": "file",
            "data": {
                "name": self.name,
                "file": payload_file,
            },
        }

    async def to_dict(self) -> dict[str, Any]:
        payload_file = self.url or self.file_
        return {
            "type": "file",
            "data": {
                "name": self.name,
                "file": payload_file,
            },
        }
+
+
class Poke(BaseMessageComponent):
    """Poke (nudge) action aimed at an optional target user."""

    type = "poke"

    def __init__(self, poke_type: str | int | None = None, **kwargs: Any) -> None:
        # Older call sites passed the poke variant as `type=`; honor it when
        # no positional value was given.
        legacy_type = kwargs.pop("type", None)
        if poke_type is None:
            poke_type = legacy_type
        # "126" is the fallback variant when nothing usable was supplied.
        if poke_type in (None, "", "poke", "Poke"):
            poke_type = "126"
        self._type = str(poke_type)
        self.id = kwargs.get("id")
        self.qq = kwargs.get("qq", 0)

    def target_id(self) -> str | None:
        """Return the first usable target id (id preferred over qq), else None."""
        for value in (self.id, self.qq):
            if value is None:
                continue
            text = str(value).strip()
            # "0" counts as "no target", matching the qq default above.
            if text and text != "0":
                return text
        return None

    def toDict(self) -> dict[str, Any]:
        data = {"type": str(self._type or "126")}
        target_id = self.target_id()
        if target_id:
            data["id"] = target_id
        return {"type": "poke", "data": data}
+
+
class Forward(BaseMessageComponent):
    """Reference to an existing forward message by its identifier."""

    type = "forward"

    def __init__(self, id: str, **_: Any) -> None:
        # Extra keyword arguments are accepted and ignored for compatibility.
        self.id = id
+
+
class UnknownComponent(BaseMessageComponent):
    """Fallback component that preserves an unrecognized payload as-is."""

    type = "unknown"

    def __init__(
        self,
        *,
        raw_type: str = "unknown",
        raw_data: dict[str, Any] | None = None,
    ) -> None:
        self.raw_type = raw_type
        self.raw_data = raw_data or {}

    def toDict(self) -> dict[str, Any]:
        """Re-emit the original payload; an empty raw_type degrades to
        the literal "unknown"."""
        effective_type = self.raw_type if self.raw_type else "unknown"
        return {"type": effective_type, "data": dict(self.raw_data)}
+
+
def is_message_component(value: Any) -> bool:
    """Return True when *value* is a BaseMessageComponent instance."""
    return isinstance(value, BaseMessageComponent)
+
+
def payload_to_component(payload: Any) -> BaseMessageComponent:
    """Convert one protocol payload dict into a message component.

    Unknown or malformed payloads degrade to :class:`UnknownComponent`
    instead of raising, so a single bad segment cannot break a whole chain.
    """
    if not isinstance(payload, dict):
        return UnknownComponent(raw_data={"value": payload})

    raw_type = str(payload.get("type", "unknown") or "unknown").lower()
    data = payload.get("data")
    if not isinstance(data, dict):
        data = {}

    if raw_type in {"text", "plain"}:
        return Plain(str(data.get("text", "")), convert=False)
    if raw_type == "image":
        return Image(str(data.get("file") or data.get("url") or ""))
    if raw_type == "at":
        qq_value = data.get("qq")
        if str(qq_value).lower() == "all":
            return AtAll()
        qq = "" if qq_value is None else str(qq_value)
        return At(qq=qq, name=str(data.get("name", "")))
    if raw_type == "reply":
        return Reply(**data)
    if raw_type == "record":
        # Fix: drop "file" from the forwarded kwargs — it is already passed
        # positionally, and forwarding it again raises TypeError
        # ("got multiple values for argument 'file'").
        extra = {key: value for key, value in data.items() if key != "file"}
        return Record(str(data.get("file") or data.get("url") or ""), **extra)
    if raw_type == "video":
        # Same duplicate-kwarg fix as the "record" branch; Video's first
        # positional parameter is named ``file``.
        extra = {key: value for key, value in data.items() if key != "file"}
        return Video(str(data.get("file") or ""), **extra)
    if raw_type == "file":
        file_value = str(data.get("file") or data.get("file_") or "")
        if not file_value:
            file_value = str(data.get("url") or "")
        is_remote = file_value.startswith(("http://", "https://"))
        return File(
            str(data.get("name", "")),
            file="" if is_remote else file_value,
            url=file_value if is_remote else "",
        )
    if raw_type == "poke":
        return Poke(
            poke_type=data.get("type"),
            id=data.get("id"),
            qq=data.get("qq"),
        )
    if raw_type == "forward":
        return Forward(id=str(data.get("id", "")))

    return UnknownComponent(raw_type=raw_type, raw_data=_stringify_mapping(data))
+
+
def payloads_to_components(payloads: list[Any]) -> list[BaseMessageComponent]:
    """Convert every raw payload in *payloads* into a message component."""
    return list(map(payload_to_component, payloads))
+
+
def component_to_payload_sync(component: Any) -> dict[str, Any]:
    """Serialize *component* to its protocol payload without awaiting.

    Order matters: UnknownComponent and Plain get dedicated fast paths,
    reply components go through the reply serializers, anything else falls
    back to its own ``toDict`` and finally to a stringified wrapper.
    """
    if isinstance(component, UnknownComponent):
        return component.toDict()
    if isinstance(component, Plain):
        return _plain_payload(component.text)
    if _component_type_name(component) == "reply":
        chain_payloads = _reply_chain_payloads_sync(getattr(component, "chain", []))
        data = _reply_payload_data(component, chain_payloads=chain_payloads)
        return {"type": "reply", "data": data}
    serializer = getattr(component, "toDict", None)
    if callable(serializer):
        produced = serializer()
        if isinstance(produced, Mapping):
            return _stringify_mapping(produced)
    return {"type": "unknown", "data": {"value": str(component)}}
+
+
async def component_to_payload(component: Any) -> dict[str, Any]:
    """Async serialization; prefers an awaitable ``to_dict`` when one exists.

    Falls back to :func:`component_to_payload_sync` when ``to_dict`` is
    missing, not awaitable, or does not resolve to a dict.
    """
    if isinstance(component, (UnknownComponent, Plain)):
        return component_to_payload_sync(component)
    candidate = getattr(component, "to_dict", None)
    if callable(candidate):
        maybe_awaitable = candidate()
        if inspect.isawaitable(maybe_awaitable):
            resolved = await maybe_awaitable
            if isinstance(resolved, dict):
                return resolved
    return component_to_payload_sync(component)
+
+
class MediaHelper:
    """Helpers for building media components and fetching remote media."""

    @staticmethod
    async def from_url(
        url: str,
        *,
        kind: str = "auto",
    ) -> BaseMessageComponent:
        """Build a media component for *url*; *kind* selects the component."""
        return build_media_component_from_url(url, kind=kind)

    @staticmethod
    async def download(url: str, save_dir: Path) -> Path:
        """Download an http(s) URL into *save_dir* and return the resolved path.

        Raises AstrBotError variants for an empty url, a non-http(s) scheme,
        a directory that cannot be created, or a failed transfer.
        """
        cleaned = str(url).strip()
        if not cleaned:
            raise AstrBotError.invalid_input(
                "MediaHelper.download requires a non-empty url"
            )
        if urlparse(cleaned).scheme not in {"http", "https"}:
            raise AstrBotError.invalid_input(
                "MediaHelper.download only supports http/https urls",
                details={"url": cleaned},
            )
        destination_dir = Path(save_dir)
        try:
            destination_dir.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            raise AstrBotError.internal_error(
                f"Failed to prepare download directory: {destination_dir}",
                details={"save_dir": str(destination_dir)},
            ) from exc
        destination = destination_dir / _filename_from_url(cleaned)
        try:
            # NOTE(review): urlretrieve is urllib's legacy interface; it runs
            # inside a worker thread so the event loop stays responsive.
            await asyncio.to_thread(urlretrieve, cleaned, destination)
        except Exception as exc:
            raise AstrBotError.network_error(
                f"Failed to download media from '{cleaned}'",
                details={"url": cleaned},
            ) from exc
        return destination.resolve()
+
+
# Public API of this module: component classes first, then payload helpers,
# each group kept alphabetical.
__all__ = [
    "At",
    "AtAll",
    "BaseMessageComponent",
    "File",
    "Forward",
    "Image",
    "MediaHelper",
    "Plain",
    "Poke",
    "Record",
    "Reply",
    "UnknownComponent",
    "Video",
    "component_to_payload",
    "component_to_payload_sync",
    "is_message_component",
    "payload_to_component",
    "payloads_to_components",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/message/result.py b/astrbot-sdk/src/astrbot_sdk/message/result.py
new file mode 100644
index 0000000000..a38c207099
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message/result.py
@@ -0,0 +1,174 @@
+"""SDK-local rich message result objects.
+
+本模块定义消息事件的结果对象,用于构建和返回富文本/多媒体消息。
+
+核心类:
+- MessageChain: 消息组件列表,支持同步/异步序列化为协议 payload
+- MessageEventResult: 事件处理结果,包含类型标记和消息链
+- EventResultType: 结果类型枚举(EMPTY / CHAIN)
+
+辅助函数:
+- coerce_message_chain: 将多种输入格式统一转换为 MessageChain,
+ 支持 MessageEventResult、MessageChain、单个组件或组件列表
+"""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+from .components import (
+ At,
+ AtAll,
+ BaseMessageComponent,
+ File,
+ Plain,
+ Reply,
+ build_media_component_from_url,
+ component_to_payload,
+ component_to_payload_sync,
+ is_message_component,
+ payloads_to_components,
+)
+
+
class EventResultType(str, Enum):
    """Marks whether an event produced a message chain or nothing at all."""

    EMPTY = "empty"
    CHAIN = "chain"
+
+
+@dataclass(slots=True)
+class MessageChain:
+ components: list[BaseMessageComponent] = field(default_factory=list)
+
+ def append(self, component: BaseMessageComponent) -> MessageChain:
+ self.components.append(component)
+ return self
+
+ def extend(self, components: list[BaseMessageComponent]) -> MessageChain:
+ self.components.extend(components)
+ return self
+
+ def __iter__(self) -> Iterator[BaseMessageComponent]:
+ return iter(self.components)
+
+ def __len__(self) -> int:
+ return len(self.components)
+
+ def to_payload(self) -> list[dict[str, Any]]:
+ return [component_to_payload_sync(component) for component in self.components]
+
+ async def to_payload_async(self) -> list[dict[str, Any]]:
+ return [await component_to_payload(component) for component in self.components]
+
+ def get_plain_text(self, with_other_comps_mark: bool = False) -> str:
+ texts: list[str] = []
+ for component in self.components:
+ if isinstance(component, Plain):
+ texts.append(component.text)
+ elif with_other_comps_mark:
+ texts.append(f"[{component.__class__.__name__}]")
+ return " ".join(texts)
+
+ def plain_text(self, with_other_comps_mark: bool = False) -> str:
+ return self.get_plain_text(with_other_comps_mark=with_other_comps_mark)
+
+
@dataclass(slots=True)
class MessageEventResult:
    """Result of handling a message event: a type tag plus the chain."""

    type: EventResultType = EventResultType.EMPTY
    chain: MessageChain = field(default_factory=MessageChain)

    def to_payload(self) -> dict[str, Any]:
        """Serialize as ``{"type": ..., "chain": [...]}``."""
        return {
            "type": self.type.value,
            "chain": self.chain.to_payload(),
        }

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> MessageEventResult:
        """Rebuild a result from a payload dict.

        Unknown type strings fall back to EMPTY, and a missing or non-list
        chain becomes an empty chain.
        """
        raw_type = str(payload.get("type", EventResultType.EMPTY.value))
        try:
            parsed_type = EventResultType(raw_type)
        except ValueError:
            parsed_type = EventResultType.EMPTY
        raw_chain = payload.get("chain")
        if isinstance(raw_chain, list):
            components = payloads_to_components(raw_chain)
        else:
            components = []
        return cls(type=parsed_type, chain=MessageChain(components))
+
+
+@dataclass(slots=True)
+class MessageBuilder:
+ components: list[BaseMessageComponent] = field(default_factory=list)
+
+ def text(self, content: str) -> MessageBuilder:
+ self.components.append(Plain(content, convert=False))
+ return self
+
+ def at(self, user_id: str) -> MessageBuilder:
+ self.components.append(At(user_id))
+ return self
+
+ def at_all(self) -> MessageBuilder:
+ self.components.append(AtAll())
+ return self
+
+ def image(self, url: str) -> MessageBuilder:
+ self.components.append(build_media_component_from_url(url, kind="image"))
+ return self
+
+ def record(self, url: str) -> MessageBuilder:
+ self.components.append(build_media_component_from_url(url, kind="record"))
+ return self
+
+ def video(self, url: str) -> MessageBuilder:
+ self.components.append(build_media_component_from_url(url, kind="video"))
+ return self
+
+ def file(self, name: str, *, file: str = "", url: str = "") -> MessageBuilder:
+ self.components.append(File(name=name, file=file, url=url))
+ return self
+
+ def reply(self, **kwargs: Any) -> MessageBuilder:
+ self.components.append(Reply(**kwargs))
+ return self
+
+ def append(self, component: BaseMessageComponent) -> MessageBuilder:
+ self.components.append(component)
+ return self
+
+ def extend(self, components: list[BaseMessageComponent]) -> MessageBuilder:
+ self.components.extend(components)
+ return self
+
+ def build(self) -> MessageChain:
+ return MessageChain(list(self.components))
+
+
def coerce_message_chain(value: Any) -> MessageChain | None:
    """Normalize supported inputs into a MessageChain.

    Accepts an existing result or chain, a single component, or a list/tuple
    of components; anything else yields None.
    """
    if isinstance(value, MessageEventResult):
        return value.chain
    if isinstance(value, MessageChain):
        return value
    if is_message_component(value):
        return MessageChain([value])
    if isinstance(value, (list, tuple)):
        if all(is_message_component(item) for item in value):
            return MessageChain(list(value))
    return None
+
+
# Public surface of the result module.
__all__ = [
    "EventResultType",
    "MessageChain",
    "MessageBuilder",
    "MessageEventResult",
    "coerce_message_chain",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/message/session.py b/astrbot-sdk/src/astrbot_sdk/message/session.py
new file mode 100644
index 0000000000..951e34d25c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message/session.py
@@ -0,0 +1,55 @@
+"""SDK-visible message session identifier.
+
+本模块定义 MessageSession 类,用于统一表示消息会话标识符。
+会话标识符格式为:platform_id:message_type:session_id
+
+例如:
+- qq:group:123456 表示 QQ 群 123456
+- wechat:private:user789 表示微信私聊用户 user789
+
+该格式与 AstrBot 核心的 unified_msg_origin 保持兼容,
+确保 SDK 与核心之间的会话信息能够正确传递。
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from .._message_types import normalize_message_type
+
+
+@dataclass(slots=True)
+class MessageSession:
+ """SDK-visible message session identifier.
+
+ The string form stays compatible with AstrBot's unified message origin:
+ ``platform_id:message_type:session_id``.
+ """
+
+ platform_id: str
+ message_type: str
+ session_id: str
+
+ def __post_init__(self) -> None:
+ self.platform_id = str(self.platform_id)
+ self.message_type = normalize_message_type(self.message_type)
+ self.session_id = str(self.session_id)
+
+ def __str__(self) -> str:
+ return f"{self.platform_id}:{self.message_type}:{self.session_id}"
+
+ @classmethod
+ def from_str(cls, session: str) -> MessageSession:
+ raw_session = str(session)
+ parts = raw_session.split(":", 2)
+ if len(parts) != 3 or any(part == "" for part in parts):
+ raise ValueError(
+ "invalid message session format, expected "
+ "'platform_id:message_type:session_id'"
+ )
+ platform_id, message_type, session_id = parts
+ return cls(
+ platform_id=platform_id,
+ message_type=message_type,
+ session_id=session_id,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/message_components.py b/astrbot-sdk/src/astrbot_sdk/message_components.py
new file mode 100644
index 0000000000..372bd54a67
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message_components.py
@@ -0,0 +1,13 @@
"""Backward-compatible alias for ``astrbot_sdk.message.components``.

This module intentionally aliases the implementation module instead of re-exporting
names one by one so private helpers keep working with existing monkeypatch sites.
"""

from __future__ import annotations

import sys

from .message import components as _components_module

# Replace this module object in sys.modules so the legacy path and
# ``astrbot_sdk.message.components`` resolve to the *same* module instance.
sys.modules[__name__] = _components_module
diff --git a/astrbot-sdk/src/astrbot_sdk/message_result.py b/astrbot-sdk/src/astrbot_sdk/message_result.py
new file mode 100644
index 0000000000..0b575aad5c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message_result.py
@@ -0,0 +1,13 @@
+"""Backward-compatible alias for ``astrbot_sdk.message.result``.
+
+Use a module alias so callers patching helper functions on the legacy module path
+still affect ``MessageBuilder`` and other implementation globals.
+"""
+
+from __future__ import annotations
+
+import sys
+
+from .message import result as _result_module
+
+sys.modules[__name__] = _result_module
diff --git a/astrbot-sdk/src/astrbot_sdk/message_session.py b/astrbot-sdk/src/astrbot_sdk/message_session.py
new file mode 100644
index 0000000000..ec87255555
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/message_session.py
@@ -0,0 +1,9 @@
"""Backward-compatible message session exports.

The canonical implementation moved to ``astrbot_sdk.message.session``. Preserve the
legacy import path to avoid breaking existing plugins.
"""

# Plain re-export (not a module alias): only MessageSession is public here.
from .message.session import MessageSession

__all__ = ["MessageSession"]
diff --git a/astrbot-sdk/src/astrbot_sdk/plugin_kv.py b/astrbot-sdk/src/astrbot_sdk/plugin_kv.py
new file mode 100644
index 0000000000..de1922b60b
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/plugin_kv.py
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
+
+if TYPE_CHECKING:
+ from .context import Context
+
_VT = TypeVar("_VT")


class _HasRuntimeContext(Protocol):
    """Structural type for objects exposing a runtime-context accessor."""

    def _require_runtime_context(self) -> Context: ...


class PluginKVStoreMixin:
    """Plugin-scoped KV helpers backed by the runtime db client."""

    def _runtime_context(self) -> Context:
        """Fetch the runtime context through the host object's accessor."""
        host = cast(_HasRuntimeContext, self)
        return host._require_runtime_context()

    @property
    def plugin_id(self) -> str:
        """Identifier of the plugin owning this KV namespace."""
        return self._runtime_context().plugin_id

    async def put_kv_data(self, key: str, value: Any) -> None:
        """Store *value* under *key* (the key is coerced to str)."""
        await self._runtime_context().db.set(str(key), value)

    async def get_kv_data(self, key: str, default: _VT) -> _VT:
        """Read *key*; return *default* when the stored value is None."""
        stored = await self._runtime_context().db.get(str(key))
        return default if stored is None else stored

    async def delete_kv_data(self, key: str) -> None:
        """Remove *key* from the plugin's KV store."""
        await self._runtime_context().db.delete(str(key))
diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py b/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py
new file mode 100644
index 0000000000..501b393074
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py
@@ -0,0 +1,164 @@
+"""AstrBot s5r 协议公共入口。
+
+这里暴露 s5r 原生协议的消息模型、描述符和解析函数。
+
+握手阶段由 `InitializeMessage` 发起,返回值不是另一条 initialize 消息,而是
+`ResultMessage(kind="initialize_result")`,其 `output` 负载可解析为
+`InitializeOutput`。
+
+## 插件作者指南:什么时候用什么?
+
+### CapabilityDescriptor vs BUILTIN_CAPABILITY_SCHEMAS
+
+**CapabilityDescriptor** 用于**声明**能力:
+- 当你的插件想**暴露**一个可被其他插件或核心调用的能力时
+- 例如:你的插件提供了一个翻译功能,想让其他插件调用
+
+ ```python
+ from astrbot_sdk.protocol import CapabilityDescriptor
+
+ descriptor = CapabilityDescriptor(
+ name="my_plugin.translate", # 格式: 插件名.能力名
+ description="翻译文本到指定语言",
+ input_schema={
+ "type": "object",
+ "properties": {
+ "text": {"type": "string", "description": "要翻译的文本"},
+ "target_lang": {"type": "string", "description": "目标语言"},
+ },
+ "required": ["text", "target_lang"],
+ },
+ output_schema={
+ "type": "object",
+ "properties": {
+ "translated": {"type": "string"},
+ },
+ },
+ )
+ ```
+
+**BUILTIN_CAPABILITY_SCHEMAS** 用于**查询**内置能力的参数格式:
+- 当你想**调用**核心提供的内置能力时,用它了解参数结构
+- 例如:你想调用 `llm.chat`,但不确定参数格式
+
+ ```python
+ from astrbot_sdk.protocol import BUILTIN_CAPABILITY_SCHEMAS
+
+ # 查看 llm.chat 的输入参数格式
+ schema = BUILTIN_CAPABILITY_SCHEMAS["llm.chat"]
+ print(schema["input"]) # 输入参数的 JSON Schema
+ print(schema["output"]) # 输出结果的 JSON Schema
+ ```
+
+### 命名规范
+
+能力名称必须遵循 `{namespace}.{action}` 或 `{namespace}.{sub_namespace}.{action}` 格式:
+- `llm.chat` - LLM 对话
+- `db.set` - 数据库写入
+- `llm_tool.manager.activate` - LLM 工具管理
+
+**保留命名空间**(插件不可使用):
+- `handler.` - 处理器相关
+- `system.` - 系统内部能力
+- `internal.` - 内部实现细节
+
+### 常用内置能力速查
+
+| 能力名 | 用途 |
+|-------|------|
+| `llm.chat` | 同步 LLM 对话 |
+| `llm.stream_chat` | 流式 LLM 对话 |
+| `memory.save` / `memory.get` | 短期记忆存储 |
+| `db.set` / `db.get` | 持久化键值存储 |
+| `platform.send` | 发送消息 |
+| `provider.get_using` | 获取当前 Provider |
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from . import _builtin_schemas as builtin_schemas
+from .codec import JsonProtocolCodec, MsgpackProtocolCodec, ProtocolCodec # noqa: F401
+from .descriptors import ( # noqa: F401
+ BUILTIN_CAPABILITY_SCHEMAS,
+ CapabilityDescriptor,
+ CommandRouteSpec,
+ CommandTrigger,
+ CompositeFilterSpec,
+ EventTrigger,
+ FilterSpec,
+ HandlerDescriptor,
+ LocalFilterRefSpec,
+ MessageTrigger,
+ MessageTypeFilterSpec,
+ ParamSpec,
+ Permissions,
+ PlatformFilterSpec,
+ ScheduleTrigger,
+ SessionRef,
+ Trigger,
+)
+from .messages import ( # noqa: F401
+ CancelMessage,
+ ErrorPayload,
+ EventMessage,
+ InitializeMessage,
+ InitializeOutput,
+ InvokeMessage,
+ PeerInfo,
+ ProtocolMessage,
+ ResultMessage,
+ parse_message,
+)
+
# Names bound eagerly at import time (codecs, descriptors, protocol messages).
_DIRECT_EXPORTS = [
    "BUILTIN_CAPABILITY_SCHEMAS",
    "CapabilityDescriptor",
    "CommandRouteSpec",
    "CommandTrigger",
    "CancelMessage",
    "builtin_schemas",
    "CompositeFilterSpec",
    "ErrorPayload",
    "EventTrigger",
    "EventMessage",
    "FilterSpec",
    "HandlerDescriptor",
    "JsonProtocolCodec",
    "InitializeMessage",
    "InitializeOutput",
    "InvokeMessage",
    "LocalFilterRefSpec",
    "MessageTrigger",
    "MessageTypeFilterSpec",
    "MsgpackProtocolCodec",
    "ParamSpec",
    "PeerInfo",
    "PlatformFilterSpec",
    "Permissions",
    "ProtocolCodec",
    "ProtocolMessage",
    "ResultMessage",
    "ScheduleTrigger",
    "SessionRef",
    "Trigger",
    "parse_message",
]

# Schema constants served lazily via module __getattr__; the aggregate
# BUILTIN_CAPABILITY_SCHEMAS mapping is excluded because it is imported
# eagerly above.
_BUILTIN_SCHEMA_EXPORTS = tuple(
    name for name in builtin_schemas.__all__ if name != "BUILTIN_CAPABILITY_SCHEMAS"
)
+
+
def __getattr__(name: str) -> Any:
    """Lazily serve builtin-schema constants (PEP 562 module ``__getattr__``).

    Fix: raise AttributeError with the interpreter-standard message instead
    of a bare ``AttributeError(name)``, so tracebacks read the same as for
    ordinary missing module attributes.
    """
    if name in _BUILTIN_SCHEMA_EXPORTS:
        return getattr(builtin_schemas, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
def __dir__() -> list[str]:
    """Include module globals plus the lazily-served schema names."""
    names = set(globals())
    names.update(_BUILTIN_SCHEMA_EXPORTS)
    return sorted(names)
+
+
+__all__ = list(dict.fromkeys([*_DIRECT_EXPORTS, *_BUILTIN_SCHEMA_EXPORTS])) # pyright: ignore[reportUnsupportedDunderAll]
diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py b/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py
new file mode 100644
index 0000000000..0aac1d90cc
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py
@@ -0,0 +1,2396 @@
+"""Builtin protocol schema constants.
+
+本模块定义了 AstrBot SDK s5r 协议中所有内置能力的 JSON Schema。
+这些 Schema 用于:
+1. 验证能力调用的输入参数是否符合预期格式
+2. 生成能力描述文档,供插件开发者参考
+3. 确保跨进程/跨语言调用时的类型安全
+
+所有 Schema 遵循 JSON Schema 规范,支持基本类型检查、必填字段、数组元素约束等。
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+JSONSchema = dict[str, Any]
+
+
def _object_schema(
    *,
    required: tuple[str, ...] = (),
    **properties: Any,
) -> JSONSchema:
    """Build an object-typed JSON Schema node.

    Keyword arguments become the property schemas; *required* lists the
    property names that must be present.
    """
    schema: JSONSchema = {"type": "object"}
    schema["properties"] = properties
    schema["required"] = list(required)
    return schema
+
+
def _nullable(schema: JSONSchema) -> JSONSchema:
    """Wrap *schema* so that JSON null is also accepted."""
    null_schema: JSONSchema = {"type": "null"}
    return {"anyOf": [schema, null_schema]}
+
+
+_OPTIONAL_CHAT_PROPERTIES: dict[str, Any] = {
+ "system": {"type": "string"},
+ "history": {"type": "array", "items": {"type": "object"}},
+ "contexts": {"type": "array", "items": {"type": "object"}},
+ "provider_id": {"type": "string"},
+ "tool_calls_result": {"type": "array", "items": {"type": "object"}},
+ "model": {"type": "string"},
+ "temperature": {"type": "number"},
+ "image_urls": {"type": "array", "items": {"type": "string"}},
+ "tools": {"type": "array"},
+ "max_steps": {"type": "integer"},
+}
+
+LLM_CHAT_INPUT_SCHEMA = _object_schema(
+ required=("prompt",),
+ prompt={"type": "string"},
+ **_OPTIONAL_CHAT_PROPERTIES,
+)
+LLM_CHAT_OUTPUT_SCHEMA = _object_schema(required=("text",), text={"type": "string"})
+LLM_CHAT_RAW_INPUT_SCHEMA = _object_schema(
+ required=("prompt",),
+ prompt={"type": "string"},
+ **_OPTIONAL_CHAT_PROPERTIES,
+)
+LLM_CHAT_RAW_OUTPUT_SCHEMA = _object_schema(
+ required=("text",),
+ text={"type": "string"},
+ usage=_nullable({"type": "object"}),
+ finish_reason=_nullable({"type": "string"}),
+ tool_calls={"type": "array", "items": {"type": "object"}},
+ role=_nullable({"type": "string"}),
+ reasoning_content=_nullable({"type": "string"}),
+ reasoning_signature=_nullable({"type": "string"}),
+)
+LLM_STREAM_CHAT_INPUT_SCHEMA = _object_schema(
+ required=("prompt",),
+ prompt={"type": "string"},
+ **_OPTIONAL_CHAT_PROPERTIES,
+)
+LLM_STREAM_CHAT_OUTPUT_SCHEMA = _object_schema(
+ required=("text",), text={"type": "string"}
+)
+MEMORY_SEARCH_INPUT_SCHEMA = _object_schema(
+ required=("query",),
+ query={"type": "string"},
+ mode={"type": "string", "enum": ["auto", "keyword", "vector", "hybrid"]},
+ limit={"type": "integer", "minimum": 1},
+ min_score={"type": "number"},
+ provider_id={"type": "string"},
+ namespace={"type": "string"},
+ include_descendants={"type": "boolean"},
+)
+MEMORY_SEARCH_OUTPUT_SCHEMA = _object_schema(
+ required=("items",),
+ items={
+ "type": "array",
+ "items": _object_schema(
+ required=("key", "value", "score", "match_type"),
+ key={"type": "string"},
+ namespace=_nullable({"type": "string"}),
+ value=_nullable({"type": "object"}),
+ score={"type": "number"},
+ match_type={
+ "type": "string",
+ "enum": ["keyword", "vector", "hybrid"],
+ },
+ ),
+ },
+)
+MEMORY_SAVE_INPUT_SCHEMA = _object_schema(
+ required=("key", "value"),
+ key={"type": "string"},
+ value={"type": "object"},
+ namespace={"type": "string"},
+)
+MEMORY_SAVE_OUTPUT_SCHEMA = _object_schema()
+MEMORY_GET_INPUT_SCHEMA = _object_schema(
+ required=("key",),
+ key={"type": "string"},
+ namespace={"type": "string"},
+)
+MEMORY_GET_OUTPUT_SCHEMA = _object_schema(
+ required=("value",),
+ value=_nullable({"type": "object"}),
+)
+MEMORY_LIST_KEYS_INPUT_SCHEMA = _object_schema(namespace={"type": "string"})
+MEMORY_LIST_KEYS_OUTPUT_SCHEMA = _object_schema(
+ required=("keys",),
+ keys={"type": "array", "items": {"type": "string"}},
+)
+MEMORY_EXISTS_INPUT_SCHEMA = _object_schema(
+ required=("key",),
+ key={"type": "string"},
+ namespace={"type": "string"},
+)
+MEMORY_EXISTS_OUTPUT_SCHEMA = _object_schema(
+ required=("exists",),
+ exists={"type": "boolean"},
+)
+MEMORY_DELETE_INPUT_SCHEMA = _object_schema(
+ required=("key",),
+ key={"type": "string"},
+ namespace={"type": "string"},
+)
+MEMORY_DELETE_OUTPUT_SCHEMA = _object_schema()
+MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA = _object_schema(
+ namespace={"type": "string"},
+ include_descendants={"type": "boolean"},
+)
+MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA = _object_schema(
+ required=("deleted_count",),
+ deleted_count={"type": "integer"},
+)
+MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA = _object_schema(
+ required=("key", "value", "ttl_seconds"),
+ key={"type": "string"},
+ value={"type": "object"},
+ ttl_seconds={"type": "integer", "minimum": 1},
+ namespace={"type": "string"},
+)
+MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA = _object_schema()
+MEMORY_GET_MANY_INPUT_SCHEMA = _object_schema(
+ required=("keys",),
+ keys={"type": "array", "items": {"type": "string"}},
+ namespace={"type": "string"},
+)
+MEMORY_GET_MANY_OUTPUT_SCHEMA = _object_schema(
+ required=("items",),
+ items={
+ "type": "array",
+ "items": _object_schema(
+ required=("key", "value"),
+ key={"type": "string"},
+ value=_nullable({"type": "object"}),
+ ),
+ },
+)
+MEMORY_DELETE_MANY_INPUT_SCHEMA = _object_schema(
+ required=("keys",),
+ keys={"type": "array", "items": {"type": "string"}},
+ namespace={"type": "string"},
+)
+MEMORY_DELETE_MANY_OUTPUT_SCHEMA = _object_schema(
+ required=("deleted_count",),
+ deleted_count={"type": "integer"},
+)
+MEMORY_COUNT_INPUT_SCHEMA = _object_schema(
+ namespace={"type": "string"},
+ include_descendants={"type": "boolean"},
+)
+MEMORY_COUNT_OUTPUT_SCHEMA = _object_schema(
+ required=("count",),
+ count={"type": "integer"},
+)
+MEMORY_STATS_INPUT_SCHEMA = _object_schema(
+ namespace={"type": "string"},
+ include_descendants={"type": "boolean"},
+)
+MEMORY_STATS_OUTPUT_SCHEMA = _object_schema(
+ total_items={"type": "integer"},
+ total_bytes=_nullable({"type": "integer"}),
+ plugin_id=_nullable({"type": "string"}),
+ ttl_entries=_nullable({"type": "integer"}),
+ namespace=_nullable({"type": "string"}),
+ namespace_count=_nullable({"type": "integer"}),
+ indexed_items=_nullable({"type": "integer"}),
+ embedded_items=_nullable({"type": "integer"}),
+ dirty_items=_nullable({"type": "integer"}),
+ fts_enabled={"type": "boolean"},
+ vector_backend=_nullable({"type": "string"}),
+ vector_indexes={"type": "array", "items": {"type": "object"}},
+)
+SYSTEM_GET_DATA_DIR_INPUT_SCHEMA = _object_schema()
+SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA = _object_schema(
+ required=("path",),
+ path={"type": "string"},
+)
+SYSTEM_TEXT_TO_IMAGE_INPUT_SCHEMA = _object_schema(
+ required=("text",),
+ text={"type": "string"},
+ return_url={"type": "boolean"},
+)
+SYSTEM_TEXT_TO_IMAGE_OUTPUT_SCHEMA = _object_schema(
+ required=("result",),
+ result={"type": "string"},
+)
+SYSTEM_HTML_RENDER_INPUT_SCHEMA = _object_schema(
+ required=("tmpl", "data"),
+ tmpl={"type": "string"},
+ data={"type": "object"},
+ return_url={"type": "boolean"},
+ options=_nullable({"type": "object"}),
+)
+SYSTEM_HTML_RENDER_OUTPUT_SCHEMA = _object_schema(
+ required=("result",),
+ result={"type": "string"},
+)
+SYSTEM_SESSION_WAITER_REGISTER_INPUT_SCHEMA = _object_schema(
+ required=("session_key",),
+ session_key={"type": "string"},
+)
+SYSTEM_SESSION_WAITER_REGISTER_OUTPUT_SCHEMA = _object_schema()
+SYSTEM_SESSION_WAITER_UNREGISTER_INPUT_SCHEMA = _object_schema(
+ required=("session_key",),
+ session_key={"type": "string"},
+)
+SYSTEM_SESSION_WAITER_UNREGISTER_OUTPUT_SCHEMA = _object_schema()
+DB_GET_INPUT_SCHEMA = _object_schema(required=("key",), key={"type": "string"})
+DB_GET_OUTPUT_SCHEMA = _object_schema(
+ required=("value",),
+ value=_nullable({}),
+)
+DB_SET_INPUT_SCHEMA = _object_schema(
+ required=("key", "value"),
+ key={"type": "string"},
+ value={},
+)
+DB_SET_OUTPUT_SCHEMA = _object_schema()
+DB_DELETE_INPUT_SCHEMA = _object_schema(required=("key",), key={"type": "string"})
+DB_DELETE_OUTPUT_SCHEMA = _object_schema()
+DB_LIST_INPUT_SCHEMA = _object_schema(prefix=_nullable({"type": "string"}))
+DB_LIST_OUTPUT_SCHEMA = _object_schema(
+ required=("keys",),
+ keys={"type": "array", "items": {"type": "string"}},
+)
+DB_GET_MANY_INPUT_SCHEMA = _object_schema(
+ required=("keys",),
+ keys={"type": "array", "items": {"type": "string"}},
+)
+DB_GET_MANY_OUTPUT_SCHEMA = _object_schema(
+ required=("items",),
+ items={
+ "type": "array",
+ "items": _object_schema(
+ required=("key", "value"),
+ key={"type": "string"},
+ value=_nullable({}),
+ ),
+ },
+)
+DB_SET_MANY_INPUT_SCHEMA = _object_schema(
+ required=("items",),
+ items={
+ "type": "array",
+ "items": _object_schema(
+ required=("key", "value"),
+ key={"type": "string"},
+ value={},
+ ),
+ },
+)
+DB_SET_MANY_OUTPUT_SCHEMA = _object_schema()
+DB_WATCH_INPUT_SCHEMA = _object_schema(prefix=_nullable({"type": "string"}))
+DB_WATCH_OUTPUT_SCHEMA = _object_schema()
+SESSION_REF_SCHEMA = _object_schema(
+ required=("conversation_id",),
+ conversation_id={"type": "string"},
+ platform=_nullable({"type": "string"}),
+ raw=_nullable({"type": "object"}),
+)
+SYSTEM_EVENT_REACT_INPUT_SCHEMA = _object_schema(
+ required=("emoji",),
+ target=_nullable(SESSION_REF_SCHEMA),
+ emoji={"type": "string"},
+)
+SYSTEM_EVENT_REACT_OUTPUT_SCHEMA = _object_schema(
+ required=("supported",),
+ supported={"type": "boolean"},
+)
+SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+)
+SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA = _object_schema(
+ required=("supported",),
+ supported={"type": "boolean"},
+)
+SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+ use_fallback={"type": "boolean"},
+)
+SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA = _object_schema(
+ required=("supported",),
+ supported={"type": "boolean"},
+ stream_id=_nullable({"type": "string"}),
+)
+SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA = _object_schema(
+ required=("stream_id", "chain"),
+ stream_id={"type": "string"},
+ chain={"type": "array", "items": {"type": "object"}},
+)
+SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA = _object_schema()
+SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA = _object_schema(
+ required=("stream_id",),
+ stream_id={"type": "string"},
+)
+SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA = _object_schema(
+ required=("supported",),
+ supported={"type": "boolean"},
+)
+SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+)
+SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA = _object_schema(
+ required=("should_call_llm", "requested_llm"),
+ should_call_llm={"type": "boolean"},
+ requested_llm={"type": "boolean"},
+)
+SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+)
+SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA = _object_schema(
+ required=("should_call_llm", "requested_llm"),
+ should_call_llm={"type": "boolean"},
+ requested_llm={"type": "boolean"},
+)
+SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+)
+SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA = _object_schema(
+ required=("result",),
+ result=_nullable({"type": "object"}),
+)
+SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA = _object_schema(
+ required=("result",),
+ target=_nullable(SESSION_REF_SCHEMA),
+ result={"type": "object"},
+)
+SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA = _object_schema(
+ required=("result",),
+ result={"type": "object"},
+)
+SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+)
+SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA = _object_schema()
+SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+)
+SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA = _object_schema(
+ required=("plugin_names",),
+ plugin_names=_nullable({"type": "array", "items": {"type": "string"}}),
+)
+SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA = _object_schema(
+ target=_nullable(SESSION_REF_SCHEMA),
+ plugin_names=_nullable({"type": "array", "items": {"type": "string"}}),
+)
+SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA = _object_schema(
+ required=("plugin_names",),
+ plugin_names=_nullable({"type": "array", "items": {"type": "string"}}),
+)
+# ---------------------------------------------------------------------------
+# Platform messaging capability schemas (send text / image / message chain,
+# group & member lookups).  All take a string `session` identifier plus an
+# optional `target` session reference; sends return the platform `message_id`.
+# ---------------------------------------------------------------------------
+PLATFORM_SEND_INPUT_SCHEMA = _object_schema(
+    required=("session", "text"),
+    session={"type": "string"},
+    target=_nullable(SESSION_REF_SCHEMA),
+    text={"type": "string"},
+)
+PLATFORM_SEND_OUTPUT_SCHEMA = _object_schema(
+    required=("message_id",),
+    message_id={"type": "string"},
+)
+PLATFORM_SEND_IMAGE_INPUT_SCHEMA = _object_schema(
+    required=("session", "image_url"),
+    session={"type": "string"},
+    target=_nullable(SESSION_REF_SCHEMA),
+    image_url={"type": "string"},
+)
+PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA = _object_schema(
+    required=("message_id",),
+    message_id={"type": "string"},
+)
+# `chain` is a list of message-component objects (component schema not pinned here).
+PLATFORM_SEND_CHAIN_INPUT_SCHEMA = _object_schema(
+    required=("session", "chain"),
+    session={"type": "string"},
+    target=_nullable(SESSION_REF_SCHEMA),
+    chain={"type": "array", "items": {"type": "object"}},
+)
+PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA = _object_schema(
+    required=("message_id",),
+    message_id={"type": "string"},
+)
+# send-by-session variant: no optional `target` override, session string only.
+PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA = _object_schema(
+    required=("session", "chain"),
+    session={"type": "string"},
+    chain={"type": "array", "items": {"type": "object"}},
+)
+PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA = _object_schema(
+    required=("message_id",),
+    message_id={"type": "string"},
+)
+PLATFORM_GET_GROUP_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    target=_nullable(SESSION_REF_SCHEMA),
+)
+PLATFORM_GET_GROUP_OUTPUT_SCHEMA = _object_schema(
+    required=("group",),
+    group=_nullable({"type": "object"}),
+)
+PLATFORM_GET_MEMBERS_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    target=_nullable(SESSION_REF_SCHEMA),
+)
+PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA = _object_schema(
+    required=("members",),
+    members={"type": "array", "items": {"type": "object"}},
+)
+# ---------------------------------------------------------------------------
+# Platform instance / platform-manager schemas: instance listing, error
+# records, per-platform manager state, and runtime stats lookups.
+# ---------------------------------------------------------------------------
+PLATFORM_INSTANCE_SCHEMA = _object_schema(
+    required=("id", "name", "type", "status"),
+    id={"type": "string"},
+    name={"type": "string"},
+    type={"type": "string"},
+    status={"type": "string"},
+)
+PLATFORM_LIST_INSTANCES_INPUT_SCHEMA = _object_schema()
+PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA = _object_schema(
+    required=("platforms",),
+    platforms={"type": "array", "items": PLATFORM_INSTANCE_SCHEMA},
+)
+# One recorded platform error; `timestamp` is a string (format not pinned here).
+PLATFORM_ERROR_SCHEMA = _object_schema(
+    required=("message", "timestamp"),
+    message={"type": "string"},
+    timestamp={"type": "string"},
+    traceback=_nullable({"type": "string"}),
+)
+PLATFORM_MANAGER_STATE_SCHEMA = _object_schema(
+    required=("id", "name", "type", "status", "errors", "unified_webhook"),
+    id={"type": "string"},
+    name={"type": "string"},
+    type={"type": "string"},
+    status={"type": "string"},
+    errors={"type": "array", "items": PLATFORM_ERROR_SCHEMA},
+    last_error=_nullable(PLATFORM_ERROR_SCHEMA),
+    unified_webhook={"type": "boolean"},
+)
+PLATFORM_STATS_SCHEMA = _object_schema(
+    required=(
+        "id",
+        "type",
+        "display_name",
+        "status",
+        "error_count",
+        "unified_webhook",
+    ),
+    id={"type": "string"},
+    type={"type": "string"},
+    display_name={"type": "string"},
+    status={"type": "string"},
+    started_at=_nullable({"type": "string"}),
+    error_count={"type": "integer"},
+    last_error=_nullable(PLATFORM_ERROR_SCHEMA),
+    unified_webhook={"type": "boolean"},
+    meta={"type": "object"},
+)
+PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA = _object_schema(
+    required=("platform_id",),
+    platform_id={"type": "string"},
+)
+# `platform` is null when the id is unknown.
+PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
+    required=("platform",),
+    platform=_nullable(PLATFORM_MANAGER_STATE_SCHEMA),
+)
+PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA = _object_schema(
+    required=("platform_id",),
+    platform_id={"type": "string"},
+)
+PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA = _object_schema()
+PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA = _object_schema(
+    required=("platform_id",),
+    platform_id={"type": "string"},
+)
+PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA = _object_schema(
+    required=("stats",),
+    stats=_nullable(PLATFORM_STATS_SCHEMA),
+)
+# ---------------------------------------------------------------------------
+# Permission capability schemas: role check, admin listing, admin add/remove.
+# ---------------------------------------------------------------------------
+# Closed role set: only "member" and "admin" are valid.
+PERMISSION_ROLE_SCHEMA = {"type": "string", "enum": ["member", "admin"]}
+PERMISSION_CHECK_INPUT_SCHEMA = _object_schema(
+    required=("user_id",),
+    user_id={"type": "string"},
+    session_id=_nullable({"type": "string"}),
+)
+PERMISSION_CHECK_RESULT_SCHEMA = _object_schema(
+    required=("is_admin", "role"),
+    is_admin={"type": "boolean"},
+    role=PERMISSION_ROLE_SCHEMA,
+)
+# The check output IS the result record (alias, not a wrapper object).
+PERMISSION_CHECK_OUTPUT_SCHEMA = PERMISSION_CHECK_RESULT_SCHEMA
+PERMISSION_GET_ADMINS_INPUT_SCHEMA = _object_schema()
+PERMISSION_GET_ADMINS_OUTPUT_SCHEMA = _object_schema(
+    required=("admins",),
+    admins={"type": "array", "items": {"type": "string"}},
+)
+PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA = _object_schema(
+    required=("user_id",),
+    user_id={"type": "string"},
+)
+# `changed` reports whether the admin set was actually modified.
+PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA = _object_schema(
+    required=("changed",),
+    changed={"type": "boolean"},
+)
+PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA = _object_schema(
+    required=("user_id",),
+    user_id={"type": "string"},
+)
+PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA = _object_schema(
+    required=("changed",),
+    changed={"type": "boolean"},
+)
+# ---------------------------------------------------------------------------
+# Per-session plugin and service toggles: plugin enablement / handler
+# filtering, and LLM / TTS on-off switches keyed by session string.
+# ---------------------------------------------------------------------------
+SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA = _object_schema(
+    required=("session", "plugin_name"),
+    session={"type": "string"},
+    plugin_name={"type": "string"},
+)
+SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA = _object_schema(
+    required=("enabled",),
+    enabled={"type": "boolean"},
+)
+# filter_handlers: takes handler records in, returns the (filtered) subset out.
+SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA = _object_schema(
+    required=("session", "handlers"),
+    session={"type": "string"},
+    handlers={"type": "array", "items": {"type": "object"}},
+)
+SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA = _object_schema(
+    required=("handlers",),
+    handlers={"type": "array", "items": {"type": "object"}},
+)
+SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+)
+SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA = _object_schema(
+    required=("enabled",),
+    enabled={"type": "boolean"},
+)
+SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA = _object_schema(
+    required=("session", "enabled"),
+    session={"type": "string"},
+    enabled={"type": "boolean"},
+)
+SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA = _object_schema()
+SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+)
+SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA = _object_schema(
+    required=("enabled",),
+    enabled={"type": "boolean"},
+)
+SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA = _object_schema(
+    required=("session", "enabled"),
+    session={"type": "string"},
+    enabled={"type": "boolean"},
+)
+SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA = _object_schema()
+# ---------------------------------------------------------------------------
+# Persona CRUD schemas.  RECORD is the full stored shape; CREATE requires only
+# id + system prompt; UPDATE is a sparse patch where every field is optional
+# (and nullable, so fields can be explicitly cleared -- presumably; confirm
+# against the persona service's patch semantics).
+# ---------------------------------------------------------------------------
+PERSONA_RECORD_SCHEMA = _object_schema(
+    required=("persona_id", "system_prompt", "begin_dialogs", "sort_order"),
+    persona_id={"type": "string"},
+    system_prompt={"type": "string"},
+    begin_dialogs={"type": "array", "items": {"type": "string"}},
+    tools=_nullable({"type": "array", "items": {"type": "string"}}),
+    skills=_nullable({"type": "array", "items": {"type": "string"}}),
+    custom_error_message=_nullable({"type": "string"}),
+    folder_id=_nullable({"type": "string"}),
+    sort_order={"type": "integer"},
+    created_at=_nullable({"type": "string"}),
+    updated_at=_nullable({"type": "string"}),
+)
+PERSONA_CREATE_SCHEMA = _object_schema(
+    required=("persona_id", "system_prompt"),
+    persona_id={"type": "string"},
+    system_prompt={"type": "string"},
+    begin_dialogs={"type": "array", "items": {"type": "string"}},
+    tools=_nullable({"type": "array", "items": {"type": "string"}}),
+    skills=_nullable({"type": "array", "items": {"type": "string"}}),
+    custom_error_message=_nullable({"type": "string"}),
+    folder_id=_nullable({"type": "string"}),
+    sort_order={"type": "integer"},
+)
+PERSONA_UPDATE_SCHEMA = _object_schema(
+    system_prompt=_nullable({"type": "string"}),
+    begin_dialogs=_nullable({"type": "array", "items": {"type": "string"}}),
+    tools=_nullable({"type": "array", "items": {"type": "string"}}),
+    skills=_nullable({"type": "array", "items": {"type": "string"}}),
+    custom_error_message=_nullable({"type": "string"}),
+)
+PERSONA_GET_INPUT_SCHEMA = _object_schema(
+    required=("persona_id",),
+    persona_id={"type": "string"},
+)
+# NOTE(review): get returns a non-nullable persona, unlike update which allows
+# null -- a missing persona presumably surfaces as an error rather than null.
+PERSONA_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("persona",),
+    persona=PERSONA_RECORD_SCHEMA,
+)
+PERSONA_LIST_INPUT_SCHEMA = _object_schema()
+PERSONA_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("personas",),
+    personas={"type": "array", "items": PERSONA_RECORD_SCHEMA},
+)
+PERSONA_CREATE_INPUT_SCHEMA = _object_schema(
+    required=("persona",),
+    persona=PERSONA_CREATE_SCHEMA,
+)
+PERSONA_CREATE_OUTPUT_SCHEMA = _object_schema(
+    required=("persona",),
+    persona=PERSONA_RECORD_SCHEMA,
+)
+PERSONA_UPDATE_INPUT_SCHEMA = _object_schema(
+    required=("persona_id", "persona"),
+    persona_id={"type": "string"},
+    persona=PERSONA_UPDATE_SCHEMA,
+)
+PERSONA_UPDATE_OUTPUT_SCHEMA = _object_schema(
+    required=("persona",),
+    persona=_nullable(PERSONA_RECORD_SCHEMA),
+)
+PERSONA_DELETE_INPUT_SCHEMA = _object_schema(
+    required=("persona_id",),
+    persona_id={"type": "string"},
+)
+PERSONA_DELETE_OUTPUT_SCHEMA = _object_schema()
+# ---------------------------------------------------------------------------
+# Conversation capability schemas: record shape, create/update patches, and
+# the new / switch / delete / get / list / update / unset-persona operations,
+# all keyed by a string `session`.
+# ---------------------------------------------------------------------------
+CONVERSATION_RECORD_SCHEMA = _object_schema(
+    required=("conversation_id", "session", "platform_id", "history"),
+    conversation_id={"type": "string"},
+    session={"type": "string"},
+    platform_id={"type": "string"},
+    history={"type": "array", "items": {"type": "object"}},
+    title=_nullable({"type": "string"}),
+    persona_id=_nullable({"type": "string"}),
+    created_at=_nullable({"type": "string"}),
+    updated_at=_nullable({"type": "string"}),
+    token_usage=_nullable({"type": "integer"}),
+)
+CONVERSATION_CREATE_SCHEMA = _object_schema(
+    platform_id=_nullable({"type": "string"}),
+    history=_nullable({"type": "array", "items": {"type": "object"}}),
+    title=_nullable({"type": "string"}),
+    persona_id=_nullable({"type": "string"}),
+)
+CONVERSATION_UPDATE_SCHEMA = _object_schema(
+    history=_nullable({"type": "array", "items": {"type": "object"}}),
+    title=_nullable({"type": "string"}),
+    persona_id=_nullable({"type": "string"}),
+    token_usage=_nullable({"type": "integer"}),
+)
+CONVERSATION_NEW_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    conversation=_nullable(CONVERSATION_CREATE_SCHEMA),
+)
+CONVERSATION_NEW_OUTPUT_SCHEMA = _object_schema(
+    required=("conversation_id",),
+    conversation_id={"type": "string"},
+)
+CONVERSATION_SWITCH_INPUT_SCHEMA = _object_schema(
+    required=("session", "conversation_id"),
+    session={"type": "string"},
+    conversation_id={"type": "string"},
+)
+CONVERSATION_SWITCH_OUTPUT_SCHEMA = _object_schema()
+# delete: `conversation_id` optional -- presumably defaults to the session's
+# current conversation when omitted; confirm against the service impl.
+CONVERSATION_DELETE_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    conversation_id=_nullable({"type": "string"}),
+)
+CONVERSATION_DELETE_OUTPUT_SCHEMA = _object_schema()
+CONVERSATION_GET_INPUT_SCHEMA = _object_schema(
+    required=("session", "conversation_id"),
+    session={"type": "string"},
+    conversation_id={"type": "string"},
+    create_if_not_exists={"type": "boolean"},
+)
+CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("conversation",),
+    conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
+)
+CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    create_if_not_exists={"type": "boolean"},
+)
+CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema(
+    required=("conversation",),
+    conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
+)
+CONVERSATION_LIST_INPUT_SCHEMA = _object_schema(
+    session=_nullable({"type": "string"}),
+    platform_id=_nullable({"type": "string"}),
+)
+CONVERSATION_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("conversations",),
+    conversations={"type": "array", "items": CONVERSATION_RECORD_SCHEMA},
+)
+CONVERSATION_UPDATE_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    conversation_id=_nullable({"type": "string"}),
+    conversation=_nullable(CONVERSATION_UPDATE_SCHEMA),
+)
+CONVERSATION_UPDATE_OUTPUT_SCHEMA = _object_schema()
+CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session={"type": "string"},
+    conversation_id=_nullable({"type": "string"}),
+)
+CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA = _object_schema()
+# ---------------------------------------------------------------------------
+# Message-history capability schemas.  Unlike the conversation schemas above,
+# `session` here is a structured object (platform + message type + session id),
+# not a plain string.  Listing is cursor-paginated.
+# ---------------------------------------------------------------------------
+MESSAGE_HISTORY_SESSION_SCHEMA = _object_schema(
+    required=("platform_id", "message_type", "session_id"),
+    platform_id={"type": "string"},
+    message_type={"type": "string", "enum": ["group", "private", "other"]},
+    session_id={"type": "string"},
+)
+MESSAGE_HISTORY_SENDER_SCHEMA = _object_schema(
+    sender_id=_nullable({"type": "string"}),
+    sender_name=_nullable({"type": "string"}),
+)
+MESSAGE_HISTORY_RECORD_SCHEMA = _object_schema(
+    required=("id", "session", "sender", "parts", "metadata"),
+    id={"type": "integer"},
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+    sender=MESSAGE_HISTORY_SENDER_SCHEMA,
+    parts={"type": "array", "items": {"type": "object"}},
+    metadata={"type": "object"},
+    created_at=_nullable({"type": "string"}),
+    updated_at=_nullable({"type": "string"}),
+    idempotency_key=_nullable({"type": "string"}),
+)
+MESSAGE_HISTORY_PAGE_SCHEMA = _object_schema(
+    required=("records",),
+    records={"type": "array", "items": MESSAGE_HISTORY_RECORD_SCHEMA},
+    next_cursor=_nullable({"type": "string"}),
+    total=_nullable({"type": "integer"}),
+)
+MESSAGE_HISTORY_LIST_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+    # Cursor is either empty (first page) or a positive decimal integer string.
+    cursor=_nullable({"type": "string", "pattern": "^(|[1-9][0-9]*)$"}),
+    limit={"type": "integer", "minimum": 1},
+)
+MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("page",),
+    page=MESSAGE_HISTORY_PAGE_SCHEMA,
+)
+MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA = _object_schema(
+    required=("session", "record_id"),
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+    record_id={"type": "integer", "minimum": 1},
+)
+MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
+    required=("record",),
+    record=_nullable(MESSAGE_HISTORY_RECORD_SCHEMA),
+)
+MESSAGE_HISTORY_APPEND_INPUT_SCHEMA = _object_schema(
+    required=("session", "sender", "parts"),
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+    sender=MESSAGE_HISTORY_SENDER_SCHEMA,
+    parts={"type": "array", "items": {"type": "object"}},
+    metadata=_nullable({"type": "object"}),
+    idempotency_key=_nullable({"type": "string"}),
+)
+MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA = _object_schema(
+    required=("record",),
+    record=MESSAGE_HISTORY_RECORD_SCHEMA,
+)
+# Bulk deletes return how many records were removed.  `before` / `after` are
+# strings -- presumably timestamps; confirm the expected format at the caller.
+MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA = _object_schema(
+    required=("session", "before"),
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+    before={"type": "string"},
+)
+MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA = _object_schema(
+    required=("deleted_count",),
+    deleted_count={"type": "integer"},
+)
+MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA = _object_schema(
+    required=("session", "after"),
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+    after={"type": "string"},
+)
+MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA = _object_schema(
+    required=("deleted_count",),
+    deleted_count={"type": "integer"},
+)
+MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA = _object_schema(
+    required=("session",),
+    session=MESSAGE_HISTORY_SESSION_SCHEMA,
+)
+MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA = _object_schema(
+    required=("deleted_count",),
+    deleted_count={"type": "integer"},
+)
+# ---------------------------------------------------------------------------
+# MCP (Model Context Protocol) server capability schemas, in three families:
+#   local  -- servers owned by one plugin (get/list/enable/disable/wait),
+#   session -- ad-hoc open/list-tools/call-tool/close flows keyed by session_id,
+#   global -- shared servers (register/get/list/enable/disable/unregister).
+# ---------------------------------------------------------------------------
+MCP_SERVER_SCOPE_SCHEMA = {"type": "string", "enum": ["local", "global"]}
+MCP_SERVER_RECORD_SCHEMA = _object_schema(
+    required=("name", "scope", "active", "running", "config", "tools", "errlogs"),
+    name={"type": "string"},
+    scope=MCP_SERVER_SCOPE_SCHEMA,
+    active={"type": "boolean"},
+    running={"type": "boolean"},
+    config={"type": "object"},
+    tools={"type": "array", "items": {"type": "string"}},
+    errlogs={"type": "array", "items": {"type": "string"}},
+    last_error=_nullable({"type": "string"}),
+)
+MCP_LOCAL_GET_INPUT_SCHEMA = _object_schema(required=("name",), name={"type": "string"})
+MCP_LOCAL_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=_nullable(MCP_SERVER_RECORD_SCHEMA),
+)
+MCP_LOCAL_LIST_INPUT_SCHEMA = _object_schema()
+MCP_LOCAL_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("servers",),
+    servers={"type": "array", "items": MCP_SERVER_RECORD_SCHEMA},
+)
+MCP_LOCAL_ENABLE_INPUT_SCHEMA = _object_schema(
+    required=("name",), name={"type": "string"}
+)
+MCP_LOCAL_ENABLE_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+MCP_LOCAL_DISABLE_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+MCP_LOCAL_DISABLE_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+# `timeout` is in seconds -- presumably; confirm against the capability impl.
+MCP_LOCAL_WAIT_UNTIL_READY_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+    timeout={"type": "number"},
+)
+MCP_LOCAL_WAIT_UNTIL_READY_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+MCP_SESSION_OPEN_INPUT_SCHEMA = _object_schema(
+    required=("name", "config"),
+    name={"type": "string"},
+    config={"type": "object"},
+    timeout={"type": "number"},
+)
+MCP_SESSION_OPEN_OUTPUT_SCHEMA = _object_schema(
+    required=("session_id", "tools"),
+    session_id={"type": "string"},
+    tools={"type": "array", "items": {"type": "string"}},
+)
+MCP_SESSION_LIST_TOOLS_INPUT_SCHEMA = _object_schema(
+    required=("session_id",),
+    session_id={"type": "string"},
+)
+MCP_SESSION_LIST_TOOLS_OUTPUT_SCHEMA = _object_schema(
+    required=("tools",),
+    tools={"type": "array", "items": {"type": "string"}},
+)
+MCP_SESSION_CALL_TOOL_INPUT_SCHEMA = _object_schema(
+    required=("session_id", "tool_name", "args"),
+    session_id={"type": "string"},
+    tool_name={"type": "string"},
+    args={"type": "object"},
+)
+MCP_SESSION_CALL_TOOL_OUTPUT_SCHEMA = _object_schema(
+    required=("result",),
+    result={"type": "object"},
+)
+MCP_SESSION_CLOSE_INPUT_SCHEMA = _object_schema(
+    required=("session_id",),
+    session_id={"type": "string"},
+)
+MCP_SESSION_CLOSE_OUTPUT_SCHEMA = _object_schema()
+MCP_GLOBAL_REGISTER_INPUT_SCHEMA = _object_schema(
+    required=("name", "config"),
+    name={"type": "string"},
+    config={"type": "object"},
+    timeout={"type": "number"},
+)
+MCP_GLOBAL_REGISTER_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+MCP_GLOBAL_GET_INPUT_SCHEMA = _object_schema(
+    required=("name",), name={"type": "string"}
+)
+MCP_GLOBAL_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=_nullable(MCP_SERVER_RECORD_SCHEMA),
+)
+MCP_GLOBAL_LIST_INPUT_SCHEMA = _object_schema()
+MCP_GLOBAL_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("servers",),
+    servers={"type": "array", "items": MCP_SERVER_RECORD_SCHEMA},
+)
+MCP_GLOBAL_ENABLE_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+    timeout={"type": "number"},
+)
+MCP_GLOBAL_ENABLE_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+MCP_GLOBAL_DISABLE_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+MCP_GLOBAL_DISABLE_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+MCP_GLOBAL_UNREGISTER_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+MCP_GLOBAL_UNREGISTER_OUTPUT_SCHEMA = _object_schema(
+    required=("server",),
+    server=MCP_SERVER_RECORD_SCHEMA,
+)
+# Internal host->plugin dispatch for executing a local MCP tool on behalf of
+# a plugin; `success` distinguishes tool failure from a null/empty `content`.
+INTERNAL_MCP_LOCAL_EXECUTE_INPUT_SCHEMA = _object_schema(
+    required=("plugin_id", "server_name", "tool_name", "tool_args"),
+    plugin_id={"type": "string"},
+    server_name={"type": "string"},
+    tool_name={"type": "string"},
+    tool_args={"type": "object"},
+)
+INTERNAL_MCP_LOCAL_EXECUTE_OUTPUT_SCHEMA = _object_schema(
+    required=("content", "success"),
+    content=_nullable({"type": "string"}),
+    success={"type": "boolean"},
+)
+# ---------------------------------------------------------------------------
+# Knowledge-base (RAG) capability schemas: KB CRUD, hybrid retrieval, and
+# document upload / list / get / delete / refresh.
+# ---------------------------------------------------------------------------
+KNOWLEDGE_BASE_RECORD_SCHEMA = _object_schema(
+    required=("kb_id", "kb_name", "embedding_provider_id", "doc_count", "chunk_count"),
+    kb_id={"type": "string"},
+    kb_name={"type": "string"},
+    description=_nullable({"type": "string"}),
+    emoji=_nullable({"type": "string"}),
+    embedding_provider_id={"type": "string"},
+    rerank_provider_id=_nullable({"type": "string"}),
+    chunk_size=_nullable({"type": "integer"}),
+    chunk_overlap=_nullable({"type": "integer"}),
+    top_k_dense=_nullable({"type": "integer"}),
+    top_k_sparse=_nullable({"type": "integer"}),
+    top_m_final=_nullable({"type": "integer"}),
+    doc_count={"type": "integer"},
+    chunk_count={"type": "integer"},
+    created_at=_nullable({"type": "string"}),
+    updated_at=_nullable({"type": "string"}),
+)
+KNOWLEDGE_BASE_CREATE_SCHEMA = _object_schema(
+    required=("kb_name", "embedding_provider_id"),
+    kb_name={"type": "string"},
+    embedding_provider_id={"type": "string"},
+    description=_nullable({"type": "string"}),
+    emoji=_nullable({"type": "string"}),
+    rerank_provider_id=_nullable({"type": "string"}),
+    chunk_size=_nullable({"type": "integer"}),
+    chunk_overlap=_nullable({"type": "integer"}),
+    top_k_dense=_nullable({"type": "integer"}),
+    top_k_sparse=_nullable({"type": "integer"}),
+    top_m_final=_nullable({"type": "integer"}),
+)
+# Sparse patch: every field optional and nullable.
+KNOWLEDGE_BASE_UPDATE_SCHEMA = _object_schema(
+    kb_name=_nullable({"type": "string"}),
+    description=_nullable({"type": "string"}),
+    emoji=_nullable({"type": "string"}),
+    embedding_provider_id=_nullable({"type": "string"}),
+    rerank_provider_id=_nullable({"type": "string"}),
+    chunk_size=_nullable({"type": "integer"}),
+    chunk_overlap=_nullable({"type": "integer"}),
+    top_k_dense=_nullable({"type": "integer"}),
+    top_k_sparse=_nullable({"type": "integer"}),
+    top_m_final=_nullable({"type": "integer"}),
+)
+KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA = _object_schema(
+    required=(
+        "doc_id",
+        "kb_id",
+        "doc_name",
+        "file_type",
+        "file_size",
+        "chunk_count",
+        "media_count",
+    ),
+    doc_id={"type": "string"},
+    kb_id={"type": "string"},
+    doc_name={"type": "string"},
+    file_type={"type": "string"},
+    file_size={"type": "integer"},
+    file_path={"type": "string"},
+    chunk_count={"type": "integer"},
+    media_count={"type": "integer"},
+    created_at=_nullable({"type": "string"}),
+    updated_at=_nullable({"type": "string"}),
+)
+# One retrieved chunk hit; `score` is the (rerank/fusion) relevance score.
+KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA = _object_schema(
+    required=(
+        "chunk_id",
+        "doc_id",
+        "kb_id",
+        "kb_name",
+        "doc_name",
+        "chunk_index",
+        "content",
+        "score",
+        "char_count",
+    ),
+    chunk_id={"type": "string"},
+    doc_id={"type": "string"},
+    kb_id={"type": "string"},
+    kb_name={"type": "string"},
+    doc_name={"type": "string"},
+    chunk_index={"type": "integer"},
+    content={"type": "string"},
+    score={"type": "number"},
+    char_count={"type": "integer"},
+)
+# Upload payload: exactly one of file_token / url / text is presumably expected
+# as the document source (the schema itself does not enforce mutual exclusion).
+KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA = _object_schema(
+    file_token=_nullable({"type": "string"}),
+    url=_nullable({"type": "string"}),
+    text=_nullable({"type": "string"}),
+    file_name=_nullable({"type": "string"}),
+    file_type=_nullable({"type": "string"}),
+    chunk_size=_nullable({"type": "integer"}),
+    chunk_overlap=_nullable({"type": "integer"}),
+    batch_size=_nullable({"type": "integer"}),
+    tasks_limit=_nullable({"type": "integer"}),
+    max_retries=_nullable({"type": "integer"}),
+    enable_cleaning=_nullable({"type": "boolean"}),
+    cleaning_provider_id=_nullable({"type": "string"}),
+)
+KB_LIST_INPUT_SCHEMA = _object_schema()
+KB_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("kbs",),
+    kbs={"type": "array", "items": KNOWLEDGE_BASE_RECORD_SCHEMA},
+)
+KB_GET_INPUT_SCHEMA = _object_schema(
+    required=("kb_id",),
+    kb_id={"type": "string"},
+)
+KB_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("kb",),
+    kb=_nullable(KNOWLEDGE_BASE_RECORD_SCHEMA),
+)
+KB_CREATE_INPUT_SCHEMA = _object_schema(
+    required=("kb",),
+    kb=KNOWLEDGE_BASE_CREATE_SCHEMA,
+)
+KB_CREATE_OUTPUT_SCHEMA = _object_schema(
+    required=("kb",),
+    kb=KNOWLEDGE_BASE_RECORD_SCHEMA,
+)
+KB_UPDATE_INPUT_SCHEMA = _object_schema(
+    required=("kb_id", "kb"),
+    kb_id={"type": "string"},
+    kb=KNOWLEDGE_BASE_UPDATE_SCHEMA,
+)
+KB_UPDATE_OUTPUT_SCHEMA = _object_schema(
+    required=("kb",),
+    kb=_nullable(KNOWLEDGE_BASE_RECORD_SCHEMA),
+)
+KB_DELETE_INPUT_SCHEMA = _object_schema(
+    required=("kb_id",),
+    kb_id={"type": "string"},
+)
+KB_DELETE_OUTPUT_SCHEMA = _object_schema(
+    required=("deleted",),
+    deleted={"type": "boolean"},
+)
+# Retrieval can target KBs by id and/or by name; both filters are optional.
+KB_RETRIEVE_INPUT_SCHEMA = _object_schema(
+    required=("query",),
+    query={"type": "string"},
+    kb_ids={"type": "array", "items": {"type": "string"}},
+    kb_names={"type": "array", "items": {"type": "string"}},
+    top_k_fusion={"type": "integer"},
+    top_m_final={"type": "integer"},
+)
+KB_RETRIEVE_OUTPUT_SCHEMA = _object_schema(
+    required=("result",),
+    result=_nullable(
+        _object_schema(
+            required=("context_text", "results"),
+            context_text={"type": "string"},
+            results={
+                "type": "array",
+                "items": KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA,
+            },
+        )
+    ),
+)
+KB_DOCUMENT_UPLOAD_INPUT_SCHEMA = _object_schema(
+    required=("kb_id", "document"),
+    kb_id={"type": "string"},
+    document=KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA,
+)
+KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA = _object_schema(
+    required=("document",),
+    document=KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA,
+)
+KB_DOCUMENT_LIST_INPUT_SCHEMA = _object_schema(
+    required=("kb_id",),
+    kb_id={"type": "string"},
+    offset={"type": "integer"},
+    limit={"type": "integer"},
+)
+KB_DOCUMENT_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("documents",),
+    documents={"type": "array", "items": KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA},
+)
+KB_DOCUMENT_GET_INPUT_SCHEMA = _object_schema(
+    required=("kb_id", "doc_id"),
+    kb_id={"type": "string"},
+    doc_id={"type": "string"},
+)
+KB_DOCUMENT_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("document",),
+    document=_nullable(KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA),
+)
+KB_DOCUMENT_DELETE_INPUT_SCHEMA = _object_schema(
+    required=("kb_id", "doc_id"),
+    kb_id={"type": "string"},
+    doc_id={"type": "string"},
+)
+KB_DOCUMENT_DELETE_OUTPUT_SCHEMA = _object_schema(
+    required=("deleted",),
+    deleted={"type": "boolean"},
+)
+KB_DOCUMENT_REFRESH_INPUT_SCHEMA = _object_schema(
+    required=("kb_id", "doc_id"),
+    kb_id={"type": "string"},
+    doc_id={"type": "string"},
+)
+KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA = _object_schema(
+    required=("document",),
+    document=_nullable(KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA),
+)
+# ---------------------------------------------------------------------------
+# Dynamic registration schemas: commands, agent skills, and plugin HTTP APIs.
+# ---------------------------------------------------------------------------
+REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA = _object_schema(
+    required=("command_name", "handler_full_name"),
+    command_name={"type": "string"},
+    handler_full_name={"type": "string"},
+    source_event_type={"type": "string"},
+    desc={"type": "string"},
+    priority={"type": "integer"},
+    use_regex={"type": "boolean"},
+    ignore_prefix={"type": "boolean"},
+)
+REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA = _object_schema()
+SKILL_REGISTER_INPUT_SCHEMA = _object_schema(
+    required=("name", "path"),
+    name={"type": "string"},
+    path={"type": "string"},
+    description={"type": "string"},
+)
+# Register echoes the resolved skill record, including the resolved skill_dir.
+SKILL_REGISTER_OUTPUT_SCHEMA = _object_schema(
+    required=("name", "description", "path", "skill_dir"),
+    name={"type": "string"},
+    description={"type": "string"},
+    path={"type": "string"},
+    skill_dir={"type": "string"},
+)
+SKILL_UNREGISTER_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+SKILL_UNREGISTER_OUTPUT_SCHEMA = _object_schema(
+    required=("removed",),
+    removed={"type": "boolean"},
+)
+SKILL_LIST_INPUT_SCHEMA = _object_schema()
+# List items reuse the register-output record shape.
+SKILL_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("skills",),
+    skills={
+        "type": "array",
+        "items": SKILL_REGISTER_OUTPUT_SCHEMA,
+    },
+)
+HTTP_REGISTER_API_INPUT_SCHEMA = _object_schema(
+    required=("route", "methods", "handler_capability"),
+    route={"type": "string"},
+    methods={"type": "array", "items": {"type": "string"}},
+    handler_capability={"type": "string"},
+    description={"type": "string"},
+)
+HTTP_REGISTER_API_OUTPUT_SCHEMA = _object_schema()
+HTTP_UNREGISTER_API_INPUT_SCHEMA = _object_schema(
+    required=("route", "methods"),
+    route={"type": "string"},
+    methods={"type": "array", "items": {"type": "string"}},
+)
+HTTP_UNREGISTER_API_OUTPUT_SCHEMA = _object_schema()
+HTTP_LIST_APIS_INPUT_SCHEMA = _object_schema()
+HTTP_LIST_APIS_OUTPUT_SCHEMA = _object_schema(
+    required=("apis",),
+    apis={"type": "array", "items": {"type": "object"}},
+)
+# ---------------------------------------------------------------------------
+# Plugin metadata / config schemas and handler-registry lookups.  Plugin and
+# handler payloads are left as free-form objects here (shapes not pinned).
+# ---------------------------------------------------------------------------
+METADATA_GET_PLUGIN_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+METADATA_GET_PLUGIN_OUTPUT_SCHEMA = _object_schema(
+    required=("plugin",),
+    plugin=_nullable({"type": "object"}),
+)
+METADATA_LIST_PLUGINS_INPUT_SCHEMA = _object_schema()
+METADATA_LIST_PLUGINS_OUTPUT_SCHEMA = _object_schema(
+    required=("plugins",),
+    plugins={"type": "array", "items": {"type": "object"}},
+)
+METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA = _object_schema(
+    required=("config",),
+    config=_nullable({"type": "object"}),
+)
+# NOTE(review): save takes no plugin `name` -- presumably the target plugin is
+# implied by the calling context; confirm against the capability dispatcher.
+METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA = _object_schema(
+    required=("config",),
+    config={"type": "object"},
+)
+METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA = _object_schema(
+    required=("config",),
+    config=_nullable({"type": "object"}),
+)
+REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA = _object_schema(
+    required=("event_type",),
+    event_type={"type": "string"},
+)
+REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA = _object_schema(
+    required=("handlers",),
+    handlers={"type": "array", "items": {"type": "object"}},
+)
+REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA = _object_schema(
+    required=("full_name",),
+    full_name={"type": "string"},
+)
+REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA = _object_schema(
+    required=("handler",),
+    handler=_nullable({"type": "object"}),
+)
+PROVIDER_META_SCHEMA = _object_schema(
+ required=("id", "type", "provider_type"),
+ id={"type": "string"},
+ model=_nullable({"type": "string"}),
+ type={"type": "string"},
+ provider_type={"type": "string"},
+)
+MANAGED_PROVIDER_RECORD_SCHEMA = _object_schema(
+ required=("id", "type", "provider_type", "loaded", "enabled"),
+ id={"type": "string"},
+ model=_nullable({"type": "string"}),
+ type={"type": "string"},
+ provider_type={"type": "string"},
+ loaded={"type": "boolean"},
+ enabled={"type": "boolean"},
+ provider_source_id=_nullable({"type": "string"}),
+)
+PROVIDER_CHANGE_EVENT_SCHEMA = _object_schema(
+ required=("provider_id", "provider_type"),
+ provider_id={"type": "string"},
+ provider_type={"type": "string"},
+ umo=_nullable({"type": "string"}),
+)
+LLM_TOOL_SPEC_SCHEMA = _object_schema(
+ required=("name", "description", "parameters_schema", "active"),
+ name={"type": "string"},
+ description={"type": "string"},
+ parameters_schema={"type": "object"},
+ handler_ref=_nullable({"type": "string"}),
+ handler_capability=_nullable({"type": "string"}),
+ active={"type": "boolean"},
+)
+AGENT_SPEC_SCHEMA = _object_schema(
+ required=("name", "description", "tool_names", "runner_class"),
+ name={"type": "string"},
+ description={"type": "string"},
+ tool_names={"type": "array", "items": {"type": "string"}},
+ runner_class={"type": "string"},
+)
+# --- provider query / STT / TTS capability schemas ---------------------------
+# Input/output pairs for the provider.* capabilities. `umo` appears to be an
+# optional session/origin selector (NOTE(review): confirm exact semantics with
+# the capability handlers). Lookups that may miss return a nullable record.
+PROVIDER_GET_USING_INPUT_SCHEMA = _object_schema(umo=_nullable({"type": "string"}))
+PROVIDER_GET_USING_OUTPUT_SCHEMA = _object_schema(
+    required=("provider",),
+    provider=_nullable(PROVIDER_META_SCHEMA),
+)
+PROVIDER_GET_BY_ID_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+)
+PROVIDER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
+    required=("provider",),
+    provider=_nullable(PROVIDER_META_SCHEMA),
+)
+PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA = _object_schema(
+    umo=_nullable({"type": "string"}),
+)
+PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id=_nullable({"type": "string"}),
+)
+# Shared by provider.list_all and its tts/stt/embedding/rerank variants (see
+# BUILTIN_CAPABILITY_SCHEMAS below): no input, array of provider metadata out.
+PROVIDER_LIST_ALL_INPUT_SCHEMA = _object_schema()
+PROVIDER_LIST_ALL_OUTPUT_SCHEMA = _object_schema(
+    required=("providers",),
+    providers={"type": "array", "items": PROVIDER_META_SCHEMA},
+)
+# Speech-to-text: audio is referenced by URL, transcription returned as text.
+PROVIDER_STT_GET_TEXT_INPUT_SCHEMA = _object_schema(
+    required=("provider_id", "audio_url"),
+    provider_id={"type": "string"},
+    audio_url={"type": "string"},
+)
+PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA = _object_schema(
+    required=("text",),
+    text={"type": "string"},
+)
+# Text-to-speech (non-streaming): result is a filesystem path to the audio.
+PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA = _object_schema(
+    required=("provider_id", "text"),
+    provider_id={"type": "string"},
+    text={"type": "string"},
+)
+PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA = _object_schema(
+    required=("audio_path",),
+    audio_path={"type": "string"},
+)
+PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+)
+PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA = _object_schema(
+    required=("supported",),
+    supported={"type": "boolean"},
+)
+# One streamed TTS chunk: base64-encoded audio plus the optional source text
+# the chunk corresponds to.
+PROVIDER_TTS_AUDIO_CHUNK_SCHEMA = _object_schema(
+    required=("audio_base64",),
+    audio_base64={"type": "string"},
+    text=_nullable({"type": "string"}),
+)
+# Streaming TTS input accepts either a single `text` or pre-split
+# `text_chunks`; only provider_id is mandatory.
+PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+    text=_nullable({"type": "string"}),
+    text_chunks={"type": "array", "items": {"type": "string"}},
+)
+# Each streamed output item is a TTS audio chunk (alias, not a copy).
+PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA = PROVIDER_TTS_AUDIO_CHUNK_SCHEMA
+# --- embedding and rerank capability schemas ---------------------------------
+# Single-text embedding: one text in, one float vector out.
+PROVIDER_EMBEDDING_GET_INPUT_SCHEMA = _object_schema(
+    required=("provider_id", "text"),
+    provider_id={"type": "string"},
+    text={"type": "string"},
+)
+PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("embedding",),
+    embedding={"type": "array", "items": {"type": "number"}},
+)
+# Batch embedding: list of texts in, list of float vectors out (same order
+# presumably -- TODO confirm ordering guarantee with the handler).
+PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA = _object_schema(
+    required=("provider_id", "texts"),
+    provider_id={"type": "string"},
+    texts={"type": "array", "items": {"type": "string"}},
+)
+PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA = _object_schema(
+    required=("embeddings",),
+    embeddings={
+        "type": "array",
+        "items": {"type": "array", "items": {"type": "number"}},
+    },
+)
+PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+)
+PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA = _object_schema(
+    required=("dim",),
+    dim={"type": "integer"},
+)
+# One rerank hit: the document's index in the input list, its relevance
+# score, and the document text itself.
+PROVIDER_RERANK_RESULT_SCHEMA = _object_schema(
+    required=("index", "score", "document"),
+    index={"type": "integer"},
+    score={"type": "number"},
+    document={"type": "string"},
+)
+# Rerank request: query + candidate documents; `top_n` optionally limits the
+# number of results returned.
+PROVIDER_RERANK_INPUT_SCHEMA = _object_schema(
+    required=("provider_id", "query", "documents"),
+    provider_id={"type": "string"},
+    query={"type": "string"},
+    documents={"type": "array", "items": {"type": "string"}},
+    top_n=_nullable({"type": "integer"}),
+)
+PROVIDER_RERANK_OUTPUT_SCHEMA = _object_schema(
+    required=("results",),
+    results={"type": "array", "items": PROVIDER_RERANK_RESULT_SCHEMA},
+)
+# --- provider manager capability schemas -------------------------------------
+# CRUD-style management of provider instances. Operations with nothing to
+# report use the empty object schema as output.
+PROVIDER_MANAGER_SET_INPUT_SCHEMA = _object_schema(
+    required=("provider_id", "provider_type"),
+    provider_id={"type": "string"},
+    provider_type={"type": "string"},
+    umo=_nullable({"type": "string"}),
+)
+PROVIDER_MANAGER_SET_OUTPUT_SCHEMA = _object_schema()
+PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+)
+PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
+    required=("provider",),
+    provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
+)
+PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+)
+PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA = _object_schema(
+    required=("config",),
+    config=_nullable({"type": "object"}),
+)
+PROVIDER_MANAGER_LOAD_INPUT_SCHEMA = _object_schema(
+    required=("provider_config",),
+    provider_config={"type": "object"},
+)
+PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA = _object_schema(
+    required=("provider",),
+    provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
+)
+PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA = _object_schema(
+    required=("provider_id",),
+    provider_id={"type": "string"},
+)
+PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA = _object_schema()
+PROVIDER_MANAGER_CREATE_INPUT_SCHEMA = _object_schema(
+    required=("provider_config",),
+    provider_config={"type": "object"},
+)
+PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA = _object_schema(
+    required=("provider",),
+    provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
+)
+PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA = _object_schema(
+    required=("origin_provider_id", "new_config"),
+    origin_provider_id={"type": "string"},
+    new_config={"type": "object"},
+)
+PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA = _object_schema(
+    required=("provider",),
+    provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
+)
+# Delete accepts either selector; neither is required (NOTE(review): the
+# handler presumably validates that at least one is supplied -- confirm).
+PROVIDER_MANAGER_DELETE_INPUT_SCHEMA = _object_schema(
+    provider_id=_nullable({"type": "string"}),
+    provider_source_id=_nullable({"type": "string"}),
+)
+PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA = _object_schema()
+PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA = _object_schema()
+PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA = _object_schema(
+    required=("providers",),
+    providers={"type": "array", "items": MANAGED_PROVIDER_RECORD_SCHEMA},
+)
+# watch_changes is a subscription: each emitted item describes which provider
+# changed (same field shape as the `set` input).
+PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA = _object_schema()
+PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA = _object_schema(
+    required=("provider_id", "provider_type"),
+    provider_id={"type": "string"},
+    provider_type={"type": "string"},
+    umo=_nullable({"type": "string"}),
+)
+# --- LLM tool manager capability schemas -------------------------------------
+# `get` reports both every registered tool and the currently active subset.
+LLM_TOOL_MANAGER_GET_INPUT_SCHEMA = _object_schema()
+LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("registered", "active"),
+    registered={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA},
+    active={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA},
+)
+# activate/deactivate/remove address a tool by name and return a boolean
+# outcome flag.
+LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA = _object_schema(
+    required=("activated",),
+    activated={"type": "boolean"},
+)
+LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA = _object_schema(
+    required=("deactivated",),
+    deactivated={"type": "boolean"},
+)
+# `add` is batch-oriented: a list of tool specs in, the accepted names out.
+LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA = _object_schema(
+    required=("tools",),
+    tools={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA},
+)
+LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA = _object_schema(
+    required=("names",),
+    names={"type": "array", "items": {"type": "string"}},
+)
+LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA = _object_schema(
+    required=("removed",),
+    removed={"type": "boolean"},
+)
+# --- agent tool-loop / registry capability schemas ---------------------------
+# Tool-loop run input: every field is optional (no `required` tuple); the
+# handler presumably applies its own defaults for temperature/max_steps/
+# timeout -- TODO confirm the default values against the implementation.
+AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA = _object_schema(
+    prompt=_nullable({"type": "string"}),
+    system_prompt=_nullable({"type": "string"}),
+    session_id=_nullable({"type": "string"}),
+    contexts={"type": "array", "items": {"type": "object"}},
+    image_urls={"type": "array", "items": {"type": "string"}},
+    tool_names=_nullable({"type": "array", "items": {"type": "string"}}),
+    tool_calls_result={"type": "array", "items": {"type": "object"}},
+    provider_id=_nullable({"type": "string"}),
+    model=_nullable({"type": "string"}),
+    temperature={"type": "number"},
+    max_steps={"type": "integer"},
+    tool_call_timeout={"type": "integer"},
+)
+# Tool-loop output reuses the raw-chat response shape (alias, not a copy).
+AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA = LLM_CHAT_RAW_OUTPUT_SCHEMA
+AGENT_REGISTRY_LIST_INPUT_SCHEMA = _object_schema()
+AGENT_REGISTRY_LIST_OUTPUT_SCHEMA = _object_schema(
+    required=("agents",),
+    agents={"type": "array", "items": AGENT_SPEC_SCHEMA},
+)
+AGENT_REGISTRY_GET_INPUT_SCHEMA = _object_schema(
+    required=("name",),
+    name={"type": "string"},
+)
+# Lookup by name may miss, hence the nullable agent record.
+AGENT_REGISTRY_GET_OUTPUT_SCHEMA = _object_schema(
+    required=("agent",),
+    agent=_nullable(AGENT_SPEC_SCHEMA),
+)
+
+# Master registry mapping each built-in capability name to its input/output
+# schema pair. Keys follow a dotted "<domain>.<operation>" convention; every
+# value is a two-entry dict with exactly the keys "input" and "output".
+# Several capabilities intentionally share schema objects (e.g. all
+# provider.list_all_* variants reuse PROVIDER_LIST_ALL_*).
+BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
+    # LLM chat (plain / raw / streaming).
+    "llm.chat": {"input": LLM_CHAT_INPUT_SCHEMA, "output": LLM_CHAT_OUTPUT_SCHEMA},
+    "llm.chat_raw": {
+        "input": LLM_CHAT_RAW_INPUT_SCHEMA,
+        "output": LLM_CHAT_RAW_OUTPUT_SCHEMA,
+    },
+    "llm.stream_chat": {
+        "input": LLM_STREAM_CHAT_INPUT_SCHEMA,
+        "output": LLM_STREAM_CHAT_OUTPUT_SCHEMA,
+    },
+    # Key/value memory store with search, TTL, and batch variants.
+    "memory.search": {
+        "input": MEMORY_SEARCH_INPUT_SCHEMA,
+        "output": MEMORY_SEARCH_OUTPUT_SCHEMA,
+    },
+    "memory.save": {
+        "input": MEMORY_SAVE_INPUT_SCHEMA,
+        "output": MEMORY_SAVE_OUTPUT_SCHEMA,
+    },
+    "memory.get": {
+        "input": MEMORY_GET_INPUT_SCHEMA,
+        "output": MEMORY_GET_OUTPUT_SCHEMA,
+    },
+    "memory.list_keys": {
+        "input": MEMORY_LIST_KEYS_INPUT_SCHEMA,
+        "output": MEMORY_LIST_KEYS_OUTPUT_SCHEMA,
+    },
+    "memory.exists": {
+        "input": MEMORY_EXISTS_INPUT_SCHEMA,
+        "output": MEMORY_EXISTS_OUTPUT_SCHEMA,
+    },
+    "memory.delete": {
+        "input": MEMORY_DELETE_INPUT_SCHEMA,
+        "output": MEMORY_DELETE_OUTPUT_SCHEMA,
+    },
+    "memory.clear_namespace": {
+        "input": MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA,
+        "output": MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA,
+    },
+    "memory.save_with_ttl": {
+        "input": MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA,
+        "output": MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA,
+    },
+    "memory.get_many": {
+        "input": MEMORY_GET_MANY_INPUT_SCHEMA,
+        "output": MEMORY_GET_MANY_OUTPUT_SCHEMA,
+    },
+    "memory.delete_many": {
+        "input": MEMORY_DELETE_MANY_INPUT_SCHEMA,
+        "output": MEMORY_DELETE_MANY_OUTPUT_SCHEMA,
+    },
+    "memory.count": {
+        "input": MEMORY_COUNT_INPUT_SCHEMA,
+        "output": MEMORY_COUNT_OUTPUT_SCHEMA,
+    },
+    "memory.stats": {
+        "input": MEMORY_STATS_INPUT_SCHEMA,
+        "output": MEMORY_STATS_OUTPUT_SCHEMA,
+    },
+    # Plugin database (simple KV plus watch subscription).
+    "db.get": {"input": DB_GET_INPUT_SCHEMA, "output": DB_GET_OUTPUT_SCHEMA},
+    "db.set": {"input": DB_SET_INPUT_SCHEMA, "output": DB_SET_OUTPUT_SCHEMA},
+    "db.delete": {"input": DB_DELETE_INPUT_SCHEMA, "output": DB_DELETE_OUTPUT_SCHEMA},
+    "db.list": {"input": DB_LIST_INPUT_SCHEMA, "output": DB_LIST_OUTPUT_SCHEMA},
+    "db.get_many": {
+        "input": DB_GET_MANY_INPUT_SCHEMA,
+        "output": DB_GET_MANY_OUTPUT_SCHEMA,
+    },
+    "db.set_many": {
+        "input": DB_SET_MANY_INPUT_SCHEMA,
+        "output": DB_SET_MANY_OUTPUT_SCHEMA,
+    },
+    "db.watch": {"input": DB_WATCH_INPUT_SCHEMA, "output": DB_WATCH_OUTPUT_SCHEMA},
+    # Messaging-platform send and lookup operations.
+    "platform.send": {
+        "input": PLATFORM_SEND_INPUT_SCHEMA,
+        "output": PLATFORM_SEND_OUTPUT_SCHEMA,
+    },
+    "platform.send_image": {
+        "input": PLATFORM_SEND_IMAGE_INPUT_SCHEMA,
+        "output": PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA,
+    },
+    "platform.send_chain": {
+        "input": PLATFORM_SEND_CHAIN_INPUT_SCHEMA,
+        "output": PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA,
+    },
+    "platform.send_by_session": {
+        "input": PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA,
+        "output": PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA,
+    },
+    "platform.get_group": {
+        "input": PLATFORM_GET_GROUP_INPUT_SCHEMA,
+        "output": PLATFORM_GET_GROUP_OUTPUT_SCHEMA,
+    },
+    "platform.get_members": {
+        "input": PLATFORM_GET_MEMBERS_INPUT_SCHEMA,
+        "output": PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA,
+    },
+    "platform.list_instances": {
+        "input": PLATFORM_LIST_INSTANCES_INPUT_SCHEMA,
+        "output": PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA,
+    },
+    # Per-session plugin/service toggles.
+    "session.plugin.is_enabled": {
+        "input": SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA,
+        "output": SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA,
+    },
+    "session.plugin.filter_handlers": {
+        "input": SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA,
+        "output": SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA,
+    },
+    "session.service.is_llm_enabled": {
+        "input": SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA,
+        "output": SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA,
+    },
+    "session.service.set_llm_status": {
+        "input": SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA,
+        "output": SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA,
+    },
+    "session.service.is_tts_enabled": {
+        "input": SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA,
+        "output": SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA,
+    },
+    "session.service.set_tts_status": {
+        "input": SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA,
+        "output": SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA,
+    },
+    # Persona CRUD.
+    "persona.get": {
+        "input": PERSONA_GET_INPUT_SCHEMA,
+        "output": PERSONA_GET_OUTPUT_SCHEMA,
+    },
+    "persona.list": {
+        "input": PERSONA_LIST_INPUT_SCHEMA,
+        "output": PERSONA_LIST_OUTPUT_SCHEMA,
+    },
+    "persona.create": {
+        "input": PERSONA_CREATE_INPUT_SCHEMA,
+        "output": PERSONA_CREATE_OUTPUT_SCHEMA,
+    },
+    "persona.update": {
+        "input": PERSONA_UPDATE_INPUT_SCHEMA,
+        "output": PERSONA_UPDATE_OUTPUT_SCHEMA,
+    },
+    "persona.delete": {
+        "input": PERSONA_DELETE_INPUT_SCHEMA,
+        "output": PERSONA_DELETE_OUTPUT_SCHEMA,
+    },
+    # Conversation lifecycle.
+    "conversation.new": {
+        "input": CONVERSATION_NEW_INPUT_SCHEMA,
+        "output": CONVERSATION_NEW_OUTPUT_SCHEMA,
+    },
+    "conversation.switch": {
+        "input": CONVERSATION_SWITCH_INPUT_SCHEMA,
+        "output": CONVERSATION_SWITCH_OUTPUT_SCHEMA,
+    },
+    "conversation.delete": {
+        "input": CONVERSATION_DELETE_INPUT_SCHEMA,
+        "output": CONVERSATION_DELETE_OUTPUT_SCHEMA,
+    },
+    "conversation.get": {
+        "input": CONVERSATION_GET_INPUT_SCHEMA,
+        "output": CONVERSATION_GET_OUTPUT_SCHEMA,
+    },
+    "conversation.get_current": {
+        "input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA,
+        "output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA,
+    },
+    "conversation.list": {
+        "input": CONVERSATION_LIST_INPUT_SCHEMA,
+        "output": CONVERSATION_LIST_OUTPUT_SCHEMA,
+    },
+    "conversation.update": {
+        "input": CONVERSATION_UPDATE_INPUT_SCHEMA,
+        "output": CONVERSATION_UPDATE_OUTPUT_SCHEMA,
+    },
+    "conversation.unset_persona": {
+        "input": CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA,
+        "output": CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA,
+    },
+    # Message-history access and pruning.
+    "message_history.list": {
+        "input": MESSAGE_HISTORY_LIST_INPUT_SCHEMA,
+        "output": MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA,
+    },
+    "message_history.get_by_id": {
+        "input": MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA,
+        "output": MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA,
+    },
+    "message_history.append": {
+        "input": MESSAGE_HISTORY_APPEND_INPUT_SCHEMA,
+        "output": MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA,
+    },
+    "message_history.delete_before": {
+        "input": MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA,
+        "output": MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA,
+    },
+    "message_history.delete_after": {
+        "input": MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA,
+        "output": MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA,
+    },
+    "message_history.delete_all": {
+        "input": MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA,
+        "output": MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA,
+    },
+    # Local MCP server management and ad-hoc MCP sessions.
+    "mcp.local.get": {
+        "input": MCP_LOCAL_GET_INPUT_SCHEMA,
+        "output": MCP_LOCAL_GET_OUTPUT_SCHEMA,
+    },
+    "mcp.local.list": {
+        "input": MCP_LOCAL_LIST_INPUT_SCHEMA,
+        "output": MCP_LOCAL_LIST_OUTPUT_SCHEMA,
+    },
+    "mcp.local.enable": {
+        "input": MCP_LOCAL_ENABLE_INPUT_SCHEMA,
+        "output": MCP_LOCAL_ENABLE_OUTPUT_SCHEMA,
+    },
+    "mcp.local.disable": {
+        "input": MCP_LOCAL_DISABLE_INPUT_SCHEMA,
+        "output": MCP_LOCAL_DISABLE_OUTPUT_SCHEMA,
+    },
+    "mcp.local.wait_until_ready": {
+        "input": MCP_LOCAL_WAIT_UNTIL_READY_INPUT_SCHEMA,
+        "output": MCP_LOCAL_WAIT_UNTIL_READY_OUTPUT_SCHEMA,
+    },
+    "mcp.session.open": {
+        "input": MCP_SESSION_OPEN_INPUT_SCHEMA,
+        "output": MCP_SESSION_OPEN_OUTPUT_SCHEMA,
+    },
+    "mcp.session.list_tools": {
+        "input": MCP_SESSION_LIST_TOOLS_INPUT_SCHEMA,
+        "output": MCP_SESSION_LIST_TOOLS_OUTPUT_SCHEMA,
+    },
+    "mcp.session.call_tool": {
+        "input": MCP_SESSION_CALL_TOOL_INPUT_SCHEMA,
+        "output": MCP_SESSION_CALL_TOOL_OUTPUT_SCHEMA,
+    },
+    "mcp.session.close": {
+        "input": MCP_SESSION_CLOSE_INPUT_SCHEMA,
+        "output": MCP_SESSION_CLOSE_OUTPUT_SCHEMA,
+    },
+    # Internal-only capability (note the "internal." prefix).
+    "internal.mcp.local.execute": {
+        "input": INTERNAL_MCP_LOCAL_EXECUTE_INPUT_SCHEMA,
+        "output": INTERNAL_MCP_LOCAL_EXECUTE_OUTPUT_SCHEMA,
+    },
+    # Knowledge-base and document management.
+    "kb.list": {"input": KB_LIST_INPUT_SCHEMA, "output": KB_LIST_OUTPUT_SCHEMA},
+    "kb.get": {"input": KB_GET_INPUT_SCHEMA, "output": KB_GET_OUTPUT_SCHEMA},
+    "kb.create": {
+        "input": KB_CREATE_INPUT_SCHEMA,
+        "output": KB_CREATE_OUTPUT_SCHEMA,
+    },
+    "kb.update": {
+        "input": KB_UPDATE_INPUT_SCHEMA,
+        "output": KB_UPDATE_OUTPUT_SCHEMA,
+    },
+    "kb.delete": {
+        "input": KB_DELETE_INPUT_SCHEMA,
+        "output": KB_DELETE_OUTPUT_SCHEMA,
+    },
+    "kb.retrieve": {
+        "input": KB_RETRIEVE_INPUT_SCHEMA,
+        "output": KB_RETRIEVE_OUTPUT_SCHEMA,
+    },
+    "kb.document.upload": {
+        "input": KB_DOCUMENT_UPLOAD_INPUT_SCHEMA,
+        "output": KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA,
+    },
+    "kb.document.list": {
+        "input": KB_DOCUMENT_LIST_INPUT_SCHEMA,
+        "output": KB_DOCUMENT_LIST_OUTPUT_SCHEMA,
+    },
+    "kb.document.get": {
+        "input": KB_DOCUMENT_GET_INPUT_SCHEMA,
+        "output": KB_DOCUMENT_GET_OUTPUT_SCHEMA,
+    },
+    "kb.document.delete": {
+        "input": KB_DOCUMENT_DELETE_INPUT_SCHEMA,
+        "output": KB_DOCUMENT_DELETE_OUTPUT_SCHEMA,
+    },
+    "kb.document.refresh": {
+        "input": KB_DOCUMENT_REFRESH_INPUT_SCHEMA,
+        "output": KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA,
+    },
+    # Registration, skills, HTTP endpoints, and plugin metadata.
+    "registry.command.register": {
+        "input": REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA,
+        "output": REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA,
+    },
+    "skill.register": {
+        "input": SKILL_REGISTER_INPUT_SCHEMA,
+        "output": SKILL_REGISTER_OUTPUT_SCHEMA,
+    },
+    "skill.unregister": {
+        "input": SKILL_UNREGISTER_INPUT_SCHEMA,
+        "output": SKILL_UNREGISTER_OUTPUT_SCHEMA,
+    },
+    "skill.list": {
+        "input": SKILL_LIST_INPUT_SCHEMA,
+        "output": SKILL_LIST_OUTPUT_SCHEMA,
+    },
+    "http.register_api": {
+        "input": HTTP_REGISTER_API_INPUT_SCHEMA,
+        "output": HTTP_REGISTER_API_OUTPUT_SCHEMA,
+    },
+    "http.unregister_api": {
+        "input": HTTP_UNREGISTER_API_INPUT_SCHEMA,
+        "output": HTTP_UNREGISTER_API_OUTPUT_SCHEMA,
+    },
+    "http.list_apis": {
+        "input": HTTP_LIST_APIS_INPUT_SCHEMA,
+        "output": HTTP_LIST_APIS_OUTPUT_SCHEMA,
+    },
+    "metadata.get_plugin": {
+        "input": METADATA_GET_PLUGIN_INPUT_SCHEMA,
+        "output": METADATA_GET_PLUGIN_OUTPUT_SCHEMA,
+    },
+    "metadata.list_plugins": {
+        "input": METADATA_LIST_PLUGINS_INPUT_SCHEMA,
+        "output": METADATA_LIST_PLUGINS_OUTPUT_SCHEMA,
+    },
+    "metadata.get_plugin_config": {
+        "input": METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA,
+        "output": METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA,
+    },
+    "metadata.save_plugin_config": {
+        "input": METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA,
+        "output": METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA,
+    },
+    "registry.get_handlers_by_event_type": {
+        "input": REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA,
+        "output": REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA,
+    },
+    "registry.get_handler_by_full_name": {
+        "input": REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA,
+        "output": REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA,
+    },
+    # Provider queries; the list_all_* and get_using_* variants deliberately
+    # share the generic list/get schemas defined above.
+    "provider.get_using": {
+        "input": PROVIDER_GET_USING_INPUT_SCHEMA,
+        "output": PROVIDER_GET_USING_OUTPUT_SCHEMA,
+    },
+    "provider.get_by_id": {
+        "input": PROVIDER_GET_BY_ID_INPUT_SCHEMA,
+        "output": PROVIDER_GET_BY_ID_OUTPUT_SCHEMA,
+    },
+    "provider.get_current_chat_provider_id": {
+        "input": PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA,
+        "output": PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA,
+    },
+    "provider.list_all": {
+        "input": PROVIDER_LIST_ALL_INPUT_SCHEMA,
+        "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA,
+    },
+    "provider.list_all_tts": {
+        "input": PROVIDER_LIST_ALL_INPUT_SCHEMA,
+        "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA,
+    },
+    "provider.list_all_stt": {
+        "input": PROVIDER_LIST_ALL_INPUT_SCHEMA,
+        "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA,
+    },
+    "provider.list_all_embedding": {
+        "input": PROVIDER_LIST_ALL_INPUT_SCHEMA,
+        "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA,
+    },
+    "provider.list_all_rerank": {
+        "input": PROVIDER_LIST_ALL_INPUT_SCHEMA,
+        "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA,
+    },
+    "provider.get_using_tts": {
+        "input": PROVIDER_GET_USING_INPUT_SCHEMA,
+        "output": PROVIDER_GET_USING_OUTPUT_SCHEMA,
+    },
+    "provider.get_using_stt": {
+        "input": PROVIDER_GET_USING_INPUT_SCHEMA,
+        "output": PROVIDER_GET_USING_OUTPUT_SCHEMA,
+    },
+    "provider.stt.get_text": {
+        "input": PROVIDER_STT_GET_TEXT_INPUT_SCHEMA,
+        "output": PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA,
+    },
+    "provider.tts.get_audio": {
+        "input": PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA,
+        "output": PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA,
+    },
+    "provider.tts.support_stream": {
+        "input": PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA,
+        "output": PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA,
+    },
+    "provider.tts.get_audio_stream": {
+        "input": PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA,
+        "output": PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA,
+    },
+    "provider.embedding.get_embedding": {
+        "input": PROVIDER_EMBEDDING_GET_INPUT_SCHEMA,
+        "output": PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA,
+    },
+    "provider.embedding.get_embeddings": {
+        "input": PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA,
+        "output": PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA,
+    },
+    "provider.embedding.get_dim": {
+        "input": PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA,
+        "output": PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA,
+    },
+    "provider.rerank.rerank": {
+        "input": PROVIDER_RERANK_INPUT_SCHEMA,
+        "output": PROVIDER_RERANK_OUTPUT_SCHEMA,
+    },
+    # Provider manager operations.
+    "provider.manager.set": {
+        "input": PROVIDER_MANAGER_SET_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_SET_OUTPUT_SCHEMA,
+    },
+    "provider.manager.get_by_id": {
+        "input": PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA,
+    },
+    "provider.manager.get_merged_provider_config": {
+        "input": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA,
+    },
+    "provider.manager.load": {
+        "input": PROVIDER_MANAGER_LOAD_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA,
+    },
+    "provider.manager.terminate": {
+        "input": PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA,
+    },
+    "provider.manager.create": {
+        "input": PROVIDER_MANAGER_CREATE_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA,
+    },
+    "provider.manager.update": {
+        "input": PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA,
+    },
+    "provider.manager.delete": {
+        "input": PROVIDER_MANAGER_DELETE_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA,
+    },
+    "provider.manager.get_insts": {
+        "input": PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA,
+    },
+    "provider.manager.watch_changes": {
+        "input": PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA,
+        "output": PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA,
+    },
+    # Platform manager operations.
+    "platform.manager.get_by_id": {
+        "input": PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA,
+        "output": PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA,
+    },
+    "platform.manager.clear_errors": {
+        "input": PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA,
+        "output": PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA,
+    },
+    "platform.manager.get_stats": {
+        "input": PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA,
+        "output": PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA,
+    },
+    # Permission checks and admin management.
+    "permission.check": {
+        "input": PERMISSION_CHECK_INPUT_SCHEMA,
+        "output": PERMISSION_CHECK_OUTPUT_SCHEMA,
+    },
+    "permission.get_admins": {
+        "input": PERMISSION_GET_ADMINS_INPUT_SCHEMA,
+        "output": PERMISSION_GET_ADMINS_OUTPUT_SCHEMA,
+    },
+    "permission.manager.add_admin": {
+        "input": PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA,
+        "output": PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA,
+    },
+    "permission.manager.remove_admin": {
+        "input": PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA,
+        "output": PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA,
+    },
+    # LLM tool manager and agent operations.
+    "llm_tool.manager.get": {
+        "input": LLM_TOOL_MANAGER_GET_INPUT_SCHEMA,
+        "output": LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA,
+    },
+    "llm_tool.manager.activate": {
+        "input": LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA,
+        "output": LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA,
+    },
+    "llm_tool.manager.deactivate": {
+        "input": LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA,
+        "output": LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA,
+    },
+    "llm_tool.manager.add": {
+        "input": LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA,
+        "output": LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA,
+    },
+    "llm_tool.manager.remove": {
+        "input": LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA,
+        "output": LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA,
+    },
+    "agent.tool_loop.run": {
+        "input": AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA,
+        "output": AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA,
+    },
+    "agent.registry.list": {
+        "input": AGENT_REGISTRY_LIST_INPUT_SCHEMA,
+        "output": AGENT_REGISTRY_LIST_OUTPUT_SCHEMA,
+    },
+    "agent.registry.get": {
+        "input": AGENT_REGISTRY_GET_INPUT_SCHEMA,
+        "output": AGENT_REGISTRY_GET_OUTPUT_SCHEMA,
+    },
+    # System utilities and event-level helpers.
+    "system.get_data_dir": {
+        "input": SYSTEM_GET_DATA_DIR_INPUT_SCHEMA,
+        "output": SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA,
+    },
+    "system.text_to_image": {
+        "input": SYSTEM_TEXT_TO_IMAGE_INPUT_SCHEMA,
+        "output": SYSTEM_TEXT_TO_IMAGE_OUTPUT_SCHEMA,
+    },
+    "system.html_render": {
+        "input": SYSTEM_HTML_RENDER_INPUT_SCHEMA,
+        "output": SYSTEM_HTML_RENDER_OUTPUT_SCHEMA,
+    },
+    "system.session_waiter.register": {
+        "input": SYSTEM_SESSION_WAITER_REGISTER_INPUT_SCHEMA,
+        "output": SYSTEM_SESSION_WAITER_REGISTER_OUTPUT_SCHEMA,
+    },
+    "system.session_waiter.unregister": {
+        "input": SYSTEM_SESSION_WAITER_UNREGISTER_INPUT_SCHEMA,
+        "output": SYSTEM_SESSION_WAITER_UNREGISTER_OUTPUT_SCHEMA,
+    },
+    "system.event.react": {
+        "input": SYSTEM_EVENT_REACT_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_REACT_OUTPUT_SCHEMA,
+    },
+    "system.event.send_typing": {
+        "input": SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA,
+    },
+    "system.event.send_streaming": {
+        "input": SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA,
+    },
+    "system.event.send_streaming_chunk": {
+        "input": SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA,
+    },
+    "system.event.send_streaming_close": {
+        "input": SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA,
+    },
+    "system.event.handler_whitelist.get": {
+        "input": SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA,
+    },
+    "system.event.handler_whitelist.set": {
+        "input": SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA,
+        "output": SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA,
+    },
+}
+
+
+__all__ = [
+ "BUILTIN_CAPABILITY_SCHEMAS",
+ "DB_DELETE_INPUT_SCHEMA",
+ "DB_DELETE_OUTPUT_SCHEMA",
+ "DB_GET_INPUT_SCHEMA",
+ "DB_GET_MANY_INPUT_SCHEMA",
+ "DB_GET_MANY_OUTPUT_SCHEMA",
+ "DB_GET_OUTPUT_SCHEMA",
+ "DB_LIST_INPUT_SCHEMA",
+ "DB_LIST_OUTPUT_SCHEMA",
+ "DB_SET_INPUT_SCHEMA",
+ "DB_SET_MANY_INPUT_SCHEMA",
+ "DB_SET_MANY_OUTPUT_SCHEMA",
+ "DB_SET_OUTPUT_SCHEMA",
+ "DB_WATCH_INPUT_SCHEMA",
+ "DB_WATCH_OUTPUT_SCHEMA",
+ "HTTP_LIST_APIS_INPUT_SCHEMA",
+ "HTTP_LIST_APIS_OUTPUT_SCHEMA",
+ "HTTP_REGISTER_API_INPUT_SCHEMA",
+ "HTTP_REGISTER_API_OUTPUT_SCHEMA",
+ "HTTP_UNREGISTER_API_INPUT_SCHEMA",
+ "HTTP_UNREGISTER_API_OUTPUT_SCHEMA",
+ "JSONSchema",
+ "LLM_CHAT_INPUT_SCHEMA",
+ "LLM_CHAT_OUTPUT_SCHEMA",
+ "LLM_CHAT_RAW_INPUT_SCHEMA",
+ "LLM_CHAT_RAW_OUTPUT_SCHEMA",
+ "LLM_STREAM_CHAT_INPUT_SCHEMA",
+ "LLM_STREAM_CHAT_OUTPUT_SCHEMA",
+ "MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA",
+ "MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA",
+ "MEMORY_COUNT_INPUT_SCHEMA",
+ "MEMORY_COUNT_OUTPUT_SCHEMA",
+ "MEMORY_DELETE_INPUT_SCHEMA",
+ "MEMORY_DELETE_MANY_INPUT_SCHEMA",
+ "MEMORY_DELETE_MANY_OUTPUT_SCHEMA",
+ "MEMORY_DELETE_OUTPUT_SCHEMA",
+ "MEMORY_EXISTS_INPUT_SCHEMA",
+ "MEMORY_EXISTS_OUTPUT_SCHEMA",
+ "MEMORY_GET_INPUT_SCHEMA",
+ "MEMORY_GET_MANY_INPUT_SCHEMA",
+ "MEMORY_GET_MANY_OUTPUT_SCHEMA",
+ "MEMORY_GET_OUTPUT_SCHEMA",
+ "MEMORY_LIST_KEYS_INPUT_SCHEMA",
+ "MEMORY_LIST_KEYS_OUTPUT_SCHEMA",
+ "MEMORY_SAVE_INPUT_SCHEMA",
+ "MEMORY_SAVE_OUTPUT_SCHEMA",
+ "MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA",
+ "MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA",
+ "MEMORY_SEARCH_INPUT_SCHEMA",
+ "MEMORY_SEARCH_OUTPUT_SCHEMA",
+ "MEMORY_STATS_INPUT_SCHEMA",
+ "MEMORY_STATS_OUTPUT_SCHEMA",
+ "METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA",
+ "METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA",
+ "METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA",
+ "METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA",
+ "METADATA_GET_PLUGIN_INPUT_SCHEMA",
+ "METADATA_GET_PLUGIN_OUTPUT_SCHEMA",
+ "METADATA_LIST_PLUGINS_INPUT_SCHEMA",
+ "METADATA_LIST_PLUGINS_OUTPUT_SCHEMA",
+ "PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA",
+ "PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA",
+ "PROVIDER_GET_BY_ID_INPUT_SCHEMA",
+ "PROVIDER_GET_BY_ID_OUTPUT_SCHEMA",
+ "PROVIDER_GET_USING_INPUT_SCHEMA",
+ "PROVIDER_GET_USING_OUTPUT_SCHEMA",
+ "PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA",
+ "PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA",
+ "PROVIDER_EMBEDDING_GET_INPUT_SCHEMA",
+ "PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA",
+ "PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA",
+ "PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA",
+ "PROVIDER_CHANGE_EVENT_SCHEMA",
+ "PROVIDER_LIST_ALL_INPUT_SCHEMA",
+ "PROVIDER_LIST_ALL_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_CREATE_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_DELETE_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_LOAD_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_SET_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_SET_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA",
+ "PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA",
+ "PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA",
+ "PROVIDER_META_SCHEMA",
+ "PROVIDER_RERANK_INPUT_SCHEMA",
+ "PROVIDER_RERANK_OUTPUT_SCHEMA",
+ "PROVIDER_RERANK_RESULT_SCHEMA",
+ "PROVIDER_STT_GET_TEXT_INPUT_SCHEMA",
+ "PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA",
+ "PROVIDER_TTS_AUDIO_CHUNK_SCHEMA",
+ "PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA",
+ "PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA",
+ "PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA",
+ "PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA",
+ "PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA",
+ "PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_GET_INPUT_SCHEMA",
+ "LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA",
+ "LLM_TOOL_SPEC_SCHEMA",
+ "AGENT_REGISTRY_GET_INPUT_SCHEMA",
+ "AGENT_REGISTRY_GET_OUTPUT_SCHEMA",
+ "AGENT_REGISTRY_LIST_INPUT_SCHEMA",
+ "AGENT_REGISTRY_LIST_OUTPUT_SCHEMA",
+ "AGENT_SPEC_SCHEMA",
+ "AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA",
+ "AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA",
+ "MANAGED_PROVIDER_RECORD_SCHEMA",
+ "PLATFORM_ERROR_SCHEMA",
+ "PLATFORM_GET_MEMBERS_INPUT_SCHEMA",
+ "PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA",
+ "PLATFORM_GET_GROUP_INPUT_SCHEMA",
+ "PLATFORM_GET_GROUP_OUTPUT_SCHEMA",
+ "PLATFORM_INSTANCE_SCHEMA",
+ "PLATFORM_LIST_INSTANCES_INPUT_SCHEMA",
+ "PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA",
+ "PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA",
+ "PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA",
+ "PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA",
+ "PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA",
+ "PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA",
+ "PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA",
+ "PLATFORM_MANAGER_STATE_SCHEMA",
+ "PERMISSION_CHECK_INPUT_SCHEMA",
+ "PERMISSION_CHECK_OUTPUT_SCHEMA",
+ "PERMISSION_CHECK_RESULT_SCHEMA",
+ "PERMISSION_GET_ADMINS_INPUT_SCHEMA",
+ "PERMISSION_GET_ADMINS_OUTPUT_SCHEMA",
+ "PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA",
+ "PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA",
+ "PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA",
+ "PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA",
+ "PERMISSION_ROLE_SCHEMA",
+ "PLATFORM_SEND_CHAIN_INPUT_SCHEMA",
+ "PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA",
+ "PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA",
+ "PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA",
+ "PLATFORM_SEND_IMAGE_INPUT_SCHEMA",
+ "PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA",
+ "PLATFORM_SEND_INPUT_SCHEMA",
+ "PLATFORM_SEND_OUTPUT_SCHEMA",
+ "PLATFORM_STATS_SCHEMA",
+ "PERSONA_CREATE_INPUT_SCHEMA",
+ "PERSONA_CREATE_OUTPUT_SCHEMA",
+ "PERSONA_CREATE_SCHEMA",
+ "PERSONA_DELETE_INPUT_SCHEMA",
+ "PERSONA_DELETE_OUTPUT_SCHEMA",
+ "PERSONA_GET_INPUT_SCHEMA",
+ "PERSONA_GET_OUTPUT_SCHEMA",
+ "PERSONA_LIST_INPUT_SCHEMA",
+ "PERSONA_LIST_OUTPUT_SCHEMA",
+ "PERSONA_RECORD_SCHEMA",
+ "PERSONA_UPDATE_INPUT_SCHEMA",
+ "PERSONA_UPDATE_OUTPUT_SCHEMA",
+ "PERSONA_UPDATE_SCHEMA",
+ "CONVERSATION_CREATE_SCHEMA",
+ "CONVERSATION_DELETE_INPUT_SCHEMA",
+ "CONVERSATION_DELETE_OUTPUT_SCHEMA",
+ "CONVERSATION_GET_CURRENT_INPUT_SCHEMA",
+ "CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA",
+ "CONVERSATION_GET_INPUT_SCHEMA",
+ "CONVERSATION_GET_OUTPUT_SCHEMA",
+ "CONVERSATION_LIST_INPUT_SCHEMA",
+ "CONVERSATION_LIST_OUTPUT_SCHEMA",
+ "CONVERSATION_NEW_INPUT_SCHEMA",
+ "CONVERSATION_NEW_OUTPUT_SCHEMA",
+ "CONVERSATION_RECORD_SCHEMA",
+ "CONVERSATION_SWITCH_INPUT_SCHEMA",
+ "CONVERSATION_SWITCH_OUTPUT_SCHEMA",
+ "CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA",
+ "CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA",
+ "CONVERSATION_UPDATE_INPUT_SCHEMA",
+ "CONVERSATION_UPDATE_OUTPUT_SCHEMA",
+ "CONVERSATION_UPDATE_SCHEMA",
+ "MESSAGE_HISTORY_APPEND_INPUT_SCHEMA",
+ "MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA",
+ "MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA",
+ "MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA",
+ "MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA",
+ "MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA",
+ "MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA",
+ "MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA",
+ "MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA",
+ "MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA",
+ "MESSAGE_HISTORY_LIST_INPUT_SCHEMA",
+ "MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA",
+ "MESSAGE_HISTORY_PAGE_SCHEMA",
+ "MESSAGE_HISTORY_RECORD_SCHEMA",
+ "MESSAGE_HISTORY_SENDER_SCHEMA",
+ "MESSAGE_HISTORY_SESSION_SCHEMA",
+ "KB_CREATE_INPUT_SCHEMA",
+ "KB_CREATE_OUTPUT_SCHEMA",
+ "KB_DOCUMENT_DELETE_INPUT_SCHEMA",
+ "KB_DOCUMENT_DELETE_OUTPUT_SCHEMA",
+ "KB_DOCUMENT_GET_INPUT_SCHEMA",
+ "KB_DOCUMENT_GET_OUTPUT_SCHEMA",
+ "KB_DOCUMENT_LIST_INPUT_SCHEMA",
+ "KB_DOCUMENT_LIST_OUTPUT_SCHEMA",
+ "KB_DOCUMENT_REFRESH_INPUT_SCHEMA",
+ "KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA",
+ "KB_DOCUMENT_UPLOAD_INPUT_SCHEMA",
+ "KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA",
+ "KB_DELETE_INPUT_SCHEMA",
+ "KB_DELETE_OUTPUT_SCHEMA",
+ "KB_GET_INPUT_SCHEMA",
+ "KB_GET_OUTPUT_SCHEMA",
+ "KB_LIST_INPUT_SCHEMA",
+ "KB_LIST_OUTPUT_SCHEMA",
+ "KB_RETRIEVE_INPUT_SCHEMA",
+ "KB_RETRIEVE_OUTPUT_SCHEMA",
+ "KB_UPDATE_INPUT_SCHEMA",
+ "KB_UPDATE_OUTPUT_SCHEMA",
+ "KNOWLEDGE_BASE_CREATE_SCHEMA",
+ "KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA",
+ "KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA",
+ "KNOWLEDGE_BASE_RECORD_SCHEMA",
+ "KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA",
+ "KNOWLEDGE_BASE_UPDATE_SCHEMA",
+ "REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA",
+ "REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA",
+ "SKILL_REGISTER_INPUT_SCHEMA",
+ "SKILL_REGISTER_OUTPUT_SCHEMA",
+ "SKILL_UNREGISTER_INPUT_SCHEMA",
+ "SKILL_UNREGISTER_OUTPUT_SCHEMA",
+ "SKILL_LIST_INPUT_SCHEMA",
+ "SKILL_LIST_OUTPUT_SCHEMA",
+ "REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA",
+ "REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA",
+ "REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA",
+ "REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA",
+ "SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA",
+ "SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA",
+ "SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA",
+ "SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA",
+ "SESSION_REF_SCHEMA",
+ "SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA",
+ "SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA",
+ "SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA",
+ "SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA",
+ "SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA",
+ "SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA",
+ "SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA",
+ "SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_REACT_INPUT_SCHEMA",
+ "SYSTEM_EVENT_REACT_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA",
+ "SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA",
+ "SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA",
+ "SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA",
+ "SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA",
+ "SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA",
+ "SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA",
+ "SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA",
+ "SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/codec.py b/astrbot-sdk/src/astrbot_sdk/protocol/codec.py
new file mode 100644
index 0000000000..852648b010
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/protocol/codec.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, cast
+
+import msgpack
+
+from .messages import ProtocolMessage, parse_message
+
+
class ProtocolCodec(ABC):
    """Abstract wire codec: serializes and deserializes ProtocolMessage objects."""

    @abstractmethod
    def encode_message(self, message: ProtocolMessage) -> bytes:
        """Serialize *message* into wire bytes."""
        raise NotImplementedError

    @abstractmethod
    def decode_message(
        self,
        payload: ProtocolMessage | bytes | str | dict[str, Any],
    ) -> ProtocolMessage:
        """Parse *payload* (wire bytes or an already-parsed form) into a message."""
        raise NotImplementedError
+
+
class JsonProtocolCodec(ProtocolCodec):
    """Codec using one UTF-8 JSON document per message."""

    def encode_message(self, message: ProtocolMessage) -> bytes:
        # exclude_none keeps unset optional fields off the wire.
        return message.model_dump_json(exclude_none=True).encode("utf-8")

    def decode_message(
        self,
        payload: ProtocolMessage | bytes | str | dict[str, Any],
    ) -> ProtocolMessage:
        # parse_message already accepts model/bytes/str/dict forms.
        return parse_message(payload)
+
+
class MsgpackProtocolCodec(ProtocolCodec):
    """Codec using MessagePack framing for the wire format."""

    def encode_message(self, message: ProtocolMessage) -> bytes:
        # use_bin_type=True distinguishes bytes from str on the wire.
        payload = msgpack.packb(
            message.model_dump(exclude_none=True), use_bin_type=True
        )
        return cast(bytes, payload)

    def decode_message(
        self,
        payload: ProtocolMessage | bytes | str | dict[str, Any],
    ) -> ProtocolMessage:
        # Only raw bytes need msgpack decoding; every other form is delegated
        # to parse_message unchanged.
        if not isinstance(payload, bytes):
            return parse_message(payload)
        try:
            # strict_map_key=True rejects maps whose keys are not str/bytes.
            unpacked = msgpack.unpackb(payload, raw=False, strict_map_key=True)
        except (
            msgpack.ExtraData,
            msgpack.FormatError,
            msgpack.StackError,
            ValueError,
        ) as exc:
            # Normalize every msgpack decode failure into a ValueError so
            # callers handle a single exception type.
            raise ValueError(str(exc)) from exc
        return parse_message(unpacked)
+
+
# Public codec exports of this module.
__all__ = [
    "JsonProtocolCodec",
    "MsgpackProtocolCodec",
    "ProtocolCodec",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py b/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py
new file mode 100644
index 0000000000..abe8b92b2d
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py
@@ -0,0 +1,413 @@
+"""s5r 协议描述符模型。
+
+`protocol` 是 s5r 新引入的协议层抽象,不对应旧树(圣诞树)中的一个同名目录。这里
+定义的是跨进程握手和调度时使用的声明式元数据,而不是运行时的具体处理器/
+能力实现。
+"""
+
+from __future__ import annotations
+
+from typing import Annotated, Any, Literal
+
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
+
+from . import _builtin_schemas
+from ._builtin_schemas import * # noqa: F403
+
# JSON-schema type alias re-exported from the builtin-schema module.
JSONSchema = _builtin_schemas.JSONSchema
# Capability namespaces reserved for the runtime; plugins may not claim them.
RESERVED_CAPABILITY_NAMESPACES = ("handler", "system", "internal")
RESERVED_CAPABILITY_PREFIXES = tuple(
    f"{namespace}." for namespace in RESERVED_CAPABILITY_NAMESPACES
)
# Registry mapping builtin capability name -> {"input": ..., "output": ...}.
BUILTIN_CAPABILITY_SCHEMAS = _builtin_schemas.BUILTIN_CAPABILITY_SCHEMAS
# Names lazily re-exported through this module's __getattr__/__dir__.
_BUILTIN_SCHEMA_EXPORTS = frozenset(_builtin_schemas.__all__)
+
+
def __getattr__(name: str) -> Any:
    """Lazily re-export builtin schema constants from ``_builtin_schemas``."""
    if name not in _BUILTIN_SCHEMA_EXPORTS:
        raise AttributeError(name)
    return getattr(_builtin_schemas, name)
+
+
def __dir__() -> list[str]:
    """Include the lazily re-exported schema names in ``dir()`` output."""
    names = set(globals())
    names.update(_BUILTIN_SCHEMA_EXPORTS)
    return sorted(names)
+
+
class _DescriptorBase(BaseModel):
    """Shared base: every descriptor model rejects unknown fields."""

    model_config = ConfigDict(extra="forbid")
+
+
class Permissions(_DescriptorBase):
    """Access-control settings attached to a handler.

    Attributes:
        require_admin: Whether triggering the handler requires admin rights.
        required_role: Minimum role required by the handler; v1 supports
            "member"/"admin".
        level: Numeric permission level; higher grants more access.
    """

    require_admin: bool = False
    required_role: Literal["member", "admin"] | None = None
    level: int = 0

    @model_validator(mode="after")
    def normalize_required_role(self) -> Permissions:
        # Keep require_admin and required_role mutually consistent in both
        # directions: the flag implies the role and the role implies the flag.
        if not self.require_admin:
            if self.required_role == "admin":
                self.require_admin = True
            return self
        if self.required_role not in {None, "admin"}:
            raise ValueError(
                "permissions.require_admin=True conflicts with required_role="
                f"{self.required_role!r}"
            )
        self.required_role = "admin"
        return self
+
+
class SessionRef(_DescriptorBase):
    """Structured session target.

    The s5r runtime internally still keeps the legacy ``session`` string as
    the lowest compatibility layer, but the public model may also carry the
    platform and raw addressing info so platform send APIs are not tied to a
    single opaque string forever.
    """

    conversation_id: str = Field(
        # Accept the legacy "session" key as an input alias.
        validation_alias=AliasChoices("conversation_id", "session"),
    )
    platform: str | None = None
    raw: dict[str, Any] | None = None

    @property
    def session(self) -> str:
        """Legacy alias for :attr:`conversation_id`."""
        return self.conversation_id

    def to_payload(self) -> dict[str, Any]:
        """Serialize for the wire, dropping unset optional fields."""
        return self.model_dump(exclude_none=True)
+
+
class CommandTrigger(_DescriptorBase):
    """Command trigger: fires on a specific command.

    Attributes:
        type: Trigger type, fixed to "command".
        command: Command name without prefix (e.g. "help").
        aliases: Alternative command names.
        description: Command description used for help output.
        platforms: Allowed platforms; empty means all platforms.
        message_types: Restricting message types; empty means unrestricted.
    """

    type: Literal["command"] = "command"
    command: str
    aliases: list[str] = Field(default_factory=list)
    description: str | None = None
    platforms: list[str] = Field(default_factory=list)
    message_types: list[str] = Field(default_factory=list)
+
+
class MessageTrigger(_DescriptorBase):
    """Message trigger: subscription conditions for message handlers.

    Attributes:
        type: Trigger type, fixed to "message".
        regex: Regular-expression pattern matched against the message text.
        keywords: Keyword list; any contained keyword fires the trigger.
        platforms: Target platforms; empty means all platforms.
        message_types: Restricting message types; empty means unrestricted.

    Note:
        ``regex`` and ``keywords`` may both be empty, which means "any
        message fires"; filtering is then left to platform filters or to the
        upper runtime layer.
    """

    type: Literal["message"] = "message"
    regex: str | None = None
    keywords: list[str] = Field(default_factory=list)
    platforms: list[str] = Field(default_factory=list)
    message_types: list[str] = Field(default_factory=list)
+
+
class EventTrigger(_DescriptorBase):
    """Event trigger: fires on a specific event type.

    Attributes:
        type: Trigger type, fixed to "event".
        event_type: Event type as a string (e.g. "message", "notice").
    """

    type: Literal["event"] = "event"
    event_type: str
+
+
class ScheduleTrigger(_DescriptorBase):
    """Schedule trigger: run on a cron expression or a fixed interval.

    Exactly one of ``cron`` and ``interval_seconds`` must be set.

    Attributes:
        type: Trigger type, fixed to "schedule".
        name: Job name; upstream falls back to a plugin-id/handler-id combo.
        cron: Cron expression (e.g. "0 9 * * *" for 9am daily); also accepts
            the legacy input key "schedule".
        interval_seconds: Run interval in seconds.
        timezone: IANA timezone name (e.g. "Asia/Shanghai").
    """

    type: Literal["schedule"] = "schedule"
    name: str | None = None
    cron: str | None = Field(
        default=None,
        validation_alias=AliasChoices("cron", "schedule"),
    )
    interval_seconds: int | None = None
    timezone: str | None = None

    @property
    def schedule(self) -> str | None:
        """Legacy alias for :attr:`cron`."""
        return self.cron

    @model_validator(mode="after")
    def validate_schedule(self) -> ScheduleTrigger:
        # Both-set and both-missing are equally invalid: exactly one source
        # of scheduling must be provided.
        if (self.cron is None) == (self.interval_seconds is None):
            raise ValueError("cron 和 interval_seconds 必须且只能有一个非 null")
        return self
+
+
class PlatformFilterSpec(_DescriptorBase):
    """Filter restricting a handler to specific platforms (empty = all)."""

    kind: Literal["platform"] = "platform"
    platforms: list[str] = Field(default_factory=list)
+
+
class MessageTypeFilterSpec(_DescriptorBase):
    """Filter restricting a handler to specific message types (empty = all)."""

    kind: Literal["message_type"] = "message_type"
    message_types: list[str] = Field(default_factory=list)
+
+
class LocalFilterRefSpec(_DescriptorBase):
    """Reference to a plugin-local filter function, by ID plus static args."""

    kind: Literal["local"] = "local"
    filter_id: str
    args: dict[str, Any] = Field(default_factory=dict)
+
+
class CompositeFilterSpec(_DescriptorBase):
    """Boolean combination ("and"/"or") of child filter specs.

    ``children`` uses the forward-referenced FilterSpec union; the reference
    is resolved by the ``model_rebuild()`` call later in this module.
    """

    kind: Literal["and", "or"]
    children: list[FilterSpec] = Field(default_factory=list)
+
+
# Discriminated union of all filter specs, keyed by the "kind" field.
FilterSpec = Annotated[
    PlatformFilterSpec
    | MessageTypeFilterSpec
    | LocalFilterRefSpec
    | CompositeFilterSpec,
    Field(discriminator="kind"),
]
+
+
class ParamSpec(_DescriptorBase):
    """Declared command parameter: name, wire type, and optionality.

    # NOTE(review): inner_type presumably carries the wrapped scalar type
    # when ``type`` is "optional"/"greedy_str" — confirm against the command
    # dispatcher that consumes these specs.
    """

    name: str
    type: Literal["str", "int", "float", "bool", "optional", "greedy_str"]
    required: bool = True
    inner_type: Literal["str", "int", "float", "bool"] | None = None
+
+
class CommandRouteSpec(_DescriptorBase):
    """Routing info placing a command inside a command-group hierarchy."""

    group_path: list[str] = Field(default_factory=list)
    display_command: str
    group_help: str | None = None
+
+
+CompositeFilterSpec.model_rebuild()
+
+
Trigger = Annotated[
    CommandTrigger | MessageTrigger | EventTrigger | ScheduleTrigger,
    Field(discriminator="type"),
]
"""Trigger union; the ``type`` field discriminates the concrete model."""
+
+
class HandlerDescriptor(_DescriptorBase):
    """Descriptor carrying the metadata of one event-handler function.

    Usually produced automatically by decorators such as ``@on_command`` or
    ``@on_message``; plugin authors rarely instantiate it by hand, but its
    shape explains how handler registration works.

    Trigger types:
        - CommandTrigger: fires on a specific command such as ``/help``
        - MessageTrigger: fires on messages (regex/keyword match)
        - EventTrigger: fires on a specific event type
        - ScheduleTrigger: fires on a cron expression or fixed interval

    Attributes:
        id: Unique handler ID, typically in "module.function_name" format.
        trigger: Trigger config deciding when the handler runs.
        kind: Handler category; defaults to a plain "handler".
        contract: Runtime contract name describing input/execution semantics;
            defaults to "schedule" for schedule triggers, otherwise
            "message_event".
        description: Optional human-readable description.
        priority: Execution priority; larger values run earlier.
        permissions: Access-control settings for triggering the handler.
        filters: Declarative filters applied before dispatch.
        param_specs: Declared command-parameter specs.
        command_route: Optional command-group routing info.

    See Also:
        Trigger: the trigger union type.
        Permissions: the access-control model.
    """

    id: str
    trigger: Trigger
    kind: Literal["handler", "hook", "tool", "session"] = "handler"
    contract: str | None = None
    description: str | None = None
    priority: int = 0
    permissions: Permissions = Field(default_factory=Permissions)
    filters: list[FilterSpec] = Field(default_factory=list)
    param_specs: list[ParamSpec] = Field(default_factory=list)
    command_route: CommandRouteSpec | None = None

    @model_validator(mode="after")
    def validate_contract_defaults(self) -> HandlerDescriptor:
        # Fill in the default contract when the author did not specify one.
        if self.contract is not None:
            return self
        is_schedule = isinstance(self.trigger, ScheduleTrigger)
        self.contract = "schedule" if is_schedule else "message_event"
        return self
+
+
class CapabilityDescriptor(_DescriptorBase):
    """Descriptor for a remotely invokable capability.

    Naming rules:
        - "namespace.action" format, e.g. "llm.chat", "db.set"
        - multi-level namespaces are allowed, e.g. "llm_tool.manager.activate"
        - builtin capabilities start with "internal.", e.g.
          "internal.legacy.call_context_function"

    Reserved namespaces (not usable by plugins): ``handler.``, ``system.``,
    ``internal.``.

    Declare one of these when your plugin *exposes* a capability that other
    plugins can call, supplying JSON Schemas for its input and output, e.g.::

        translate_desc = CapabilityDescriptor(
            name="my_plugin.translate",
            description="Translate text into a target language",
            input_schema={
                "type": "object",
                "properties": {
                    "text": {"type": "string"},
                    "target_lang": {"type": "string"},
                },
                "required": ["text", "target_lang"],
            },
            output_schema={
                "type": "object",
                "properties": {"translated": {"type": "string"}},
            },
        )

    To *call* a builtin capability (e.g. ``llm.chat``, ``db.set``) you do not
    create a descriptor; invoke it through ``Context.invoke()`` and consult
    ``BUILTIN_CAPABILITY_SCHEMAS`` for the parameter format.

    Attributes:
        name: Capability name in "namespace.action" format.
        description: Human-readable description for docs and debugging.
        input_schema: JSON Schema validating the input payload.
        output_schema: JSON Schema validating the output payload.
        supports_stream: Whether streaming responses are supported.
        cancelable: Whether the invocation can be cancelled.

    See Also:
        BUILTIN_CAPABILITY_SCHEMAS: schema registry for builtin capabilities.
    """

    name: str
    description: str
    input_schema: JSONSchema | None = None
    output_schema: JSONSchema | None = None
    supports_stream: bool = False
    cancelable: bool = False

    @model_validator(mode="after")
    def validate_builtin_schema_governance(self) -> CapabilityDescriptor:
        # A descriptor that claims a builtin capability name must carry the
        # exact schemas from the protocol registry.
        expected = BUILTIN_CAPABILITY_SCHEMAS.get(self.name)
        if expected is None:
            return self
        if self.input_schema is None or self.output_schema is None:
            raise ValueError(
                f"内建 capability {self.name} 必须同时提供 input_schema 和 output_schema"
            )
        matches_registry = (
            self.input_schema == expected["input"]
            and self.output_schema == expected["output"]
        )
        if not matches_registry:
            raise ValueError(
                f"内建 capability {self.name} 的 schema 必须与协议注册表保持一致"
            )
        return self
+
+
# Descriptor-layer public API.
__all__ = [
    "Trigger",
    "BUILTIN_CAPABILITY_SCHEMAS",
    "CapabilityDescriptor",
    "CommandRouteSpec",
    "CommandTrigger",
    "CompositeFilterSpec",
    "EventTrigger",
    "FilterSpec",
    "HandlerDescriptor",
    "JSONSchema",
    "LocalFilterRefSpec",
    "MessageTrigger",
    "MessageTypeFilterSpec",
    "ParamSpec",
    "Permissions",
    "PlatformFilterSpec",
    "RESERVED_CAPABILITY_NAMESPACES",
    "RESERVED_CAPABILITY_PREFIXES",
    "ScheduleTrigger",
    "SessionRef",
]
# Also re-export every builtin schema constant (served lazily by __getattr__).
__all__ += list(_BUILTIN_SCHEMA_EXPORTS)
diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/messages.py b/astrbot-sdk/src/astrbot_sdk/protocol/messages.py
new file mode 100644
index 0000000000..c249bf16bd
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/protocol/messages.py
@@ -0,0 +1,323 @@
+"""s5r 协议消息模型。
+
+这些模型描述的是 `Peer` 与 `Peer` 之间的线协议。握手阶段通过
+`InitializeMessage` 发起,再由 `ResultMessage(kind="initialize_result")`
+返回 `InitializeOutput`;能力调用阶段则使用 `InvokeMessage` / `ResultMessage`
+或 `EventMessage` 序列。
+
+TODO: Batch Invoke(协议 v1.1 候选特性)
+==========================================
+
+设计概要:
+ 新增 BatchInvokeMessage / BatchResultMessage,将多个独立非流式调用
+ 打包为单次 IPC 传输,减少序列化和 I/O syscall 开销。
+
+约束:
+ - 只支持非流式子调用(stream=false)
+ - 结果保序返回,但服务端内部可 asyncio.gather 并发处理
+ - 单个子调用失败不拖垮整个 batch,各自返回独立的 success/error
+ - 仅协议级错误(空 calls、重复 id、子项带 stream=true)整体失败
+ - 取消只到 batch 粒度:取消 batch ID → 取消全部未完成子调用
+
+改动范围:
+ - messages.py : 加 BatchInvokeMessage / BatchResultMessage
+ - peer.py : 加 invoke_batch() 和 _handle_batch_invoke()
+ - clients/_proxy.py : 加 call_batch()
+ - transport.py : 不动(batch 仍然是一行 JSON)
+
+暂不实现的原因(2026-03-28):
+ 1. SDK 集成(feat/sdk-integration)尚在主干开发期,协议层应保持简单稳定
+ 2. 现有 pipelining(asyncio.gather + 多行 InvokeMessage)已覆盖并发场景,
+ 单次 stdio IPC 延迟在微秒级,实测中不构成瓶颈
+ 3. peer.py 已 776 行,是协议栈核心文件,batch 会引入子调用生命周期管理、
+ 超时聚合等额外复杂度
+ 4. 目前无真实插件在单次 handler 中发出 10+ 独立 capability 调用,
+ 缺乏可测量的性能收益数据
+
+触发条件(何时重新评估):
+ - 有插件在单次 handler 中 gather 10+ 独立 capability 调用
+ - IPC 序列化/解析耗时经 profile 确认占总延迟 >5%
+ - 需要 WebSocket 传输场景下的带宽优化
+"""
+
+from __future__ import annotations
+
+import json
+from typing import Any, Literal
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from .descriptors import CapabilityDescriptor, HandlerDescriptor
+
+
class _MessageBase(BaseModel):
    """Shared base: every wire message rejects unknown fields."""

    model_config = ConfigDict(extra="forbid")
+
+
class ErrorPayload(_MessageBase):
    """Error payload carried by ResultMessage and EventMessage.

    Attributes:
        code: Error code string, for semantic error classification.
        message: Human-readable error description.
        hint: Optional remediation hint or suggestion.
        retryable: Whether retrying may resolve the error.
        docs_url: Optional documentation link with further details.
        details: Optional structured details for debugging and log display.
    """

    code: str
    message: str
    hint: str = ""
    retryable: bool = False
    docs_url: str = ""
    details: dict[str, Any] | None = None
+
+
class PeerInfo(_MessageBase):
    """Identity of the peer that sent a message.

    Attributes:
        name: Peer name, typically a plugin ID or the core identifier.
        role: Peer role, "plugin" or "core".
        version: Optional peer version string.
    """

    name: str
    role: Literal["plugin", "core"]
    version: str | None = None
+ version: str | None = None
+
+
class InitializeMessage(_MessageBase):
    """Handshake message exchanged when a connection is established.

    Attributes:
        type: Message type, fixed to "initialize".
        id: Message ID used to correlate the response.
        protocol_version: Protocol version proposed by the sender.
        peer: Sender peer identity.
        handlers: Handler descriptors registered by the sender.
        provided_capabilities: Capability descriptors the sender exposes.
        metadata: Extension metadata (e.g. plugin configuration).
    """

    type: Literal["initialize"] = "initialize"
    id: str
    protocol_version: str
    peer: PeerInfo
    handlers: list[HandlerDescriptor] = Field(default_factory=list)
    provided_capabilities: list[CapabilityDescriptor] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
+
+
class InitializeOutput(_MessageBase):
    """Response data answering an InitializeMessage.

    Attributes:
        peer: Responding (core) peer identity.
        protocol_version: Negotiated protocol version; None when not
            negotiated.
        capabilities: Capability descriptors provided by the core.
        metadata: Extension metadata.
    """

    peer: PeerInfo
    protocol_version: str | None = None
    capabilities: list[CapabilityDescriptor] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
+
+
class ResultMessage(_MessageBase):
    """Result message: the outcome of a capability invocation.

    Attributes:
        type: Message type, fixed to "result".
        id: ID of the request this result answers.
        kind: Optional result kind, e.g. "initialize_result" for handshakes.
        success: Whether the invocation succeeded.
        output: Output data when successful.
        error: Error payload when failed.
    """

    type: Literal["result"] = "result"
    id: str
    kind: str | None = None
    success: bool
    output: dict[str, Any] = Field(default_factory=dict)
    error: ErrorPayload | None = None

    @model_validator(mode="after")
    def validate_result_state(self) -> ResultMessage:
        """Enforce the legal success/output/error combinations."""
        if not self.success:
            if self.error is None:
                raise ValueError("success=false 时必须提供 error")
            if self.output:
                raise ValueError("success=false 时 output 必须为空")
        elif self.error is not None:
            raise ValueError("success=true 时 error 必须为空")
        return self
+
+
class InvokeMessage(_MessageBase):
    """Invoke message: request execution of a remote capability.

    Attributes:
        type: Message type, fixed to "invoke".
        id: Request ID used to correlate the response.
        capability: Target capability name, "namespace.action" format.
        input: Invocation input parameters.
        stream: Whether a streaming response is expected; when True the
            caller receives an EventMessage sequence.
        caller_plugin_id: Caller plugin ID injected by the runtime; not part
            of the business payload.
    """

    type: Literal["invoke"] = "invoke"
    id: str
    capability: str
    input: dict[str, Any] = Field(default_factory=dict)
    stream: bool = False
    caller_plugin_id: str | None = None
+
+
class EventMessage(_MessageBase):
    """Event message: progress notifications for a streaming invocation.

    Streaming lifecycle:
        1. started:   call began; all payload fields empty
        2. delta:     incremental update; carries ``data``
        3. completed: call finished; carries ``output``
        4. failed:    call failed; carries ``error``

    Attributes:
        type: Message type, fixed to "event".
        id: ID of the associated request.
        phase: Lifecycle phase, one of started/delta/completed/failed.
        data: Incremental payload; only valid in the delta phase.
        output: Final output; only valid in the completed phase.
        error: Error payload; only valid in the failed phase.
    """

    type: Literal["event"] = "event"
    id: str
    phase: Literal["started", "delta", "completed", "failed"]
    data: dict[str, Any] = Field(default_factory=dict)
    output: dict[str, Any] = Field(default_factory=dict)
    error: ErrorPayload | None = None

    @model_validator(mode="after")
    def validate_phase_constraints(self) -> EventMessage:
        """Reject field combinations that are illegal for the current phase."""
        has_data = bool(self.data)
        has_output = bool(self.output)
        has_error = self.error is not None
        if self.phase == "started":
            if has_data or has_output or has_error:
                raise ValueError("started phase 必须所有字段为空")
        elif self.phase == "delta":
            if not has_data:
                raise ValueError("delta phase 需要 data")
            if has_output or has_error:
                raise ValueError("delta phase 的 output/error 必须为空")
        elif self.phase == "completed":
            if not has_output:
                raise ValueError("completed phase 需要 output")
            if has_data or has_error:
                raise ValueError("completed phase 的 data/error 必须为空")
        elif self.phase == "failed":
            if not has_error:
                raise ValueError("failed phase 需要 error")
            if has_data or has_output:
                raise ValueError("failed phase 的 data/output 必须为空")
        return self
+
+
class CancelMessage(_MessageBase):
    """Cancellation request for an in-flight invocation.

    Attributes:
        type: Message type, fixed to "cancel".
        id: ID of the request to cancel.
        reason: Cancellation reason; defaults to "user_cancelled".
    """

    type: Literal["cancel"] = "cancel"
    id: str
    reason: str = "user_cancelled"
+
+
ProtocolMessage = (
    InitializeMessage | ResultMessage | InvokeMessage | EventMessage | CancelMessage
)
"""Union of every valid protocol message type."""

# Dispatch table: wire "type" field -> message model, used by parse_message().
_PROTOCOL_MESSAGE_MODELS = {
    "initialize": InitializeMessage,
    "result": ResultMessage,
    "invoke": InvokeMessage,
    "event": EventMessage,
    "cancel": CancelMessage,
}
+
+
def parse_message(
    payload: ProtocolMessage | str | bytes | dict[str, Any],
) -> ProtocolMessage:
    """Parse a raw payload into the matching protocol message.

    The concrete model is selected by the payload's "type" field and then
    validated; already-parsed models are returned unchanged.

    Args:
        payload: Already-parsed model, JSON string, UTF-8 bytes, or dict.

    Returns:
        The validated protocol message object.

    Raises:
        ValueError: If the payload is not a JSON object or the message type
            is unknown.

    Example:
        >>> msg = parse_message('{"type": "invoke", "id": "1", "capability": "test"}')
        >>> isinstance(msg, InvokeMessage)
        True
    """
    if isinstance(payload, tuple(_PROTOCOL_MESSAGE_MODELS.values())):
        return payload
    raw: Any = payload
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    if isinstance(raw, str):
        raw = json.loads(raw)
    if not isinstance(raw, dict):
        raise ValueError("协议消息必须是 JSON object")
    message_type = raw.get("type")
    model = _PROTOCOL_MESSAGE_MODELS.get(str(message_type))
    if model is None:
        raise ValueError(f"未知消息类型:{message_type}")
    return model.model_validate(raw)
+
+
# Public wire-protocol API of this module.
__all__ = [
    "CancelMessage",
    "ErrorPayload",
    "EventMessage",
    "InitializeMessage",
    "InitializeOutput",
    "InvokeMessage",
    "PeerInfo",
    "ProtocolMessage",
    "ResultMessage",
    "parse_message",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/__init__.py
new file mode 100644
index 0000000000..7601f745c2
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/__init__.py
@@ -0,0 +1,63 @@
+"""AstrBot SDK runtime public exports.
+
+本模块提供运行时核心组件的公共导出,包括:
+- CapabilityRouter: 能力路由器,处理能力调用的分发和路由
+- HandlerDispatcher: 事件处理器分发器,将事件分发到注册的 handler
+- Peer: 与 AstrBot 核心通信的对等端抽象
+- Transport 系列: 进程间通信传输层实现(stdio/websocket)
+
+延迟加载策略:
+为避免导入时触发 websocket/aiohttp 等重型依赖,采用 __getattr__ 实现按需加载。
+这样轻量级导入(如仅使用类型提示)不会产生不必要的依赖开销。
+"""
+
+from __future__ import annotations
+
+from importlib import import_module
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from .capability_router import CapabilityRouter, StreamExecution
+ from .handler_dispatcher import HandlerDispatcher
+ from .peer import Peer
+ from .transport import (
+ MessageHandler,
+ StdioTransport,
+ Transport,
+ WebSocketClientTransport,
+ WebSocketServerTransport,
+ )
+
# All names here are resolved lazily by __getattr__ below so importing this
# package does not pull in heavy transport dependencies.
__all__ = [
    "CapabilityRouter",
    "HandlerDispatcher",
    "MessageHandler",
    "Peer",
    "StdioTransport",
    "StreamExecution",
    "Transport",
    "WebSocketClientTransport",
    "WebSocketServerTransport",
]
+
+
+def __getattr__(name: str) -> Any:
+ if name in {"CapabilityRouter", "StreamExecution"}:
+ module = import_module(".capability_router", __name__)
+ return getattr(module, name)
+ if name == "HandlerDispatcher":
+ module = import_module(".handler_dispatcher", __name__)
+ return getattr(module, name)
+ if name == "Peer":
+ module = import_module(".peer", __name__)
+ return getattr(module, name)
+ if name in {
+ "MessageHandler",
+ "StdioTransport",
+ "Transport",
+ "WebSocketClientTransport",
+ "WebSocketServerTransport",
+ }:
+ module = import_module(".transport", __name__)
+ return getattr(module, name)
+ raise AttributeError(name)
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py
new file mode 100644
index 0000000000..ce168e2883
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from .bridge_base import CapabilityRouterBridgeBase
+from .capabilities import (
+ ConversationCapabilityMixin,
+ DBCapabilityMixin,
+ HttpCapabilityMixin,
+ KnowledgeBaseCapabilityMixin,
+ LLMCapabilityMixin,
+ MemoryCapabilityMixin,
+ MessageHistoryCapabilityMixin,
+ MetadataCapabilityMixin,
+ PermissionCapabilityMixin,
+ PersonaCapabilityMixin,
+ PlatformCapabilityMixin,
+ ProviderCapabilityMixin,
+ SessionCapabilityMixin,
+ SkillCapabilityMixin,
+ SystemCapabilityMixin,
+)
+
+
class BuiltinCapabilityRouterMixin(
    LLMCapabilityMixin,
    MemoryCapabilityMixin,
    DBCapabilityMixin,
    PlatformCapabilityMixin,
    HttpCapabilityMixin,
    MetadataCapabilityMixin,
    PermissionCapabilityMixin,
    ProviderCapabilityMixin,
    SessionCapabilityMixin,
    SkillCapabilityMixin,
    PersonaCapabilityMixin,
    ConversationCapabilityMixin,
    MessageHistoryCapabilityMixin,
    KnowledgeBaseCapabilityMixin,
    SystemCapabilityMixin,
    CapabilityRouterBridgeBase,
):
    """Aggregates every builtin capability mixin into one registration point.

    NOTE(review): base-class order determines the MRO — keep it stable unless
    the mixins' override relationships are re-checked.
    """

    def _register_builtin_capabilities(self) -> None:
        """Register all builtin capability families, one family per call.

        # NOTE(review): the agent-tool and provider/platform-manager
        # registrations have no dedicated mixin in the import list above;
        # presumably they live on the bridge/host side — confirm there.
        """
        self._register_llm_capabilities()
        self._register_memory_capabilities()
        self._register_db_capabilities()
        self._register_platform_capabilities()
        self._register_http_capabilities()
        self._register_metadata_capabilities()
        self._register_permission_capabilities()
        self._register_provider_capabilities()
        self._register_agent_tool_capabilities()
        self._register_session_capabilities()
        self._register_skill_capabilities()
        self._register_persona_capabilities()
        self._register_conversation_capabilities()
        self._register_message_history_capabilities()
        self._register_kb_capabilities()
        self._register_provider_manager_capabilities()
        self._register_platform_manager_capabilities()
        self._register_system_capabilities()
+
+
+__all__ = ["BuiltinCapabilityRouterMixin"]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py
new file mode 100644
index 0000000000..6d31ba6f2c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import asyncio
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+from ...protocol.descriptors import CapabilityDescriptor
+
+
class CapabilityRouterHost:
    """Typing-only contract for the concrete host of the capability mixins.

    Declares the attributes and methods the builtin capability mixins expect
    their final class to provide. Every method body raises
    NotImplementedError: the real implementations live on the concrete router
    class outside this module, and this class exists so the mixins can be
    type-checked in isolation.
    """

    # --- mutable state the concrete host owns -------------------------------
    memory_store: dict[str, dict[str, Any]]
    _memory_backends: dict[str, Any]
    _memory_index: dict[str, dict[str, Any]]
    _memory_dirty_keys: set[str]
    _memory_expires_at: dict[str, datetime | None]
    db_store: dict[str, Any]
    sent_messages: list[dict[str, Any]]
    event_actions: list[dict[str, Any]]
    http_api_store: list[dict[str, Any]]
    _event_streams: dict[str, dict[str, Any]]
    _plugins: dict[str, Any]
    _request_overlays: dict[str, dict[str, Any]]
    _provider_catalog: dict[str, list[dict[str, Any]]]
    _provider_configs: dict[str, dict[str, Any]]
    _active_provider_ids: dict[str, str | None]
    _provider_change_subscriptions: dict[str, asyncio.Queue[dict[str, Any]]]
    _system_data_root: Path
    _session_waiters: dict[str, set[str]]
    _session_plugin_configs: dict[str, dict[str, Any]]
    _session_service_configs: dict[str, dict[str, Any]]
    _db_watch_subscriptions: dict[str, tuple[str | None, asyncio.Queue[dict[str, Any]]]]
    _dynamic_command_routes: dict[str, list[dict[str, Any]]]
    _file_token_store: dict[str, str]
    _platform_instances: list[dict[str, Any]]
    _persona_store: dict[str, dict[str, Any]]
    _conversation_store: dict[str, dict[str, Any]]
    _session_current_conversation_ids: dict[str, str]
    _kb_store: dict[str, dict[str, Any]]
    _kb_document_store: dict[str, dict[str, dict[str, Any]]]
    _kb_document_content_store: dict[str, str]

    # --- capability registration --------------------------------------------
    def register(
        self,
        descriptor: CapabilityDescriptor,
        *,
        call_handler=None,
        stream_handler=None,
        finalize=None,
        exposed: bool = True,
    ) -> None:
        """Register a capability with optional call/stream handlers."""
        raise NotImplementedError

    def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
        raise NotImplementedError

    # --- plugin identity / data-dir helpers ---------------------------------
    @staticmethod
    def _require_caller_plugin_id(capability_name: str) -> str:
        raise NotImplementedError

    @staticmethod
    def _validated_plugin_id(plugin_id: str, *, capability_name: str) -> str:
        raise NotImplementedError

    def _plugin_data_dir(self, plugin_id: str, *, capability_name: str) -> Path:
        raise NotImplementedError

    def register_dynamic_command_route(
        self,
        *,
        plugin_id: str,
        command_name: str,
        handler_full_name: str,
        desc: str = "",
        priority: int = 0,
        use_regex: bool = False,
    ) -> None:
        raise NotImplementedError

    # --- platform lookup / validation helpers -------------------------------
    def get_platform_instances(self) -> list[dict[str, Any]]:
        raise NotImplementedError

    @staticmethod
    def _normalize_platform_name(value: Any) -> str:
        raise NotImplementedError

    @classmethod
    def _normalized_platform_names(cls, values: Any) -> set[str]:
        raise NotImplementedError

    def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool:
        raise NotImplementedError

    def _platform_name_from_id(self, platform_id: str) -> str:
        raise NotImplementedError

    def _session_platform_name(self, session: str) -> str:
        raise NotImplementedError

    def _require_platform_support_for_session(
        self,
        capability_name: str,
        session: str,
    ) -> str:
        raise NotImplementedError

    # --- provider / agent-tool hooks ----------------------------------------
    def _register_agent_tool_capabilities(self) -> None:
        raise NotImplementedError

    def _provider_entry(
        self,
        payload: dict[str, Any],
        capability_name: str,
        expected_kind: str | None = None,
    ) -> dict[str, Any]:
        raise NotImplementedError

    async def _provider_embedding_get_embedding(
        self, request_id: str, payload: dict[str, Any], token
    ) -> dict[str, Any]:
        raise NotImplementedError

    async def _provider_embedding_get_embeddings(
        self, request_id: str, payload: dict[str, Any], token
    ) -> dict[str, Any]:
        raise NotImplementedError
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py
new file mode 100644
index 0000000000..f1e36516fe
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import copy
+import hashlib
+import math
+import re
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+from ..._internal.plugin_ids import resolve_plugin_data_dir, validate_plugin_id
+from ...errors import AstrBotError
+from ...protocol.descriptors import (
+ BUILTIN_CAPABILITY_SCHEMAS,
+ CapabilityDescriptor,
+ SessionRef,
+)
+from ._host import CapabilityRouterHost
+
+
+def _clone_target_payload(value: Any) -> dict[str, Any] | None:
+ if not isinstance(value, dict):
+ return None
+ return {str(key): item for key, item in value.items()}
+
+
+def _clone_chain_payload(value: Any) -> list[dict[str, Any]]:
+ if not isinstance(value, list):
+ return []
+ return [
+ {str(key): item for key, item in chunk.items()}
+ for chunk in value
+ if isinstance(chunk, dict)
+ ]
+
+
# Dimensionality of the deterministic mock embedding vectors produced below.
_MOCK_EMBEDDING_DIM = 24
+
+
+def _embedding_terms(text: str) -> list[str]:
+ """Build stable tokens for the mock embedding implementation."""
+ normalized = re.sub(r"\s+", " ", str(text).strip().casefold())
+ compact = normalized.replace(" ", "")
+ if not normalized:
+ return []
+
+ terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word]
+ if compact:
+ if len(compact) == 1:
+ terms.append(compact)
+ else:
+ terms.extend(
+ compact[index : index + 2] for index in range(len(compact) - 1)
+ )
+ terms.append(compact)
+ return terms or [normalized]
+
+
def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
    """Generate a deterministic, L2-normalized mock embedding vector.

    Each token is hashed (salted with *provider_id*) into one of the
    `_MOCK_EMBEDDING_DIM` buckets with a small length-dependent weight.
    """
    vector = [0.0 for _ in range(_MOCK_EMBEDDING_DIM)]
    for term in _embedding_terms(text):
        digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest()
        bucket = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
        vector[bucket] += 1.0 + min(len(term), 8) * 0.05
    norm = math.sqrt(sum(component * component for component in vector))
    if norm > 0:
        return [component / norm for component in vector]
    return vector
+
+
class CapabilityRouterBridgeBase(CapabilityRouterHost):
    """Concrete shared helpers layered on top of the abstract router host.

    Provides plugin-id validation, platform-name normalization, mock payload
    builders, and payload-normalization utilities reused by the capability
    mixins in ``capabilities/``.
    """

    # Backing stores for the memory capabilities; populated by the concrete
    # router — presumably keyed by backend id (TODO confirm against router).
    _memory_backends: dict[str, Any]

    @staticmethod
    def _normalize_platform_name(value: Any) -> str:
        """Return *value* as a stripped, lower-cased platform name ('' for falsy input)."""
        return str(value or "").strip().lower()

    @classmethod
    def _normalized_platform_names(cls, values: Any) -> set[str]:
        """Normalize a list of platform names, dropping non-lists and empty entries."""
        if not isinstance(values, list):
            return set()
        return {
            cls._normalize_platform_name(item)
            for item in values
            if cls._normalize_platform_name(item)
        }

    @staticmethod
    def _validated_plugin_id(plugin_id: str, *, capability_name: str) -> str:
        """Validate *plugin_id*, converting ValueError into an invalid-input error."""
        try:
            return validate_plugin_id(plugin_id)
        except ValueError as exc:
            raise AstrBotError.invalid_input(
                f"{capability_name} requires a safe plugin_id: {exc}"
            ) from exc

    def _plugin_data_dir(self, plugin_id: str, *, capability_name: str) -> Path:
        """Resolve the plugin's data directory under ``self._system_data_root``."""
        try:
            return resolve_plugin_data_dir(self._system_data_root, plugin_id)
        except ValueError as exc:
            raise AstrBotError.invalid_input(
                f"{capability_name} requires a safe plugin_id: {exc}"
            ) from exc

    def _builtin_descriptor(
        self,
        name: str,
        description: str,
        *,
        supports_stream: bool = False,
        cancelable: bool = False,
    ) -> CapabilityDescriptor:
        """Build a descriptor for a built-in capability from its shared schema.

        Schemas are deep-copied so later mutation cannot corrupt the shared
        ``BUILTIN_CAPABILITY_SCHEMAS`` table.
        """
        schema = BUILTIN_CAPABILITY_SCHEMAS[name]
        return CapabilityDescriptor(
            name=name,
            description=description,
            input_schema=copy.deepcopy(schema["input"]),
            output_schema=copy.deepcopy(schema["output"]),
            supports_stream=supports_stream,
            cancelable=cancelable,
        )

    def _resolve_target(
        self, payload: dict[str, Any]
    ) -> tuple[str, dict[str, Any] | None]:
        """Extract ``(session, target_payload)`` from a capability payload.

        Prefers a structured ``target`` object; falls back to the flat
        ``session`` string with no target payload.
        """
        target_payload = payload.get("target")
        if isinstance(target_payload, dict):
            target = SessionRef.model_validate(target_payload)
            return target.session, target.to_payload()
        return str(payload.get("session", "")), None

    @staticmethod
    def _is_group_session(session: str) -> bool:
        """Heuristically detect a group session from its id string."""
        normalized = str(session).lower()
        return ":group:" in normalized or ":groupmessage:" in normalized

    @staticmethod
    def _mock_group_payload(session: str) -> dict[str, Any] | None:
        """Build a deterministic mock group-info payload (None for non-group sessions)."""
        if not CapabilityRouterBridgeBase._is_group_session(session):
            return None
        members = [
            {
                "user_id": f"{session}:member-1",
                "nickname": "Member 1",
                "role": "member",
            },
            {
                "user_id": f"{session}:member-2",
                "nickname": "Member 2",
                "role": "admin",
            },
        ]
        return {
            # The trailing session segment stands in for the group id.
            "group_id": session.rsplit(":", maxsplit=1)[-1],
            "group_name": f"Mock Group {session.rsplit(':', maxsplit=1)[-1]}",
            "group_avatar": "",
            "group_owner": members[0]["user_id"],
            "group_admins": [members[1]["user_id"]],
            "members": members,
        }

    def _session_plugin_config(self, session: str) -> dict[str, Any]:
        """Shallow-copied per-session plugin config ({} when absent or invalid)."""
        config = self._session_plugin_configs.get(str(session), {})
        return dict(config) if isinstance(config, dict) else {}

    def _session_service_config(self, session: str) -> dict[str, Any]:
        """Shallow-copied per-session service config ({} when absent or invalid)."""
        config = self._session_service_configs.get(str(session), {})
        return dict(config) if isinstance(config, dict) else {}

    @staticmethod
    def _now_iso() -> str:
        """Current UTC time as a timezone-aware ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _session_platform_id(session: str) -> str:
        """Platform-id prefix of a session string ('unknown' when missing/blank)."""
        parts = str(session).split(":", maxsplit=1)
        if parts and parts[0].strip():
            return parts[0].strip()
        return "unknown"

    def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool:
        """Whether the plugin declares support for *platform_name*.

        Defaults to True whenever support cannot be determined: unknown
        platform, unknown plugin, missing metadata, or an empty declaration.
        """
        normalized_platform = self._normalize_platform_name(platform_name)
        if not normalized_platform:
            return True
        plugin = self._plugins.get(str(plugin_id))
        if plugin is None:
            return True
        metadata = getattr(plugin, "metadata", None)
        if not isinstance(metadata, dict):
            return True
        supported = self._normalized_platform_names(metadata.get("support_platforms"))
        if not supported:
            return True
        return normalized_platform in supported

    def _platform_name_from_id(self, platform_id: str) -> str:
        """Look up the normalized platform type for a platform instance id ('' if unknown)."""
        normalized_platform_id = str(platform_id).strip()
        if not normalized_platform_id:
            return ""
        for item in self.get_platform_instances():
            if not isinstance(item, dict):
                continue
            if str(item.get("id", "")).strip() != normalized_platform_id:
                continue
            return self._normalize_platform_name(item.get("type"))
        return ""

    def _session_platform_name(self, session: str) -> str:
        """Normalized platform type for the session's platform-id prefix."""
        return self._platform_name_from_id(self._session_platform_id(session))

    def _require_platform_support_for_session(
        self,
        capability_name: str,
        session: str,
    ) -> str:
        """Validate that the calling plugin supports the session's platform.

        Returns the caller's plugin id on success; raises invalid-input when
        the platform is known and not supported by the plugin.
        """
        plugin_id = self._require_caller_plugin_id(capability_name)
        platform_name = self._session_platform_name(session)
        if not platform_name or self._plugin_supports_platform(
            plugin_id, platform_name
        ):
            return plugin_id
        raise AstrBotError.invalid_input(
            f"{capability_name} does not support platform '{platform_name}' for plugin '{plugin_id}'"
        )

    @staticmethod
    def _normalize_history_payload(value: Any) -> list[dict[str, Any]]:
        """Copy a history list keeping only dict entries ([] for non-lists)."""
        if not isinstance(value, list):
            return []
        return [dict(item) for item in value if isinstance(item, dict)]

    @staticmethod
    def _normalize_persona_dialogs_payload(value: Any) -> list[str]:
        """Copy a persona-dialog list keeping only str entries ([] for non-lists)."""
        if not isinstance(value, list):
            return []
        return [str(item) for item in value if isinstance(item, str)]

    @staticmethod
    def _optional_int(value: Any) -> int | None:
        """Best-effort int conversion; None for None or unconvertible values."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py
new file mode 100644
index 0000000000..0c8b01c741
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py
@@ -0,0 +1,33 @@
+from .conversation import ConversationCapabilityMixin
+from .db import DBCapabilityMixin
+from .http import HttpCapabilityMixin
+from .kb import KnowledgeBaseCapabilityMixin
+from .llm import LLMCapabilityMixin
+from .memory import MemoryCapabilityMixin
+from .message_history import MessageHistoryCapabilityMixin
+from .metadata import MetadataCapabilityMixin
+from .permission import PermissionCapabilityMixin
+from .persona import PersonaCapabilityMixin
+from .platform import PlatformCapabilityMixin
+from .provider import ProviderCapabilityMixin
+from .session import SessionCapabilityMixin
+from .skill import SkillCapabilityMixin
+from .system import SystemCapabilityMixin
+
# Public surface of the built-in capability package: one mixin per
# capability family, composed together by the concrete capability router.
__all__ = [
    "ConversationCapabilityMixin",
    "DBCapabilityMixin",
    "HttpCapabilityMixin",
    "KnowledgeBaseCapabilityMixin",
    "LLMCapabilityMixin",
    "MemoryCapabilityMixin",
    "MessageHistoryCapabilityMixin",
    "MetadataCapabilityMixin",
    "PermissionCapabilityMixin",
    "PersonaCapabilityMixin",
    "PlatformCapabilityMixin",
    "ProviderCapabilityMixin",
    "SessionCapabilityMixin",
    "SkillCapabilityMixin",
    "SystemCapabilityMixin",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py
new file mode 100644
index 0000000000..a250f43e5a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py
@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+import uuid
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
class ConversationCapabilityMixin(CapabilityRouterBridgeBase):
    """Built-in ``conversation.*`` capabilities backed by in-memory stores.

    State lives in ``self._conversation_store`` (conversation_id -> record)
    and ``self._session_current_conversation_ids`` (session -> current id),
    both provided by the concrete router.

    Improvement over the previous revision: ``delete``, ``update`` and
    ``unset_persona`` duplicated the same target-resolution logic; it is now
    factored into :meth:`_resolve_mutable_conversation`.
    """

    def _resolve_mutable_conversation(
        self,
        session: str,
        payload: dict[str, Any],
        *,
        capability_name: str,
    ) -> tuple[str, dict[str, Any]] | None:
        """Resolve the conversation record targeted by a mutation capability.

        Falls back to the session's current conversation when the payload
        omits ``conversation_id``. Returns ``None`` when nothing is targeted
        or the record no longer exists (callers then silently no-op); raises
        an invalid-input error when the record belongs to another session.
        """
        conversation_id = payload.get("conversation_id")
        normalized_conversation_id = (
            str(conversation_id).strip() if conversation_id is not None else ""
        )
        if not normalized_conversation_id:
            normalized_conversation_id = self._session_current_conversation_ids.get(
                session, ""
            )
        if not normalized_conversation_id:
            return None
        record = self._conversation_store.get(normalized_conversation_id)
        if record is None:
            return None
        if str(record.get("session", "")) != session:
            raise AstrBotError.invalid_input(
                f"{capability_name} requires a conversation in the same session"
            )
        return normalized_conversation_id, record

    async def _conversation_new(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Create a conversation for the session and make it the current one."""
        session = str(payload.get("session", "")).strip()
        if not session:
            raise AstrBotError.invalid_input("conversation.new requires session")
        raw_conversation = payload.get("conversation")
        if raw_conversation is None:
            raw_conversation = {}
        if not isinstance(raw_conversation, dict):
            raise AstrBotError.invalid_input(
                "conversation.new requires conversation object"
            )
        conversation_id = uuid.uuid4().hex
        now = self._now_iso()
        record = {
            "conversation_id": conversation_id,
            "session": session,
            # Fall back to the session's platform prefix when no explicit id given.
            "platform_id": (
                str(raw_conversation.get("platform_id"))
                if raw_conversation.get("platform_id") is not None
                else self._session_platform_id(session)
            ),
            "history": self._normalize_history_payload(raw_conversation.get("history")),
            "title": (
                str(raw_conversation.get("title"))
                if raw_conversation.get("title") is not None
                else None
            ),
            "persona_id": (
                str(raw_conversation.get("persona_id"))
                if raw_conversation.get("persona_id") is not None
                else None
            ),
            "created_at": now,
            "updated_at": now,
            "token_usage": None,
        }
        self._conversation_store[conversation_id] = record
        self._session_current_conversation_ids[session] = conversation_id
        return {"conversation_id": conversation_id}

    async def _conversation_switch(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Make an existing same-session conversation the current one."""
        session = str(payload.get("session", "")).strip()
        conversation_id = str(payload.get("conversation_id", "")).strip()
        record = self._conversation_store.get(conversation_id)
        if record is None or str(record.get("session", "")) != session:
            raise AstrBotError.invalid_input(
                "conversation.switch requires a conversation in the same session"
            )
        self._session_current_conversation_ids[session] = conversation_id
        return {}

    async def _conversation_delete(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete the targeted conversation; repoint the current conversation if needed."""
        session = str(payload.get("session", "")).strip()
        resolved = self._resolve_mutable_conversation(
            session, payload, capability_name="conversation.delete"
        )
        if resolved is None:
            return {}
        normalized_conversation_id, _record = resolved
        del self._conversation_store[normalized_conversation_id]
        current_conversation_id = self._session_current_conversation_ids.get(session)
        if current_conversation_id == normalized_conversation_id:
            # Fall back to any other conversation still owned by this session.
            replacement = next(
                (
                    conversation_id
                    for conversation_id, item in self._conversation_store.items()
                    if str(item.get("session", "")) == session
                ),
                None,
            )
            if replacement is None:
                self._session_current_conversation_ids.pop(session, None)
            else:
                self._session_current_conversation_ids[session] = replacement
        return {}

    async def _conversation_get(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Fetch a conversation by id, optionally creating one when missing."""
        session = str(payload.get("session", "")).strip()
        conversation_id = str(payload.get("conversation_id", "")).strip()
        record = self._conversation_store.get(conversation_id)
        if record is None and bool(payload.get("create_if_not_exists", False)):
            created = await self._conversation_new(
                _request_id,
                {"session": session, "conversation": {}},
                _token,
            )
            record = self._conversation_store.get(
                str(created.get("conversation_id", "")).strip()
            )
        if record is None or str(record.get("session", "")) != session:
            return {"conversation": None}
        return {"conversation": dict(record)}

    async def _conversation_get_current(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Fetch the session's current conversation, optionally creating one."""
        session = str(payload.get("session", "")).strip()
        conversation_id = self._session_current_conversation_ids.get(session, "")
        if not conversation_id and bool(payload.get("create_if_not_exists", False)):
            created = await self._conversation_new(
                _request_id,
                {"session": session, "conversation": {}},
                _token,
            )
            conversation_id = str(created.get("conversation_id", "")).strip()
        if not conversation_id:
            return {"conversation": None}
        record = self._conversation_store.get(conversation_id)
        if record is None or str(record.get("session", "")) != session:
            return {"conversation": None}
        return {"conversation": dict(record)}

    async def _conversation_list(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List conversations, optionally filtered by session and/or platform id."""
        session = payload.get("session")
        platform_id = payload.get("platform_id")
        conversations = []
        # Sorted ids give a deterministic listing order.
        for conversation_id in sorted(self._conversation_store.keys()):
            item = self._conversation_store[conversation_id]
            if session is not None and str(item.get("session", "")) != str(session):
                continue
            if platform_id is not None and str(item.get("platform_id", "")) != str(
                platform_id
            ):
                continue
            conversations.append(dict(item))
        return {"conversations": conversations}

    async def _conversation_update(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Merge the provided fields into the targeted conversation record."""
        session = str(payload.get("session", "")).strip()
        resolved = self._resolve_mutable_conversation(
            session, payload, capability_name="conversation.update"
        )
        if resolved is None:
            return {}
        _conversation_id, record = resolved
        raw_conversation = payload.get("conversation")
        if not isinstance(raw_conversation, dict):
            raw_conversation = {}
        # Only keys present in the payload are touched; an explicit null clears.
        if "history" in raw_conversation:
            history = raw_conversation.get("history")
            record["history"] = (
                self._normalize_history_payload(history) if history is not None else []
            )
        if "title" in raw_conversation:
            title = raw_conversation.get("title")
            record["title"] = str(title) if title is not None else None
        if "persona_id" in raw_conversation:
            persona_id = raw_conversation.get("persona_id")
            record["persona_id"] = str(persona_id) if persona_id is not None else None
        if "token_usage" in raw_conversation:
            token_usage = raw_conversation.get("token_usage")
            record["token_usage"] = (
                int(token_usage) if token_usage is not None else None
            )
        record["updated_at"] = self._now_iso()
        return {}

    async def _conversation_unset_persona(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Clear the persona of the targeted conversation."""
        session = str(payload.get("session", "")).strip()
        resolved = self._resolve_mutable_conversation(
            session, payload, capability_name="conversation.unset_persona"
        )
        if resolved is None:
            return {}
        _conversation_id, record = resolved
        record["persona_id"] = None
        record["updated_at"] = self._now_iso()
        return {}

    def _register_conversation_capabilities(self) -> None:
        """Register all conversation.* capability descriptors and handlers."""
        self.register(
            self._builtin_descriptor("conversation.new", "新建对话"),
            call_handler=self._conversation_new,
        )
        self.register(
            self._builtin_descriptor("conversation.switch", "切换对话"),
            call_handler=self._conversation_switch,
        )
        self.register(
            self._builtin_descriptor("conversation.delete", "删除对话"),
            call_handler=self._conversation_delete,
        )
        self.register(
            self._builtin_descriptor("conversation.get", "获取对话"),
            call_handler=self._conversation_get,
        )
        self.register(
            self._builtin_descriptor("conversation.get_current", "获取当前对话"),
            call_handler=self._conversation_get_current,
        )
        self.register(
            self._builtin_descriptor("conversation.list", "列出对话"),
            call_handler=self._conversation_list,
        )
        self.register(
            self._builtin_descriptor("conversation.update", "更新对话"),
            call_handler=self._conversation_update,
        )
        self.register(
            self._builtin_descriptor("conversation.unset_persona", "清空对话人格"),
            call_handler=self._conversation_unset_persona,
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py
new file mode 100644
index 0000000000..f8bdfedf9a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from ....errors import AstrBotError
+from ..._streaming import StreamExecution
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
class DBCapabilityMixin(CapabilityRouterBridgeBase):
    """Built-in ``db.*`` key-value capabilities, namespaced per calling plugin."""

    def _db_scoped_key(self, plugin_id: str, key: str) -> str:
        """Prefix a user-supplied key with the plugin namespace to prevent
        cross-plugin access."""
        return f"{plugin_id}:{key}"

    def _db_strip_scope(self, plugin_id: str, scoped_key: str) -> str:
        """Strip the namespace prefix, returning the key as the plugin sees it."""
        prefix = f"{plugin_id}:"
        return (
            scoped_key[len(prefix) :] if scoped_key.startswith(prefix) else scoped_key
        )

    def _db_public_event(
        self, plugin_id: str, raw_event: dict[str, Any]
    ) -> dict[str, Any]:
        """Convert an internal change event back to the plugin-visible key view."""
        event = dict(raw_event)
        key = event.get("key")
        if isinstance(key, str):
            event["key"] = self._db_strip_scope(plugin_id, key)
        return event

    async def _db_get(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Read one value from the caller's namespace (None when missing)."""
        plugin_id = self._require_caller_plugin_id("db.get")
        key = self._db_scoped_key(plugin_id, str(payload.get("key", "")))
        return {"value": self.db_store.get(key)}

    async def _db_set(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Write one value and emit a 'set' change event."""
        plugin_id = self._require_caller_plugin_id("db.set")
        key = self._db_scoped_key(plugin_id, str(payload.get("key", "")))
        value = payload.get("value")
        self.db_store[key] = value
        self._emit_db_change(op="set", key=key, value=value)
        return {}

    async def _db_delete(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete one key (no error when absent) and emit a 'delete' change event."""
        plugin_id = self._require_caller_plugin_id("db.delete")
        key = self._db_scoped_key(plugin_id, str(payload.get("key", "")))
        self.db_store.pop(key, None)
        self._emit_db_change(op="delete", key=key, value=None)
        return {}

    async def _db_list(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List the caller's keys, optionally filtered by a plugin-side prefix."""
        plugin_id = self._require_caller_plugin_id("db.list")
        ns_prefix = f"{plugin_id}:"
        # Only list keys inside the current plugin's namespace, and strip the
        # namespace prefix before returning them to the plugin.
        user_prefix = payload.get("prefix")
        all_keys = sorted(
            key for key in self.db_store.keys() if key.startswith(ns_prefix)
        )
        stripped = [self._db_strip_scope(plugin_id, k) for k in all_keys]
        if isinstance(user_prefix, str):
            stripped = [k for k in stripped if k.startswith(user_prefix)]
        return {"keys": stripped}

    async def _db_get_many(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Batch read; result items preserve the requested key order."""
        plugin_id = self._require_caller_plugin_id("db.get_many")
        keys_payload = payload.get("keys")
        if not isinstance(keys_payload, (list, tuple)):
            raise AstrBotError.invalid_input("db.get_many 的 keys 必须是数组")
        items = [
            {
                "key": str(k),
                "value": self.db_store.get(self._db_scoped_key(plugin_id, str(k))),
            }
            for k in keys_payload
        ]
        return {"items": items}

    async def _db_set_many(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Batch write; each entry emits its own 'set' change event."""
        plugin_id = self._require_caller_plugin_id("db.set_many")
        items_payload = payload.get("items")
        if not isinstance(items_payload, (list, tuple)):
            raise AstrBotError.invalid_input("db.set_many 的 items 必须是数组")
        for entry in items_payload:
            if not isinstance(entry, dict):
                raise AstrBotError.invalid_input(
                    "db.set_many 的 items 必须是 object 数组"
                )
            key = self._db_scoped_key(plugin_id, str(entry.get("key", "")))
            value = entry.get("value")
            self.db_store[key] = value
            self._emit_db_change(op="set", key=key, value=value)
        return {}

    async def _db_watch(
        self, request_id: str, payload: dict[str, Any], _token
    ) -> StreamExecution:
        """Stream change events for the caller's namespace (or a sub-prefix)."""
        plugin_id = self._require_caller_plugin_id("db.watch")
        prefix = payload.get("prefix")
        prefix_value: str | None
        if isinstance(prefix, str):
            # Namespace the user-supplied prefix so only this plugin's key
            # changes are observed.
            prefix_value = self._db_scoped_key(plugin_id, prefix)
        elif prefix is None:
            # With no prefix, watch the plugin's whole namespace by default.
            prefix_value = f"{plugin_id}:"
        else:
            raise AstrBotError.invalid_input("db.watch 的 prefix 必须是 string 或 null")

        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._db_watch_subscriptions[request_id] = (prefix_value, queue)

        async def iterator() -> AsyncIterator[dict[str, Any]]:
            try:
                while True:
                    yield self._db_public_event(plugin_id, await queue.get())
            finally:
                # Drop the subscription when the stream is closed or cancelled.
                self._db_watch_subscriptions.pop(request_id, None)

        return StreamExecution(
            iterator=iterator(),
            finalize=lambda _chunks: {},
            collect_chunks=False,
        )

    def _register_db_capabilities(self) -> None:
        """Register all db.* capability descriptors and handlers."""
        self.register(
            self._builtin_descriptor("db.get", "读取 KV"), call_handler=self._db_get
        )
        self.register(
            self._builtin_descriptor("db.set", "写入 KV"), call_handler=self._db_set
        )
        self.register(
            self._builtin_descriptor("db.delete", "删除 KV"),
            call_handler=self._db_delete,
        )
        self.register(
            self._builtin_descriptor("db.list", "列出 KV"), call_handler=self._db_list
        )
        self.register(
            self._builtin_descriptor("db.get_many", "批量读取 KV"),
            call_handler=self._db_get_many,
        )
        self.register(
            self._builtin_descriptor("db.set_many", "批量写入 KV"),
            call_handler=self._db_set_many,
        )
        self.register(
            self._builtin_descriptor(
                "db.watch",
                "订阅 KV 变更",
                supports_stream=True,
                cancelable=True,
            ),
            stream_handler=self._db_watch,
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py
new file mode 100644
index 0000000000..c0e6e59bbf
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+import re
+from typing import Any
+
+from ...._internal.plugin_ids import (
+ capability_belongs_to_plugin,
+ http_route_belongs_to_plugin,
+ plugin_capability_prefix,
+ plugin_http_route_root,
+)
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
# Routes only allow letters, digits, '/', '-', '_', '.' and {param} path
# parameters, and must start with '/'. A parameter segment must be exactly of
# the form {param}; empty segments (e.g. consecutive slashes) are rejected.
_ROUTE_SEGMENT_RE = re.compile(r"^(?:[\w\-._]+|\{[\w\-._]+\})$")
+
+
def _validate_route(route: str, capability_name: str) -> None:
    """Validate an HTTP route path, rejecting traversal and illegal characters."""
    if ".." in route:
        raise AstrBotError.invalid_input(f"{capability_name}: 路由路径不允许包含 '..'")
    if not route.startswith("/"):
        raise AstrBotError.invalid_input(
            f"{capability_name}: 路由路径格式非法,只允许字母/数字/-/_/./{{param}} 段,"
            "且必须以 / 开头,如 /foo/bar"
        )
    if route == "/":
        return
    # Every segment after the leading slash must be non-empty and well-formed.
    for segment in route.split("/")[1:]:
        if segment and _ROUTE_SEGMENT_RE.fullmatch(segment):
            continue
        raise AstrBotError.invalid_input(
            f"{capability_name}: 路由路径格式非法,只允许字母/数字/-/_/./{{param}} 段,"
            "禁止连续斜杠,且必须以 / 开头,如 /foo/bar"
        )
+
+
def _validate_plugin_route_namespace(route: str, plugin_id: str) -> None:
    """Ensure *route* sits under the calling plugin's public route namespace."""
    if not http_route_belongs_to_plugin(route, plugin_id):
        route_root = plugin_http_route_root(plugin_id)
        raise AstrBotError.invalid_input(
            "http.register_api 要求 route 使用当前插件的公开命名空间前缀:"
            f" route={route!r}, plugin_id={plugin_id!r}, expected={route_root!r} "
            f"或 {route_root + '/...'}"
        )
+
+
def _validate_handler_capability_namespace(
    handler_capability: str,
    plugin_id: str,
) -> None:
    """Ensure the handler capability belongs to the calling plugin's namespace."""
    if not capability_belongs_to_plugin(handler_capability, plugin_id):
        expected_prefix = plugin_capability_prefix(plugin_id)
        raise AstrBotError.invalid_input(
            "http.register_api 要求 handler_capability 属于当前插件:"
            f" capability={handler_capability!r}, plugin_id={plugin_id!r}, "
            f"expected_prefix={expected_prefix!r}"
        )
+
+
class HttpCapabilityMixin(CapabilityRouterBridgeBase):
    """Built-in ``http.*`` capabilities: per-plugin HTTP route registration."""

    async def _http_register_api(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Register an HTTP API route for the calling plugin.

        Validates route shape and that both the route and the handler
        capability live inside the caller's namespace; replaces a previous
        identical (route, plugin, methods) registration.
        """
        methods_payload = payload.get("methods")
        if not isinstance(methods_payload, list) or not all(
            isinstance(item, str) for item in methods_payload
        ):
            raise AstrBotError.invalid_input(
                "http.register_api 的 methods 必须是 string 数组"
            )
        route = str(payload.get("route", "")).strip()
        handler_capability = str(payload.get("handler_capability", "")).strip()
        if not route or not handler_capability:
            raise AstrBotError.invalid_input(
                "http.register_api 需要 route 和 handler_capability"
            )
        _validate_route(route, "http.register_api")
        plugin_name = self._require_caller_plugin_id("http.register_api")
        _validate_plugin_route_namespace(route, plugin_name)
        _validate_handler_capability_namespace(handler_capability, plugin_name)
        # Methods are de-duplicated, upper-cased and sorted for stable storage.
        methods = sorted(
            {method.strip().upper() for method in methods_payload if method.strip()}
        )
        if not methods:
            raise AstrBotError.invalid_input(
                "http.register_api 的 methods 至少需要一个非空 HTTP 方法"
            )
        entry: dict[str, Any] = {
            "route": route,
            "methods": methods,
            "handler_capability": handler_capability,
            "description": str(payload.get("description", "")),
            "plugin_id": plugin_name,
        }
        # Drop any previous identical registration before appending the new one.
        self.http_api_store = [
            item
            for item in self.http_api_store
            if not (
                item.get("route") == route
                and item.get("plugin_id") == entry["plugin_id"]
                and item.get("methods") == methods
            )
        ]
        self.http_api_store.append(entry)
        return {}

    async def _http_unregister_api(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Unregister some or all methods of the caller's route registration."""
        route = str(payload.get("route", "")).strip()
        methods_payload = payload.get("methods")
        if not isinstance(methods_payload, list) or not all(
            isinstance(item, str) for item in methods_payload
        ):
            raise AstrBotError.invalid_input(
                "http.unregister_api 的 methods 必须是 string 数组"
            )
        plugin_name = self._require_caller_plugin_id("http.unregister_api")
        methods = {method.upper() for method in methods_payload if method}
        updated: list[dict[str, Any]] = []
        for entry in self.http_api_store:
            if entry.get("route") != route:
                updated.append(entry)
                continue
            if entry.get("plugin_id") != plugin_name:
                updated.append(entry)
                continue
            if not methods:
                # `HTTPClient.unregister_api(methods=None)` normalizes to an
                # empty list; the public semantics are "remove every method the
                # current plugin registered under this route".
                continue
            remaining_methods = [
                method for method in entry.get("methods", []) if method not in methods
            ]
            if remaining_methods:
                updated.append({**entry, "methods": remaining_methods})
        self.http_api_store = updated
        return {}

    async def _http_list_apis(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List the HTTP routes registered by the calling plugin."""
        plugin_name = self._require_caller_plugin_id("http.list_apis")
        apis = [
            dict(entry)
            for entry in self.http_api_store
            if entry.get("plugin_id") == plugin_name
        ]
        return {"apis": apis}

    def _register_http_capabilities(self) -> None:
        """Register all http.* capability descriptors and handlers."""
        self.register(
            self._builtin_descriptor("http.register_api", "注册 HTTP 路由"),
            call_handler=self._http_register_api,
        )
        self.register(
            self._builtin_descriptor("http.unregister_api", "注销 HTTP 路由"),
            call_handler=self._http_unregister_api,
        )
        self.register(
            self._builtin_descriptor("http.list_apis", "列出 HTTP 路由"),
            call_handler=self._http_list_apis,
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py
new file mode 100644
index 0000000000..77a03d86c7
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py
@@ -0,0 +1,427 @@
+from __future__ import annotations
+
+import math
+import uuid
+from pathlib import Path
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
+def _term_set(text: str) -> set[str]:
+ normalized = " ".join(str(text).strip().casefold().split())
+ compact = normalized.replace(" ", "")
+ if not normalized:
+ return set()
+ terms = {item for item in normalized.split(" ") if item}
+ if compact:
+ terms.add(compact)
+ if len(compact) > 1:
+ terms.update(
+ compact[index : index + 2] for index in range(len(compact) - 1)
+ )
+ return terms
+
+
class KnowledgeBaseCapabilityMixin(CapabilityRouterBridgeBase):
    """In-memory mock implementation of the ``kb.*`` capability family.

    Knowledge bases, documents, and raw document text live in the
    dict-based stores supplied by the bridge base (``_kb_store``,
    ``_kb_document_store``, ``_kb_document_content_store``).  Retrieval is
    a keyword-overlap heuristic, not real vector search.
    """

    def _kb_documents(self, kb_id: str) -> dict[str, dict[str, Any]]:
        """Return the doc_id -> document map for *kb_id*, creating it on demand."""
        return self._kb_document_store.setdefault(kb_id, {})

    def _refresh_mock_kb_stats(self, kb_id: str) -> None:
        """Recompute doc/chunk counts for *kb_id* and bump ``updated_at``."""
        kb = self._kb_store.get(kb_id)
        if not isinstance(kb, dict):
            # Unknown or malformed KB record; nothing to refresh.
            return
        documents = self._kb_documents(kb_id)
        kb["doc_count"] = len(documents)
        kb["chunk_count"] = sum(
            int(document.get("chunk_count", 0) or 0) for document in documents.values()
        )
        kb["updated_at"] = self._now_iso()

    def _resolve_mock_kb_ids(self, payload: dict[str, Any]) -> list[str]:
        """Resolve target KB ids from payload ``kb_ids`` or ``kb_names``.

        Explicit ``kb_ids`` take precedence; otherwise ``kb_names`` are
        matched against stored names.  Unknown ids/names are silently
        dropped rather than raising.
        """
        kb_ids = [
            str(item).strip() for item in payload.get("kb_ids", []) if str(item).strip()
        ]
        if kb_ids:
            return [kb_id for kb_id in kb_ids if kb_id in self._kb_store]

        kb_names = [
            str(item).strip()
            for item in payload.get("kb_names", [])
            if str(item).strip()
        ]
        if not kb_names:
            return []
        name_set = set(kb_names)
        return [
            kb_id
            for kb_id, kb in self._kb_store.items()
            if str(kb.get("kb_name", "")).strip() in name_set
        ]

    @staticmethod
    def _score_mock_document(query: str, content: str) -> float:
        """Score *content* against *query* by term overlap, clamped to [0, 1].

        A 0.25 bonus is added when the whole query appears verbatim
        (case-insensitively) inside the content.
        """
        query_terms = _term_set(query)
        content_terms = _term_set(content)
        if not query_terms or not content_terms:
            return 0.0
        overlap = len(query_terms & content_terms)
        if overlap <= 0:
            return 0.0
        score = overlap / len(query_terms)
        if query.strip().casefold() in str(content).casefold():
            score += 0.25
        return min(score, 1.0)

    @staticmethod
    def _build_mock_context_text(results: list[dict[str, Any]]) -> str:
        """Render retrieval *results* into the (Chinese) context prompt text."""
        lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
        for index, item in enumerate(results, start=1):
            lines.append(f"【知识 {index}】")
            lines.append(f"来源: {item['kb_name']} / {item['doc_name']}")
            lines.append(f"内容: {item['content']}")
            lines.append(f"相关度: {float(item['score']):.2f}")
            lines.append("")
        return "\n".join(lines)

    async def _kb_list(
        self,
        _request_id: str,
        _payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        """List all knowledge bases, ordered by their ``created_at`` string."""
        return {
            "kbs": [
                dict(record)
                for record in sorted(
                    self._kb_store.values(),
                    key=lambda item: str(item.get("created_at", "")),
                )
            ]
        }

    async def _kb_get(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the knowledge base for ``kb_id``, or ``None`` when unknown."""
        kb_id = str(payload.get("kb_id", "")).strip()
        record = self._kb_store.get(kb_id)
        return {"kb": dict(record) if isinstance(record, dict) else None}

    async def _kb_create(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Create a knowledge base from payload ``kb``.

        Raises:
            AstrBotError: When ``kb`` is not a dict or lacks
                ``embedding_provider_id``.
        """
        raw_kb = payload.get("kb")
        if not isinstance(raw_kb, dict):
            raise AstrBotError.invalid_input("kb.create requires kb object")
        embedding_provider_id = str(raw_kb.get("embedding_provider_id", "")).strip()
        if not embedding_provider_id:
            raise AstrBotError.invalid_input("kb.create requires embedding_provider_id")
        # Server-generated id; counts start at zero until documents arrive.
        kb_id = uuid.uuid4().hex
        now = self._now_iso()
        record = {
            "kb_id": kb_id,
            "kb_name": str(raw_kb.get("kb_name", "")),
            "description": (
                str(raw_kb.get("description"))
                if raw_kb.get("description") is not None
                else None
            ),
            "emoji": (
                str(raw_kb.get("emoji")) if raw_kb.get("emoji") is not None else None
            ),
            "embedding_provider_id": embedding_provider_id,
            "rerank_provider_id": (
                str(raw_kb.get("rerank_provider_id"))
                if raw_kb.get("rerank_provider_id") is not None
                else None
            ),
            "chunk_size": self._optional_int(raw_kb.get("chunk_size")),
            "chunk_overlap": self._optional_int(raw_kb.get("chunk_overlap")),
            "top_k_dense": self._optional_int(raw_kb.get("top_k_dense")),
            "top_k_sparse": self._optional_int(raw_kb.get("top_k_sparse")),
            "top_m_final": self._optional_int(raw_kb.get("top_m_final")),
            "doc_count": 0,
            "chunk_count": 0,
            "created_at": now,
            "updated_at": now,
        }
        self._kb_store[kb_id] = record
        self._kb_document_store[kb_id] = {}
        return {"kb": dict(record)}

    async def _kb_update(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Partially update a knowledge base; only fields present in ``kb`` change.

        Returns ``{"kb": None}`` when ``kb_id`` is unknown.
        """
        kb_id = str(payload.get("kb_id", "")).strip()
        raw_kb = payload.get("kb")
        if not isinstance(raw_kb, dict):
            raise AstrBotError.invalid_input("kb.update requires kb object")
        record = self._kb_store.get(kb_id)
        if not isinstance(record, dict):
            return {"kb": None}

        # String-valued fields: overwrite with str(), allow explicit None.
        for field_name in (
            "kb_name",
            "description",
            "emoji",
            "embedding_provider_id",
            "rerank_provider_id",
        ):
            if field_name in raw_kb:
                value = raw_kb.get(field_name)
                record[field_name] = str(value) if value is not None else None
        # Integer-valued tuning fields: normalized via _optional_int.
        for field_name in (
            "chunk_size",
            "chunk_overlap",
            "top_k_dense",
            "top_k_sparse",
            "top_m_final",
        ):
            if field_name in raw_kb:
                record[field_name] = self._optional_int(raw_kb.get(field_name))
        record["updated_at"] = self._now_iso()
        return {"kb": dict(record)}

    async def _kb_delete(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete a knowledge base plus its documents and cached content."""
        kb_id = str(payload.get("kb_id", "")).strip()
        documents = self._kb_document_store.pop(kb_id, {})
        for document in documents.values():
            doc_id = str(document.get("doc_id", "")).strip()
            if doc_id:
                self._kb_document_content_store.pop(doc_id, None)
        deleted = self._kb_store.pop(kb_id, None) is not None
        return {"deleted": deleted}

    async def _kb_retrieve(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Keyword-score documents of the resolved KBs and return the top hits.

        Each document is treated as a single chunk (index 0).  Returns
        ``{"result": None}`` when nothing scores above zero.

        Raises:
            AstrBotError: When ``query`` is empty or no KB could be resolved.
        """
        query = str(payload.get("query", "")).strip()
        if not query:
            raise AstrBotError.invalid_input("kb.retrieve requires query")
        kb_ids = self._resolve_mock_kb_ids(payload)
        if not kb_ids:
            raise AstrBotError.invalid_input("kb.retrieve requires kb_ids or kb_names")

        top_m_final = self._optional_int(payload.get("top_m_final")) or 5
        results: list[dict[str, Any]] = []
        for kb_id in kb_ids:
            kb = self._kb_store.get(kb_id)
            if not isinstance(kb, dict):
                continue
            for document in self._kb_documents(kb_id).values():
                doc_id = str(document.get("doc_id", "")).strip()
                if not doc_id:
                    continue
                content = self._kb_document_content_store.get(doc_id, "")
                score = self._score_mock_document(query, content)
                if score <= 0:
                    continue
                results.append(
                    {
                        "chunk_id": f"{doc_id}:0",
                        "doc_id": doc_id,
                        "kb_id": kb_id,
                        "kb_name": str(kb.get("kb_name", "")),
                        "doc_name": str(document.get("doc_name", "")),
                        "chunk_index": 0,
                        "content": content,
                        "score": score,
                        "char_count": len(content),
                    }
                )
        results.sort(key=lambda item: float(item["score"]), reverse=True)
        results = results[:top_m_final]
        if not results:
            return {"result": None}
        return {
            "result": {
                "context_text": self._build_mock_context_text(results),
                "results": results,
            }
        }

    async def _kb_document_upload(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Ingest a document from inline ``text``, a ``url``, or a ``file_token``.

        Exactly one source is used, in that precedence order.  The chunk
        count is estimated from the content length and the effective
        chunk size; no real chunking or embedding happens in this mock.

        Raises:
            AstrBotError: On unknown KB, malformed document payload,
                unknown/stale file token, or missing file.
        """
        kb_id = str(payload.get("kb_id", "")).strip()
        kb = self._kb_store.get(kb_id)
        if not isinstance(kb, dict):
            raise AstrBotError.invalid_input(f"Unknown knowledge base: {kb_id}")
        raw_document = payload.get("document")
        if not isinstance(raw_document, dict):
            raise AstrBotError.invalid_input(
                "kb.document.upload requires document object"
            )

        file_name = str(raw_document.get("file_name", "")).strip()
        file_type = str(raw_document.get("file_type", "")).strip()
        file_path = ""
        content_text = ""
        file_size = 0

        text_value = raw_document.get("text")
        url_value = raw_document.get("url")
        file_token = str(raw_document.get("file_token", "")).strip()

        if isinstance(text_value, str) and text_value.strip():
            # Inline text upload: synthesize file metadata as needed.
            content_text = text_value
            if not file_name:
                file_name = "document.txt"
            if not file_type:
                file_type = "txt"
            file_size = len(content_text.encode("utf-8"))
        elif isinstance(url_value, str) and url_value.strip():
            # URL upload: not fetched; placeholder content records the origin.
            url_text = url_value.strip()
            content_text = f"Imported from {url_text}"
            if not file_name:
                file_name = (
                    Path(url_text.split("?", maxsplit=1)[0]).name or "document.url"
                )
            if not file_type:
                suffix = Path(file_name).suffix.lstrip(".")
                file_type = suffix or "url"
            file_path = url_text
            file_size = len(content_text.encode("utf-8"))
        elif file_token:
            # Token upload: one-shot — the token is consumed (popped) here.
            file_path = self._file_token_store.pop(file_token, "")
            if not file_path:
                raise AstrBotError.invalid_input(f"Unknown file token: {file_token}")
            path = Path(file_path)
            if not path.exists():
                raise AstrBotError.invalid_input(f"File does not exist: {file_path}")
            raw_bytes = path.read_bytes()
            content_text = raw_bytes.decode("utf-8", errors="ignore")
            if not file_name:
                file_name = path.name
            if not file_type:
                file_type = path.suffix.lstrip(".")
                if not file_type:
                    raise AstrBotError.invalid_input(
                        "kb.document.upload requires file_type when the file has no suffix"
                    )
            file_size = len(raw_bytes)
        else:
            raise AstrBotError.invalid_input(
                "kb.document.upload requires file_token, url, or text"
            )

        # Per-document chunk_size overrides the KB default (fallback 512).
        chunk_size = self._optional_int(raw_document.get("chunk_size"))
        if chunk_size is None or chunk_size <= 0:
            chunk_size = self._optional_int(kb.get("chunk_size")) or 512
        chunk_count = max(1, math.ceil(max(len(content_text), 1) / chunk_size))
        doc_id = uuid.uuid4().hex
        now = self._now_iso()
        document = {
            "doc_id": doc_id,
            "kb_id": kb_id,
            "doc_name": file_name,
            "file_type": file_type,
            "file_size": file_size,
            "file_path": file_path,
            "chunk_count": chunk_count,
            "media_count": 0,
            "created_at": now,
            "updated_at": now,
        }
        self._kb_documents(kb_id)[doc_id] = document
        self._kb_document_content_store[doc_id] = content_text
        self._refresh_mock_kb_stats(kb_id)
        return {"document": dict(document)}

    async def _kb_document_list(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List a KB's documents, creation-ordered, with offset/limit paging."""
        kb_id = str(payload.get("kb_id", "")).strip()
        offset = max(self._optional_int(payload.get("offset")) or 0, 0)
        # NOTE(review): limit=0 falls back to 100 via `or`; presumably
        # intentional "unset" semantics — confirm against the client contract.
        limit = max(self._optional_int(payload.get("limit")) or 100, 0)
        documents = list(self._kb_documents(kb_id).values())
        documents.sort(key=lambda item: str(item.get("created_at", "")))
        return {
            "documents": [dict(item) for item in documents[offset : offset + limit]]
        }

    async def _kb_document_get(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return a single document record, or ``None`` when not found."""
        kb_id = str(payload.get("kb_id", "")).strip()
        doc_id = str(payload.get("doc_id", "")).strip()
        document = self._kb_documents(kb_id).get(doc_id)
        return {"document": dict(document) if isinstance(document, dict) else None}

    async def _kb_document_delete(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete a document, its cached text, and refresh KB statistics."""
        kb_id = str(payload.get("kb_id", "")).strip()
        doc_id = str(payload.get("doc_id", "")).strip()
        deleted = self._kb_documents(kb_id).pop(doc_id, None) is not None
        if deleted:
            self._kb_document_content_store.pop(doc_id, None)
            self._refresh_mock_kb_stats(kb_id)
        return {"deleted": deleted}

    async def _kb_document_refresh(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Re-derive a document's chunk count from the current KB chunk size."""
        kb_id = str(payload.get("kb_id", "")).strip()
        doc_id = str(payload.get("doc_id", "")).strip()
        document = self._kb_documents(kb_id).get(doc_id)
        if not isinstance(document, dict):
            return {"document": None}
        kb = self._kb_store.get(kb_id, {})
        chunk_size = self._optional_int(kb.get("chunk_size")) or 512
        content_text = self._kb_document_content_store.get(doc_id, "")
        document["chunk_count"] = max(
            1, math.ceil(max(len(content_text), 1) / chunk_size)
        )
        document["updated_at"] = self._now_iso()
        self._refresh_mock_kb_stats(kb_id)
        return {"document": dict(document)}

    def _register_kb_capabilities(self) -> None:
        """Register all ``kb.*`` capability handlers on the router."""
        self.register(
            self._builtin_descriptor("kb.list", "列出知识库"),
            call_handler=self._kb_list,
        )
        self.register(
            self._builtin_descriptor("kb.get", "获取知识库"),
            call_handler=self._kb_get,
        )
        self.register(
            self._builtin_descriptor("kb.create", "创建知识库"),
            call_handler=self._kb_create,
        )
        self.register(
            self._builtin_descriptor("kb.update", "更新知识库"),
            call_handler=self._kb_update,
        )
        self.register(
            self._builtin_descriptor("kb.delete", "删除知识库"),
            call_handler=self._kb_delete,
        )
        self.register(
            self._builtin_descriptor("kb.retrieve", "检索知识库"),
            call_handler=self._kb_retrieve,
        )
        self.register(
            self._builtin_descriptor("kb.document.upload", "上传知识库文档"),
            call_handler=self._kb_document_upload,
        )
        self.register(
            self._builtin_descriptor("kb.document.list", "列出知识库文档"),
            call_handler=self._kb_document_list,
        )
        self.register(
            self._builtin_descriptor("kb.document.get", "获取知识库文档"),
            call_handler=self._kb_document_get,
        )
        self.register(
            self._builtin_descriptor("kb.document.delete", "删除知识库文档"),
            call_handler=self._kb_document_delete,
        )
        self.register(
            self._builtin_descriptor("kb.document.refresh", "刷新知识库文档"),
            call_handler=self._kb_document_refresh,
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py
new file mode 100644
index 0000000000..daf1621128
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncIterator
+from typing import Any
+
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
class LLMCapabilityMixin(CapabilityRouterBridgeBase):
    """Mock LLM capability handlers that simply echo the prompt back."""

    async def _llm_chat(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the echoed prompt as plain text."""
        reply = f"Echo: {str(payload.get('prompt', ''))}"
        return {"text": reply}

    async def _llm_chat_raw(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the echoed prompt with mock usage metadata.

        Token counts are character counts — this is a mock, not a real
        tokenizer.
        """
        prompt_text = str(payload.get("prompt", ""))
        reply = f"Echo: {prompt_text}"
        usage = {
            "input_tokens": len(prompt_text),
            "output_tokens": len(reply),
        }
        return {
            "text": reply,
            "usage": usage,
            "finish_reason": "stop",
            "tool_calls": [],
        }

    async def _llm_stream(
        self,
        _request_id: str,
        payload: dict[str, Any],
        token,
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield the echoed prompt one character per chunk, honoring cancel."""
        reply = f"Echo: {str(payload.get('prompt', ''))}"
        for piece in reply:
            token.raise_if_cancelled()
            # Yield control so cancellation and other tasks can interleave.
            await asyncio.sleep(0)
            yield {"text": piece}

    def _register_llm_capabilities(self) -> None:
        """Register the mock chat, raw-chat, and streaming capabilities."""
        self.register(
            self._builtin_descriptor("llm.chat", "发送对话请求,返回文本"),
            call_handler=self._llm_chat,
        )
        self.register(
            self._builtin_descriptor("llm.chat_raw", "发送对话请求,返回完整响应"),
            call_handler=self._llm_chat_raw,
        )
        stream_descriptor = self._builtin_descriptor(
            "llm.stream_chat",
            "流式对话",
            supports_stream=True,
            cancelable=True,
        )

        def _join_chunks(chunks):
            # Concatenate streamed chunk texts into the final response.
            return {"text": "".join(item.get("text", "") for item in chunks)}

        self.register(
            stream_descriptor,
            stream_handler=self._llm_stream,
            finalize=_join_chunks,
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py
new file mode 100644
index 0000000000..f55ef7ccf0
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py
@@ -0,0 +1,655 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Any
+
+from ...._internal.invocation_context import current_caller_plugin_id
+from ...._internal.memory_utils import (
+ cosine_similarity,
+ extract_memory_text,
+ is_ttl_memory_entry,
+ memory_expiration_from_ttl,
+ memory_index_entry,
+ memory_keyword_score,
+ memory_value_for_search,
+)
+from ...._memory_backend import PluginMemoryBackend
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
    """Mock implementation of the ``memory.*`` capability family.

    Persistent storage is delegated to a per-plugin
    :class:`PluginMemoryBackend`; this mixin also maintains in-process
    "sidecar" state (search index, TTL expirations, dirty-key set) used by
    the embedding-based search helpers.
    """

    def _memory_plugin_id(self) -> str:
        """Resolve and validate the calling plugin's id.

        Falls back to ``"__anonymous__"`` when no caller id is present.
        """
        plugin_id = current_caller_plugin_id()
        return self._validated_plugin_id(
            str(plugin_id).strip() or "__anonymous__",
            capability_name="memory.*",
        )

    def _memory_backend_for_plugin(self, plugin_id: str) -> PluginMemoryBackend:
        """Return the cached memory backend for *plugin_id*, creating it lazily."""
        backend = self._memory_backends.get(plugin_id)
        if backend is None:
            backend = PluginMemoryBackend(
                self._plugin_data_dir(plugin_id, capability_name="memory.*")
            )
            self._memory_backends[plugin_id] = backend
        return backend

    @staticmethod
    def _is_ttl_memory_entry(value: Any) -> bool:
        """Check whether a stored value uses the TTL wrapper structure.

        Args:
            value: Stored value to inspect.

        Returns:
            bool: ``True`` when the value carries ``value`` and
                ``ttl_seconds`` fields.
        """
        return is_ttl_memory_entry(value)

    @classmethod
    def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
        """Extract the raw memory payload used for search.

        Args:
            stored: Raw value kept in the memory store.

        Returns:
            dict[str, Any] | None: The dict after unwrapping any TTL
                envelope; ``None`` when it cannot be parsed.
        """
        return memory_value_for_search(stored)

    @classmethod
    def _extract_memory_text(cls, stored: Any) -> str:
        """Extract the preferred text for the search index.

        Args:
            stored: Raw value kept in the memory store.

        Returns:
            str: Prefers fields such as ``embedding_text`` / ``content``;
                falls back to a JSON rendering.
        """
        return extract_memory_text(stored)

    @staticmethod
    def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
        """Convert a TTL in seconds into an absolute UTC expiration time.

        Args:
            ttl_seconds: TTL in seconds.

        Returns:
            datetime | None: Absolute expiration time; ``None`` when the
                input is invalid.
        """
        return memory_expiration_from_ttl(ttl_seconds)

    @staticmethod
    def _memory_keyword_score(query: str, key: str, text: str) -> float:
        """Compute a keyword match score.

        Args:
            query: Query text.
            key: Key of the memory entry.
            text: Indexed search text.

        Returns:
            float: Coarse keyword score based on key-name and text hits.
        """
        return memory_keyword_score(query, key, text)

    @staticmethod
    def _cosine_similarity(left: list[float], right: list[float]) -> float:
        """Compute the cosine similarity between two vectors.

        Args:
            left: Left-hand vector.
            right: Right-hand vector.

        Returns:
            float: Cosine similarity; ``0.0`` for invalid input.
        """
        return cosine_similarity(left, right)

    def _resolve_memory_embedding_provider_id(
        self,
        provider_id: Any,
        *,
        required: bool,
    ) -> str | None:
        """Resolve the embedding provider used by ``memory.search``.

        Args:
            provider_id: Provider id explicitly supplied by the caller.
            required: Whether the current search mode mandates an
                embedding provider.

        Returns:
            str | None: The selected provider id; ``None`` is allowed when
                *required* is false.
        """
        normalized = str(provider_id).strip() if provider_id is not None else ""
        if normalized:
            # Validates that the provider exists before using it.
            self._provider_entry(
                {"provider_id": normalized},
                "memory.search",
                "embedding",
            )
            return normalized
        active_id = self._active_provider_ids.get("embedding")
        if active_id is not None:
            normalized_active = str(active_id).strip()
            if normalized_active:
                self._provider_entry(
                    {"provider_id": normalized_active},
                    "memory.search",
                    "embedding",
                )
                return normalized_active
        if required:
            raise AstrBotError.invalid_input(
                "memory.search requires an embedding provider",
            )
        return None

    @staticmethod
    def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
        """Normalize a raw index item into the internal unified structure.

        Args:
            entry: Raw item from the current index table.
            text: Indexed text of the current entry.

        Returns:
            dict[str, Any]: Unified entry containing ``text``,
                ``embedding``, and ``provider_id``.
        """
        return memory_index_entry(entry, text=text)

    def _clear_memory_sidecars(self, key: str) -> None:
        """Clear all sidecar state for the given memory key.

        Args:
            key: Key of the memory entry.

        Returns:
            None
        """
        self._memory_index.pop(key, None)
        self._memory_expires_at.pop(key, None)
        self._memory_dirty_keys.discard(key)

    def _delete_memory_entry(self, key: str) -> bool:
        """Delete a memory entry and clear its sidecar state.

        Args:
            key: Key of the memory entry.

        Returns:
            bool: ``True`` when the entry existed and was deleted.
        """
        deleted = self.memory_store.pop(key, None) is not None
        self._clear_memory_sidecars(key)
        return deleted

    def _upsert_memory_sidecars(
        self,
        key: str,
        stored: dict[str, Any],
        *,
        expires_at: datetime | None = None,
    ) -> None:
        """Create or update the sidecar index state for one memory entry.

        Args:
            key: Key of the memory entry.
            stored: Raw stored value to index.
            expires_at: Optional absolute expiration time.

        Returns:
            None
        """
        # Embedding is invalidated on every upsert; recomputed lazily.
        self._memory_index[key] = {
            "text": self._extract_memory_text(stored),
            "embedding": None,
            "provider_id": None,
        }
        if expires_at is None:
            self._memory_expires_at.pop(key, None)
        else:
            self._memory_expires_at[key] = expires_at
        self._memory_dirty_keys.add(key)

    def _ensure_memory_sidecars(self, key: str, stored: Any) -> None:
        """Keep sidecar state consistent with the current stored value.

        Args:
            key: Key of the memory entry.
            stored: Current value in the memory store.

        Returns:
            None
        """
        if not isinstance(stored, dict):
            return
        text = self._extract_memory_text(stored)
        existed = key in self._memory_index
        entry = self._memory_index_entry(self._memory_index.get(key), text=text)
        if entry["text"] != text:
            # Text changed: drop the stale embedding and mark dirty.
            entry["text"] = text
            entry["embedding"] = None
            entry["provider_id"] = None
            self._memory_dirty_keys.add(key)
        self._memory_index[key] = entry
        if not existed:
            self._memory_dirty_keys.add(key)

    def _is_memory_expired(self, key: str) -> bool:
        """Check whether a memory entry has expired.

        Args:
            key: Key of the memory entry.

        Returns:
            bool: ``True`` when the current time is past the recorded
                expiration.
        """
        expires_at = self._memory_expires_at.get(key)
        return expires_at is not None and expires_at <= datetime.now(timezone.utc)

    def _purge_expired_memory_entry(self, key: str) -> bool:
        """Immediately purge one memory entry if it has expired.

        Args:
            key: Key of the memory entry.

        Returns:
            bool: ``True`` when the entry was expired and purged.
        """
        if not self._is_memory_expired(key):
            return False
        self._delete_memory_entry(key)
        return True

    def _purge_expired_memory_entries(self) -> None:
        """Purge all tracked TTL entries that have expired.

        Returns:
            None
        """
        for key in list(self._memory_expires_at):
            self._purge_expired_memory_entry(key)

    async def _embedding_for_text(
        self,
        *,
        provider_id: str,
        text: str,
    ) -> list[float]:
        """Fetch a single text embedding through the embedding capability.

        Args:
            provider_id: Embedding provider id to use.
            text: Text to embed.

        Returns:
            list[float]: Vector returned by the provider; an empty list on
                malformed output.
        """
        output = await self._provider_embedding_get_embedding(
            "",
            {"provider_id": provider_id, "text": text},
            None,
        )
        embedding = output.get("embedding")
        if not isinstance(embedding, list):
            return []
        return [float(item) for item in embedding]

    async def _embeddings_for_texts(
        self,
        *,
        provider_id: str,
        texts: list[str],
    ) -> list[list[float]]:
        """Fetch embeddings for multiple texts in one call.

        Args:
            provider_id: Embedding provider id to use.
            texts: Texts to embed.

        Returns:
            list[list[float]]: Vectors matching the input order.
        """
        if not texts:
            return []
        output = await self._provider_embedding_get_embeddings(
            "",
            {"provider_id": provider_id, "texts": texts},
            None,
        )
        embeddings = output.get("embeddings")
        if not isinstance(embeddings, list):
            return []
        return [
            [float(value) for value in item]
            for item in embeddings
            if isinstance(item, list)
        ]

    async def _refresh_memory_embeddings(self, *, provider_id: str) -> None:
        """Refresh dirty or stale memory embeddings for the given provider.

        Args:
            provider_id: Embedding provider id currently in use.

        Returns:
            None
        """
        keys_to_refresh: list[str] = []
        texts_to_refresh: list[str] = []
        for key, stored in self.memory_store.items():
            self._ensure_memory_sidecars(key, stored)
            entry = self._memory_index_entry(
                self._memory_index.get(key),
                text=self._extract_memory_text(stored),
            )
            should_refresh = (
                key in self._memory_dirty_keys
                or entry["embedding"] is None
                or entry["provider_id"] != provider_id
            )
            self._memory_index[key] = entry
            if should_refresh:
                keys_to_refresh.append(key)
                texts_to_refresh.append(str(entry["text"]))
        # Request in batches so a single payload cannot grow unbounded
        # (avoids OOM / HTTP 413 style failures).
        _BATCH_SIZE = 64
        embeddings: list[list[float]] = []
        for batch_start in range(0, len(texts_to_refresh), _BATCH_SIZE):
            batch = texts_to_refresh[batch_start : batch_start + _BATCH_SIZE]
            embeddings.extend(
                await self._embeddings_for_texts(
                    provider_id=provider_id,
                    texts=batch,
                )
            )
        for index, key in enumerate(keys_to_refresh):
            entry = self._memory_index_entry(
                self._memory_index.get(key),
                text=str(texts_to_refresh[index]),
            )
            entry["embedding"] = embeddings[index] if index < len(embeddings) else []
            entry["provider_id"] = provider_id
            self._memory_index[key] = entry
            self._memory_dirty_keys.discard(key)

    async def _memory_search(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Search the caller's memories via keyword, vector, or hybrid mode.

        ``mode="auto"`` resolves to hybrid when an embedding provider is
        available and keyword otherwise.  Embedding callables are only
        wired in for vector/hybrid modes.
        """
        plugin_id = self._memory_plugin_id()
        query = str(payload.get("query", ""))
        mode = str(payload.get("mode", "auto")).strip().lower() or "auto"
        limit = self._optional_int(payload.get("limit"))
        raw_min_score = payload.get("min_score")
        min_score = float(raw_min_score) if raw_min_score is not None else None
        namespace = payload.get("namespace")
        include_descendants = bool(payload.get("include_descendants", True))
        provider_id = self._resolve_memory_embedding_provider_id(
            payload.get("provider_id"),
            required=mode in {"vector", "hybrid"},
        )
        effective_mode = mode
        if effective_mode == "auto":
            effective_mode = "hybrid" if provider_id is not None else "keyword"
        backend = self._memory_backend_for_plugin(plugin_id)
        items = await backend.search(
            query,
            namespace=str(namespace) if namespace is not None else None,
            include_descendants=include_descendants,
            mode=effective_mode,
            limit=limit,
            min_score=min_score,
            provider_id=provider_id,
            embed_one=(
                (
                    lambda text: self._embedding_for_text(
                        provider_id=provider_id, text=text
                    )
                )
                if provider_id is not None and effective_mode in {"vector", "hybrid"}
                else None
            ),
            embed_many=(
                (
                    lambda texts: self._embeddings_for_texts(
                        provider_id=provider_id,
                        texts=texts,
                    )
                )
                if provider_id is not None and effective_mode in {"vector", "hybrid"}
                else None
            ),
        )
        return {"items": items}

    async def _memory_save(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Save one memory value (must be a JSON object) under ``key``."""
        plugin_id = self._memory_plugin_id()
        key = str(payload.get("key", ""))
        value = payload.get("value")
        if not isinstance(value, dict):
            raise AstrBotError.invalid_input("memory.save 的 value 必须是 object")
        await self._memory_backend_for_plugin(plugin_id).save(
            key,
            value,
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {}

    async def _memory_get(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Fetch a single memory value by key (``None`` when absent)."""
        plugin_id = self._memory_plugin_id()
        key = str(payload.get("key", ""))
        value = await self._memory_backend_for_plugin(plugin_id).get(
            key,
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {"value": value}

    async def _memory_list_keys(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List memory keys within the optional namespace."""
        plugin_id = self._memory_plugin_id()
        keys = await self._memory_backend_for_plugin(plugin_id).list_keys(
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {"keys": keys}

    async def _memory_exists(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Report whether a memory key exists."""
        plugin_id = self._memory_plugin_id()
        exists = await self._memory_backend_for_plugin(plugin_id).exists(
            str(payload.get("key", "")),
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {"exists": exists}

    async def _memory_delete(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete a single memory entry by key."""
        plugin_id = self._memory_plugin_id()
        await self._memory_backend_for_plugin(plugin_id).delete(
            str(payload.get("key", "")),
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {}

    async def _memory_clear_namespace(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Clear a namespace, optionally including descendant namespaces."""
        plugin_id = self._memory_plugin_id()
        deleted_count = await self._memory_backend_for_plugin(
            plugin_id
        ).clear_namespace(
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
            include_descendants=bool(payload.get("include_descendants", False)),
        )
        return {"deleted_count": deleted_count}

    async def _memory_save_with_ttl(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Save a memory value that expires after ``ttl_seconds``."""
        plugin_id = self._memory_plugin_id()
        key = str(payload.get("key", ""))
        value = payload.get("value")
        ttl_seconds = payload.get("ttl_seconds", 0)
        if not isinstance(value, dict):
            raise AstrBotError.invalid_input(
                "memory.save_with_ttl 的 value 必须是 object"
            )
        await self._memory_backend_for_plugin(plugin_id).save_with_ttl(
            key,
            value,
            int(ttl_seconds),
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {}

    async def _memory_get_many(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Fetch multiple memory values by key list."""
        plugin_id = self._memory_plugin_id()
        keys_payload = payload.get("keys")
        if not isinstance(keys_payload, (list, tuple)):
            raise AstrBotError.invalid_input("memory.get_many 的 keys 必须是数组")
        items = await self._memory_backend_for_plugin(plugin_id).get_many(
            [str(item) for item in keys_payload],
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {"items": items}

    async def _memory_delete_many(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete multiple memory entries; returns the deleted count."""
        plugin_id = self._memory_plugin_id()
        keys_payload = payload.get("keys")
        if not isinstance(keys_payload, (list, tuple)):
            raise AstrBotError.invalid_input("memory.delete_many 的 keys 必须是数组")
        deleted_count = await self._memory_backend_for_plugin(plugin_id).delete_many(
            [str(item) for item in keys_payload],
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
        )
        return {"deleted_count": deleted_count}

    async def _memory_count(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Count memory entries in a namespace."""
        plugin_id = self._memory_plugin_id()
        count = await self._memory_backend_for_plugin(plugin_id).count(
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
            include_descendants=bool(payload.get("include_descendants", False)),
        )
        return {"count": count}

    async def _memory_stats(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return backend statistics annotated with the caller's plugin id."""
        plugin_id = self._memory_plugin_id()
        stats = await self._memory_backend_for_plugin(plugin_id).stats(
            namespace=(
                str(payload.get("namespace"))
                if payload.get("namespace") is not None
                else None
            ),
            include_descendants=bool(payload.get("include_descendants", True)),
        )
        stats["plugin_id"] = plugin_id
        return stats

    def _register_memory_capabilities(self) -> None:
        """Register all ``memory.*`` capability handlers on the router."""
        self.register(
            self._builtin_descriptor("memory.search", "搜索记忆"),
            call_handler=self._memory_search,
        )
        self.register(
            self._builtin_descriptor("memory.save", "保存记忆"),
            call_handler=self._memory_save,
        )
        self.register(
            self._builtin_descriptor("memory.get", "读取单条记忆"),
            call_handler=self._memory_get,
        )
        self.register(
            self._builtin_descriptor("memory.list_keys", "列出命名空间内的记忆键"),
            call_handler=self._memory_list_keys,
        )
        self.register(
            self._builtin_descriptor("memory.exists", "检查记忆键是否存在"),
            call_handler=self._memory_exists,
        )
        self.register(
            self._builtin_descriptor("memory.delete", "删除记忆"),
            call_handler=self._memory_delete,
        )
        self.register(
            self._builtin_descriptor("memory.clear_namespace", "清理记忆命名空间"),
            call_handler=self._memory_clear_namespace,
        )
        self.register(
            self._builtin_descriptor("memory.save_with_ttl", "保存带过期时间的记忆"),
            call_handler=self._memory_save_with_ttl,
        )
        self.register(
            self._builtin_descriptor("memory.get_many", "批量获取记忆"),
            call_handler=self._memory_get_many,
        )
        self.register(
            self._builtin_descriptor("memory.delete_many", "批量删除记忆"),
            call_handler=self._memory_delete_many,
        )
        self.register(
            self._builtin_descriptor("memory.count", "统计命名空间内的记忆数量"),
            call_handler=self._memory_count,
        )
        self.register(
            self._builtin_descriptor("memory.stats", "获取记忆统计信息"),
            call_handler=self._memory_stats,
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py
new file mode 100644
index 0000000000..3e2b6666bc
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py
@@ -0,0 +1,343 @@
+"""In-memory ``message_history.*`` capability handlers for the capability router bridge."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Any
+
+from ....errors import AstrBotError
+from ....message.session import MessageSession
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
+# Convert a MessageSession into the plain-dict wire shape embedded in each stored record.
+def _session_payload(session: MessageSession) -> dict[str, str]:
+    return {
+        "platform_id": str(session.platform_id),
+        "message_type": str(session.message_type),
+        "session_id": str(session.session_id),
+    }
+
+
+class MessageHistoryCapabilityMixin(CapabilityRouterBridgeBase):
+    """List/get/append/delete message-history handlers over an in-memory per-session store."""
+
+ @staticmethod
+ def _normalize_timestamp(raw_value: Any) -> datetime:
+ normalized = str(raw_value or "").strip()
+ if normalized.endswith("Z"):
+ normalized = f"{normalized[:-1]}+00:00"
+ parsed = datetime.fromisoformat(normalized)
+ if parsed.tzinfo is None:
+ parsed = parsed.replace(tzinfo=timezone.utc)
+ return parsed.astimezone(timezone.utc)
+
+ @staticmethod
+ def _typed_session_from_payload(payload: Any) -> MessageSession:
+ if not isinstance(payload, dict):
+ raise AstrBotError.invalid_input(
+ "message_history capabilities require a session object"
+ )
+ platform_id = str(payload.get("platform_id", "")).strip()
+ message_type = str(payload.get("message_type", "")).strip()
+ session_id = str(payload.get("session_id", "")).strip()
+ if not platform_id or not message_type or not session_id:
+ raise AstrBotError.invalid_input(
+ "message_history session requires platform_id, message_type, and session_id"
+ )
+ return MessageSession(
+ platform_id=platform_id,
+ message_type=message_type,
+ session_id=session_id,
+ )
+
+ @staticmethod
+ def _typed_key(session: MessageSession) -> str:
+ return (
+ f"{str(session.platform_id)}:{str(session.message_type).lower()}:"
+ f"{str(session.session_id)}"
+ )
+
+ def _message_history_records(self, session: MessageSession) -> list[dict[str, Any]]:
+ key = self._typed_key(session)
+ records = self._message_history_store.get(key)
+ if records is None:
+ records = []
+ self._message_history_store[key] = records
+ return records
+
+ def _next_message_history_id(self) -> int:
+ next_id = int(self._message_history_next_id)
+ self._message_history_next_id += 1
+ return next_id
+
+ def _create_message_history_record(
+ self,
+ *,
+ session: MessageSession,
+ sender_payload: dict[str, Any],
+ parts_payload: list[dict[str, Any]],
+ metadata: dict[str, Any],
+ idempotency_key: str | None,
+ ) -> dict[str, Any]:
+ now = self._now_iso()
+ return {
+ "id": self._next_message_history_id(),
+ "session": _session_payload(session),
+ "sender": {
+ "sender_id": (
+ str(sender_payload.get("sender_id"))
+ if sender_payload.get("sender_id") is not None
+ else None
+ ),
+ "sender_name": (
+ str(sender_payload.get("sender_name"))
+ if sender_payload.get("sender_name") is not None
+ else None
+ ),
+ },
+ "parts": [dict(item) for item in parts_payload if isinstance(item, dict)],
+ "metadata": dict(metadata),
+ "created_at": now,
+ "updated_at": now,
+ "idempotency_key": idempotency_key,
+ }
+
+ @staticmethod
+ def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
+ return {
+ "id": int(record.get("id", 0) or 0),
+ "session": (
+ dict(record.get("session"))
+ if isinstance(record.get("session"), dict)
+ else {}
+ ),
+ "sender": (
+ dict(record.get("sender"))
+ if isinstance(record.get("sender"), dict)
+ else {}
+ ),
+ "parts": (
+ [
+ dict(item)
+ for item in record.get("parts", [])
+ if isinstance(item, dict)
+ ]
+ if isinstance(record.get("parts"), list)
+ else []
+ ),
+ "metadata": (
+ dict(record.get("metadata"))
+ if isinstance(record.get("metadata"), dict)
+ else {}
+ ),
+ "created_at": record.get("created_at"),
+ "updated_at": record.get("updated_at"),
+ "idempotency_key": (
+ str(record.get("idempotency_key"))
+ if record.get("idempotency_key") is not None
+ else None
+ ),
+ }
+
+ @staticmethod
+ def _parse_boundary(raw_value: Any, field_name: str) -> datetime:
+ text = str(raw_value or "").strip()
+ if not text:
+ raise AstrBotError.invalid_input(
+ f"message_history.{field_name} requires {field_name}"
+ )
+ try:
+ return MessageHistoryCapabilityMixin._normalize_timestamp(text)
+ except ValueError as exc:
+ raise AstrBotError.invalid_input(
+ f"message_history.{field_name} requires an ISO datetime string"
+ ) from exc
+
+ async def _message_history_list(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session = self._typed_session_from_payload(payload.get("session"))
+ raw_limit = self._optional_int(payload.get("limit"))
+ limit = 50 if raw_limit is None else raw_limit
+ if limit < 1:
+ raise AstrBotError.invalid_input("message_history.list requires limit >= 1")
+ raw_cursor = payload.get("cursor")
+ cursor_id = (
+ self._optional_int(raw_cursor) if raw_cursor not in (None, "") else None
+ )
+ if raw_cursor not in (None, "") and (cursor_id is None or cursor_id < 1):
+ raise AstrBotError.invalid_input(
+ "message_history.list requires cursor to be a positive integer string"
+ )
+ records = list(reversed(self._message_history_records(session)))
+ total = len(records)
+ if cursor_id is not None:
+ records = [
+ record for record in records if int(record.get("id", 0)) < cursor_id
+ ]
+ page_records = records[:limit]
+ next_cursor = (
+ str(page_records[-1]["id"])
+ if len(records) > limit and page_records
+ else None
+ )
+ return {
+ "page": {
+ "records": [self._serialize_record(record) for record in page_records],
+ "next_cursor": next_cursor,
+ "total": total,
+ }
+ }
+
+ async def _message_history_get_by_id(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session = self._typed_session_from_payload(payload.get("session"))
+ record_id = self._optional_int(payload.get("record_id"))
+ if record_id is None or record_id < 1:
+ raise AstrBotError.invalid_input(
+ "message_history.get_by_id requires record_id >= 1"
+ )
+ record = next(
+ (
+ item
+ for item in self._message_history_records(session)
+ if int(item.get("id", 0) or 0) == record_id
+ ),
+ None,
+ )
+ return {
+ "record": self._serialize_record(record) if record is not None else None
+ }
+
+ async def _message_history_append(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session = self._typed_session_from_payload(payload.get("session"))
+ sender_payload = payload.get("sender")
+ if not isinstance(sender_payload, dict):
+ raise AstrBotError.invalid_input(
+ "message_history.append requires sender object"
+ )
+ parts_payload = payload.get("parts")
+ if not isinstance(parts_payload, list) or any(
+ not isinstance(item, dict) for item in parts_payload
+ ):
+ raise AstrBotError.invalid_input(
+ "message_history.append requires parts array"
+ )
+ metadata = payload.get("metadata")
+ if metadata is not None and not isinstance(metadata, dict):
+ raise AstrBotError.invalid_input(
+ "message_history.append requires metadata object when provided"
+ )
+ idempotency_key = (
+ str(payload.get("idempotency_key"))
+ if payload.get("idempotency_key") is not None
+ else None
+ )
+ records = self._message_history_records(session)
+ if idempotency_key:
+ existing = next(
+ (
+ record
+ for record in records
+ if str(record.get("idempotency_key") or "") == idempotency_key
+ ),
+ None,
+ )
+ if existing is not None:
+ return {"record": self._serialize_record(existing)}
+ record = self._create_message_history_record(
+ session=session,
+ sender_payload=sender_payload,
+ parts_payload=parts_payload,
+ metadata=dict(metadata or {}),
+ idempotency_key=idempotency_key,
+ )
+ records.append(record)
+ return {"record": self._serialize_record(record)}
+
+ async def _message_history_delete_before(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session = self._typed_session_from_payload(payload.get("session"))
+ before = self._parse_boundary(payload.get("before"), "delete_before")
+ records = self._message_history_records(session)
+ retained: list[dict[str, Any]] = []
+ deleted_count = 0
+ for record in records:
+ created_at = self._normalize_timestamp(record.get("created_at"))
+ if created_at < before:
+ deleted_count += 1
+ continue
+ retained.append(record)
+ self._message_history_store[self._typed_key(session)] = retained
+ return {"deleted_count": deleted_count}
+
+ async def _message_history_delete_after(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session = self._typed_session_from_payload(payload.get("session"))
+ after = self._parse_boundary(payload.get("after"), "delete_after")
+ records = self._message_history_records(session)
+ retained: list[dict[str, Any]] = []
+ deleted_count = 0
+ for record in records:
+ created_at = self._normalize_timestamp(record.get("created_at"))
+ if created_at > after:
+ deleted_count += 1
+ continue
+ retained.append(record)
+ self._message_history_store[self._typed_key(session)] = retained
+ return {"deleted_count": deleted_count}
+
+ async def _message_history_delete_all(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session = self._typed_session_from_payload(payload.get("session"))
+ key = self._typed_key(session)
+ deleted_count = len(self._message_history_store.get(key, []))
+ self._message_history_store[key] = []
+ return {"deleted_count": deleted_count}
+
+ def _register_message_history_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("message_history.list", "List message history"),
+ call_handler=self._message_history_list,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "message_history.get_by_id",
+ "Get message history by id",
+ ),
+ call_handler=self._message_history_get_by_id,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "message_history.append", "Append message history"
+ ),
+ call_handler=self._message_history_append,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "message_history.delete_before",
+ "Delete message history before timestamp",
+ ),
+ call_handler=self._message_history_delete_before,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "message_history.delete_after",
+ "Delete message history after timestamp",
+ ),
+ call_handler=self._message_history_delete_after,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "message_history.delete_all",
+ "Delete all message history in session",
+ ),
+ call_handler=self._message_history_delete_all,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/metadata.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/metadata.py
new file mode 100644
index 0000000000..787f63369b
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/metadata.py
@@ -0,0 +1,77 @@
+"""Plugin metadata and plugin-config capability handlers for the capability router bridge."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
+class MetadataCapabilityMixin(CapabilityRouterBridgeBase):
+    """Expose plugin metadata; config read/write is restricted to the calling plugin's own config."""
+
+ async def _metadata_get_plugin(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ name = str(payload.get("name", "")).strip()
+ plugin = self._plugins.get(name)
+ if plugin is None:
+ return {"plugin": None}
+ return {"plugin": dict(plugin.metadata)}
+
+ async def _metadata_list_plugins(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugins = [
+ dict(self._plugins[name].metadata) for name in sorted(self._plugins.keys())
+ ]
+ return {"plugins": plugins}
+
+ async def _metadata_get_plugin_config(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ name = str(payload.get("name", "")).strip()
+ caller_plugin_id = self._require_caller_plugin_id("metadata.get_plugin_config")
+ if name != caller_plugin_id:
+ return {"config": None}
+ plugin = self._plugins.get(name)
+ if plugin is None:
+ return {"config": None}
+ return {"config": dict(plugin.config)}
+
+ async def _metadata_save_plugin_config(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ caller_plugin_id = self._require_caller_plugin_id("metadata.save_plugin_config")
+ plugin = self._plugins.get(caller_plugin_id)
+ if plugin is None:
+ return {"config": None}
+ config = payload.get("config")
+ if not isinstance(config, dict):
+ return {"config": dict(plugin.config)}
+ plugin.config = dict(config)
+ return {"config": dict(plugin.config)}
+
+ def _register_metadata_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("metadata.get_plugin", "获取单个插件元数据"),
+ call_handler=self._metadata_get_plugin,
+ )
+ self.register(
+ self._builtin_descriptor("metadata.list_plugins", "列出插件元数据"),
+ call_handler=self._metadata_list_plugins,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "metadata.get_plugin_config",
+ "获取插件配置",
+ ),
+ call_handler=self._metadata_get_plugin_config,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "metadata.save_plugin_config",
+ "保存当前插件配置",
+ ),
+ call_handler=self._metadata_save_plugin_config,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/permission.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/permission.py
new file mode 100644
index 0000000000..063ab840c9
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/permission.py
@@ -0,0 +1,137 @@
+"""Permission capability handlers (admin lookup plus reserved-plugin admin management)."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
+class PermissionCapabilityMixin(CapabilityRouterBridgeBase):
+    """Admin-list backed checks; ``permission.manager.*`` requires a reserved/system plugin and an admin event context."""
+
+ def _register_permission_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("permission.check", "查询用户权限角色"),
+ call_handler=self._permission_check,
+ )
+ self.register(
+ self._builtin_descriptor("permission.get_admins", "列出管理员 ID"),
+ call_handler=self._permission_get_admins,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "permission.manager.add_admin",
+ "添加管理员 ID",
+ ),
+ call_handler=self._permission_manager_add_admin,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "permission.manager.remove_admin",
+ "移除管理员 ID",
+ ),
+ call_handler=self._permission_manager_remove_admin,
+ )
+
+ @staticmethod
+ def _normalize_admin_ids(values: Any) -> list[str]:
+ if not isinstance(values, list):
+ return []
+ normalized: list[str] = []
+ for item in values:
+ user_id = str(item).strip()
+ if user_id:
+ normalized.append(user_id)
+ return normalized
+
+ def _admin_ids_snapshot(self) -> list[str]:
+ normalized = self._normalize_admin_ids(
+ getattr(self, "_permission_admin_ids", [])
+ )
+ self._permission_admin_ids = list(normalized)
+ return normalized
+
+ @staticmethod
+ def _required_user_id(payload: dict[str, Any], capability_name: str) -> str:
+ user_id = str(payload.get("user_id", "")).strip()
+ if not user_id:
+ raise AstrBotError.invalid_input(f"{capability_name} requires user_id")
+ return user_id
+
+ def _require_reserved_plugin(self, capability_name: str) -> str:
+ plugin_id = self._require_caller_plugin_id(capability_name)
+ plugin = self._plugins.get(plugin_id)
+ if plugin is not None and bool(plugin.metadata.get("reserved", False)):
+ return plugin_id
+ if plugin_id in {"system", "__system__"}:
+ return plugin_id
+ raise AstrBotError.invalid_input(
+ f"{capability_name} is restricted to reserved/system plugins"
+ )
+
+ @staticmethod
+ def _require_admin_event_context(
+ payload: dict[str, Any],
+ capability_name: str,
+ ) -> None:
+ if bool(payload.get("_caller_is_admin", False)):
+ return
+ raise AstrBotError.invalid_input(
+ f"{capability_name} requires an active admin event context"
+ )
+
+ async def _permission_check(
+ self,
+ _request_id: str,
+ payload: dict[str, Any],
+ _token,
+ ) -> dict[str, Any]:
+ user_id = self._required_user_id(payload, "permission.check")
+ admins = self._admin_ids_snapshot()
+ is_admin = user_id in admins
+ return {
+ "is_admin": is_admin,
+ "role": "admin" if is_admin else "member",
+ }
+
+ async def _permission_get_admins(
+ self,
+ _request_id: str,
+ _payload: dict[str, Any],
+ _token,
+ ) -> dict[str, Any]:
+ return {"admins": self._admin_ids_snapshot()}
+
+ async def _permission_manager_add_admin(
+ self,
+ _request_id: str,
+ payload: dict[str, Any],
+ _token,
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("permission.manager.add_admin")
+ self._require_admin_event_context(payload, "permission.manager.add_admin")
+ user_id = self._required_user_id(payload, "permission.manager.add_admin")
+ admins = self._admin_ids_snapshot()
+ if user_id in admins:
+ return {"changed": False}
+ admins.append(user_id)
+ self._permission_admin_ids = admins
+ return {"changed": True}
+
+ async def _permission_manager_remove_admin(
+ self,
+ _request_id: str,
+ payload: dict[str, Any],
+ _token,
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("permission.manager.remove_admin")
+ self._require_admin_event_context(payload, "permission.manager.remove_admin")
+ user_id = self._required_user_id(payload, "permission.manager.remove_admin")
+ admins = self._admin_ids_snapshot()
+ if user_id not in admins:
+ return {"changed": False}
+ admins.remove(user_id)
+ self._permission_admin_ids = admins
+ return {"changed": True}
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/persona.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/persona.py
new file mode 100644
index 0000000000..6d7b3b3531
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/persona.py
@@ -0,0 +1,146 @@
+"""Persona CRUD capability handlers over the bridge's in-memory persona store."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
+class PersonaCapabilityMixin(CapabilityRouterBridgeBase):
+    """get/list/create/update/delete handlers keyed by ``persona_id``."""
+
+ async def _persona_get(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ persona_id = str(payload.get("persona_id", "")).strip()
+ record = self._persona_store.get(persona_id)
+ if record is None:
+ raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
+ return {"persona": dict(record)}
+
+ async def _persona_list(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ personas = [
+ dict(self._persona_store[persona_id])
+ for persona_id in sorted(self._persona_store.keys())
+ ]
+ return {"personas": personas}
+
+ async def _persona_create(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ raw_persona = payload.get("persona")
+ if not isinstance(raw_persona, dict):
+ raise AstrBotError.invalid_input("persona.create requires persona object")
+ persona_id = str(raw_persona.get("persona_id", "")).strip()
+ if not persona_id:
+ raise AstrBotError.invalid_input("persona.create requires persona_id")
+ if persona_id in self._persona_store:
+ raise AstrBotError.invalid_input(f"persona already exists: {persona_id}")
+ now = self._now_iso()
+ record = {
+ "persona_id": persona_id,
+ "system_prompt": str(raw_persona.get("system_prompt", "")),
+ "begin_dialogs": self._normalize_persona_dialogs_payload(
+ raw_persona.get("begin_dialogs")
+ ),
+ "tools": (
+ [str(item) for item in raw_persona.get("tools", [])]
+ if isinstance(raw_persona.get("tools"), list)
+ else None
+ ),
+ "skills": (
+ [str(item) for item in raw_persona.get("skills", [])]
+ if isinstance(raw_persona.get("skills"), list)
+ else None
+ ),
+ "custom_error_message": (
+ str(raw_persona.get("custom_error_message"))
+ if raw_persona.get("custom_error_message") is not None
+ else None
+ ),
+ "folder_id": (
+ str(raw_persona.get("folder_id"))
+ if raw_persona.get("folder_id") is not None
+ else None
+ ),
+ "sort_order": int(raw_persona.get("sort_order", 0)),
+ "created_at": now,
+ "updated_at": now,
+ }
+ self._persona_store[persona_id] = record
+ return {"persona": dict(record)}
+
+ async def _persona_update(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ persona_id = str(payload.get("persona_id", "")).strip()
+ record = self._persona_store.get(persona_id)
+ if record is None:
+ return {"persona": None}
+ raw_persona = payload.get("persona")
+ if not isinstance(raw_persona, dict):
+ raise AstrBotError.invalid_input("persona.update requires persona object")
+ if (
+ "system_prompt" in raw_persona
+ and raw_persona.get("system_prompt") is not None
+ ):
+ record["system_prompt"] = str(raw_persona.get("system_prompt", ""))
+ if "begin_dialogs" in raw_persona:
+ begin_dialogs = raw_persona.get("begin_dialogs")
+ record["begin_dialogs"] = (
+ self._normalize_persona_dialogs_payload(begin_dialogs)
+ if begin_dialogs is not None
+ else []
+ )
+ if "tools" in raw_persona:
+ tools = raw_persona.get("tools")
+ record["tools"] = (
+ [str(item) for item in tools] if isinstance(tools, list) else None
+ )
+ if "skills" in raw_persona:
+ skills = raw_persona.get("skills")
+ record["skills"] = (
+ [str(item) for item in skills] if isinstance(skills, list) else None
+ )
+ if "custom_error_message" in raw_persona:
+ custom_error_message = raw_persona.get("custom_error_message")
+ record["custom_error_message"] = (
+ str(custom_error_message) if custom_error_message is not None else None
+ )
+ record["updated_at"] = self._now_iso()
+ return {"persona": dict(record)}
+
+ async def _persona_delete(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ persona_id = str(payload.get("persona_id", "")).strip()
+ if persona_id not in self._persona_store:
+ raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
+ del self._persona_store[persona_id]
+ return {}
+
+ def _register_persona_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("persona.get", "获取人格"),
+ call_handler=self._persona_get,
+ )
+ self.register(
+ self._builtin_descriptor("persona.list", "列出人格"),
+ call_handler=self._persona_list,
+ )
+ self.register(
+ self._builtin_descriptor("persona.create", "创建人格"),
+ call_handler=self._persona_create,
+ )
+ self.register(
+ self._builtin_descriptor("persona.update", "更新人格"),
+ call_handler=self._persona_update,
+ )
+ self.register(
+ self._builtin_descriptor("persona.delete", "删除人格"),
+ call_handler=self._persona_delete,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/platform.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/platform.py
new file mode 100644
index 0000000000..dbc565a013
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/platform.py
@@ -0,0 +1,240 @@
+"""Platform messaging and platform-manager capability handlers for the capability router bridge."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
+class PlatformCapabilityMixin(CapabilityRouterBridgeBase):
+    """Records outbound sends in ``self.sent_messages``; ``platform.manager.*`` is reserved-plugin only."""
+
+ async def _platform_send(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session, target = self._resolve_target(payload)
+ self._require_platform_support_for_session("platform.send", session)
+ text = str(payload.get("text", ""))
+ message_id = f"msg_{len(self.sent_messages) + 1}"
+ sent: dict[str, Any] = {
+ "message_id": message_id,
+ "session": session,
+ "text": text,
+ }
+ if target is not None:
+ sent["target"] = target
+ self.sent_messages.append(sent)
+ return {"message_id": message_id}
+
+ async def _platform_send_image(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session, target = self._resolve_target(payload)
+ self._require_platform_support_for_session("platform.send_image", session)
+ image_url = str(payload.get("image_url", ""))
+ message_id = f"img_{len(self.sent_messages) + 1}"
+ sent: dict[str, Any] = {
+ "message_id": message_id,
+ "session": session,
+ "image_url": image_url,
+ }
+ if target is not None:
+ sent["target"] = target
+ self.sent_messages.append(sent)
+ return {"message_id": message_id}
+
+ async def _platform_send_chain(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session, target = self._resolve_target(payload)
+ self._require_platform_support_for_session("platform.send_chain", session)
+ chain = payload.get("chain")
+ if not isinstance(chain, list) or not all(
+ isinstance(item, dict) for item in chain
+ ):
+ raise AstrBotError.invalid_input(
+ "platform.send_chain 的 chain 必须是 object 数组"
+ )
+ message_id = f"chain_{len(self.sent_messages) + 1}"
+ sent: dict[str, Any] = {
+ "message_id": message_id,
+ "session": session,
+ "chain": [dict(item) for item in chain],
+ }
+ if target is not None:
+ sent["target"] = target
+ self.sent_messages.append(sent)
+ return {"message_id": message_id}
+
+ async def _platform_send_by_session(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ chain = payload.get("chain")
+ if not isinstance(chain, list) or not all(
+ isinstance(item, dict) for item in chain
+ ):
+ raise AstrBotError.invalid_input(
+ "platform.send_by_session 的 chain 必须是 object 数组"
+ )
+ session = str(payload.get("session", ""))
+ self._require_platform_support_for_session("platform.send_by_session", session)
+ message_id = f"proactive_{len(self.sent_messages) + 1}"
+ self.sent_messages.append(
+ {
+ "message_id": message_id,
+ "session": session,
+ "chain": [dict(item) for item in chain],
+ }
+ )
+ return {"message_id": message_id}
+
+ async def _platform_get_group(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session, _target = self._resolve_target(payload)
+ return {"group": self._mock_group_payload(session)}
+
+ async def _platform_get_members(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ session, _target = self._resolve_target(payload)
+ group = self._mock_group_payload(session)
+ if group is None:
+ return {"members": []}
+ return {"members": list(group.get("members", []))}
+
+ async def _platform_list_instances(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("platform.list_instances")
+ return {
+ "platforms": [
+ {
+ "id": str(item.get("id", "")),
+ "name": str(item.get("name", "")),
+ "type": str(item.get("type", "")),
+ "status": str(item.get("status", "unknown")),
+ }
+ for item in self.get_platform_instances()
+ if isinstance(item, dict)
+ and self._plugin_supports_platform(plugin_id, str(item.get("type", "")))
+ ]
+ }
+
+ def _register_platform_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("platform.send", "发送消息"),
+ call_handler=self._platform_send,
+ )
+ self.register(
+ self._builtin_descriptor("platform.send_image", "发送图片"),
+ call_handler=self._platform_send_image,
+ )
+ self.register(
+ self._builtin_descriptor("platform.send_chain", "发送消息链"),
+ call_handler=self._platform_send_chain,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "platform.send_by_session", "按会话主动发送消息链"
+ ),
+ call_handler=self._platform_send_by_session,
+ )
+ self.register(
+ self._builtin_descriptor("platform.get_group", "获取当前群信息"),
+ call_handler=self._platform_get_group,
+ )
+ self.register(
+ self._builtin_descriptor("platform.get_members", "获取群成员"),
+ call_handler=self._platform_get_members,
+ )
+ self.register(
+ self._builtin_descriptor("platform.list_instances", "列出平台实例元信息"),
+ call_handler=self._platform_list_instances,
+ )
+
+ async def _platform_manager_get_by_id(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("platform.manager.get_by_id")
+ platform_id = str(payload.get("platform_id", "")).strip()
+ platform = next(
+ (
+ dict(item)
+ for item in self._platform_instances
+ if str(item.get("id", "")) == platform_id
+ ),
+ None,
+ )
+ return {"platform": platform}
+
+ async def _platform_manager_clear_errors(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("platform.manager.clear_errors")
+ platform_id = str(payload.get("platform_id", "")).strip()
+ for item in self._platform_instances:
+ if str(item.get("id", "")) != platform_id:
+ continue
+ item["errors"] = []
+ item["last_error"] = None
+ if str(item.get("status", "")) == "error":
+ item["status"] = "running"
+ break
+ return {}
+
+ async def _platform_manager_get_stats(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("platform.manager.get_stats")
+ platform_id = str(payload.get("platform_id", "")).strip()
+ for item in self._platform_instances:
+ if str(item.get("id", "")) != platform_id:
+ continue
+ stats = item.get("stats")
+ if isinstance(stats, dict):
+ return {"stats": dict(stats)}
+ errors = item.get("errors")
+ last_error = item.get("last_error")
+ meta = item.get("meta")
+ return {
+ "stats": {
+ "id": platform_id,
+ "type": str(item.get("type", "")),
+ "display_name": str(item.get("name", platform_id)),
+ "status": str(item.get("status", "pending")),
+ "started_at": item.get("started_at"),
+ "error_count": len(errors) if isinstance(errors, list) else 0,
+ "last_error": dict(last_error)
+ if isinstance(last_error, dict)
+ else None,
+ "unified_webhook": bool(item.get("unified_webhook", False)),
+ "meta": dict(meta) if isinstance(meta, dict) else {},
+ }
+ }
+ return {"stats": None}
+
+ def _register_platform_manager_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor(
+ "platform.manager.get_by_id",
+ "按 ID 获取平台管理快照",
+ ),
+ call_handler=self._platform_manager_get_by_id,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "platform.manager.clear_errors",
+ "清除平台错误",
+ ),
+ call_handler=self._platform_manager_clear_errors,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "platform.manager.get_stats",
+ "获取平台统计信息",
+ ),
+ call_handler=self._platform_manager_get_stats,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/provider.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/provider.py
new file mode 100644
index 0000000000..7d3f7bad4c
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/provider.py
@@ -0,0 +1,1060 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+from collections.abc import AsyncIterator
+from typing import Any
+
+from ....errors import AstrBotError
+from ..._streaming import StreamExecution
+from ..bridge_base import (
+ _MOCK_EMBEDDING_DIM,
+ CapabilityRouterBridgeBase,
+ _mock_embedding_vector,
+)
+
+
+class ProviderCapabilityMixin(CapabilityRouterBridgeBase):
+ def _provider_payload(
+ self, kind: str, provider_id: str | None
+ ) -> dict[str, Any] | None:
+ if not provider_id:
+ return None
+ for item in self._provider_catalog.get(kind, []):
+ if str(item.get("id", "")) == provider_id:
+ return dict(item)
+ return None
+
+ def _provider_payload_by_id(self, provider_id: str) -> dict[str, Any] | None:
+ normalized = str(provider_id).strip()
+ if not normalized:
+ return None
+ for items in self._provider_catalog.values():
+ for item in items:
+ if str(item.get("id", "")) == normalized:
+ return dict(item)
+ return None
+
+ @staticmethod
+ def _provider_kind_from_type(provider_type: str) -> str:
+ mapping = {
+ "chat_completion": "chat",
+ "text_to_speech": "tts",
+ "speech_to_text": "stt",
+ "embedding": "embedding",
+ "rerank": "rerank",
+ }
+ normalized = str(provider_type).strip().lower()
+ if normalized not in mapping:
+ raise AstrBotError.invalid_input(f"unknown provider_type: {provider_type}")
+ return mapping[normalized]
+
+ def _provider_config_by_id(self, provider_id: str) -> dict[str, Any] | None:
+ record = self._provider_configs.get(str(provider_id).strip())
+ return dict(record) if isinstance(record, dict) else None
+
+ @staticmethod
+ def _managed_provider_record(
+ payload: dict[str, Any],
+ *,
+ loaded: bool,
+ ) -> dict[str, Any]:
+ return {
+ "id": str(payload.get("id", "")),
+ "model": (
+ str(payload.get("model")) if payload.get("model") is not None else None
+ ),
+ "type": str(payload.get("type", "")),
+ "provider_type": str(payload.get("provider_type", "chat_completion")),
+ "loaded": bool(loaded),
+ "enabled": bool(payload.get("enable", True)),
+ "provider_source_id": (
+ str(payload.get("provider_source_id"))
+ if payload.get("provider_source_id") is not None
+ else None
+ ),
+ }
+
+ def _managed_provider_record_by_id(self, provider_id: str) -> dict[str, Any] | None:
+ provider = self._provider_payload_by_id(provider_id)
+ if provider is not None:
+ config = self._provider_config_by_id(provider_id) or provider
+ merged = dict(provider)
+ merged.update(
+ {
+ "enable": config.get("enable", True),
+ "provider_source_id": config.get("provider_source_id"),
+ }
+ )
+ return self._managed_provider_record(merged, loaded=True)
+ config = self._provider_config_by_id(provider_id)
+ if config is None:
+ return None
+ return self._managed_provider_record(config, loaded=False)
+
+ def _emit_provider_change(
+ self,
+ provider_id: str,
+ provider_type: str,
+ umo: str | None,
+ ) -> None:
+ event = {
+ "provider_id": str(provider_id),
+ "provider_type": str(provider_type),
+ "umo": str(umo) if umo is not None else None,
+ }
+ for queue in list(self._provider_change_subscriptions.values()):
+ queue.put_nowait(dict(event))
+
+ def _require_reserved_plugin(self, capability_name: str) -> str:
+ plugin_id = self._require_caller_plugin_id(capability_name)
+ plugin = self._plugins.get(plugin_id)
+ if plugin is not None and bool(plugin.metadata.get("reserved", False)):
+ return plugin_id
+ if plugin_id in {"system", "__system__"}:
+ return plugin_id
+ raise AstrBotError.invalid_input(
+ f"{capability_name} is restricted to reserved/system plugins"
+ )
+
+ def _provider_entry(
+ self,
+ payload: dict[str, Any],
+ capability_name: str,
+ expected_kind: str | None = None,
+ ) -> dict[str, Any]:
+ provider_id = str(payload.get("provider_id", "")).strip()
+ if not provider_id:
+ raise AstrBotError.invalid_input(
+ f"{capability_name} requires provider_id",
+ )
+ provider = self._provider_payload_by_id(provider_id)
+ if provider is None:
+ raise AstrBotError.invalid_input(
+ f"{capability_name} unknown provider_id: {provider_id}",
+ )
+ if (
+ expected_kind is not None
+ and str(provider.get("provider_type")) != expected_kind
+ ):
+ raise AstrBotError.invalid_input(
+ f"{capability_name} requires a {expected_kind} provider",
+ )
+ return provider
+
+ async def _provider_get_using(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider_id = self._active_provider_ids.get("chat")
+ return {"provider": self._provider_payload("chat", provider_id)}
+
+ async def _provider_get_by_id(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return {
+ "provider": self._provider_payload_by_id(
+ str(payload.get("provider_id", ""))
+ )
+ }
+
+ async def _provider_get_current_chat_provider_id(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return {"provider_id": self._active_provider_ids.get("chat")}
+
+ def _provider_list_payload(self, kind: str) -> dict[str, Any]:
+ return {
+ "providers": [dict(item) for item in self._provider_catalog.get(kind, [])]
+ }
+
+ async def _provider_list_all(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return self._provider_list_payload("chat")
+
+ async def _provider_list_all_tts(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return self._provider_list_payload("tts")
+
+ async def _provider_list_all_stt(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return self._provider_list_payload("stt")
+
+ async def _provider_list_all_embedding(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return self._provider_list_payload("embedding")
+
+ async def _provider_list_all_rerank(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ return self._provider_list_payload("rerank")
+
+ async def _provider_get_using_tts(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider_id = self._active_provider_ids.get("tts")
+ return {"provider": self._provider_payload("tts", provider_id)}
+
+ async def _provider_get_using_stt(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider_id = self._active_provider_ids.get("stt")
+ return {"provider": self._provider_payload("stt", provider_id)}
+
+ async def _provider_stt_get_text(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._provider_entry(
+ payload,
+ "provider.stt.get_text",
+ "speech_to_text",
+ )
+ return {"text": f"Mock transcript: {str(payload.get('audio_url', ''))}"}
+
+ async def _provider_tts_get_audio(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider = self._provider_entry(
+ payload,
+ "provider.tts.get_audio",
+ "text_to_speech",
+ )
+ return {
+ "audio_path": (
+ f"mock://tts/{provider.get('id', '')}/{str(payload.get('text', ''))}"
+ )
+ }
+
+ async def _provider_tts_support_stream(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider = self._provider_entry(
+ payload,
+ "provider.tts.support_stream",
+ "text_to_speech",
+ )
+ return {"supported": bool(provider.get("support_stream", True))}
+
+ async def _provider_tts_get_audio_stream(
+ self,
+ _request_id: str,
+ payload: dict[str, Any],
+ token,
+ ) -> StreamExecution:
+ self._provider_entry(
+ payload,
+ "provider.tts.get_audio_stream",
+ "text_to_speech",
+ )
+ text = payload.get("text")
+ text_chunks = payload.get("text_chunks")
+ if isinstance(text, str):
+ chunks = [text]
+ elif isinstance(text_chunks, list) and text_chunks:
+ chunks = [str(item) for item in text_chunks]
+ else:
+ raise AstrBotError.invalid_input(
+ "provider.tts.get_audio_stream requires text or text_chunks"
+ )
+
+ async def iterator() -> AsyncIterator[dict[str, Any]]:
+ for chunk in chunks:
+ token.raise_if_cancelled()
+ await asyncio.sleep(0)
+ yield {
+ "audio_base64": base64.b64encode(
+ f"mock-audio:{chunk}".encode()
+ ).decode("ascii"),
+ "text": chunk,
+ }
+
+ return StreamExecution(
+ iterator=iterator(),
+ finalize=lambda items: (
+ items[-1] if items else {"audio_base64": "", "text": None}
+ ),
+ )
+
+ async def _provider_embedding_get_embedding(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider = self._provider_entry(
+ payload,
+ "provider.embedding.get_embedding",
+ "embedding",
+ )
+ return {
+ "embedding": _mock_embedding_vector(
+ str(payload.get("text", "")),
+ provider_id=str(provider.get("id", "")),
+ )
+ }
+
+ async def _provider_embedding_get_embeddings(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ provider = self._provider_entry(
+ payload,
+ "provider.embedding.get_embeddings",
+ "embedding",
+ )
+ texts = payload.get("texts")
+ if not isinstance(texts, list):
+ raise AstrBotError.invalid_input(
+ "provider.embedding.get_embeddings requires texts",
+ )
+ return {
+ "embeddings": [
+ _mock_embedding_vector(
+ str(text),
+ provider_id=str(provider.get("id", "")),
+ )
+ for text in texts
+ ],
+ }
+
+ async def _provider_embedding_get_dim(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._provider_entry(
+ payload,
+ "provider.embedding.get_dim",
+ "embedding",
+ )
+ return {"dim": _MOCK_EMBEDDING_DIM}
+
+ async def _provider_rerank_rerank(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._provider_entry(
+ payload,
+ "provider.rerank.rerank",
+ "rerank",
+ )
+ documents = payload.get("documents")
+ if not isinstance(documents, list):
+ raise AstrBotError.invalid_input(
+ "provider.rerank.rerank requires documents",
+ )
+ scored = [
+ {
+ "index": index,
+ "score": 1.0,
+ "document": str(raw_document),
+ }
+ for index, raw_document in enumerate(documents)
+ ]
+ top_n = payload.get("top_n")
+ if top_n is not None:
+ scored = scored[: max(int(top_n), 0)]
+ return {"results": scored}
+
+ async def _provider_manager_set(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.set")
+ provider_id = str(payload.get("provider_id", "")).strip()
+ provider_type = str(payload.get("provider_type", "")).strip()
+ kind = self._provider_kind_from_type(provider_type)
+ if not provider_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.set requires provider_id"
+ )
+ if self._provider_payload(kind, provider_id) is None:
+ raise AstrBotError.invalid_input(
+ f"provider.manager.set unknown provider_id: {provider_id}"
+ )
+ self._active_provider_ids[kind] = provider_id
+ self._emit_provider_change(
+ provider_id,
+ provider_type,
+ str(payload.get("umo")) if payload.get("umo") is not None else None,
+ )
+ return {}
+
+ async def _provider_manager_get_by_id(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.get_by_id")
+ return {
+ "provider": self._managed_provider_record_by_id(
+ str(payload.get("provider_id", ""))
+ )
+ }
+
+ async def _provider_manager_get_merged_provider_config(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.get_merged_provider_config")
+ provider_id = str(payload.get("provider_id", "")).strip()
+ if not provider_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.get_merged_provider_config requires provider_id"
+ )
+ provider = self._provider_payload_by_id(provider_id)
+ config = self._provider_config_by_id(provider_id)
+ if provider is None and config is None:
+ raise AstrBotError.invalid_input(
+ "provider.manager.get_merged_provider_config "
+ f"unknown provider_id: {provider_id}"
+ )
+ if provider is None:
+ return {"config": dict(config) if isinstance(config, dict) else config}
+ if config is None:
+ return {"config": dict(provider)}
+ merged_config = dict(provider)
+ merged_config.update(config)
+ return {"config": merged_config}
+
+ @staticmethod
+ def _normalize_provider_config_object(
+ payload: Any,
+ capability_name: str,
+ field_name: str,
+ ) -> dict[str, Any]:
+ if not isinstance(payload, dict):
+ raise AstrBotError.invalid_input(
+ f"{capability_name} requires {field_name} object"
+ )
+ return dict(payload)
+
+ async def _provider_manager_load(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.load")
+ provider_config = self._normalize_provider_config_object(
+ payload.get("provider_config"),
+ "provider.manager.load",
+ "provider_config",
+ )
+ provider_id = str(provider_config.get("id", "")).strip()
+ provider_type = str(provider_config.get("provider_type", "")).strip()
+ kind = self._provider_kind_from_type(provider_type)
+ if not provider_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.load requires provider id"
+ )
+ if bool(provider_config.get("enable", True)):
+ record = {
+ "id": provider_id,
+ "model": (
+ str(provider_config.get("model"))
+ if provider_config.get("model") is not None
+ else None
+ ),
+ "type": str(provider_config.get("type", "")),
+ "provider_type": provider_type,
+ }
+ self._provider_catalog[kind] = [
+ item
+ for item in self._provider_catalog.get(kind, [])
+ if str(item.get("id", "")) != provider_id
+ ]
+ self._provider_catalog[kind].append(record)
+ self._emit_provider_change(provider_id, provider_type, None)
+ return {
+ "provider": self._managed_provider_record(
+ provider_config,
+ loaded=bool(provider_config.get("enable", True)),
+ )
+ }
+
+ async def _provider_manager_terminate(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.terminate")
+ provider_id = str(payload.get("provider_id", "")).strip()
+ if not provider_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.terminate requires provider_id"
+ )
+ managed = self._managed_provider_record_by_id(provider_id)
+ if managed is None:
+ raise AstrBotError.invalid_input(
+ f"provider.manager.terminate unknown provider_id: {provider_id}"
+ )
+ kind = self._provider_kind_from_type(str(managed.get("provider_type", "")))
+ self._provider_catalog[kind] = [
+ item
+ for item in self._provider_catalog.get(kind, [])
+ if str(item.get("id", "")) != provider_id
+ ]
+ if self._active_provider_ids.get(kind) == provider_id:
+ catalog = self._provider_catalog.get(kind, [])
+ self._active_provider_ids[kind] = (
+ str(catalog[0].get("id")) if catalog else None
+ )
+ self._emit_provider_change(
+ provider_id, str(managed.get("provider_type", "")), None
+ )
+ return {}
+
+ async def _provider_manager_create(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.create")
+ provider_config = self._normalize_provider_config_object(
+ payload.get("provider_config"),
+ "provider.manager.create",
+ "provider_config",
+ )
+ provider_id = str(provider_config.get("id", "")).strip()
+ provider_type = str(provider_config.get("provider_type", "")).strip()
+ kind = self._provider_kind_from_type(provider_type)
+ if not provider_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.create requires provider id"
+ )
+ self._provider_configs[provider_id] = dict(provider_config)
+ if bool(provider_config.get("enable", True)):
+ self._provider_catalog[kind] = [
+ item
+ for item in self._provider_catalog.get(kind, [])
+ if str(item.get("id", "")) != provider_id
+ ]
+ self._provider_catalog[kind].append(
+ {
+ "id": provider_id,
+ "model": (
+ str(provider_config.get("model"))
+ if provider_config.get("model") is not None
+ else None
+ ),
+ "type": str(provider_config.get("type", "")),
+ "provider_type": provider_type,
+ }
+ )
+ self._emit_provider_change(provider_id, provider_type, None)
+ return {"provider": self._managed_provider_record_by_id(provider_id)}
+
+ async def _provider_manager_update(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.update")
+ origin_provider_id = str(payload.get("origin_provider_id", "")).strip()
+ new_config = self._normalize_provider_config_object(
+ payload.get("new_config"),
+ "provider.manager.update",
+ "new_config",
+ )
+ if not origin_provider_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.update requires origin_provider_id"
+ )
+ current = self._provider_config_by_id(origin_provider_id)
+ if current is None:
+ current = self._managed_provider_record_by_id(origin_provider_id)
+ if current is None:
+ raise AstrBotError.invalid_input(
+ f"provider.manager.update unknown provider_id: {origin_provider_id}"
+ )
+ target_provider_id = str(new_config.get("id") or origin_provider_id).strip()
+ provider_type = str(
+ new_config.get("provider_type") or current.get("provider_type", "")
+ ).strip()
+ kind = self._provider_kind_from_type(provider_type)
+ self._provider_configs.pop(origin_provider_id, None)
+ merged = dict(current)
+ merged.update(new_config)
+ merged["id"] = target_provider_id
+ merged["provider_type"] = provider_type
+ self._provider_configs[target_provider_id] = merged
+ for catalog_kind, items in list(self._provider_catalog.items()):
+ self._provider_catalog[catalog_kind] = [
+ item for item in items if str(item.get("id", "")) != origin_provider_id
+ ]
+ if bool(merged.get("enable", True)):
+ self._provider_catalog[kind].append(
+ {
+ "id": target_provider_id,
+ "model": (
+ str(merged.get("model"))
+ if merged.get("model") is not None
+ else None
+ ),
+ "type": str(merged.get("type", "")),
+ "provider_type": provider_type,
+ }
+ )
+ for active_kind, active_id in list(self._active_provider_ids.items()):
+ if active_id == origin_provider_id:
+ self._active_provider_ids[active_kind] = (
+ target_provider_id if active_kind == kind else None
+ )
+ self._emit_provider_change(target_provider_id, provider_type, None)
+ return {"provider": self._managed_provider_record_by_id(target_provider_id)}
+
+ async def _provider_manager_delete(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.delete")
+ provider_id = (
+ str(payload.get("provider_id")).strip()
+ if payload.get("provider_id") is not None
+ else None
+ )
+ provider_source_id = (
+ str(payload.get("provider_source_id")).strip()
+ if payload.get("provider_source_id") is not None
+ else None
+ )
+ if not provider_id and not provider_source_id:
+ raise AstrBotError.invalid_input(
+ "provider.manager.delete requires provider_id or provider_source_id"
+ )
+ deleted: list[dict[str, Any]] = []
+ if provider_id:
+ record = self._managed_provider_record_by_id(provider_id)
+ if record is not None:
+ deleted.append(record)
+ self._provider_configs.pop(provider_id, None)
+ else:
+ for record_id, record in list(self._provider_configs.items()):
+ if (
+ str(record.get("provider_source_id", "")).strip()
+ != provider_source_id
+ ):
+ continue
+ deleted_record = self._managed_provider_record_by_id(record_id)
+ if deleted_record is not None:
+ deleted.append(deleted_record)
+ self._provider_configs.pop(record_id, None)
+ deleted_ids = {str(item.get("id", "")) for item in deleted}
+ for kind, items in list(self._provider_catalog.items()):
+ self._provider_catalog[kind] = [
+ item for item in items if str(item.get("id", "")) not in deleted_ids
+ ]
+ if self._active_provider_ids.get(kind) in deleted_ids:
+ catalog = self._provider_catalog.get(kind, [])
+ self._active_provider_ids[kind] = (
+ str(catalog[0].get("id")) if catalog else None
+ )
+ for record in deleted:
+ self._emit_provider_change(
+ str(record.get("id", "")),
+ str(record.get("provider_type", "")),
+ None,
+ )
+ return {}
+
+ async def _provider_manager_get_insts(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("provider.manager.get_insts")
+ return {
+ "providers": [
+ self._managed_provider_record(item, loaded=True)
+ for item in self._provider_catalog.get("chat", [])
+ ]
+ }
+
+ async def _provider_manager_watch_changes(
+ self, request_id: str, _payload: dict[str, Any], _token
+ ) -> StreamExecution:
+ self._require_reserved_plugin("provider.manager.watch_changes")
+ queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
+ self._provider_change_subscriptions[request_id] = queue
+
+ async def iterator() -> AsyncIterator[dict[str, Any]]:
+ try:
+ while True:
+ yield await queue.get()
+ finally:
+ self._provider_change_subscriptions.pop(request_id, None)
+
+ return StreamExecution(
+ iterator=iterator(),
+ finalize=lambda _chunks: {},
+ collect_chunks=False,
+ )
+
+ async def _platform_manager_get_by_id(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("platform.manager.get_by_id")
+ platform_id = str(payload.get("platform_id", "")).strip()
+ platform = next(
+ (
+ dict(item)
+ for item in self._platform_instances
+ if str(item.get("id", "")) == platform_id
+ ),
+ None,
+ )
+ return {"platform": platform}
+
+ async def _platform_manager_clear_errors(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("platform.manager.clear_errors")
+ platform_id = str(payload.get("platform_id", "")).strip()
+ for item in self._platform_instances:
+ if str(item.get("id", "")) != platform_id:
+ continue
+ item["errors"] = []
+ item["last_error"] = None
+ if str(item.get("status", "")) == "error":
+ item["status"] = "running"
+ break
+ return {}
+
+ async def _platform_manager_get_stats(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ self._require_reserved_plugin("platform.manager.get_stats")
+ platform_id = str(payload.get("platform_id", "")).strip()
+ for item in self._platform_instances:
+ if str(item.get("id", "")) != platform_id:
+ continue
+ stats = item.get("stats")
+ if isinstance(stats, dict):
+ return {"stats": dict(stats)}
+ errors = item.get("errors")
+ last_error = item.get("last_error")
+ meta = item.get("meta")
+ return {
+ "stats": {
+ "id": platform_id,
+ "type": str(item.get("type", "")),
+ "display_name": str(item.get("name", platform_id)),
+ "status": str(item.get("status", "pending")),
+ "started_at": item.get("started_at"),
+ "error_count": len(errors) if isinstance(errors, list) else 0,
+ "last_error": dict(last_error)
+ if isinstance(last_error, dict)
+ else None,
+ "unified_webhook": bool(item.get("unified_webhook", False)),
+ "meta": dict(meta) if isinstance(meta, dict) else {},
+ }
+ }
+ return {"stats": None}
+
+ async def _llm_tool_manager_get(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("llm_tool.manager.get")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"registered": [], "active": []}
+ registered = [dict(item) for item in plugin.llm_tools.values()]
+ active = [
+ dict(item)
+ for name, item in plugin.llm_tools.items()
+ if name in plugin.active_llm_tools
+ ]
+ return {"registered": registered, "active": active}
+
+ async def _llm_tool_manager_activate(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("llm_tool.manager.activate")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"activated": False}
+ name = str(payload.get("name", ""))
+ spec = plugin.llm_tools.get(name)
+ if spec is None:
+ return {"activated": False}
+ spec["active"] = True
+ plugin.active_llm_tools.add(name)
+ return {"activated": True}
+
+ async def _llm_tool_manager_deactivate(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("llm_tool.manager.deactivate")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"deactivated": False}
+ name = str(payload.get("name", ""))
+ spec = plugin.llm_tools.get(name)
+ if spec is None:
+ return {"deactivated": False}
+ spec["active"] = False
+ plugin.active_llm_tools.discard(name)
+ return {"deactivated": True}
+
+ async def _llm_tool_manager_add(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("llm_tool.manager.add")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"names": []}
+ tools_payload = payload.get("tools")
+ if not isinstance(tools_payload, list):
+ raise AstrBotError.invalid_input("llm_tool.manager.add 的 tools 必须是数组")
+ names: list[str] = []
+ for item in tools_payload:
+ if not isinstance(item, dict):
+ continue
+ name = str(item.get("name", "")).strip()
+ if not name:
+ continue
+ plugin.llm_tools[name] = dict(item)
+ if bool(item.get("active", True)):
+ plugin.active_llm_tools.add(name)
+ else:
+ plugin.active_llm_tools.discard(name)
+ names.append(name)
+ return {"names": names}
+
+ async def _llm_tool_manager_remove(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("llm_tool.manager.remove")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"removed": False}
+ name = str(payload.get("name", "")).strip()
+ removed = plugin.llm_tools.pop(name, None) is not None
+ plugin.active_llm_tools.discard(name)
+ return {"removed": removed}
+
+ async def _agent_registry_list(
+ self, _request_id: str, _payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("agent.registry.list")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"agents": []}
+ return {"agents": [dict(item) for item in plugin.agents.values()]}
+
+ async def _agent_registry_get(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("agent.registry.get")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return {"agent": None}
+ agent = plugin.agents.get(str(payload.get("name", "")))
+ return {"agent": dict(agent) if isinstance(agent, dict) else None}
+
+ async def _agent_tool_loop_run(
+ self, _request_id: str, payload: dict[str, Any], _token
+ ) -> dict[str, Any]:
+ plugin_id = self._require_caller_plugin_id("agent.tool_loop.run")
+ plugin = self._plugins.get(plugin_id)
+ requested_tools = payload.get("tool_names")
+ active_tools: list[str] = []
+ if plugin is not None:
+ if isinstance(requested_tools, list) and requested_tools:
+ active_tools = [
+ name
+ for name in (str(item) for item in requested_tools)
+ if name in plugin.active_llm_tools
+ ]
+ else:
+ active_tools = sorted(plugin.active_llm_tools)
+ prompt = str(payload.get("prompt", "") or "")
+ suffix = ""
+ if active_tools:
+ suffix = f" tools={','.join(active_tools)}"
+ return {
+ "text": f"Mock tool loop: {prompt}{suffix}".strip(),
+ "usage": {
+ "input_tokens": len(prompt),
+ "output_tokens": len(prompt) + len(suffix),
+ },
+ "finish_reason": "stop",
+ "tool_calls": [],
+ "role": "assistant",
+ "reasoning_content": None,
+ "reasoning_signature": None,
+ }
+
+ def _register_provider_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("provider.get_using", "获取当前聊天 Provider"),
+ call_handler=self._provider_get_using,
+ )
+ self.register(
+ self._builtin_descriptor("provider.get_by_id", "按 ID 获取 Provider"),
+ call_handler=self._provider_get_by_id,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.get_current_chat_provider_id",
+ "获取当前聊天 Provider ID",
+ ),
+ call_handler=self._provider_get_current_chat_provider_id,
+ )
+ self.register(
+ self._builtin_descriptor("provider.list_all", "列出聊天 Providers"),
+ call_handler=self._provider_list_all,
+ )
+ self.register(
+ self._builtin_descriptor("provider.list_all_tts", "列出 TTS Providers"),
+ call_handler=self._provider_list_all_tts,
+ )
+ self.register(
+ self._builtin_descriptor("provider.list_all_stt", "列出 STT Providers"),
+ call_handler=self._provider_list_all_stt,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.list_all_embedding",
+ "列出 Embedding Providers",
+ ),
+ call_handler=self._provider_list_all_embedding,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.list_all_rerank",
+ "列出 Rerank Providers",
+ ),
+ call_handler=self._provider_list_all_rerank,
+ )
+ self.register(
+ self._builtin_descriptor("provider.get_using_tts", "获取当前 TTS Provider"),
+ call_handler=self._provider_get_using_tts,
+ )
+ self.register(
+ self._builtin_descriptor("provider.get_using_stt", "获取当前 STT Provider"),
+ call_handler=self._provider_get_using_stt,
+ )
+ self.register(
+ self._builtin_descriptor("provider.stt.get_text", "STT 转写"),
+ call_handler=self._provider_stt_get_text,
+ )
+ self.register(
+ self._builtin_descriptor("provider.tts.get_audio", "TTS 合成音频"),
+ call_handler=self._provider_tts_get_audio,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.tts.support_stream",
+ "检查 TTS 流式支持",
+ ),
+ call_handler=self._provider_tts_support_stream,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.tts.get_audio_stream",
+ "流式 TTS 音频输出",
+ supports_stream=True,
+ cancelable=True,
+ ),
+ stream_handler=self._provider_tts_get_audio_stream,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.embedding.get_embedding",
+ "获取单条向量",
+ ),
+ call_handler=self._provider_embedding_get_embedding,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.embedding.get_embeddings",
+ "批量获取向量",
+ ),
+ call_handler=self._provider_embedding_get_embeddings,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.embedding.get_dim",
+ "获取向量维度",
+ ),
+ call_handler=self._provider_embedding_get_dim,
+ )
+ self.register(
+ self._builtin_descriptor("provider.rerank.rerank", "文档重排序"),
+ call_handler=self._provider_rerank_rerank,
+ )
+
+ def _register_provider_manager_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("provider.manager.set", "设置当前 Provider"),
+ call_handler=self._provider_manager_set,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.manager.get_by_id",
+ "按 ID 获取 Provider 管理记录",
+ ),
+ call_handler=self._provider_manager_get_by_id,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.manager.get_merged_provider_config",
+ "获取 Provider 合并配置",
+ ),
+ call_handler=self._provider_manager_get_merged_provider_config,
+ )
+ self.register(
+ self._builtin_descriptor("provider.manager.load", "运行时加载 Provider"),
+ call_handler=self._provider_manager_load,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.manager.terminate",
+ "终止已加载的 Provider",
+ ),
+ call_handler=self._provider_manager_terminate,
+ )
+ self.register(
+ self._builtin_descriptor("provider.manager.create", "创建 Provider"),
+ call_handler=self._provider_manager_create,
+ )
+ self.register(
+ self._builtin_descriptor("provider.manager.update", "更新 Provider"),
+ call_handler=self._provider_manager_update,
+ )
+ self.register(
+ self._builtin_descriptor("provider.manager.delete", "删除 Provider"),
+ call_handler=self._provider_manager_delete,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.manager.get_insts",
+ "列出已加载聊天 Provider",
+ ),
+ call_handler=self._provider_manager_get_insts,
+ )
+ self.register(
+ self._builtin_descriptor(
+ "provider.manager.watch_changes",
+ "订阅 Provider 变更",
+ supports_stream=True,
+ cancelable=True,
+ ),
+ stream_handler=self._provider_manager_watch_changes,
+ )
+
+ def _register_agent_tool_capabilities(self) -> None:
+ self.register(
+ self._builtin_descriptor("llm_tool.manager.get", "获取 LLM 工具状态"),
+ call_handler=self._llm_tool_manager_get,
+ )
+ self.register(
+ self._builtin_descriptor("llm_tool.manager.activate", "激活 LLM 工具"),
+ call_handler=self._llm_tool_manager_activate,
+ )
+ self.register(
+ self._builtin_descriptor("llm_tool.manager.deactivate", "停用 LLM 工具"),
+ call_handler=self._llm_tool_manager_deactivate,
+ )
+ self.register(
+ self._builtin_descriptor("llm_tool.manager.add", "动态添加 LLM 工具"),
+ call_handler=self._llm_tool_manager_add,
+ )
+ self.register(
+ self._builtin_descriptor("llm_tool.manager.remove", "动态移除 LLM 工具"),
+ call_handler=self._llm_tool_manager_remove,
+ )
+ self.register(
+ self._builtin_descriptor("agent.tool_loop.run", "运行 mock tool loop"),
+ call_handler=self._agent_tool_loop_run,
+ )
+ self.register(
+ self._builtin_descriptor("agent.registry.list", "列出 Agent 元数据"),
+ call_handler=self._agent_registry_list,
+ )
+ self.register(
+ self._builtin_descriptor("agent.registry.get", "获取 Agent 元数据"),
+ call_handler=self._agent_registry_get,
+ )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py
new file mode 100644
index 0000000000..e56f979e9e
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
class SessionCapabilityMixin(CapabilityRouterBridgeBase):
    """Builtin capabilities implementing per-session plugin/service switches."""

    @staticmethod
    def _nonblank_names(values: Any) -> set[str]:
        # Normalize a config-sourced list into a set of non-blank string names.
        return {str(value) for value in values if str(value).strip()}

    def _read_service_flag(self, session: str, key: str) -> dict[str, Any]:
        # Both service switches default to enabled when the key is absent.
        config = self._session_service_config(session)
        return {"enabled": bool(config.get(key, True))}

    def _write_service_flag(
        self, session: str, key: str, enabled: Any
    ) -> dict[str, Any]:
        # Persist one boolean switch back into the per-session service store.
        config = self._session_service_config(session)
        config[key] = bool(enabled)
        self._session_service_configs[session] = config
        return {}

    async def _session_plugin_is_enabled(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Report whether one plugin is enabled for the given session.

        An explicit enabled_plugins entry wins; otherwise the plugin is
        enabled unless it appears in disabled_plugins.
        """
        session = str(payload.get("session", ""))
        plugin_name = str(payload.get("plugin_name", ""))
        config = self._session_plugin_config(session)
        if plugin_name in self._nonblank_names(config.get("enabled_plugins", [])):
            return {"enabled": True}
        disabled = self._nonblank_names(config.get("disabled_plugins", []))
        return {"enabled": plugin_name not in disabled}

    async def _session_plugin_filter_handlers(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Drop handler entries whose plugin is disabled for this session.

        Reserved plugins are never filtered out; non-dict entries are skipped
        silently and surviving entries are shallow-copied.
        """
        session = str(payload.get("session", ""))
        handlers = payload.get("handlers")
        if not isinstance(handlers, list):
            raise AstrBotError.invalid_input(
                "session.plugin.filter_handlers 的 handlers 必须是 object 数组"
            )
        disabled = self._nonblank_names(
            self._session_plugin_config(session).get("disabled_plugins", [])
        )
        reserved = {
            str(plugin.metadata.get("name", ""))
            for plugin in self._plugins.values()
            if bool(plugin.metadata.get("reserved", False))
        }
        kept: list[dict[str, Any]] = []
        for entry in handlers:
            if not isinstance(entry, dict):
                continue
            owner = str(entry.get("plugin_name", ""))
            if owner and owner in disabled and owner not in reserved:
                continue
            kept.append(dict(entry))
        return {"handlers": kept}

    async def _session_service_is_llm_enabled(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the per-session LLM switch (defaults to enabled)."""
        return self._read_service_flag(str(payload.get("session", "")), "llm_enabled")

    async def _session_service_set_llm_status(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Write the per-session LLM switch (a missing "enabled" means False)."""
        return self._write_service_flag(
            str(payload.get("session", "")),
            "llm_enabled",
            payload.get("enabled", False),
        )

    async def _session_service_is_tts_enabled(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the per-session TTS switch (defaults to enabled)."""
        return self._read_service_flag(str(payload.get("session", "")), "tts_enabled")

    async def _session_service_set_tts_status(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Write the per-session TTS switch (a missing "enabled" means False)."""
        return self._write_service_flag(
            str(payload.get("session", "")),
            "tts_enabled",
            payload.get("enabled", False),
        )

    def _register_session_capabilities(self) -> None:
        """Register every session.* builtin on this capability router."""
        for capability_name, description, handler in (
            (
                "session.plugin.is_enabled",
                "获取会话级插件开关",
                self._session_plugin_is_enabled,
            ),
            (
                "session.plugin.filter_handlers",
                "按会话过滤 handler 元数据",
                self._session_plugin_filter_handlers,
            ),
            (
                "session.service.is_llm_enabled",
                "获取会话级 LLM 开关",
                self._session_service_is_llm_enabled,
            ),
            (
                "session.service.set_llm_status",
                "写入会话级 LLM 开关",
                self._session_service_set_llm_status,
            ),
            (
                "session.service.is_tts_enabled",
                "获取会话级 TTS 开关",
                self._session_service_is_tts_enabled,
            ),
            (
                "session.service.set_tts_status",
                "写入会话级 TTS 开关",
                self._session_service_set_tts_status,
            ),
        ):
            self.register(
                self._builtin_descriptor(capability_name, description),
                call_handler=handler,
            )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py
new file mode 100644
index 0000000000..942f696989
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import CapabilityRouterBridgeBase
+
+
class SkillCapabilityMixin(CapabilityRouterBridgeBase):
    """Builtin capabilities letting a plugin manage its own skill entries."""

    def _register_skill_capabilities(self) -> None:
        """Register the skill.* builtins on this capability router."""
        for capability_name, description, handler in (
            ("skill.register", "注册插件 skill", self._skill_register),
            ("skill.unregister", "注销插件 skill", self._skill_unregister),
            ("skill.list", "列出插件 skill", self._skill_list),
        ):
            self.register(
                self._builtin_descriptor(capability_name, description),
                call_handler=handler,
            )

    def _caller_plugin(self, capability_name: str):
        # Resolve the calling plugin's record or fail with invalid_input.
        plugin_id = self._require_caller_plugin_id(capability_name)
        plugin = self._plugins.get(plugin_id)
        if plugin is None:
            raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}")
        return plugin

    async def _skill_register(
        self,
        _request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, str]:
        """Register (or overwrite) a named skill for the calling plugin.

        Requires non-blank "name" and "path"; a path whose basename is
        SKILL.md is treated as the manifest inside the skill directory.
        """
        plugin = self._caller_plugin("skill.register")
        skill_name = str(payload.get("name", "")).strip()
        if not skill_name:
            raise AstrBotError.invalid_input("skill.register requires name")
        skill_path = str(payload.get("path", "")).strip()
        if not skill_path:
            raise AstrBotError.invalid_input("skill.register requires path")

        path_obj = Path(skill_path)
        skill_dir = path_obj.parent if path_obj.name == "SKILL.md" else path_obj
        entry = {
            "name": skill_name,
            "description": str(payload.get("description", "") or ""),
            "path": skill_path,
            "skill_dir": str(skill_dir),
        }
        plugin.skills[skill_name] = entry
        return dict(entry)

    async def _skill_unregister(
        self,
        _request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, bool]:
        """Remove a skill by name; reports whether anything was removed."""
        plugin = self._caller_plugin("skill.unregister")
        target = str(payload.get("name", "")).strip()
        return {"removed": plugin.skills.pop(target, None) is not None}

    async def _skill_list(
        self,
        _request_id: str,
        _payload: dict[str, Any],
        _token,
    ) -> dict[str, list[dict[str, str]]]:
        """List the calling plugin's skills, sorted by skill name."""
        plugin = self._caller_plugin("skill.list")
        return {
            "skills": [dict(plugin.skills[name]) for name in sorted(plugin.skills)]
        }
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py
new file mode 100644
index 0000000000..f23e63ce4a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py
@@ -0,0 +1,370 @@
+from __future__ import annotations
+
+import json
+from typing import Any
+
+from ....errors import AstrBotError
+from ..bridge_base import (
+ CapabilityRouterBridgeBase,
+ _clone_chain_payload,
+ _clone_target_payload,
+)
+
+
class SystemCapabilityMixin(CapabilityRouterBridgeBase):
    """Builtin system.* and registry.* capabilities for the mock bridge.

    Covers plugin data directories, mock text/HTML rendering, session
    waiters, event reactions/typing/streaming, per-request handler
    whitelists, and handler/command registry queries.
    """

    @staticmethod
    def _overlay_request_id(request_id: str, payload: dict[str, Any]) -> str:
        """Pick the overlay scope: an explicit, non-blank ``_request_scope_id``
        in the payload wins over the transport-level request id."""
        scope_request_id = payload.get("_request_scope_id")
        if isinstance(scope_request_id, str) and scope_request_id.strip():
            return scope_request_id
        return request_id

    def _register_system_capabilities(self) -> None:
        """Register the system.* and registry.* builtins.

        All system.* entries are registered with ``exposed=False`` — they are
        internal bridge plumbing rather than externally listed capabilities;
        the registry.* entries stay exposed.
        """
        self.register(
            self._builtin_descriptor("system.get_data_dir", "获取插件数据目录"),
            call_handler=self._system_get_data_dir,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor("system.text_to_image", "文本转图片"),
            call_handler=self._system_text_to_image,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor("system.html_render", "渲染 HTML 模板"),
            call_handler=self._system_html_render,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.session_waiter.register",
                "注册会话等待器",
            ),
            call_handler=self._system_session_waiter_register,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.session_waiter.unregister",
                "注销会话等待器",
            ),
            call_handler=self._system_session_waiter_unregister,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor("system.event.react", "发送事件表情回应"),
            call_handler=self._system_event_react,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor("system.event.send_typing", "发送输入中状态"),
            call_handler=self._system_event_send_typing,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.event.send_streaming",
                "发送事件流式消息",
            ),
            call_handler=self._system_event_send_streaming,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.event.send_streaming_chunk",
                "推送事件流式消息分片",
            ),
            call_handler=self._system_event_send_streaming_chunk,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.event.send_streaming_close",
                "关闭事件流式消息会话",
            ),
            call_handler=self._system_event_send_streaming_close,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.event.handler_whitelist.get",
                "读取当前请求 handler 白名单",
            ),
            call_handler=self._system_event_handler_whitelist_get,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "system.event.handler_whitelist.set",
                "写入当前请求 handler 白名单",
            ),
            call_handler=self._system_event_handler_whitelist_set,
            exposed=False,
        )
        self.register(
            self._builtin_descriptor(
                "registry.get_handlers_by_event_type",
                "按事件类型列出 handler 元数据",
            ),
            call_handler=self._registry_get_handlers_by_event_type,
        )
        self.register(
            self._builtin_descriptor(
                "registry.get_handler_by_full_name",
                "按 full name 查询 handler 元数据",
            ),
            call_handler=self._registry_get_handler_by_full_name,
        )
        self.register(
            self._builtin_descriptor(
                "registry.command.register",
                "注册动态命令路由",
            ),
            call_handler=self._registry_command_register,
        )

    def _ensure_request_overlay(self, request_id: str) -> dict[str, Any]:
        """Return the per-request overlay dict, lazily creating an empty one
        (no handler whitelist) on first access for this request id."""
        overlay = self._request_overlays.get(request_id)
        if overlay is None:
            overlay = {
                "handler_whitelist": None,
            }
            self._request_overlays[request_id] = overlay
        return overlay

    async def _system_get_data_dir(
        self, _request_id: str, _payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return (creating it if missing) the calling plugin's data dir."""
        plugin_id = self._require_caller_plugin_id("system.get_data_dir")
        data_dir = self._plugin_data_dir(
            plugin_id,
            capability_name="system.get_data_dir",
        )
        data_dir.mkdir(parents=True, exist_ok=True)
        return {"path": str(data_dir)}

    async def _system_text_to_image(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Mock text-to-image: a mock:// URL by default, or the text itself
        when return_url is falsy."""
        text = str(payload.get("text", ""))
        if bool(payload.get("return_url", True)):
            return {"result": f"mock://text_to_image/{text}"}
        return {"result": f"{text}"}

    async def _system_html_render(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Mock HTML render: a mock:// URL by default, or the JSON-serialized
        template+data when return_url is falsy. ``data`` must be an object."""
        tmpl = str(payload.get("tmpl", ""))
        data = payload.get("data")
        if not isinstance(data, dict):
            raise AstrBotError.invalid_input("system.html_render requires object data")
        if bool(payload.get("return_url", True)):
            return {"result": f"mock://html_render/{tmpl}"}
        return {"result": json.dumps({"tmpl": tmpl, "data": data}, ensure_ascii=False)}

    async def _system_event_handler_whitelist_get(
        self, request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Read the handler whitelist for the current request scope.

        ``plugin_names`` is None when no whitelist is set, otherwise a sorted
        list of plugin names.
        """
        overlay = self._ensure_request_overlay(
            self._overlay_request_id(request_id, payload)
        )
        whitelist = overlay.get("handler_whitelist")
        if whitelist is None:
            return {"plugin_names": None}
        return {"plugin_names": sorted(str(item) for item in whitelist)}

    async def _system_event_handler_whitelist_set(
        self, request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Replace (or clear, with null) the request-scoped handler whitelist,
        then echo the resulting state back through the getter."""
        overlay_request_id = self._overlay_request_id(request_id, payload)
        overlay = self._ensure_request_overlay(overlay_request_id)
        plugin_names_payload = payload.get("plugin_names")
        if plugin_names_payload is None:
            overlay["handler_whitelist"] = None
        elif isinstance(plugin_names_payload, list):
            # Blank entries are dropped; remaining names are stored as a set.
            overlay["handler_whitelist"] = {
                str(item) for item in plugin_names_payload if str(item).strip()
            }
        else:
            raise AstrBotError.invalid_input(
                "system.event.handler_whitelist.set 的 plugin_names 必须是数组或 null"
            )
        return await self._system_event_handler_whitelist_get(
            request_id,
            {"_request_scope_id": overlay_request_id},
            _token,
        )

    async def _registry_get_handlers_by_event_type(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List handler metadata for an event type.

        For "message" events, dynamically registered command routes are
        appended as synthetic handler entries.
        """
        event_type = str(payload.get("event_type", "")).strip()
        handlers: list[dict[str, Any]] = []
        for plugin in self._plugins.values():
            handlers.extend(
                [
                    dict(handler)
                    for handler in plugin.handlers
                    if event_type in handler.get("event_types", [])
                ]
            )
        if event_type == "message":
            for plugin_name, routes in self._dynamic_command_routes.items():
                for route in routes:
                    if not isinstance(route, dict):
                        continue
                    handlers.append(
                        {
                            "plugin_name": str(route.get("plugin_name", plugin_name)),
                            "handler_full_name": str(
                                route.get("handler_full_name", "")
                            ),
                            # Regex routes surface as raw message handlers,
                            # plain routes as commands.
                            "trigger_type": (
                                "message"
                                if bool(route.get("use_regex", False))
                                else "command"
                            ),
                            "description": (
                                None
                                if route.get("desc") is None
                                else str(route.get("desc", "")).strip() or None
                            ),
                            "event_types": ["message"],
                            "enabled": True,
                            "group_path": [],
                            "priority": int(route.get("priority", 0) or 0),
                            "kind": "handler",
                            "require_admin": False,
                            "required_role": None,
                        }
                    )
        return {"handlers": handlers}

    async def _registry_get_handler_by_full_name(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Look up a single handler by its full name; handler is None when
        no plugin declares that name."""
        full_name = str(payload.get("full_name", "")).strip()
        for plugin in self._plugins.values():
            for handler in plugin.handlers:
                if handler.get("handler_full_name") == full_name:
                    return {"handler": dict(handler)}
        return {"handler": None}

    async def _registry_command_register(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Register a dynamic command route for the calling plugin.

        Only allowed from astrbot_loaded/platform_loaded lifecycle events;
        ignore_prefix is unsupported and priority must be a plain int.
        """
        source_event_type = str(payload.get("source_event_type", "")).strip()
        if source_event_type not in {"astrbot_loaded", "platform_loaded"}:
            raise AstrBotError.invalid_input(
                "register_commands is only available in astrbot_loaded/platform_loaded events"
            )
        if bool(payload.get("ignore_prefix", False)):
            raise AstrBotError.invalid_input(
                "register_commands(ignore_prefix=True) is unsupported in SDK runtime"
            )
        priority_value = payload.get("priority", 0)
        # bool is an int subclass, so reject it explicitly.
        if isinstance(priority_value, bool) or not isinstance(priority_value, int):
            raise AstrBotError.invalid_input(
                "registry.command.register 的 priority 必须是 integer"
            )
        plugin_id = self._require_caller_plugin_id("registry.command.register")
        self.register_dynamic_command_route(
            plugin_id=plugin_id,
            command_name=str(payload.get("command_name", "")),
            handler_full_name=str(payload.get("handler_full_name", "")),
            desc=str(payload.get("desc", "")),
            priority=priority_value,
            use_regex=bool(payload.get("use_regex", False)),
        )
        return {}

    async def _system_session_waiter_register(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Record a session-waiter key for the calling plugin."""
        plugin_id = self._require_caller_plugin_id("system.session_waiter.register")
        session_key = str(payload.get("session_key", "")).strip()
        if not session_key:
            raise AstrBotError.invalid_input(
                "system.session_waiter.register requires session_key"
            )
        self._session_waiters.setdefault(plugin_id, set()).add(session_key)
        return {}

    async def _system_session_waiter_unregister(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Remove a session-waiter key; prunes the plugin's entry once empty.
        Unknown plugins/keys are a silent no-op."""
        plugin_id = self._require_caller_plugin_id("system.session_waiter.unregister")
        session_key = str(payload.get("session_key", "")).strip()
        plugin_waiters = self._session_waiters.get(plugin_id)
        if plugin_waiters is None:
            return {}
        plugin_waiters.discard(session_key)
        if not plugin_waiters:
            self._session_waiters.pop(plugin_id, None)
        return {}

    async def _system_event_react(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Record a mock emoji reaction in the event-action log."""
        self.event_actions.append(
            {
                "action": "react",
                "emoji": str(payload.get("emoji", "")),
                "target": _clone_target_payload(payload.get("target")),
            }
        )
        return {"supported": True}

    async def _system_event_send_typing(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Record a mock typing indicator in the event-action log."""
        self.event_actions.append(
            {
                "action": "send_typing",
                "target": _clone_target_payload(payload.get("target")),
            }
        )
        return {"supported": True}

    async def _system_event_send_streaming(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Open a mock streaming session and hand back its stream id."""
        # NOTE(review): the id derives from the current dict size, so closing
        # a stream (which pops it) could let a later id collide with a live
        # one — acceptable for a mock, but worth confirming.
        stream_id = f"mock-stream-{len(self._event_streams) + 1}"
        stream_state: dict[str, Any] = {
            "target": _clone_target_payload(payload.get("target")),
            "chunks": [],
            "use_fallback": bool(payload.get("use_fallback", False)),
        }
        self._event_streams[stream_id] = stream_state
        return {"supported": True, "stream_id": stream_id}

    async def _system_event_send_streaming_chunk(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Append one message-chain chunk to an open streaming session."""
        stream = self._event_streams.get(str(payload.get("stream_id", "")))
        if stream is None:
            raise AstrBotError.invalid_input("Unknown sdk event streaming session")
        chain = payload.get("chain")
        if not isinstance(chain, list):
            raise AstrBotError.invalid_input(
                "system.event.send_streaming_chunk requires a chain array"
            )
        stream["chunks"].append({"chain": _clone_chain_payload(chain)})
        return {}

    async def _system_event_send_streaming_close(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Close a streaming session and flush it into the event-action log."""
        stream_id = str(payload.get("stream_id", ""))
        stream = self._event_streams.pop(stream_id, None)
        if stream is None:
            raise AstrBotError.invalid_input("Unknown sdk event streaming session")
        self.event_actions.append(
            {
                "action": "send_streaming",
                "target": stream["target"],
                "chunks": list(stream["chunks"]),
                "use_fallback": bool(stream["use_fallback"]),
            }
        )
        return {"supported": True}
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py b/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py
new file mode 100644
index 0000000000..cb8ba44c2a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import re
+import shlex
+from collections.abc import Sequence
+from typing import Any
+
+from ..protocol.descriptors import ParamSpec
+
+
def normalize_command_invocation(text: str) -> str:
    """Collapse whitespace and strip one leading slash from a command string.

    Returns "" for blank input; e.g. "/  help   me " becomes "help me".
    """
    normalized = re.sub(r"\s+", " ", str(text).strip())
    if not normalized:
        return ""
    normalized = re.sub(r"^/\s*", "", normalized)
    return normalized.strip()


def command_root_name(text: str) -> str:
    """Return the first token of a normalized invocation ("" when blank)."""
    normalized = normalize_command_invocation(text)
    if not normalized:
        return ""
    return normalized.split(" ", 1)[0]


def match_command_name(text: str, command_name: str) -> str | None:
    """Match *text* against *command_name* and return the argument remainder.

    Returns None when the text does not invoke the command, "" when it is
    invoked with no arguments, and the stripped remainder otherwise. An
    optional leading "/" and flexible inter-token whitespace are tolerated.
    """
    normalized_command = normalize_command_invocation(command_name)
    if not normalized_command:
        return None
    command_tokens = [re.escape(token) for token in normalized_command.split()]
    command_pattern = r"\s+".join(command_tokens)
    # Bug fix: the group was written as "(?P.*)", which is invalid regex
    # syntax and made re.match raise re.error on every call; it must be the
    # named group "(?P<remainder>...)" read via match.group("remainder") below.
    pattern = rf"^\s*/?\s*{command_pattern}(?:\s+(?P<remainder>.*))?\s*$"
    match = re.match(pattern, text)
    if match is None:
        return None
    remainder = match.group("remainder")
    if remainder is None:
        return ""
    return remainder.strip()
+
+
def build_command_args(
    param_specs: Sequence[ParamSpec], remainder: str
) -> dict[str, Any]:
    """Bind a command's raw argument remainder to its parameter specs.

    A single-parameter command receives the whole remainder verbatim; with
    multiple parameters the remainder is tokenized shell-style and bound
    positionally, and a greedy_str parameter absorbs every remaining token.
    Missing trailing arguments are simply left unbound.
    """
    if not param_specs or not remainder:
        return {}
    if len(param_specs) == 1:
        return {param_specs[0].name: remainder}
    tokens = split_command_remainder(remainder)
    bound: dict[str, Any] = {}
    for position, spec in enumerate(param_specs[: len(tokens)]):
        if spec.type == "greedy_str":
            bound[spec.name] = " ".join(tokens[position:])
            break
        bound[spec.name] = tokens[position]
    return bound
+
+
def build_regex_args(
    param_specs: Sequence[ParamSpec], match: re.Match[str]
) -> dict[str, Any]:
    """Build handler kwargs from a regex match.

    Matched named groups bind directly; the capture groups (note that
    ``re.Match.groups`` also includes named groups' values, and that quirk
    is preserved here) are then assigned positionally to the specs not
    already covered by a named group.
    """
    bound = {
        name: value
        for name, value in match.groupdict().items()
        if value is not None
    }
    unbound_names = [spec.name for spec in param_specs if spec.name not in bound]
    captured = (value for value in match.groups() if value is not None)
    # zip truncates at the shorter side, matching the original index bound.
    for name, value in zip(unbound_names, captured):
        bound[name] = value
    return bound
+
+
def split_command_remainder(remainder: str) -> list[str]:
    """Tokenize a command remainder, honoring shell-style quoting.

    Falls back to plain whitespace splitting when the input is not valid
    shell syntax (e.g. an unterminated quote raises ValueError in shlex).
    """
    if not remainder:
        return []
    try:
        tokens = shlex.split(remainder)
    except ValueError:
        tokens = remainder.split()
    return tokens
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py b/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py
new file mode 100644
index 0000000000..40d162d355
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py
@@ -0,0 +1,156 @@
+"""Support helpers for runtime loader reflection and signature validation.
+
+本模块提供运行时加载器所需的反射和签名验证工具函数,主要用于:
+1. 解析 handler/capability 函数签名,提取参数类型信息
+2. 识别需要注入的框架对象(如 Context、MessageEvent、ScheduleContext)
+3. 构建参数规格 (ParamSpec) 供协议层使用
+4. 验证 schedule handler 的签名合法性
+
+关键函数:
+- build_param_specs: 从 handler 签名构建参数规格列表
+- is_injected_parameter: 判断参数是否应由框架注入而非从命令行解析
+- validate_schedule_signature: 确保 schedule handler 只接受允许的注入参数
+"""
+
+from __future__ import annotations
+
+import inspect
+import typing
+from typing import Any, Literal, TypeAlias, cast
+
+from .._internal.injected_params import is_framework_injected_parameter
+from .._internal.typing_utils import unwrap_optional
+from ..decorators import get_capability_meta, get_handler_meta
+from ..protocol.descriptors import ParamSpec
+from ..types import GreedyStr
+
+ParamTypeName: TypeAlias = Literal[
+ "str", "int", "float", "bool", "optional", "greedy_str"
+]
+OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None
+
+
def is_injected_parameter(annotation: Any, parameter_name: str) -> bool:
    """Return True when the parameter should be framework-injected.

    Thin wrapper around ``is_framework_injected_parameter`` that flips the
    argument order to (annotation, name) for loader-side callers.
    """
    return is_framework_injected_parameter(parameter_name, annotation)
+
+
def param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]:
    """Map a parameter annotation onto a protocol-level type triple.

    Returns ``(type_name, optional_inner_type, required)``: optional
    parameters report type "optional" plus the unwrapped inner scalar name,
    and any annotation outside GreedyStr/int/float/bool/str degrades to "str".
    """
    inner, wrapped_in_optional = unwrap_optional(annotation)
    if inner is GreedyStr:
        # Greedy strings swallow the rest of the command line and are never
        # required scalars.
        return "greedy_str", None, False
    if inner not in {int, float, bool, str}:
        # Unknown annotations fall back to string parsing.
        if wrapped_in_optional:
            return "optional", "str", False
        return "str", None, True
    scalar = cast(Literal["str", "int", "float", "bool"], inner.__name__)
    if wrapped_in_optional:
        return "optional", scalar, False
    return scalar, None, True
+
+
def build_param_specs(handler: Any) -> list[ParamSpec]:
    """Build protocol ParamSpec entries from a handler's signature.

    Framework-injected parameters are skipped, a default value makes a
    parameter optional, and the final GreedyStr parameter must come last
    (otherwise ValueError is raised).
    """
    try:
        signature = inspect.signature(handler)
    except (TypeError, ValueError):
        # Some callables (e.g. builtins) expose no introspectable signature.
        return []
    try:
        hints = typing.get_type_hints(handler)
    except Exception:
        # Unresolvable forward references: treat parameters as untyped.
        hints = {}

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    specs: list[ParamSpec] = []
    for parameter in signature.parameters.values():
        if parameter.kind not in positional_kinds:
            continue
        annotation = hints.get(parameter.name)
        if is_injected_parameter(annotation, parameter.name):
            continue
        type_name, inner, required = param_type_name(annotation)
        has_default = parameter.default is not inspect.Parameter.empty
        specs.append(
            ParamSpec(
                name=parameter.name,
                type=type_name,
                required=False if has_default else required,
                inner_type=inner,
            )
        )

    # Only the position of the LAST greedy parameter is validated, matching
    # the existing contract.
    greedy_positions = [
        position for position, spec in enumerate(specs) if spec.type == "greedy_str"
    ]
    if greedy_positions and greedy_positions[-1] != len(specs) - 1:
        offender = specs[greedy_positions[-1]]
        raise ValueError(f"参数 '{offender.name}' (GreedyStr) 必须是最后一个参数。")
    return specs
+
+
def validate_schedule_signature(handler: Any) -> None:
    """Reject schedule handlers that declare non-injectable parameters.

    Only ctx/context and sched/schedule positional parameters are allowed;
    callables without an introspectable signature are accepted as-is.
    Raises ValueError on the first offending parameter.
    """
    try:
        signature = inspect.signature(handler)
    except (TypeError, ValueError):
        return
    allowed = {"ctx", "context", "sched", "schedule"}
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    for parameter in signature.parameters.values():
        if parameter.kind in positional_kinds and parameter.name not in allowed:
            raise ValueError(
                "Schedule handler 只允许注入 ctx/context 和 sched/schedule 参数。"
            )
+
+
def resolve_handler_candidate(instance: Any, name: str) -> tuple[Any, Any] | None:
    """Return (bound attribute, handler meta) when *name* carries a trigger.

    The attribute is resolved statically (without firing descriptors); both
    the raw attribute and its ``__func__`` (for bound/wrapped methods) are
    probed for decorator metadata with a non-None trigger.
    """
    try:
        static_attr = inspect.getattr_static(instance, name)
    except AttributeError:
        return None
    probes = [static_attr]
    unwrapped = getattr(static_attr, "__func__", None)
    if unwrapped is not None:
        probes.append(unwrapped)
    for probe in probes:
        meta = get_handler_meta(probe)
        if meta is not None and meta.trigger is not None:
            return getattr(instance, name), meta
    return None
+
+
def resolve_capability_candidate(instance: Any, name: str) -> tuple[Any, Any] | None:
    """Return (bound attribute, capability meta) for a decorated capability.

    Mirrors ``resolve_handler_candidate`` but looks for capability metadata
    and does not require a trigger.
    """
    try:
        static_attr = inspect.getattr_static(instance, name)
    except AttributeError:
        return None
    probes = [static_attr]
    unwrapped = getattr(static_attr, "__func__", None)
    if unwrapped is not None:
        probes.append(unwrapped)
    for probe in probes:
        meta = get_capability_meta(probe)
        if meta is not None:
            return getattr(instance, name), meta
    return None
+
+
# Public surface of this helper module; note that ``unwrap_optional`` is
# re-exported here even though it is imported from ``_internal.typing_utils``.
__all__ = [
    "build_param_specs",
    "is_injected_parameter",
    "param_type_name",
    "resolve_capability_candidate",
    "resolve_handler_candidate",
    "unwrap_optional",
    "validate_schedule_signature",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py b/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py
new file mode 100644
index 0000000000..29d2671caa
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py
@@ -0,0 +1,28 @@
+"""Shared stream execution primitives for runtime internals.
+
+本模块定义流式执行的通用数据结构 StreamExecution,用于:
+1. 封装异步生成器迭代器,支持逐块返回数据
+2. 提供收集完成后的聚合回调 (finalize)
+3. 控制是否需要在内存中累积所有分块
+
+使用场景:
+- LLM 流式对话返回逐字输出
+- DB watch 监听键值变更流
+- 任何需要分块返回而非一次性返回的能力调用
+"""
+
+from __future__ import annotations
+
+from collections.abc import AsyncIterator, Callable
+from dataclasses import dataclass
+from typing import Any
+
+
@dataclass(slots=True)
class StreamExecution:
    """One streamed capability execution: a chunk source plus aggregation."""

    # Async source of result chunks, yielded one dict at a time.
    iterator: AsyncIterator[dict[str, Any]]
    # Aggregation callback invoked with the collected chunks once iteration
    # completes (presumably an empty list when collect_chunks is False —
    # confirm at call sites); returns the final aggregated payload.
    finalize: Callable[[list[dict[str, Any]]], dict[str, Any]]
    # Whether chunks should be accumulated in memory for ``finalize``.
    collect_chunks: bool = True


__all__ = ["StreamExecution"]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py b/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py
new file mode 100644
index 0000000000..b293b6b7d7
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py
@@ -0,0 +1,201 @@
+"""启动引导入口。
+
+对外提供三个顶层启动函数:
+
+- ``run_supervisor``: 启动 Supervisor 进程
+- ``run_plugin_worker``: 启动单插件或组 Worker 进程
+- ``run_websocket_server``: 以 WebSocket 方式启动 Worker
+
+运行时核心类分布在同目录的子模块:
+
+- ``runtime.supervisor``: ``SupervisorRuntime`` / ``WorkerSession``
+- ``runtime.worker``: ``PluginWorkerRuntime`` / ``GroupWorkerRuntime``
+"""
+
+from __future__ import annotations
+
+import asyncio
+import sys
+from pathlib import Path
+from typing import IO
+
+from astrbot_sdk.protocol.codec import (
+ JsonProtocolCodec,
+ MsgpackProtocolCodec,
+ ProtocolCodec,
+)
+
+from .loader import PluginEnvironmentManager
+from .supervisor import (
+ SupervisorRuntime,
+ WorkerSession,
+ _install_signal_handlers,
+ _prepare_stdio_transport,
+ _sdk_source_dir,
+ _wait_for_shutdown,
+)
+from .transport import (
+ StdioTransport,
+ WebSocketServerTransport,
+ build_websocket_server_ssl_context,
+)
+from .worker import GroupWorkerRuntime, PluginWorkerRuntime, _load_plugin_specs
+
+__all__ = [
+ "GroupWorkerRuntime",
+ "PluginWorkerRuntime",
+ "SupervisorRuntime",
+ "WorkerSession",
+ "_install_signal_handlers",
+ "_prepare_stdio_transport",
+ "_sdk_source_dir",
+ "_wait_for_shutdown",
+ "run_supervisor",
+ "run_plugin_worker",
+ "run_websocket_server",
+]
+
+
def _resolve_wire_codec(wire_codec: str | ProtocolCodec | None = None) -> ProtocolCodec:
    """Normalize a wire-codec selector into a ProtocolCodec instance.

    Accepts an existing codec instance, the names "msgpack"/"json", or None
    (defaults to msgpack). Raises ValueError for anything else.
    """
    if isinstance(wire_codec, ProtocolCodec):
        return wire_codec
    if wire_codec in (None, "msgpack"):
        return MsgpackProtocolCodec()
    if wire_codec == "json":
        return JsonProtocolCodec()
    raise ValueError(f"unsupported wire codec: {wire_codec}")
+
+
async def run_supervisor(
    *,
    plugins_dir: Path = Path("plugins"),
    stdin: IO[str] | None = None,
    stdout: IO[str] | None = None,
    env_manager: PluginEnvironmentManager | None = None,
    workers_manifest: Path | None = None,
    wire_codec: str | ProtocolCodec | None = None,
) -> None:
    """Start a SupervisorRuntime over stdio and block until shutdown.

    Prepares the stdio transport (keeping hold of the original stdout so it
    can be restored afterwards), starts the runtime, installs signal
    handlers, and waits for either a signal or a peer-initiated shutdown.
    The runtime is always stopped on the way out.
    """
    transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport(
        stdin,
        stdout,
    )
    transport = StdioTransport(stdin=transport_stdin, stdout=transport_stdout)
    resolved_wire_codec = _resolve_wire_codec(wire_codec)
    runtime = SupervisorRuntime(
        transport=transport,
        plugins_dir=plugins_dir,
        env_manager=env_manager,
        workers_manifest=workers_manifest,
        wire_codec=resolved_wire_codec,
    )

    try:
        await runtime.start()
        stop_event = asyncio.Event()
        _install_signal_handlers(stop_event)
        await _wait_for_shutdown(runtime.peer, stop_event)
    finally:
        await runtime.stop()
        # Restore process stdout if _prepare_stdio_transport replaced it.
        if original_stdout is not None:
            sys.stdout = original_stdout
+
+
async def run_plugin_worker(
    *,
    plugin_dir: Path | None = None,
    group_metadata: Path | None = None,
    stdin: IO[str] | None = None,
    stdout: IO[str] | None = None,
    wire_codec: str | ProtocolCodec | None = None,
) -> None:
    """Run a single-plugin or group worker over stdio until shutdown.

    Exactly one of ``plugin_dir`` (single-plugin mode) or ``group_metadata``
    (group mode) must be given; ValueError is raised otherwise. The runtime
    is always stopped and the original stdout restored on exit.
    """
    if plugin_dir is None and group_metadata is None:
        raise ValueError("plugin_dir or group_metadata is required")
    if plugin_dir is not None and group_metadata is not None:
        raise ValueError("plugin_dir and group_metadata are mutually exclusive")

    transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport(
        stdin,
        stdout,
    )
    transport = StdioTransport(stdin=transport_stdin, stdout=transport_stdout)
    resolved_wire_codec = _resolve_wire_codec(wire_codec)
    if group_metadata is not None:
        runtime = GroupWorkerRuntime(
            group_metadata_path=group_metadata,
            transport=transport,
            wire_codec=resolved_wire_codec,
        )
    else:
        # The mutual-exclusion checks above guarantee plugin_dir is set in
        # single-plugin mode; narrow it explicitly here instead of propagating
        # the entry-level Optional into the single-plugin runtime.
        assert plugin_dir is not None
        runtime = PluginWorkerRuntime(
            plugin_dir=plugin_dir,
            transport=transport,
            wire_codec=resolved_wire_codec,
        )
    try:
        await runtime.start()
        stop_event = asyncio.Event()
        _install_signal_handlers(stop_event)
        await _wait_for_shutdown(runtime.peer, stop_event)
    finally:
        await runtime.stop()
        # Restore process stdout if _prepare_stdio_transport replaced it.
        if original_stdout is not None:
            sys.stdout = original_stdout
+
+
async def run_websocket_server(
    *,
    worker_id: str | None = None,
    host: str = "127.0.0.1",
    port: int = 8765,
    path: str = "/",
    plugin_dirs: list[Path] | None = None,
    tls_ca_file: Path | None = None,
    tls_cert_file: Path | None = None,
    tls_key_file: Path | None = None,
    wire_codec: str | ProtocolCodec | None = None,
) -> None:
    """Serve one or more plugins as a worker over a TLS WebSocket.

    All three TLS files (CA, certificate, key) are mandatory. With a single
    plugin dir the worker id defaults to that plugin's spec name; with
    multiple dirs a GroupWorkerRuntime is used and ``worker_id`` must be
    provided explicitly. ``plugin_dirs`` defaults to the current directory.
    """
    # The comprehension variable is scoped to the comprehension itself, so
    # the ``path`` URL parameter above is not affected outside of it.
    resolved_plugin_dirs = [path.resolve() for path in (plugin_dirs or [Path.cwd()])]
    resolved_wire_codec = _resolve_wire_codec(wire_codec)
    if tls_ca_file is None or tls_cert_file is None or tls_key_file is None:
        raise ValueError(
            "tls_ca_file, tls_cert_file, and tls_key_file are required for websocket workers"
        )
    transport = WebSocketServerTransport(
        host=host,
        port=port,
        path=path,
        ssl_context=build_websocket_server_ssl_context(
            ca_file=tls_ca_file,
            cert_file=tls_cert_file,
            key_file=tls_key_file,
        ),
    )
    resolved_worker_id = worker_id
    if resolved_worker_id is None and len(resolved_plugin_dirs) == 1:
        # Derive the worker id from the single plugin's spec name.
        resolved_worker_id = _load_plugin_specs([resolved_plugin_dirs[0]])[0].name
    if len(resolved_plugin_dirs) == 1:
        runtime = PluginWorkerRuntime(
            plugin_dir=resolved_plugin_dirs[0],
            worker_id=resolved_worker_id,
            transport=transport,
            wire_codec=resolved_wire_codec,
        )
    else:
        if resolved_worker_id is None:
            raise ValueError("worker_id is required when serving multiple plugins")
        runtime = GroupWorkerRuntime(
            plugin_dirs=resolved_plugin_dirs,
            worker_id=resolved_worker_id,
            transport=transport,
            wire_codec=resolved_wire_codec,
        )
    try:
        await runtime.start()
        stop_event = asyncio.Event()
        _install_signal_handlers(stop_event)
        await _wait_for_shutdown(runtime.peer, stop_event)
    finally:
        await runtime.stop()
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py b/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py
new file mode 100644
index 0000000000..1e149413a1
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py
@@ -0,0 +1,515 @@
+"""Capability invocation dispatcher.
+
+本模块实现能力调用的分发器,负责:
+1. 接收能力调用请求,定位对应的已注册能力
+2. 构建调用上下文 (Context),注入必要的依赖
+3. 支持同步和流式两种调用模式
+4. 管理活跃调用任务的生命周期和取消
+
+参数注入策略:
+按类型注入 Context / CancelToken / dict,或按参数名注入
+ctx / context / payload / input / data / cancel_token / token。
+若无法匹配则抛出详细的错误信息,帮助开发者定位问题。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import json
+import typing
+from collections.abc import AsyncIterator, Sequence
+from typing import Any, cast, get_type_hints
+
+from .._internal.invocation_context import caller_plugin_scope
+from .._internal.plugin_logger import PluginLogger
+from .._internal.sdk_logger import logger
+from .._internal.star_runtime import bind_star_runtime
+from .._internal.typing_utils import unwrap_optional
+from ..context import CancelToken, Context
+from ..errors import AstrBotError
+from ..events import MessageEvent
+from ..star import Star
+from ._streaming import StreamExecution
+from .loader import LoadedCapability, LoadedLLMTool
+
+
class CapabilityDispatcher:
    """Dispatches capability and LLM-tool invocations inside a plugin worker.

    Responsibilities:
    - locate the registered capability (or LLM tool) for an incoming message
    - build a per-request ``Context`` and bind a structured logger to it
    - run the handler as an ``asyncio`` task so it can be cancelled by id
    - normalize handler results into dict payloads or ``StreamExecution``s
    """

    def __init__(
        self,
        *,
        plugin_id: str,
        peer,
        capabilities: Sequence[LoadedCapability],
        llm_tools: Sequence[LoadedLLMTool] | None = None,
    ) -> None:
        self._plugin_id = plugin_id
        self._peer = peer
        self._capabilities = {item.descriptor.name: item for item in capabilities}
        self._llm_tools: dict[tuple[str, str], LoadedLLMTool] = {}
        # Expose this dispatcher on the peer so dynamically registered LLM
        # tools can reach it later; best-effort because the peer type may
        # reject attribute assignment (e.g. __slots__).
        try:
            setattr(peer, "_sdk_capability_dispatcher", self)
        except AttributeError:
            logger.warning(
                f"Failed to attach _sdk_capability_dispatcher to peer {peer}, "
                "dynamic LLM tool registration may not work"
            )
        for item in llm_tools or []:
            self._register_llm_tool(item, item.plugin_id or plugin_id)
        # request_id -> (running task, cancel token); consumed by cancel().
        self._active: dict[str, tuple[asyncio.Task[Any], CancelToken]] = {}

    def _register_llm_tool(
        self,
        loaded: LoadedLLMTool,
        owner_plugin: str,
    ) -> None:
        """Index a loaded tool under its name and, if distinct, its handler_ref."""
        self._llm_tools[(owner_plugin, loaded.spec.name)] = loaded
        if loaded.spec.handler_ref and loaded.spec.handler_ref != loaded.spec.name:
            self._llm_tools[(owner_plugin, loaded.spec.handler_ref)] = loaded

    def add_dynamic_llm_tool(
        self,
        *,
        plugin_id: str,
        spec,
        callable_obj,
        owner: Any | None = None,
    ) -> None:
        """Register an LLM tool at runtime, replacing any same-named tool."""
        self.remove_llm_tool(plugin_id, spec.name)
        loaded = LoadedLLMTool(
            spec=spec.model_copy(deep=True),
            callable=callable_obj,
            owner=owner,
            plugin_id=plugin_id,
        )
        self._register_llm_tool(loaded, plugin_id)

    def remove_llm_tool(self, plugin_id: str, name: str) -> bool:
        """Remove every index entry whose tool matches *name* (spec name or
        handler_ref) for *plugin_id*; returns True when anything was removed."""
        removed = False
        for key, value in list(self._llm_tools.items()):
            if key[0] != plugin_id:
                continue
            spec_name = str(getattr(value.spec, "name", "")).strip()
            handler_ref = str(getattr(value.spec, "handler_ref", "") or "").strip()
            if name not in {spec_name, handler_ref}:
                continue
            self._llm_tools.pop(key, None)
            removed = True
        return removed

    async def invoke(
        self,
        message,
        cancel_token: CancelToken,
    ) -> dict[str, Any] | StreamExecution:
        """Invoke the capability named by *message*.

        ``internal.llm_tool.execute`` is routed to the registered LLM tools;
        all other names are looked up in the capability table. The handler
        runs as a tracked asyncio task so cancel() can find it by request id.

        Raises:
            LookupError: when the capability name is unknown.
        """
        if message.capability == "internal.llm_tool.execute":
            return await self._invoke_registered_llm_tool(message, cancel_token)

        loaded = self._capabilities.get(message.capability)
        if loaded is None:
            raise LookupError(f"capability not found: {message.capability}")

        plugin_id = self._resolve_plugin_id(loaded)
        ctx = Context(
            peer=self._peer,
            plugin_id=plugin_id,
            request_id=message.id,
            cancel_token=cancel_token,
        )
        bound_logger = cast(PluginLogger, ctx.logger).bind(
            plugin_id=plugin_id,
            request_id=message.id,
            capability=message.capability,
            session_id=self._logger_session_id(dict(message.input)),
            event_type=self._logger_event_type(dict(message.input)),
        )
        ctx.logger = bound_logger

        with caller_plugin_scope(plugin_id):
            task = asyncio.create_task(
                self._run_capability(
                    loaded,
                    payload=dict(message.input),
                    ctx=ctx,
                    cancel_token=cancel_token,
                    stream=bool(message.stream),
                )
            )
        self._active[message.id] = (task, cancel_token)
        try:
            return await task
        finally:
            self._active.pop(message.id, None)

    async def _invoke_registered_llm_tool(
        self,
        message,
        cancel_token: CancelToken,
    ) -> dict[str, Any]:
        """Execute the registered LLM tool described by the message payload.

        Looks the tool up by ``handler_ref`` first, then by ``tool_name``.

        Raises:
            LookupError: when no matching tool is registered.
        """
        payload = dict(message.input)
        plugin_id = str(payload.get("plugin_id") or self._plugin_id)
        tool_name = str(payload.get("tool_name", ""))
        handler_ref = str(payload.get("handler_ref") or tool_name)
        loaded = self._llm_tools.get((plugin_id, handler_ref))
        if loaded is None:
            loaded = self._llm_tools.get((plugin_id, tool_name))
        if loaded is None:
            raise LookupError(f"llm tool not found: {plugin_id}:{tool_name}")

        event_payload = payload.get("event")
        ctx = Context(
            peer=self._peer,
            plugin_id=plugin_id,
            request_id=message.id,
            cancel_token=cancel_token,
            source_event_payload=event_payload
            if isinstance(event_payload, dict)
            else None,
        )
        bound_logger = cast(PluginLogger, ctx.logger).bind(
            plugin_id=plugin_id,
            request_id=message.id,
            capability="internal.llm_tool.execute",
            session_id=self._logger_session_id(payload),
            event_type=self._logger_event_type(payload),
        )
        ctx.logger = bound_logger
        event = MessageEvent.from_payload(
            event_payload if isinstance(event_payload, dict) else {},
            context=ctx,
        )
        self._bind_event_reply_handler(ctx, event)
        tool_args = payload.get("tool_args")
        normalized_args = dict(tool_args) if isinstance(tool_args, dict) else {}

        with caller_plugin_scope(plugin_id):
            task = asyncio.create_task(
                self._run_registered_llm_tool(loaded, event, ctx, normalized_args)
            )
        self._active[message.id] = (task, cancel_token)
        try:
            return await task
        finally:
            self._active.pop(message.id, None)

    def _bind_event_reply_handler(self, ctx: Context, event: MessageEvent) -> None:
        """Bind a reply callback on the event that sends via ctx.platform."""

        async def reply(text: str) -> None:
            try:
                await ctx.platform.send(event.session_ref or event.session_id, text)
            except TypeError:
                # Fallback for peers whose send() takes (session_id, text)
                # directly; re-raise when the peer has no usable send at all.
                send = getattr(self._peer, "send", None)
                if not callable(send):
                    raise
                result = send(event.session_id, text)
                if inspect.isawaitable(result):
                    await result

        event.bind_reply_handler(reply)

    async def _run_registered_llm_tool(
        self,
        loaded: LoadedLLMTool,
        event: MessageEvent,
        ctx: Context,
        tool_args: dict[str, Any],
    ) -> dict[str, Any]:
        """Call the tool, await its result, and wrap it into the reply dict."""
        owner = loaded.owner if isinstance(loaded.owner, Star) else None
        with bind_star_runtime(owner, ctx):
            result = loaded.callable(
                *self._build_tool_args(
                    loaded.callable,
                    event,
                    ctx,
                    tool_args,
                )
            )
        if inspect.isasyncgen(result):
            raise AstrBotError.protocol_error(
                "SDK LLM tool must return awaitable result, async generator is unsupported"
            )
        if inspect.isawaitable(result):
            result = await result
        if result is None:
            # content=None means the tool completed successfully but produced no
            # textual payload. The core bridge preserves this as a real None.
            return {"content": None, "success": True}
        if isinstance(result, dict):
            return {
                "content": json.dumps(result, ensure_ascii=False, default=str),
                "success": True,
            }
        return {"content": str(result), "success": True}

    def _build_tool_args(
        self,
        handler,
        event: MessageEvent,
        ctx: Context,
        tool_args: dict[str, Any],
    ) -> list[Any]:
        """Build positional arguments for an LLM tool handler.

        Injection order per parameter: by annotated type (Context /
        MessageEvent), then by name (``event``, ``ctx``/``context``), then
        from ``tool_args`` by parameter name; otherwise the parameter's own
        default, or TypeError when it has none.

        NOTE(review): a ``tool_args`` entry whose value is None is
        indistinguishable from "not provided" here and falls back to the
        default — confirm whether callers ever pass explicit None.
        """
        signature = inspect.signature(handler)
        args: list[Any] = []
        type_hints: dict[str, Any] = {}
        try:
            type_hints = get_type_hints(handler)
        except Exception:
            type_hints = {}

        for parameter in signature.parameters.values():
            if parameter.kind not in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                continue

            injected = None
            param_type = type_hints.get(parameter.name)
            if param_type is not None:
                injected = self._inject_tool_by_type(param_type, event, ctx)
            if injected is None:
                if parameter.name == "event":
                    injected = event
                elif parameter.name in {"ctx", "context"}:
                    injected = ctx
                elif parameter.name in tool_args:
                    injected = tool_args[parameter.name]
            if injected is None:
                if parameter.default is not parameter.empty:
                    # Pass the default explicitly instead of skipping the slot:
                    # skipping would shift every later positional argument left
                    # and bind it to the wrong parameter.
                    args.append(parameter.default)
                    continue
                raise TypeError(
                    f"SDK LLM tool '{getattr(handler, '__name__', repr(handler))}' missing required argument '{parameter.name}'"
                )
            args.append(injected)
        return args

    def _inject_tool_by_type(
        self,
        param_type: Any,
        event: MessageEvent,
        ctx: Context,
    ) -> Any:
        """Return ctx/event when the (optionally Optional) type matches, else None."""
        param_type, _is_optional = unwrap_optional(param_type)

        if param_type is Context or (
            isinstance(param_type, type) and issubclass(param_type, Context)
        ):
            return ctx
        if param_type is MessageEvent or (
            isinstance(param_type, type) and issubclass(param_type, MessageEvent)
        ):
            return event
        return None

    def _resolve_plugin_id(self, loaded: LoadedCapability) -> str:
        """Prefer the capability's own plugin id, falling back to the worker's."""
        if loaded.plugin_id:
            return loaded.plugin_id
        return self._plugin_id

    @staticmethod
    def _logger_session_id(payload: dict[str, Any]) -> str:
        """Best-effort session id for log binding ("" when unavailable)."""
        if isinstance(payload.get("event"), dict):
            return str(payload["event"].get("session_id", ""))
        return str(payload.get("session", ""))

    @staticmethod
    def _logger_event_type(payload: dict[str, Any]) -> str:
        """Best-effort event type for log binding; "capability" when no event."""
        if isinstance(payload.get("event"), dict):
            event_payload = payload["event"]
            return str(
                event_payload.get("event_type")
                or event_payload.get("type")
                or event_payload.get("message_type")
                or "message"
            )
        # The original returned "capability" from both the session and
        # no-session branches; the dead branch is collapsed here.
        return "capability"

    async def cancel(self, request_id: str) -> None:
        """Cancel the active invocation for *request_id*, if any."""
        active = self._active.get(request_id)
        if active is None:
            return
        task, cancel_token = active
        cancel_token.cancel()
        task.cancel()

    async def _run_capability(
        self,
        loaded: LoadedCapability,
        *,
        payload: dict[str, Any],
        ctx: Context,
        cancel_token: CancelToken,
        stream: bool,
    ) -> dict[str, Any] | StreamExecution:
        """Run a capability handler and normalize its result.

        For stream=True the handler must yield an async generator or a
        StreamExecution (possibly behind an awaitable); for stream=False an
        async generator is rejected and the awaited result is normalized.
        """
        result = loaded.callable(
            *self._build_args(
                loaded.callable,
                payload,
                ctx,
                cancel_token,
                plugin_id=self._resolve_plugin_id(loaded),
                capability_name=loaded.descriptor.name,
            )
        )
        if stream:
            if inspect.isasyncgen(result):
                return StreamExecution(
                    iterator=self._iterate_generator(result),
                    finalize=lambda chunks: {"items": chunks},
                )
            if inspect.isawaitable(result):
                result = await result
            if inspect.isasyncgen(result):
                return StreamExecution(
                    iterator=self._iterate_generator(result),
                    finalize=lambda chunks: {"items": chunks},
                )
            if isinstance(result, StreamExecution):
                return result
            raise AstrBotError.protocol_error(
                "stream=true 的插件 capability 必须返回 async generator 或 StreamExecution"
            )

        if inspect.isasyncgen(result):
            raise AstrBotError.protocol_error(
                "stream=false 的插件 capability 不能返回 async generator"
            )
        if inspect.isawaitable(result):
            result = await result
        return self._normalize_output(result)

    def _build_args(
        self,
        handler,
        payload: dict[str, Any],
        ctx: Context,
        cancel_token: CancelToken,
        *,
        plugin_id: str | None = None,
        capability_name: str | None = None,
    ) -> list[Any]:
        """Build positional arguments for a capability handler.

        Injection order per parameter: by annotated type (Context /
        CancelToken / dict), then by name (ctx/context, payload/input/data,
        cancel_token/token); otherwise the parameter's own default, or a
        detailed TypeError when it has none.
        """
        signature = inspect.signature(handler)
        args: list[Any] = []

        type_hints: dict[str, Any] = {}
        try:
            type_hints = get_type_hints(handler)
        except Exception:
            pass

        for parameter in signature.parameters.values():
            if parameter.kind not in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                continue

            injected = None
            param_type = type_hints.get(parameter.name)
            if param_type is not None:
                injected = self._inject_by_type(param_type, payload, ctx, cancel_token)

            if injected is None:
                if parameter.name in {"ctx", "context"}:
                    injected = ctx
                elif parameter.name in {"payload", "input", "data"}:
                    injected = payload
                elif parameter.name in {"cancel_token", "token"}:
                    injected = cancel_token

            if injected is None:
                if parameter.default is not parameter.empty:
                    # Pass the default explicitly instead of skipping the slot:
                    # skipping would shift every later positional argument left
                    # and bind it to the wrong parameter.
                    args.append(parameter.default)
                    continue
                raise TypeError(
                    self._format_capability_injection_error(
                        handler=handler,
                        parameter_name=parameter.name,
                        plugin_id=plugin_id,
                        capability_name=capability_name,
                        payload=payload,
                    )
                )
            args.append(injected)

        return args

    def _inject_by_type(
        self,
        param_type: Any,
        payload: dict[str, Any],
        ctx: Context,
        cancel_token: CancelToken,
    ) -> Any:
        """Return ctx/cancel_token/payload by annotated type, else None."""
        param_type, _is_optional = unwrap_optional(param_type)
        origin = typing.get_origin(param_type)

        if param_type is Context or (
            isinstance(param_type, type) and issubclass(param_type, Context)
        ):
            return ctx
        if param_type is CancelToken or (
            isinstance(param_type, type) and issubclass(param_type, CancelToken)
        ):
            return cancel_token
        if param_type is dict or origin is dict:
            return payload
        return None

    def _format_capability_injection_error(
        self,
        *,
        handler,
        parameter_name: str,
        plugin_id: str | None,
        capability_name: str | None,
        payload: dict[str, Any],
    ) -> str:
        """Build the detailed injection-failure message shown to developers."""
        plugin_text = plugin_id or self._plugin_id
        target = capability_name or getattr(handler, "__name__", "")
        payload_keys = sorted(str(key) for key in payload.keys())
        payload_keys_text = ", ".join(payload_keys) if payload_keys else ""
        return (
            f"插件 '{plugin_text}' 的 capability '{target}' 参数注入失败:"
            f"必填参数 '{parameter_name}' 无法注入。"
            f"签名: {getattr(handler, '__name__', '')}"
            f"{self._callable_signature(handler)}。"
            "当前支持按类型注入 Context / CancelToken / dict,"
            "按参数名注入 ctx / context / payload / input / data / cancel_token / token,"
            f"以及 payload 中现有键:{payload_keys_text}。"
        )

    async def _iterate_generator(
        self,
        generator: AsyncIterator[Any],
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield each chunk of *generator* normalized to a dict."""
        async for item in generator:
            yield self._normalize_chunk(item)

    def _normalize_chunk(self, item: Any) -> dict[str, Any]:
        """Normalize a stream chunk; empty results become {"ok": True}."""
        output = self._normalize_output(item)
        if output:
            return output
        return {"ok": True}

    def _normalize_output(self, result: Any) -> dict[str, Any]:
        """Coerce a handler result into a dict (None -> {}, model_dump() ok)."""
        if result is None:
            return {}
        if isinstance(result, dict):
            return result
        model_dump = getattr(result, "model_dump", None)
        if callable(model_dump):
            dumped = model_dump()
            if isinstance(dumped, dict):
                return dumped
        raise AstrBotError.invalid_input("插件 capability 必须返回 dict 或可序列化对象")

    @staticmethod
    def _callable_signature(handler) -> str:
        """Best-effort str(signature); "(?)" for non-introspectable callables."""
        try:
            return str(inspect.signature(handler))
        except (TypeError, ValueError):
            return "(?)"
+
+
# Public API of this module.
__all__ = ["CapabilityDispatcher"]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py b/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py
new file mode 100644
index 0000000000..cc45dcd898
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py
@@ -0,0 +1,970 @@
+"""能力路由模块。
+
+定义 CapabilityRouter 类,负责能力的注册、发现和执行路由。
+能力是核心侧提供给插件侧调用的功能,如 LLM 聊天、存储、消息发送等。
+
+核心概念:
+ CapabilityDescriptor: 能力描述符,声明能力名称、输入输出 Schema 等
+ CallHandler: 同步调用处理器,签名 (request_id, payload, cancel_token) -> dict
+ StreamHandler: 流式调用处理器,签名 (request_id, payload, cancel_token) -> AsyncIterator
+ FinalizeHandler: 流式结果聚合器,签名 (chunks) -> dict
+
+内置能力:
+ LLM:
+ llm.chat: 同步 LLM 聊天
+ llm.chat_raw: 同步 LLM 聊天(完整响应)
+ llm.stream_chat: 流式 LLM 聊天
+ Memory:
+ memory.search: 搜索记忆
+ memory.save: 保存记忆
+ memory.save_with_ttl: 保存带过期时间的记忆
+ memory.get: 读取单条记忆
+ memory.list_keys: 列出命名空间中的记忆键
+ memory.exists: 检查记忆键是否存在
+ memory.get_many: 批量获取多条记忆
+ memory.delete: 删除记忆
+ memory.clear_namespace: 清理命名空间中的记忆
+ memory.delete_many: 批量删除多条记忆
+ memory.count: 统计命名空间中的记忆数量
+ memory.stats: 获取记忆统计信息
+ DB:
+ db.get: 读取 KV 存储
+ db.set: 写入 KV 存储
+ db.delete: 删除 KV 存储
+ db.list: 列出 KV 键
+ db.get_many: 批量读取多个 KV 键
+ db.set_many: 批量写入多个 KV 键
+ db.watch: 订阅 KV 变更事件
+ Platform:
+ platform.send: 发送消息
+ platform.send_image: 发送图片
+ platform.send_chain: 发送消息链
+ platform.send_by_session: 主动按会话发送消息链
+ platform.get_group: 获取当前群信息
+ platform.get_members: 获取群成员
+ Permission:
+ permission.check: 查询用户权限角色
+ permission.get_admins: 列出管理员 ID
+ permission.manager.add_admin: 添加管理员 ID
+ permission.manager.remove_admin: 移除管理员 ID
+ HTTP:
+ http.register_api: 注册 HTTP 路由到插件 capability
+ http.unregister_api: 注销 HTTP 路由
+ http.list_apis: 查询已注册的 HTTP 路由
+ Metadata:
+ metadata.get_plugin: 获取单个插件元数据
+ metadata.list_plugins: 列出所有插件元数据
+ metadata.get_plugin_config: 获取当前调用插件自己的配置
+ Provider:
+ provider.get_using: 获取当前聊天 Provider
+ provider.get_current_chat_provider_id: 获取当前聊天 Provider ID
+ provider.list_all: 列出聊天 Providers
+ provider.list_all_tts: 列出 TTS Providers
+ provider.list_all_stt: 列出 STT Providers
+ provider.list_all_embedding: 列出 Embedding Providers
+ provider.list_all_rerank: 列出 Rerank Providers
+ provider.get_using_tts: 获取当前 TTS Provider
+ provider.get_using_stt: 获取当前 STT Provider
+ provider.get_by_id: 按 ID 获取 Provider
+ provider.stt.get_text: STT 转写
+ provider.tts.get_audio: TTS 合成音频
+ provider.tts.support_stream: 检查 TTS 原生流式支持
+ provider.tts.get_audio_stream: 流式 TTS 音频输出
+ provider.embedding.get_embedding: 获取单条向量
+ provider.embedding.get_embeddings: 批量获取向量
+ provider.embedding.get_dim: 获取向量维度
+ provider.rerank.rerank: 文档重排序
+ provider.manager.set: 设置当前 Provider
+ provider.manager.get_by_id: 按 ID 获取 Provider 管理记录
+ provider.manager.get_merged_provider_config: 获取 Provider 合并配置
+ provider.manager.load: 运行时加载 Provider
+ provider.manager.terminate: 终止已加载的 Provider
+ provider.manager.create: 创建 Provider
+ provider.manager.update: 更新 Provider
+ provider.manager.delete: 删除 Provider
+ provider.manager.get_insts: 列出已加载聊天 Provider
+ provider.manager.watch_changes: 订阅 Provider 变更(流式)
+ Platform Manager:
+ platform.manager.get_by_id: 按 ID 获取平台管理快照
+ platform.manager.clear_errors: 清除平台错误
+ platform.manager.get_stats: 获取平台统计信息
+ LLM Tool:
+ llm_tool.manager.get: 获取 LLM 工具状态
+ llm_tool.manager.activate: 激活 LLM 工具
+ llm_tool.manager.deactivate: 停用 LLM 工具
+ llm_tool.manager.add: 动态添加 LLM 工具
+ llm_tool.manager.remove: 动态移除 LLM 工具
+ Agent:
+ agent.tool_loop.run: 运行 tool loop
+ agent.registry.list: 列出 Agent 元数据
+ agent.registry.get: 获取 Agent 元数据
+ Registry:
+ registry.get_handlers_by_event_type: 按事件类型列出 handler 元数据
+ registry.get_handler_by_full_name: 按 full name 查询 handler 元数据
+ Session:
+ session.plugin.is_enabled: 获取会话级插件开关
+ session.plugin.filter_handlers: 按会话过滤 handler 元数据
+ session.service.is_llm_enabled: 获取会话级 LLM 开关
+ session.service.set_llm_status: 写入会话级 LLM 开关
+ session.service.is_tts_enabled: 获取会话级 TTS 开关
+ session.service.set_tts_status: 写入会话级 TTS 开关
+ Managers:
+ persona.get / persona.list / persona.create / persona.update / persona.delete
+ conversation.new / conversation.switch / conversation.delete
+ conversation.get / conversation.list / conversation.update
+ kb.list / kb.get / kb.create / kb.update / kb.delete / kb.retrieve
+ kb.document.upload / kb.document.list / kb.document.get
+ kb.document.delete / kb.document.refresh
+ System (内部使用):
+ system.get_data_dir: 获取插件数据目录
+ system.text_to_image: 文本转图片
+ system.html_render: 渲染 HTML 模板
+ system.session_waiter.register: 注册会话等待器
+ system.session_waiter.unregister: 注销会话等待器
+ system.event.react: 发送事件表情回应
+ system.event.send_typing: 发送输入中状态
+ system.event.send_streaming: 发送事件流式消息
+ system.event.send_streaming_chunk: 推送事件流式消息分片
+ system.dynamic_command.register: 注册动态命令路由
+ system.dynamic_command.list: 列出动态命令路由
+ system.dynamic_command.remove: 移除动态命令路由
+
+能力命名规范:
+ - 格式: {namespace}.{action} 或 {namespace}.{sub_namespace}.{action}
+ - 内置能力命名空间: llm, memory, db, platform, permission, http, metadata, provider, llm_tool, agent, registry
+ - 保留命名空间前缀: handler., system., internal.
+
+使用示例:
+ router = CapabilityRouter()
+
+ # 注册同步能力
+ router.register(
+ CapabilityDescriptor(
+ name="my_plugin.calculate",
+ description="执行计算",
+ input_schema={"type": "object", "properties": {"x": {"type": "number"}}},
+ output_schema={"type": "object", "properties": {"result": {"type": "number"}}},
+ ),
+ call_handler=my_calculate,
+ )
+
+ # 注册流式能力
+ async def stream_data(request_id, payload, token):
+ for i in range(10):
+ yield {"index": i}
+
+ router.register(
+ CapabilityDescriptor(
+ name="my_plugin.stream",
+ description="流式数据",
+ supports_stream=True,
+ cancelable=True,
+ ),
+ stream_handler=stream_data,
+ finalize=lambda chunks: {"count": len(chunks)},
+ )
+
+ # 执行能力
+ result = await router.execute("my_plugin.calculate", {"x": 42}, stream=False, ...)
+ stream_result = await router.execute("my_plugin.stream", {}, stream=True, ...)
+"""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import re
+from collections.abc import AsyncIterator, Awaitable, Callable
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+from .._internal.invocation_context import current_caller_plugin_id
+from ..errors import AstrBotError
+from ..protocol.descriptors import (
+ RESERVED_CAPABILITY_PREFIXES,
+ CapabilityDescriptor,
+)
+from ._capability_router_builtins import BuiltinCapabilityRouterMixin
+from ._streaming import StreamExecution
+
# Handler for non-streaming capability calls:
# (request_id, payload, cancel_token) -> awaitable result dict.
CallHandler = Callable[[str, dict[str, Any], object], Awaitable[dict[str, Any]]]
# Aggregates the collected stream chunks into the final result dict.
FinalizeHandler = Callable[[list[dict[str, Any]]], dict[str, Any]]
# Capability names: two or more lowercase dot-separated segments, each
# starting with a letter (e.g. "llm.chat", "provider.manager.set").
CAPABILITY_NAME_PATTERN = re.compile(r"^[a-z][a-z0-9_]*(?:\.[a-z][a-z0-9_]*)+$")


# Handler for streaming capability calls; may return the async iterator
# directly, a StreamExecution, or an awaitable resolving to either.
StreamHandler = Callable[
    [str, dict[str, Any], object],
    AsyncIterator[dict[str, Any]]
    | StreamExecution
    | Awaitable[AsyncIterator[dict[str, Any]] | StreamExecution],
]
+
+
@dataclass(slots=True)
class _CapabilityRegistration:
    """Internal record tying a capability descriptor to its handlers."""

    descriptor: CapabilityDescriptor
    # Handler for synchronous (non-streaming) invocations.
    call_handler: CallHandler | None = None
    # Handler for streaming invocations.
    stream_handler: StreamHandler | None = None
    # Optional aggregator that folds streamed chunks into the final dict.
    finalize: FinalizeHandler | None = None
    # NOTE(review): presumably controls visibility in capability discovery —
    # confirm against the listing logic outside this chunk.
    exposed: bool = True
+
+
@dataclass(slots=True)
class _RegisteredPlugin:
    """Internal per-plugin registration state tracked by the router."""

    metadata: dict[str, Any]
    config: dict[str, Any]
    handlers: list[dict[str, Any]]
    # LLM tool specs keyed by tool name.
    llm_tools: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Names of the LLM tools currently active for this plugin.
    active_llm_tools: set[str] = field(default_factory=set)
    # Agent metadata keyed by agent name.
    agents: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Skill entries keyed by name.
    skills: dict[str, dict[str, str]] = field(default_factory=dict)
+
+
+class CapabilityRouter(BuiltinCapabilityRouterMixin):
    def __init__(self) -> None:
        """Initialize all in-memory stores and register built-in capabilities.

        State lives purely in process memory and the default provider /
        platform entries use "mock-*" ids, which suggests this router backs
        the SDK testing harness — NOTE(review): confirm against the package
        layout outside this chunk.
        """
        # capability name -> registration record.
        self._registrations: dict[str, _CapabilityRegistration] = {}
        # KV store backing the db.* capabilities.
        self.db_store: dict[str, Any] = {}
        # memory.* stores and bookkeeping (backends, index, dirty keys, TTLs).
        self.memory_store: dict[str, dict[str, Any]] = {}
        self._memory_backends: dict[str, Any] = {}
        self._memory_index: dict[str, dict[str, Any]] = {}
        self._memory_dirty_keys: set[str] = set()
        self._memory_expires_at: dict[str, datetime | None] = {}
        # Outbound messages and event actions recorded for inspection.
        self.sent_messages: list[dict[str, Any]] = []
        self.event_actions: list[dict[str, Any]] = []
        self._event_streams: dict[str, dict[str, Any]] = {}
        self.http_api_store: list[dict[str, Any]] = []
        # plugin name -> registration state.
        self._plugins: dict[str, _RegisteredPlugin] = {}
        self._request_overlays: dict[str, dict[str, Any]] = {}
        # Mock provider catalog, one entry list per provider kind.
        self._provider_catalog: dict[str, list[dict[str, Any]]] = {
            "chat": [
                {
                    "id": "mock-chat-provider",
                    "model": "mock-chat-model",
                    "type": "mock",
                    "provider_type": "chat_completion",
                }
            ],
            "tts": [
                {
                    "id": "mock-tts-provider",
                    "model": "mock-tts-model",
                    "type": "mock",
                    "provider_type": "text_to_speech",
                }
            ],
            "stt": [
                {
                    "id": "mock-stt-provider",
                    "model": "mock-stt-model",
                    "type": "mock",
                    "provider_type": "speech_to_text",
                }
            ],
            "embedding": [
                {
                    "id": "mock-embedding-provider",
                    "model": "mock-embedding-model",
                    "type": "mock",
                    "provider_type": "embedding",
                }
            ],
            "rerank": [
                {
                    "id": "mock-rerank-provider",
                    "model": "mock-rerank-model",
                    "type": "mock",
                    "provider_type": "rerank",
                }
            ],
        }
        # provider id -> merged config; every catalog entry starts enabled.
        self._provider_configs: dict[str, dict[str, Any]] = {
            str(item["id"]): {**item, "enable": True}
            for providers in self._provider_catalog.values()
            for item in providers
        }
        # Active provider per kind; defaults to the first catalog entry.
        self._active_provider_ids: dict[str, str | None] = {
            kind: providers[0]["id"] if providers else None
            for kind, providers in self._provider_catalog.items()
        }
        # subscription id -> queue of provider-change events.
        self._provider_change_subscriptions: dict[
            str, asyncio.Queue[dict[str, Any]]
        ] = {}
        # Root directory handed out for plugin data (system.get_data_dir).
        self._system_data_root = Path.cwd() / ".astrbot_sdk_testing" / "plugin_data"
        self._session_waiters: dict[str, set[str]] = {}
        # subscription id -> (key filter, queue) for db.watch.
        self._db_watch_subscriptions: dict[
            str, tuple[str | None, asyncio.Queue[dict[str, Any]]]
        ] = {}
        # Per-session plugin/service switches.
        self._session_plugin_configs: dict[str, dict[str, Any]] = {}
        self._session_service_configs: dict[str, dict[str, Any]] = {}
        self._dynamic_command_routes: dict[str, list[dict[str, Any]]] = {}
        # Persona / conversation / message-history stores.
        self._persona_store: dict[str, dict[str, Any]] = {}
        self._conversation_store: dict[str, dict[str, Any]] = {}
        self._session_current_conversation_ids: dict[str, str] = {}
        self._message_history_store: dict[str, list[dict[str, Any]]] = {}
        self._message_history_next_id = 1
        # Knowledge-base stores (kb.* capabilities).
        self._kb_store: dict[str, dict[str, Any]] = {}
        self._kb_document_store: dict[str, dict[str, dict[str, Any]]] = {}
        self._kb_document_content_store: dict[str, str] = {}
        self._platform_instances: list[dict[str, Any]] = [
            {
                "id": "mock-platform",
                "name": "Mock Platform",
                "type": "mock",
                "status": "running",
            }
        ]
        self._permission_admin_ids: list[str] = ["astrbot"]
        self._register_builtin_capabilities()
+
+ def upsert_plugin(
+ self,
+ *,
+ metadata: dict[str, Any],
+ config: dict[str, Any] | None = None,
+ ) -> None:
+ name = str(metadata.get("name", "")).strip()
+ if not name:
+ raise ValueError("plugin metadata must include a non-empty name")
+ normalized_metadata = dict(metadata)
+ normalized_metadata.setdefault("display_name", name)
+ normalized_metadata.setdefault("description", "")
+ normalized_metadata.setdefault("repo", "")
+ normalized_metadata.setdefault("author", "")
+ normalized_metadata.setdefault("version", "0.0.0")
+ normalized_metadata.setdefault("enabled", True)
+ normalized_metadata.setdefault("reserved", False)
+ normalized_metadata.setdefault("support_platforms", [])
+ normalized_metadata.setdefault("astrbot_version", None)
+ existing = self._plugins.get(name)
+ if existing is not None:
+ existing.metadata = normalized_metadata
+ existing.config = dict(config or {})
+ return
+ self._plugins[name] = _RegisteredPlugin(
+ metadata=normalized_metadata,
+ config=dict(config or {}),
+ handlers=[],
+ )
+
    def set_plugin_handlers(
        self,
        name: str,
        handlers: list[dict[str, Any]],
    ) -> None:
        """Replace a plugin's handler metadata and prune stale command routes.

        Dynamic command routes whose ``handler_full_name`` no longer matches
        any registered handler are removed; when no valid handlers remain,
        all of the plugin's routes are dropped. Unknown names are ignored.
        """
        plugin = self._plugins.get(name)
        if plugin is None:
            return
        plugin.handlers = [dict(item) for item in handlers]
        # Full names of the handlers that still exist after this update.
        valid_handlers = {
            str(item.get("handler_full_name", "")).strip()
            for item in plugin.handlers
            if isinstance(item, dict)
        }
        if not valid_handlers:
            self._dynamic_command_routes.pop(name, None)
            return
        routes = self._dynamic_command_routes.get(name)
        if routes is None:
            return
        # Keep only routes that still point at an existing handler.
        self._dynamic_command_routes[name] = [
            dict(item)
            for item in routes
            if str(item.get("handler_full_name", "")).strip() in valid_handlers
        ]
        if not self._dynamic_command_routes[name]:
            self._dynamic_command_routes.pop(name, None)
+
+ def set_plugin_enabled(self, name: str, enabled: bool) -> None:
+ plugin = self._plugins.get(name)
+ if plugin is None:
+ return
+ plugin.metadata["enabled"] = enabled
+
+ def register_dynamic_command_route(
+ self,
+ *,
+ plugin_id: str,
+ command_name: str,
+ handler_full_name: str,
+ desc: str = "",
+ priority: int = 0,
+ use_regex: bool = False,
+ ) -> None:
+ command_text = str(command_name).strip()
+ if not command_text:
+ raise AstrBotError.invalid_input("command_name must not be empty")
+ handler_text = str(handler_full_name).strip()
+ if not handler_text:
+ raise AstrBotError.invalid_input("handler_full_name must not be empty")
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}")
+ if not self._plugin_has_handler(plugin_id, handler_text):
+ raise AstrBotError.invalid_input(
+ "handler_full_name must belong to the caller plugin and exist"
+ )
+ route = {
+ "plugin_name": plugin_id,
+ "command_name": command_text,
+ "handler_full_name": handler_text,
+ "desc": str(desc),
+ "priority": int(priority),
+ "use_regex": bool(use_regex),
+ }
+ routes = [
+ item
+ for item in self._dynamic_command_routes.get(plugin_id, [])
+ if str(item.get("command_name", "")).strip() != command_text
+ or bool(item.get("use_regex", False)) != bool(use_regex)
+ ]
+ routes.append(route)
+ self._dynamic_command_routes[plugin_id] = routes
+
+ def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]:
+ return [dict(item) for item in self._dynamic_command_routes.get(plugin_id, [])]
+
+ def remove_dynamic_command_routes_for_plugin(self, plugin_id: str) -> None:
+ self._dynamic_command_routes.pop(plugin_id, None)
+
    def set_platform_instances(self, instances: list[dict[str, Any]]) -> None:
        """Replace the tracked platform instances with a normalized copy.

        Entries that are not dicts or lack a non-empty ``id``/``type`` are
        silently dropped; nested ``errors``/``last_error``/``stats``/``meta``
        values are defensively copied.
        """
        normalized: list[dict[str, Any]] = []
        for item in instances:
            if not isinstance(item, dict):
                continue
            platform_id = str(item.get("id", "")).strip()
            platform_type = str(item.get("type", "")).strip()
            if not platform_id or not platform_type:
                # id and type are mandatory; skip malformed entries.
                continue
            errors = item.get("errors")
            last_error = item.get("last_error")
            stats = item.get("stats")
            meta = item.get("meta")
            normalized.append(
                {
                    "id": platform_id,
                    "name": str(item.get("name", platform_id)),
                    "type": platform_type,
                    "status": str(item.get("status", "unknown")),
                    # Keep only well-formed dict error entries.
                    "errors": [
                        dict(error) for error in errors if isinstance(error, dict)
                    ]
                    if isinstance(errors, list)
                    else [],
                    "last_error": (
                        dict(last_error) if isinstance(last_error, dict) else None
                    ),
                    "unified_webhook": bool(item.get("unified_webhook", False)),
                    "stats": dict(stats) if isinstance(stats, dict) else None,
                    "meta": dict(meta) if isinstance(meta, dict) else {},
                    "started_at": item.get("started_at"),
                }
            )
        self._platform_instances = normalized
+
+ def get_platform_instances(self) -> list[dict[str, Any]]:
+ return [dict(item) for item in self._platform_instances]
+
+ def set_admin_ids(self, admin_ids: list[str]) -> None:
+ self._permission_admin_ids = [
+ user_id for user_id in (str(item).strip() for item in admin_ids) if user_id
+ ]
+
+ def _plugin_has_handler(self, plugin_id: str, handler_full_name: str) -> bool:
+ plugin = self._plugins.get(plugin_id)
+ if plugin is None:
+ return False
+ handler_name = str(handler_full_name).strip()
+ if not handler_name:
+ return False
+ for handler in plugin.handlers:
+ if not isinstance(handler, dict):
+ continue
+ if str(handler.get("handler_full_name", "")).strip() == handler_name:
+ return True
+ return False
+
+ def set_plugin_llm_tools(
+ self,
+ name: str,
+ tools: list[dict[str, Any]],
+ ) -> None:
+ plugin = self._plugins.get(name)
+ if plugin is None:
+ return
+ plugin.llm_tools = {
+ str(item.get("name", "")): dict(item)
+ for item in tools
+ if isinstance(item, dict) and str(item.get("name", "")).strip()
+ }
+ plugin.active_llm_tools = {
+ tool_name
+ for tool_name, item in plugin.llm_tools.items()
+ if bool(item.get("active", True))
+ }
+
+ def set_plugin_agents(
+ self,
+ name: str,
+ agents: list[dict[str, Any]],
+ ) -> None:
+ plugin = self._plugins.get(name)
+ if plugin is None:
+ return
+ plugin.agents = {
+ str(item.get("name", "")): dict(item)
+ for item in agents
+ if isinstance(item, dict) and str(item.get("name", "")).strip()
+ }
+
    def set_provider_catalog(
        self,
        kind: str,
        providers: list[dict[str, Any]],
        *,
        active_id: str | None = None,
    ) -> None:
        """Replace the provider catalog for *kind* and refresh derived state.

        Entries without a non-empty ``id`` are dropped. Each surviving entry
        also gets an enabled config record. The active provider of the kind
        becomes ``active_id`` when given, otherwise the first catalog entry
        (or ``None`` for an empty catalog).
        """
        self._provider_catalog[kind] = [
            dict(item)
            for item in providers
            if isinstance(item, dict) and str(item.get("id", "")).strip()
        ]
        for item in self._provider_catalog[kind]:
            provider_id = str(item.get("id", "")).strip()
            if not provider_id:
                continue
            self._provider_configs[provider_id] = {**item, "enable": True}
        if active_id is not None:
            self._active_provider_ids[kind] = active_id
        else:
            catalog = self._provider_catalog[kind]
            self._active_provider_ids[kind] = catalog[0]["id"] if catalog else None
+
+ def emit_provider_change(
+ self,
+ provider_id: str,
+ provider_type: str,
+ umo: str | None = None,
+ ) -> None:
+ event = {
+ "provider_id": str(provider_id),
+ "provider_type": str(provider_type),
+ "umo": str(umo) if umo is not None else None,
+ }
+ for queue in list(self._provider_change_subscriptions.values()):
+ queue.put_nowait(dict(event))
+
+ def record_platform_error(
+ self,
+ platform_id: str,
+ message: str,
+ *,
+ traceback: str | None = None,
+ ) -> None:
+ for item in self._platform_instances:
+ if str(item.get("id", "")) != str(platform_id):
+ continue
+ error = {
+ "message": str(message),
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "traceback": str(traceback) if traceback is not None else None,
+ }
+ errors = item.setdefault("errors", [])
+ if isinstance(errors, list):
+ errors.append(error)
+ item["last_error"] = error
+ item["status"] = "error"
+ return
+
def set_platform_stats(self, platform_id: str, stats: dict[str, Any]) -> None:
    """Store a copy of *stats* on the first platform instance matching *platform_id*."""
    target = str(platform_id)
    for instance in self._platform_instances:
        if str(instance.get("id", "")) == target:
            instance["stats"] = dict(stats)
            return
+
+ def set_session_plugin_config(
+ self,
+ session_id: str,
+ *,
+ enabled_plugins: list[str] | None = None,
+ disabled_plugins: list[str] | None = None,
+ ) -> None:
+ config: dict[str, Any] = {}
+ if enabled_plugins is not None:
+ config["enabled_plugins"] = [str(item) for item in enabled_plugins]
+ if disabled_plugins is not None:
+ config["disabled_plugins"] = [str(item) for item in disabled_plugins]
+ self._session_plugin_configs[str(session_id)] = config
+
+ def set_session_service_config(
+ self,
+ session_id: str,
+ *,
+ llm_enabled: bool | None = None,
+ tts_enabled: bool | None = None,
+ ) -> None:
+ config: dict[str, Any] = {}
+ if llm_enabled is not None:
+ config["llm_enabled"] = bool(llm_enabled)
+ if tts_enabled is not None:
+ config["tts_enabled"] = bool(tts_enabled)
+ self._session_service_configs[str(session_id)] = config
+
def remove_http_apis_for_plugin(self, plugin_id: str) -> None:
    """Drop every registered HTTP API entry owned by *plugin_id*."""
    kept = [
        entry for entry in self.http_api_store if entry.get("plugin_id") != plugin_id
    ]
    self.http_api_store = kept
+
@staticmethod
def _require_caller_plugin_id(capability_name: str) -> str:
    """Return the plugin id bound to the current invocation context.

    Raises ``AstrBotError.invalid_input`` when called outside a plugin
    runtime context (i.e. no caller plugin id is bound); *capability_name*
    is only used to build the error message.
    """
    caller_plugin_id = current_caller_plugin_id()
    if caller_plugin_id:
        return caller_plugin_id
    raise AstrBotError.invalid_input(
        f"{capability_name} 只能在插件运行时上下文中调用"
    )
+
+ def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
+ event = {"op": op, "key": key, "value": value}
+ for prefix, queue in list(self._db_watch_subscriptions.values()):
+ if prefix is not None and not key.startswith(prefix):
+ continue
+ queue.put_nowait(event)
+
def descriptors(self) -> list[CapabilityDescriptor]:
    """Return the descriptors of exposed registrations only."""
    exposed = []
    for registration in self._registrations.values():
        if registration.exposed:
            exposed.append(registration.descriptor)
    return exposed
+
def all_descriptors(self) -> list[CapabilityDescriptor]:
    """Return the descriptor of every registration, hidden ones included."""
    return [registration.descriptor for registration in self._registrations.values()]
+
def contains(self, name: str) -> bool:
    """Report whether a capability named *name* is registered."""
    return name in self._registrations.keys()
+
def unregister(self, name: str) -> None:
    """Drop the registration for *name*; unknown names are a silent no-op."""
    if name in self._registrations:
        del self._registrations[name]
+
def register(
    self,
    descriptor: CapabilityDescriptor,
    *,
    call_handler: CallHandler | None = None,
    stream_handler: StreamHandler | None = None,
    finalize: FinalizeHandler | None = None,
    exposed: bool = True,
) -> None:
    """Register (or overwrite) a capability.

    Naming rules enforced here:

    - every name must match ``{namespace}.{method}``
      (``CAPABILITY_NAME_PATTERN``), except internal (non-exposed)
      registrations under a reserved prefix;
    - reserved prefixes are rejected for exposed registrations.

    Re-registering an existing name silently replaces the previous entry.
    Raises ``ValueError`` on a naming violation.
    """
    # Internal registrations under a reserved prefix bypass the pattern check.
    is_internal_reserved = not exposed and descriptor.name.startswith(
        RESERVED_CAPABILITY_PREFIXES
    )
    if (
        not CAPABILITY_NAME_PATTERN.fullmatch(descriptor.name)
        and not is_internal_reserved
    ):
        raise ValueError(
            f"capability 名称必须匹配 {{namespace}}.{{method}}:{descriptor.name}"
        )
    if exposed and descriptor.name.startswith(RESERVED_CAPABILITY_PREFIXES):
        raise ValueError(
            f"保留 capability 命名空间仅供框架内部使用:{descriptor.name}"
        )
    self._registrations[descriptor.name] = _CapabilityRegistration(
        descriptor=descriptor,
        call_handler=call_handler,
        stream_handler=stream_handler,
        finalize=finalize,
        exposed=exposed,
    )
+
async def execute(
    self,
    capability: str,
    payload: dict[str, Any],
    *,
    stream: bool,
    cancel_token,
    request_id: str,
) -> dict[str, Any] | StreamExecution:
    """Dispatch one capability invocation with input/output validation.

    Returns the call handler's output dict for non-streaming calls, or a
    ``StreamExecution`` whose finalize step is wrapped with output
    validation for streaming calls.

    Raises ``AstrBotError``:

    - ``capability_not_found`` when *capability* is unregistered;
    - ``invalid_input`` when schema validation fails, when ``stream=True``
      but the registration has no stream handler, or when ``stream=False``
      but only a stream handler exists.
    """
    registration = self._registrations.get(capability)
    if registration is None:
        raise AstrBotError.capability_not_found(capability)

    self._validate_schema_with_context(
        capability=capability,
        phase="输入",
        schema=registration.descriptor.input_schema,
        payload=payload,
    )
    if stream:
        if registration.stream_handler is None:
            raise AstrBotError.invalid_input(f"{capability} 不支持 stream=true")
        raw_execution = registration.stream_handler(
            request_id, payload, cancel_token
        )
        # A stream handler may be async (returning an awaitable) or sync.
        if inspect.isawaitable(raw_execution):
            raw_execution = await raw_execution
        if isinstance(raw_execution, StreamExecution):
            return self._wrap_stream_execution(
                registration.descriptor,
                raw_execution,
            )
        # Bare iterators get a default finalize wrapping chunks in {"items": ...}.
        finalize = registration.finalize or (lambda chunks: {"items": chunks})
        return self._wrap_stream_execution(
            registration.descriptor,
            StreamExecution(
                iterator=raw_execution,
                finalize=finalize,
            ),
        )

    if registration.call_handler is None:
        raise AstrBotError.invalid_input(
            f"{capability} 只能以 stream=true 调用,registration.call_handler 为 None"
        )
    output = await registration.call_handler(request_id, payload, cancel_token)
    # Non-streaming output is validated eagerly; streaming output is
    # validated later inside the wrapped finalize.
    self._validate_schema_with_context(
        capability=capability,
        phase="输出",
        schema=registration.descriptor.output_schema,
        payload=output,
    )
    return output
+
def _wrap_stream_execution(
    self,
    descriptor: CapabilityDescriptor,
    execution: StreamExecution,
) -> StreamExecution:
    """Return a copy of *execution* whose finalized result is schema-validated."""

    def validated_finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]:
        # Run the original finalize first, then validate the aggregate output
        # against the capability's declared output schema.
        output = execution.finalize(chunks)
        self._validate_schema_with_context(
            capability=descriptor.name,
            phase="输出",
            schema=descriptor.output_schema,
            payload=output,
        )
        return output

    return StreamExecution(
        iterator=execution.iterator,
        finalize=validated_finalize,
        collect_chunks=execution.collect_chunks,
    )
+
+ # ------------------------------------------------------------------
+ # Schema validation
+ # ------------------------------------------------------------------
+
+ def _validate_schema(
+ self,
+ schema: dict[str, Any] | None,
+ payload: Any,
+ ) -> None:
+ if not isinstance(schema, dict) or not schema:
+ return
+ self._validate_value(schema, payload, path="")
+
def _validate_schema_with_context(
    self,
    *,
    capability: str,
    phase: str,
    schema: dict[str, Any] | None,
    payload: Any,
) -> None:
    """Validate *payload*, re-raising validation errors with capability context.

    *phase* is a human-readable label (e.g. "输入"/"输出") embedded in the
    error message.  Only ``invalid_input`` errors are rewrapped; any other
    ``AstrBotError`` code propagates unchanged.
    """
    try:
        self._validate_schema(schema, payload)
    except AstrBotError as exc:
        if exc.code != "invalid_input":
            raise
        raise AstrBotError.invalid_input(
            f"capability '{capability}' 的{phase}校验失败:{exc.message}",
            hint=(
                f"请检查 capability '{capability}' 的{phase.lower()}是否符合声明的 schema"
            ),
        ) from exc
+
def _validate_value(
    self,
    schema: dict[str, Any],
    value: Any,
    *,
    path: str,
) -> None:
    """Recursively validate *value* against a minimal JSON-schema subset.

    Supported keywords: ``anyOf``, ``enum`` and ``type`` with the values
    object/array/string/integer/number/boolean/null.  *path* is the
    dotted/indexed location used in error messages ("" meaning the root
    input).  Raises ``AstrBotError.invalid_input`` on the first violation;
    schemas with no recognized ``type`` accept any value.
    """
    # anyOf: the value must satisfy at least one dict candidate schema.
    any_of = schema.get("anyOf")
    if isinstance(any_of, list):
        for candidate in any_of:
            if not isinstance(candidate, dict):
                continue
            try:
                self._validate_value(candidate, value, path=path)
                return
            except AstrBotError:
                continue
        raise AstrBotError.invalid_input(
            f"{self._field_label(path)} 不符合允许的 schema 约束,"
            f"实际收到 {self._value_type_name(value)}"
        )

    # enum: value must be one of the listed literals.
    enum = schema.get("enum")
    if isinstance(enum, list) and value not in enum:
        raise AstrBotError.invalid_input(
            f"{self._field_label(path)} 必须是 {enum},实际收到 {value!r}"
        )

    schema_type = schema.get("type")
    if schema_type == "object":
        if not isinstance(value, dict):
            # The root gets a dedicated message without a field label.
            if not path:
                raise AstrBotError.invalid_input(
                    f"输入必须是 object,实际收到 {self._value_type_name(value)}"
                )
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 object,"
                f"实际收到 {self._value_type_name(value)}"
            )
        properties = schema.get("properties", {})
        required_fields = schema.get("required", [])
        for field_name in required_fields:
            field_path = self._join_path(path, str(field_name))
            if field_name not in value:
                raise AstrBotError.invalid_input(f"缺少必填字段:{field_path}")
            field_schema = self._property_schema(properties, field_name)
            # A required field that is present but None is reported as missing,
            # unless its schema explicitly allows null.
            if value[field_name] is None and not self._schema_allows_null(
                field_schema
            ):
                raise AstrBotError.invalid_input(f"缺少必填字段:{field_path}")
            self._validate_value(
                field_schema,
                value[field_name],
                path=field_path,
            )
        # Validate any additional fields that have a declared sub-schema;
        # undeclared extra fields are accepted as-is.
        for field_name, field_value in value.items():
            field_schema = properties.get(field_name)
            if isinstance(field_schema, dict):
                self._validate_value(
                    field_schema,
                    field_value,
                    path=self._join_path(path, str(field_name)),
                )
        return

    if schema_type == "array":
        if not isinstance(value, list):
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 array,"
                f"实际收到 {self._value_type_name(value)}"
            )
        item_schema = schema.get("items")
        if isinstance(item_schema, dict):
            for index, item in enumerate(value):
                self._validate_value(
                    item_schema,
                    item,
                    path=self._index_path(path, index),
                )
        return

    if schema_type == "string":
        if not isinstance(value, str):
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 string,"
                f"实际收到 {self._value_type_name(value)}"
            )
        return

    if schema_type == "integer":
        # bool is an int subclass, so it must be rejected explicitly.
        if not isinstance(value, int) or isinstance(value, bool):
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 integer,"
                f"实际收到 {self._value_type_name(value)}"
            )
        return

    if schema_type == "number":
        # Accept int or float, but not bool (an int subclass).
        if not isinstance(value, (int, float)) or isinstance(value, bool):
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 number,"
                f"实际收到 {self._value_type_name(value)}"
            )
        return

    if schema_type == "boolean":
        if not isinstance(value, bool):
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 boolean,"
                f"实际收到 {self._value_type_name(value)}"
            )
        return

    if schema_type == "null":
        if value is not None:
            raise AstrBotError.invalid_input(
                f"{self._field_label(path)} 必须是 null,"
                f"实际收到 {self._value_type_name(value)}"
            )
        return
+
+ @staticmethod
+ def _field_label(path: str) -> str:
+ if not path:
+ return "输入"
+ return f"字段 {path}"
+
+ @staticmethod
+ def _join_path(path: str, field_name: str) -> str:
+ if not path:
+ return field_name
+ return f"{path}.{field_name}"
+
+ @staticmethod
+ def _index_path(path: str, index: int) -> str:
+ return f"{path}[{index}]" if path else f"[{index}]"
+
+ @staticmethod
+ def _property_schema(
+ properties: Any,
+ field_name: str,
+ ) -> dict[str, Any]:
+ if not isinstance(properties, dict):
+ return {}
+ field_schema = properties.get(field_name)
+ if isinstance(field_schema, dict):
+ return field_schema
+ return {}
+
+ @staticmethod
+ def _schema_allows_null(field_schema: Any) -> bool:
+ if not isinstance(field_schema, dict):
+ return False
+ if field_schema.get("type") == "null":
+ return True
+ any_of = field_schema.get("anyOf")
+ if not isinstance(any_of, list):
+ return False
+ return any(
+ isinstance(candidate, dict) and candidate.get("type") == "null"
+ for candidate in any_of
+ )
+
+ @staticmethod
+ def _value_type_name(value: Any) -> str:
+ if value is None:
+ return "null"
+ if isinstance(value, bool):
+ return "boolean"
+ if isinstance(value, int):
+ return "integer"
+ if isinstance(value, float):
+ return "number"
+ if isinstance(value, str):
+ return "string"
+ if isinstance(value, list):
+ return "array"
+ if isinstance(value, dict):
+ return "object"
+ return type(value).__name__
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py b/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py
new file mode 100644
index 0000000000..6503cb842d
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py
@@ -0,0 +1,675 @@
+"""astrbot-sdk runtime 的插件共享环境规划模块。
+
+这个模块负责“多个插件,共享较少数量 Python 环境”的策略。核心约束是:
+
+- 插件仍然独立发现、独立加载
+- Worker 运行时既可以是一插件一进程,也可以由 GroupWorkerRuntime 在同一进程承载多个插件
+- 只有在依赖兼容时才共享 Python 环境
+
+整体流程如下:
+
+1. 先按插件声明的 `runtime.python` 分桶
+2. 再按依赖兼容性构建候选分组
+3. 为每个分组在 `.astrbot/` 下落地 source、lock、metadata 和 venv 路径
+4. 在 worker 启动前准备或同步该分组的共享环境
+
+当前阶段优先保证兼容性,因此仍保留 `--system-site-packages`,也不改变
+现有插件 manifest 语义。
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import os
+import re
+import shutil
+import subprocess
+import tempfile
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .loader import PluginSpec
+
# State file written inside each group venv recording what it was built from.
GROUP_STATE_FILE_NAME = ".group-venv-state.json"

# Matches an exact version pin such as "pkg==1.2.3" (no markers after ";").
_EXACT_PIN_PATTERN = re.compile(r"^([A-Za-z0-9_.-]+)==([^\s;]+)$")
# Runs of "-", "_" and "." collapse to a single "-" when normalizing names.
_NORMALIZE_PATTERN = re.compile(r"[-_.]+")
# Extracts the major.minor Python version from a venv's pyvenv.cfg.
_PYVENV_VERSION_PATTERN = re.compile(
    r"^(?:version|version_info)\s*=\s*(\d+\.\d+)(?:\.\d+)?\s*$",
    re.IGNORECASE | re.MULTILINE,
)
+
+
+def _require_uv_binary(uv_binary: str | None) -> str:
+ if not uv_binary:
+ raise RuntimeError("uv executable not found")
+ return uv_binary
+
+
def _venv_python_path(venv_path: Path) -> Path:
    """Locate the interpreter inside a venv (Windows vs POSIX layout)."""
    subdir, exe = ("Scripts", "python.exe") if os.name == "nt" else ("bin", "python")
    return venv_path / subdir / exe
+
+
def _normalize_package_name(name: str) -> str:
    """Normalize a distribution name: collapse ``-``/``_``/``.`` runs to ``-`` and lowercase."""
    return re.sub(r"[-_.]+", "-", name).lower()
+
+
+def _read_pyvenv_major_minor(pyvenv_cfg: Path) -> str | None:
+ if not pyvenv_cfg.exists():
+ return None
+ try:
+ content = pyvenv_cfg.read_text(encoding="utf-8")
+ except OSError:
+ return None
+ match = _PYVENV_VERSION_PATTERN.search(content)
+ if match is None:
+ return None
+ return match.group(1)
+
+
def _requirement_lines(plugin: PluginSpec) -> list[str]:
    """Read a plugin's requirements file, dropping blank lines and ``#`` comments.

    Returns an empty list when the file does not exist.
    """
    if not plugin.requirements_path.exists():
        return []
    raw = plugin.requirements_path.read_text(encoding="utf-8")
    return [
        stripped
        for stripped in (line.strip() for line in raw.splitlines())
        if stripped and not stripped.startswith("#")
    ]
+
+
@dataclass(slots=True)
class EnvironmentGroup:
    """Description of the environment ultimately shared by one or more compatible plugins.

    A group is the smallest unit of environment reuse.  Every plugin in
    ``plugins`` uses the same ``python_path``, lockfile and venv directory,
    while each still runs in its own worker process at runtime.
    """

    id: str  # stable identity derived from python version + plugin set
    python_version: str  # major.minor interpreter version shared by the group
    plugins: list[PluginSpec]  # members sharing this environment
    source_path: Path  # requirements input file (.in) used for lock generation
    lockfile_path: Path  # compiled lockfile (.txt)
    metadata_path: Path  # JSON metadata describing the group
    venv_path: Path  # shared virtualenv directory
    python_path: Path  # interpreter inside venv_path
    environment_fingerprint: str  # hash over version + plugins + lockfile content
+
+
@dataclass(slots=True)
class EnvironmentPlanResult:
    """Result of one complete planning pass.

    ``plugins`` contains only the plugins whose planning succeeded.
    ``skipped_plugins`` maps a plugin name to the failure reason for
    plugins that did not obtain a usable shared environment, even when
    tried as a single-plugin group.
    """

    groups: list[EnvironmentGroup] = field(default_factory=list)
    plugins: list[PluginSpec] = field(default_factory=list)
    plugin_to_group: dict[str, EnvironmentGroup] = field(default_factory=dict)
    skipped_plugins: dict[str, str] = field(default_factory=dict)
+
+
class EnvironmentPlanner:
    """Plans shared environments and materializes per-group artifacts.

    For supervisor startup this class answers two questions:

    - which plugins may share one environment;
    - which lockfile and which venv path that shared environment maps to.

    It does not create or sync venvs itself; after planning that part is
    handed to ``GroupEnvironmentManager``.
    """

    def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None:
        self.repo_root = repo_root.resolve()
        # Fall back to whatever `uv` is on PATH when no explicit binary is given.
        self.uv_binary = uv_binary or shutil.which("uv")
        self.cache_dir = self.repo_root / ".uv-cache"
        self.artifacts_dir = self.repo_root / ".astrbot"
        self.group_dir = self.artifacts_dir / "groups"
        self.lock_dir = self.artifacts_dir / "locks"
        self.env_dir = self.artifacts_dir / "envs"
        # Memoized compatibility verdicts keyed by a hash of the group payload.
        self._compatibility_cache: dict[str, bool] = {}

    def plan(self, plugins: list[PluginSpec]) -> EnvironmentPlanResult:
        """Produce a stable shared-environment plan for the current plugin set.

        Planning is completed before workers start so the supervisor can:

        - skip only the plugins whose dependencies cannot be satisfied;
        - reuse one environment across compatible plugins;
        - clean up `.astrbot` artifacts left behind by older plans.
        """
        if not plugins:
            self.cleanup_artifacts([])
            return EnvironmentPlanResult()
        _require_uv_binary(self.uv_binary)

        candidate_groups = self._build_candidate_groups(plugins)
        planned_groups: list[EnvironmentGroup] = []
        skipped_plugins: dict[str, str] = {}
        for group_plugins in candidate_groups:
            materialized, skipped = self._materialize_candidate_group(group_plugins)
            planned_groups.extend(materialized)
            skipped_plugins.update(skipped)

        # Deterministic ordering keeps downstream output and cleanup stable.
        planned_groups.sort(key=lambda group: (group.python_version, group.id))
        self.cleanup_artifacts(planned_groups)

        plugin_to_group = {
            plugin.name: group for group in planned_groups for plugin in group.plugins
        }
        planned_plugins = [
            plugin for plugin in plugins if plugin.name in plugin_to_group
        ]
        return EnvironmentPlanResult(
            groups=planned_groups,
            plugins=planned_plugins,
            plugin_to_group=plugin_to_group,
            skipped_plugins=skipped_plugins,
        )

    def _build_candidate_groups(
        self, plugins: list[PluginSpec]
    ) -> list[list[PluginSpec]]:
        """Greedily pack plugins into compatibility candidate groups.

        Grouping stays deterministic with these rules:

        - the Python version is the first hard boundary;
        - plugins with more ``requirements.txt`` constraints are placed first;
        - remaining ties are broken by plugin name.
        """
        buckets: dict[str, list[PluginSpec]] = {}
        for plugin in plugins:
            buckets.setdefault(plugin.python_version, []).append(plugin)

        planned_groups: list[list[PluginSpec]] = []
        for python_version in sorted(buckets):
            python_groups: list[list[PluginSpec]] = []
            for plugin in self._sort_plugins(buckets[python_version]):
                placed = False
                for group_plugins in python_groups:
                    # First-fit: join the first existing group that stays compatible.
                    if self._is_compatible([*group_plugins, plugin]):
                        group_plugins.append(plugin)
                        placed = True
                        break
                if not placed:
                    python_groups.append([plugin])
            planned_groups.extend(python_groups)
        return planned_groups

    @staticmethod
    def _sort_plugins(plugins: list[PluginSpec]) -> list[PluginSpec]:
        # Most-constrained plugin first; name as the deterministic tie-breaker.
        return sorted(
            plugins,
            key=lambda plugin: (-len(_requirement_lines(plugin)), plugin.name),
        )

    def _is_compatible(self, plugins: list[PluginSpec]) -> bool:
        """Decide whether a set of plugins can share one environment.

        A cheap fast path runs first:

        - if every requirement is an exact pin like ``pkg==1.2.3``,
        - and the normalized package names resolve without conflicting
          versions,
        - then the group is deemed compatible without invoking the solver.

        Anything more complex falls back to ``uv pip compile``, whose
        resolution result is the final verdict on dependency compatibility.
        """
        cache_key = self._compatibility_cache_key(plugins)
        cached = self._compatibility_cache.get(cache_key)
        if cached is not None:
            return cached

        requirement_lines = self._collect_requirement_lines(plugins)
        if not requirement_lines:
            self._compatibility_cache[cache_key] = True
            return True

        if self._merge_exact_requirements(requirement_lines) is not None:
            self._compatibility_cache[cache_key] = True
            return True

        with tempfile.TemporaryDirectory(
            prefix="astrbot-env-plan-",
            dir=self.repo_root,
        ) as temp_dir:
            source_path = Path(temp_dir) / "compat.in"
            output_path = Path(temp_dir) / "compat.txt"
            self._write_source_file(source_path, plugins)
            try:
                self._compile_lockfile(
                    source_path=source_path,
                    output_path=output_path,
                    python_version=plugins[0].python_version,
                )
            except RuntimeError:
                # Solver failure means the combined requirements are unsatisfiable.
                self._compatibility_cache[cache_key] = False
                return False

        self._compatibility_cache[cache_key] = True
        return True

    def _materialize_candidate_group(
        self,
        plugins: list[PluginSpec],
    ) -> tuple[list[EnvironmentGroup], dict[str, str]]:
        """Create artifacts for one candidate group, splitting automatically on failure.

        If the whole group fails to produce a lockfile, the planner falls
        back to one-plugin-per-group so a single bad plugin cannot block
        the whole batch from starting.
        """
        try:
            return [self._materialize_group(plugins)], {}
        except RuntimeError as exc:
            if len(plugins) == 1:
                return [], {plugins[0].name: str(exc)}

            materialized: list[EnvironmentGroup] = []
            skipped: dict[str, str] = {}
            for plugin in plugins:
                groups, child_skipped = self._materialize_candidate_group([plugin])
                materialized.extend(groups)
                skipped.update(child_skipped)
            return materialized, skipped

    def _materialize_group(self, plugins: list[PluginSpec]) -> EnvironmentGroup:
        """Write every file that defines one shared environment.

        The group identity is derived from the Python version together with
        the plugin set.  The environment fingerprint additionally covers the
        compiled lockfile content, so when dependency resolution changes an
        existing environment can be synced incrementally instead of being
        blindly rebuilt.
        """
        group_id = self._group_identity(plugins)[:16]
        python_version = plugins[0].python_version
        source_path = self.group_dir / f"{group_id}.in"
        lockfile_path = self.lock_dir / f"{group_id}.txt"
        metadata_path = self.group_dir / f"{group_id}.json"
        venv_path = self.env_dir / group_id
        python_path = _venv_python_path(venv_path)

        source_path.parent.mkdir(parents=True, exist_ok=True)
        lockfile_path.parent.mkdir(parents=True, exist_ok=True)
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        venv_path.parent.mkdir(parents=True, exist_ok=True)

        self._write_source_file(source_path, plugins)
        self._write_lockfile(
            lockfile_path=lockfile_path,
            source_path=source_path,
            plugins=plugins,
            python_version=python_version,
        )
        environment_fingerprint = self._environment_fingerprint(
            plugins=plugins,
            python_version=python_version,
            lockfile_path=lockfile_path,
        )
        metadata_path.write_text(
            json.dumps(
                {
                    "group_id": group_id,
                    "python_version": python_version,
                    "plugins": [plugin.name for plugin in plugins],
                    "plugin_entries": [
                        {
                            "name": plugin.name,
                            "plugin_dir": str(plugin.plugin_dir),
                        }
                        for plugin in plugins
                    ],
                    "source_path": str(source_path),
                    "lockfile_path": str(lockfile_path),
                    "venv_path": str(venv_path),
                    "environment_fingerprint": environment_fingerprint,
                },
                ensure_ascii=True,
                indent=2,
                sort_keys=True,
            ),
            encoding="utf-8",
        )

        return EnvironmentGroup(
            id=group_id,
            python_version=python_version,
            plugins=list(plugins),
            source_path=source_path,
            lockfile_path=lockfile_path,
            metadata_path=metadata_path,
            venv_path=venv_path,
            python_path=python_path,
            environment_fingerprint=environment_fingerprint,
        )

    def _write_source_file(self, source_path: Path, plugins: list[PluginSpec]) -> None:
        """Write the group's requirements input file used for lockfile generation."""
        lines: list[str] = []
        for plugin in sorted(plugins, key=lambda item: item.name):
            requirements = _requirement_lines(plugin)
            if not requirements:
                continue
            # Each plugin's section is labeled by name for readability.
            lines.append(f"# {plugin.name}")
            lines.extend(requirements)
            lines.append("")

        content = "\n".join(lines).rstrip()
        if content:
            content += "\n"
        source_path.write_text(content, encoding="utf-8")

    def _write_lockfile(
        self,
        *,
        lockfile_path: Path,
        source_path: Path,
        plugins: list[PluginSpec],
        python_version: str,
    ) -> None:
        """Generate the lockfile for one group.

        Even when the requirement set is empty, an empty lockfile is
        deliberately written so the whole shared-environment pipeline can
        treat every group uniformly.
        """
        if not self._collect_requirement_lines(plugins):
            lockfile_path.write_text("", encoding="utf-8")
            return

        self._compile_lockfile(
            source_path=source_path,
            output_path=lockfile_path,
            python_version=python_version,
        )

    def _compile_lockfile(
        self,
        *,
        source_path: Path,
        output_path: Path,
        python_version: str,
    ) -> None:
        """Delegate dependency resolution to ``uv pip compile``."""
        uv_binary = _require_uv_binary(self.uv_binary)
        self._run_command(
            [
                uv_binary,
                "pip",
                "compile",
                "--python-version",
                python_version,
                "--no-managed-python",
                "--no-python-downloads",
                "--quiet",
                str(source_path),
                "-o",
                str(output_path),
            ],
            cwd=self.repo_root,
            command_name=f"compile lockfile for {source_path.name}",
        )

    def _run_command(self, command: list[str], *, cwd: Path, command_name: str) -> None:
        # Runs uv with a pinned cache dir; raises RuntimeError with captured
        # output on any non-zero exit so callers can surface the reason.
        process = subprocess.run(
            command,
            cwd=str(cwd),
            env={**os.environ, "UV_CACHE_DIR": str(self.cache_dir)},
            capture_output=True,
            text=True,
            check=False,
        )
        if process.returncode != 0:
            raise RuntimeError(
                f"{command_name} failed with exit code {process.returncode}: "
                f"{process.stderr.strip() or process.stdout.strip()}"
            )

    def cleanup_artifacts(self, groups: list[EnvironmentGroup]) -> None:
        """Remove `.astrbot` artifacts no longer referenced by the current plan.

        Cleanup only covers the shared-environment artifacts this planner
        maintains; it never touches legacy per-plugin local `.venv`
        directories.
        """
        active_group_ids = {group.id for group in groups}
        self._cleanup_group_artifacts(active_group_ids)
        self._cleanup_lockfiles(active_group_ids)
        self._cleanup_envs(active_group_ids)

    def _cleanup_group_artifacts(self, active_group_ids: set[str]) -> None:
        # Drop stale .in / .json files whose stem is not an active group id.
        if not self.group_dir.exists():
            return
        for entry in self.group_dir.iterdir():
            if entry.suffix not in {".in", ".json"}:
                continue
            if entry.stem in active_group_ids:
                continue
            entry.unlink(missing_ok=True)

    def _cleanup_lockfiles(self, active_group_ids: set[str]) -> None:
        # Drop stale .txt lockfiles for groups that no longer exist.
        if not self.lock_dir.exists():
            return
        for entry in self.lock_dir.iterdir():
            if entry.suffix != ".txt":
                continue
            if entry.stem in active_group_ids:
                continue
            entry.unlink(missing_ok=True)

    def _cleanup_envs(self, active_group_ids: set[str]) -> None:
        # Remove whole venv trees (or stray files) for retired groups.
        if not self.env_dir.exists():
            return
        for entry in self.env_dir.iterdir():
            if entry.name in active_group_ids:
                continue
            if entry.is_dir():
                shutil.rmtree(entry)
            else:
                entry.unlink(missing_ok=True)

    def _compatibility_cache_key(self, plugins: list[PluginSpec]) -> str:
        # Key covers python version plus each plugin's name and requirement
        # lines, sorted for order independence.
        payload = {
            "python_version": plugins[0].python_version if plugins else "",
            "plugins": [
                {
                    "name": plugin.name,
                    "requirements": _requirement_lines(plugin),
                }
                for plugin in sorted(plugins, key=lambda item: item.name)
            ],
        }
        encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    @staticmethod
    def _group_identity(plugins: list[PluginSpec]) -> str:
        # Identity intentionally excludes requirements so a dependency change
        # keeps the same group id (and venv path) and only moves the fingerprint.
        payload = {
            "python_version": plugins[0].python_version if plugins else "",
            "plugins": sorted(plugin.name for plugin in plugins),
        }
        encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    @staticmethod
    def _environment_fingerprint(
        *,
        plugins: list[PluginSpec],
        python_version: str,
        lockfile_path: Path,
    ) -> str:
        # Unlike the identity, the fingerprint covers the resolved lockfile
        # content, so it changes whenever resolution results change.
        payload = {
            "python_version": python_version,
            "plugins": sorted(plugin.name for plugin in plugins),
            "lockfile": lockfile_path.read_text(encoding="utf-8"),
        }
        encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    @staticmethod
    def _collect_requirement_lines(plugins: list[PluginSpec]) -> list[str]:
        # Flat concatenation of all plugins' requirement lines (duplicates kept).
        lines: list[str] = []
        for plugin in plugins:
            lines.extend(_requirement_lines(plugin))
        return lines

    @staticmethod
    def _merge_exact_requirements(requirement_lines: list[str]) -> list[str] | None:
        """Merge requirements when every line is an exact ``pkg==ver`` pin.

        Returns the sorted merged pin list, or ``None`` when any line is not
        an exact pin or two pins of the same (normalized) package disagree —
        in which case the caller must fall back to the real solver.
        """
        merged: dict[str, str] = {}
        for line in requirement_lines:
            match = _EXACT_PIN_PATTERN.fullmatch(line)
            if match is None:
                return None
            package_name = _normalize_package_name(match.group(1))
            existing = merged.get(package_name)
            if existing is not None and existing != line:
                return None
            merged[package_name] = line
        return [merged[name] for name in sorted(merged)]
+
+
class GroupEnvironmentManager:
    """Creates, verifies and syncs an already-planned shared environment."""

    def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None:
        self.repo_root = repo_root.resolve()
        # Fall back to whatever `uv` is on PATH when no explicit binary is given.
        self.uv_binary = uv_binary or shutil.which("uv")
        self.cache_dir = self.repo_root / ".uv-cache"

    def prepare(self, group: EnvironmentGroup) -> Path:
        """Ensure the group's interpreter path is ready for worker startup.

        Behavior summary:

        - environment missing, wrong Python version, or lockfile gone: rebuild;
        - environment structure intact but fingerprint changed: ``uv pip sync``;
        - otherwise: reuse the existing interpreter path as-is.
        """
        _require_uv_binary(self.uv_binary)

        state_path = group.venv_path / GROUP_STATE_FILE_NAME
        state = self._load_state(state_path)
        if (
            not group.python_path.exists()
            or not self._matches_python_version(group.venv_path, group.python_version)
            or not group.lockfile_path.exists()
        ):
            self._rebuild(group)
            self._write_state(state_path, group)
        elif not self._state_matches_group(state, group):
            self._sync_existing(group)
            self._write_state(state_path, group)
        return group.python_path

    def _rebuild(self, group: EnvironmentGroup) -> None:
        # Full teardown and re-creation; the state file inside the venv is
        # removed with the tree and rewritten by the caller.
        if group.venv_path.exists():
            shutil.rmtree(group.venv_path)
        self._create_venv(group)
        self._sync_lockfile(group)

    def _sync_existing(self, group: EnvironmentGroup) -> None:
        self._sync_lockfile(group)

    def _sync_lockfile(self, group: EnvironmentGroup) -> None:
        """Align the installed packages exactly with this group's lockfile."""
        uv_binary = _require_uv_binary(self.uv_binary)
        self._run_command(
            [
                uv_binary,
                "pip",
                "sync",
                "--python",
                str(group.python_path),
                "--allow-empty-requirements",
                str(group.lockfile_path),
            ],
            cwd=self.repo_root,
            command_name=f"sync group env {group.id}",
        )

    def _create_venv(self, group: EnvironmentGroup) -> None:
        """Create the shared venv for a group.

        ``--system-site-packages`` is kept during the current migration
        phase for compatibility with legacy plugins that still implicitly
        depend on host-environment packages.
        """
        uv_binary = _require_uv_binary(self.uv_binary)
        self._run_command(
            [
                uv_binary,
                "venv",
                "--python",
                group.python_version,
                "--system-site-packages",
                "--no-python-downloads",
                "--no-managed-python",
                str(group.venv_path),
            ],
            cwd=self.repo_root,
            command_name=f"create group venv {group.id}",
        )

    def _run_command(self, command: list[str], *, cwd: Path, command_name: str) -> None:
        # Runs uv with a pinned cache dir; raises RuntimeError with captured
        # output on any non-zero exit so callers can surface the reason.
        process = subprocess.run(
            command,
            cwd=str(cwd),
            env={**os.environ, "UV_CACHE_DIR": str(self.cache_dir)},
            capture_output=True,
            text=True,
            check=False,
        )
        if process.returncode != 0:
            raise RuntimeError(
                f"{command_name} failed with exit code {process.returncode}: "
                f"{process.stderr.strip() or process.stdout.strip()}"
            )

    @staticmethod
    def _matches_python_version(venv_path: Path, version: str) -> bool:
        # Compares against the major.minor recorded in pyvenv.cfg.
        return _read_pyvenv_major_minor(venv_path / "pyvenv.cfg") == version

    @staticmethod
    def _load_state(state_path: Path) -> dict[str, object]:
        # Any read/parse failure degrades to "no recorded state", which
        # simply forces a sync or rebuild.
        if not state_path.exists():
            return {}
        try:
            data = json.loads(state_path.read_text(encoding="utf-8"))
        except Exception:
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _write_state(state_path: Path, group: EnvironmentGroup) -> None:
        state_path.parent.mkdir(parents=True, exist_ok=True)
        state_path.write_text(
            json.dumps(
                {
                    "group_id": group.id,
                    "python_version": group.python_version,
                    "environment_fingerprint": group.environment_fingerprint,
                    "plugins": [plugin.name for plugin in group.plugins],
                },
                ensure_ascii=True,
                indent=2,
                sort_keys=True,
            ),
            encoding="utf-8",
        )

    @staticmethod
    def _state_matches_group(state: dict[str, object], group: EnvironmentGroup) -> bool:
        # The fingerprint comparison is what detects lockfile content drift.
        return (
            state.get("group_id") == group.id
            and state.get("python_version") == group.python_version
            and state.get("environment_fingerprint") == group.environment_fingerprint
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py b/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py
new file mode 100644
index 0000000000..72e6098edf
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py
@@ -0,0 +1,1048 @@
+"""处理器分发模块。
+
+定义 HandlerDispatcher 类,负责将能力调用分发到具体的处理器函数。
+支持参数注入、流式执行、错误处理。
+
+核心职责:
+ - 根据处理器 ID 查找处理器
+ - 构建处理器参数(支持类型注解注入)
+ - 执行处理器并处理结果
+ - 处理异步生成器流式结果
+ - 统一的错误处理
+
+参数注入优先级:
+ 1. 按类型注解注入(支持 Optional[Type])
+ 2. 按参数名注入(兼容无类型注解)
+ 3. 从 args 注入(命令参数等)
+
+支持的注入类型:
+    - MessageEvent: 消息事件
+    - Context: 运行时上下文
+    - ScheduleContext: 定时任务上下文
+    - ConversationSession: 对话会话
+    - ProviderRequest / LLMResponse / MessageEventResult: 事件负载对象
+"""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import re
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Any, cast, get_type_hints
+
+from .._internal.command_model import (
+ parse_command_model_remainder,
+ resolve_command_model_param,
+)
+from .._internal.injected_params import legacy_arg_parameter_names
+from .._internal.invocation_context import caller_plugin_scope
+from .._internal.plugin_logger import PluginLogger
+from .._internal.sdk_logger import logger
+from .._internal.star_runtime import bind_star_runtime
+from .._internal.typing_utils import unwrap_optional
+from ..clients.llm import LLMResponse
+from ..context import CancelToken, Context
+from ..conversation import (
+ DEFAULT_BUSY_MESSAGE,
+ ConversationClosed,
+ ConversationReplaced,
+ ConversationSession,
+ ConversationState,
+)
+from ..events import MessageEvent
+from ..filters import LocalFilterBinding
+from ..llm.entities import ProviderRequest
+from ..message.components import BaseMessageComponent
+from ..message.result import (
+ MessageChain,
+ MessageEventResult,
+ coerce_message_chain,
+)
+from ..protocol.descriptors import (
+ CommandTrigger,
+ MessageTrigger,
+ ParamSpec,
+ ScheduleTrigger,
+)
+from ..schedule import ScheduleContext
+from ..session_waiter import (
+ SessionWaiterManager,
+ _mark_session_waiter_background_task,
+ _mark_session_waiter_handler_task,
+ _unmark_session_waiter_background_task,
+ _unmark_session_waiter_handler_task,
+)
+from ..star import Star
+from ._command_matching import (
+ build_command_args,
+ build_regex_args,
+ match_command_name,
+)
+from .capability_dispatcher import CapabilityDispatcher
+from .limiter import LimiterEngine
+from .loader import LoadedHandler
+
+
+@dataclass(slots=True)
+class _ActiveConversation:
+    """Bookkeeping for one in-flight conversation bound to a session key."""
+    session: ConversationSession
+    task: Any  # _ManagedConversationTask wrapping the runner task
+
+@dataclass(slots=True)
+class _ManagedConversationTask:
+    """Awaitable task wrapper that runs *cleanup* once awaiting finishes."""
+    task: asyncio.Task[Any]
+    cleanup: Any  # zero-arg callable invoked in the finally of _wait
+
+    def __await__(self):
+        # Delegate so `await managed` behaves like awaiting the task itself.
+        return self._wait().__await__()
+
+    async def _wait(self) -> Any:
+        try:
+            return await self.task
+        finally:
+            self.cleanup()
+
+    def cancel(self) -> bool:
+        """Proxy to asyncio.Task.cancel()."""
+        return self.task.cancel()
+
+    def done(self) -> bool:
+        """Proxy to asyncio.Task.done()."""
+        return self.task.done()
+
+
+@dataclass(slots=True)
+class _InjectedEventPayloads:
+    """Caches event payload objects injected into a handler, so mutations can be echoed back in the summary."""
+    provider_request: ProviderRequest | None = None
+    llm_response: LLMResponse | None = None
+    event_result: MessageEventResult | None = None
+
+
+class HandlerDispatcher:
+    def __init__(
+        self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler]
+    ) -> None:
+        """Index *handlers* by descriptor id and wire session-waiter/limiter helpers."""
+        self._plugin_id = plugin_id
+        self._peer = peer
+        self._handlers = {item.descriptor.id: item for item in handlers}
+        # request_id -> (task, cancel token) for in-flight invocations.
+        self._active: dict[str, tuple[asyncio.Task[Any], CancelToken]] = {}
+        self._session_waiters = SessionWaiterManager(plugin_id=plugin_id, peer=peer)
+        self._limiter = LimiterEngine()
+        # "<plugin_id>:<session_id>" -> active conversation bookkeeping.
+        self._conversations: dict[str, _ActiveConversation] = {}
+        try:
+            setattr(peer, "_session_waiter_manager", self._session_waiters)
+        except AttributeError:
+            logger.warning(
+                f"Failed to attach _session_waiter_manager to peer {peer}, "
+                "some features may not work as expected"
+            )
+
+    def has_active_waiter(self, event: MessageEvent) -> bool:
+        """Return True when a session_waiter is registered for *event*'s session."""
+        return self._session_waiters.has_active_waiter(event)
+
+    async def invoke(self, message, cancel_token: CancelToken) -> dict[str, Any]:
+        """Dispatch one capability invocation to its handler and await the summary.
+
+        The sentinel handler id ``__sdk_session_waiter__`` routes the event to
+        the session-waiter manager instead of a registered handler.  Raises
+        LookupError for unknown handler ids.
+        """
+        handler_id = str(message.input.get("handler_id", ""))
+        event_payload = self._coerce_event_payload(message.input.get("event"))
+        if handler_id == "__sdk_session_waiter__":
+            requested_plugin_id = str(message.input.get("plugin_id") or "").strip()
+            plugin_id = self._resolve_waiter_plugin_id(
+                event_payload=event_payload,
+                requested_plugin_id=requested_plugin_id,
+            )
+            ctx, event = self._create_context_event(
+                plugin_id=plugin_id,
+                request_id=message.id,
+                cancel_token=cancel_token,
+                event_payload=event_payload,
+            )
+            event.bind_reply_handler(self._create_reply_handler(ctx, event))
+            task = self._spawn_plugin_task(
+                plugin_id,
+                self._session_waiters.dispatch(event, plugin_id=plugin_id),
+            )
+            return await self._await_tracked_task(message.id, task, cancel_token)
+
+        loaded = self._handlers.get(handler_id)
+        if loaded is None:
+            raise LookupError(f"handler not found: {handler_id}")
+
+        plugin_id = self._resolve_plugin_id(loaded)
+        ctx, event = self._create_context_event(
+            plugin_id=plugin_id,
+            request_id=message.id,
+            cancel_token=cancel_token,
+            event_payload=event_payload,
+        )
+        # Attach structured context to the logger for this invocation.
+        bound_logger = cast(PluginLogger, ctx.logger).bind(
+            plugin_id=plugin_id,
+            request_id=message.id,
+            handler_ref=handler_id,
+            session_id=event.session_id,
+            event_type=str(
+                event_payload.get("event_type")
+                or event_payload.get("type")
+                or event.message_type
+            ),
+        )
+        ctx.logger = bound_logger
+        event.bind_reply_handler(self._create_reply_handler(ctx, event))
+        schedule_context = self._build_schedule_context(loaded, event_payload)
+
+        # Extract args for handler-signature compatibility; derive them from
+        # the trigger (command/regex) when the caller supplied none.
+        raw_args = message.input.get("args") or {}
+        args = dict(raw_args) if isinstance(raw_args, dict) else {}
+        if not args:
+            args = self._derive_args(loaded, event)
+
+        task = self._spawn_plugin_task(
+            plugin_id,
+            self._run_handler(
+                loaded,
+                event,
+                ctx,
+                args,
+                schedule_context=schedule_context,
+            ),
+        )
+        return await self._await_tracked_task(message.id, task, cancel_token)
+
+    @staticmethod
+    def _coerce_event_payload(payload: Any) -> dict[str, Any]:
+        """Return *payload* if it is a dict, else an empty dict."""
+        return payload if isinstance(payload, dict) else {}
+
+    @staticmethod
+    def _session_key_from_payload(event_payload: dict[str, Any]) -> str:
+        """Derive the session key for waiter lookup from a raw event payload."""
+        return MessageEvent.session_key_from_payload(event_payload)
+
+    def _resolve_waiter_plugin_id(
+        self,
+        *,
+        event_payload: dict[str, Any],
+        requested_plugin_id: str,
+    ) -> str:
+        """Pick the plugin that owns a session-waiter dispatch.
+
+        Raises LookupError when multiple plugins hold waiters for the session
+        and the caller did not name one explicitly.
+        """
+        if requested_plugin_id:
+            return requested_plugin_id
+        # Resolve the owning plugin before constructing the runtime Context so a
+        # worker-group waiter follow-up does not rebuild the event twice.
+        plugin_ids = self._session_waiters.get_waiter_plugin_ids(
+            self._session_key_from_payload(event_payload)
+        )
+        if len(plugin_ids) > 1:
+            raise LookupError(
+                "multiple active session_waiters found for session; "
+                "dispatch requires explicit plugin identity"
+            )
+        return plugin_ids[0] if plugin_ids else self._plugin_id
+
+    def _create_context_event(
+        self,
+        *,
+        plugin_id: str,
+        request_id: str,
+        cancel_token: CancelToken,
+        event_payload: dict[str, Any],
+    ) -> tuple[Context, MessageEvent]:
+        """Build the per-invocation Context and the MessageEvent bound to it."""
+        ctx = Context(
+            peer=self._peer,
+            plugin_id=plugin_id,
+            request_id=request_id,
+            cancel_token=cancel_token,
+            source_event_payload=event_payload,
+        )
+        event = MessageEvent.from_payload(event_payload, context=ctx)
+        return ctx, event
+
+    @staticmethod
+    def _spawn_plugin_task(plugin_id: str, coroutine):
+        """Create a task with the caller-plugin scope active at creation time."""
+        with caller_plugin_scope(plugin_id):
+            return asyncio.create_task(coroutine)
+
+    async def _await_tracked_task(
+        self,
+        request_id: str,
+        task: asyncio.Task[Any],
+        cancel_token: CancelToken,
+    ) -> dict[str, Any]:
+        """Await *task* while registering it in _active so cancel() can reach it."""
+        _mark_session_waiter_handler_task(task)
+        task.add_done_callback(_unmark_session_waiter_handler_task)
+        self._active[request_id] = (task, cancel_token)
+        try:
+            return await task
+        finally:
+            self._active.pop(request_id, None)
+
+    def _resolve_plugin_id(self, loaded: LoadedHandler) -> str:
+        """Resolve the owning plugin id: explicit field, handler-id prefix, or self."""
+        if loaded.plugin_id:
+            return loaded.plugin_id
+        handler_id = getattr(loaded.descriptor, "id", "")
+        # Handler ids may encode their plugin as "<plugin>:<rest>".
+        if isinstance(handler_id, str) and ":" in handler_id:
+            return handler_id.split(":", 1)[0]
+        return self._plugin_id
+
+    def _create_reply_handler(self, ctx: Context, event: MessageEvent):
+        """Build the async reply callback bound to *event*'s session.
+
+        Falls back to a bare ``peer.send`` when ctx.platform.send rejects the
+        arguments (TypeError), e.g. older peer implementations.
+        """
+        async def reply(text: str) -> None:
+            try:
+                await ctx.platform.send(event.session_ref or event.session_id, text)
+            except TypeError:
+                send = getattr(self._peer, "send", None)
+                if not callable(send):
+                    raise
+                result = send(event.session_id, text)
+                if inspect.isawaitable(result):
+                    await result
+
+        return reply
+
+    async def cancel(self, request_id: str) -> None:
+        """Cancel an in-flight invocation (no-op when the id is unknown)."""
+        active = self._active.get(request_id)
+        if active is None:
+            return
+        task, cancel_token = active
+        # Signal cooperative cancellation first, then hard-cancel the task.
+        cancel_token.cancel()
+        task.cancel()
+
+    async def _run_handler(
+        self,
+        loaded: LoadedHandler,
+        event: MessageEvent,
+        ctx: Context,
+        args: dict[str, Any] | None = None,
+        *,
+        schedule_context: ScheduleContext | None = None,
+    ) -> dict[str, Any]:
+        """Execute one handler end to end and summarise its effects.
+
+        Pipeline: limiter -> local filters -> arg parsing (may short-circuit
+        with help text) -> conversation start or direct call.  Supports plain
+        return values, awaitables, and async generators (streamed).  Errors
+        are routed through the owner's on_error hook and then re-raised.
+        """
+        summary = {"sent_message": False, "stop": False, "call_llm": False}
+        injected_payloads = _InjectedEventPayloads()
+        event_type = self._event_type_name(event)
+        try:
+            limiter = loaded.limiter
+            if limiter is not None:
+                decision = self._limiter.evaluate(
+                    plugin_id=self._resolve_plugin_id(loaded),
+                    handler_id=loaded.descriptor.id,
+                    limiter=limiter,
+                    event=event,
+                )
+                if not decision.allowed:
+                    if decision.error is not None:
+                        raise decision.error
+                    if decision.hint:
+                        await event.reply(decision.hint)
+                        summary["sent_message"] = True
+                    return summary
+            if not self._run_local_filters(
+                loaded.local_filters,
+                event=event,
+                ctx=ctx,
+            ):
+                return summary
+            parsed_args, help_text = self._prepare_handler_args(
+                loaded,
+                args or {},
+            )
+            if help_text is not None:
+                # Argument parsing produced usage/help output instead of a call.
+                await event.reply(help_text)
+                summary["sent_message"] = True
+                return summary
+            if loaded.conversation is not None:
+                return await self._start_conversation(
+                    loaded,
+                    event,
+                    ctx,
+                    parsed_args,
+                    schedule_context=schedule_context,
+                )
+            owner = loaded.owner if isinstance(loaded.owner, Star) else None
+            with bind_star_runtime(owner, ctx):
+                result = loaded.callable(
+                    *self._build_args(
+                        loaded.callable,
+                        event,
+                        ctx,
+                        parsed_args,
+                        plugin_id=self._resolve_plugin_id(loaded),
+                        handler_ref=loaded.descriptor.id,
+                        schedule_context=schedule_context,
+                        injected_payloads=injected_payloads,
+                    )
+                )
+            if inspect.isasyncgen(result):
+                # Streamed results: each yielded item is sent/merged as it arrives.
+                async for item in result:
+                    self._merge_handler_summary(
+                        summary,
+                        await self._handle_result_item(item, event, ctx),
+                    )
+                summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
+                self._append_injected_payloads(
+                    summary,
+                    injected_payloads,
+                    event=event,
+                    event_type=event_type,
+                )
+                return summary
+            if inspect.isawaitable(result):
+                result = await result
+            if result is not None:
+                self._merge_handler_summary(
+                    summary,
+                    await self._handle_result_item(result, event, ctx),
+                )
+            summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
+            self._append_injected_payloads(
+                summary,
+                injected_payloads,
+                event=event,
+                event_type=event_type,
+            )
+            return summary
+        except Exception as exc:
+            await self._handle_error(
+                loaded.owner,
+                exc,
+                event,
+                ctx,
+                handler_name=loaded.callable.__name__,
+                plugin_id=self._resolve_plugin_id(loaded),
+            )
+            raise
+
+    def _derive_args(
+        self,
+        loaded: LoadedHandler,
+        event: MessageEvent,
+    ) -> dict[str, Any]:
+        """Derive handler args from the trigger when the caller supplied none.
+
+        Command triggers: match name/aliases against event.text, then either
+        defer to a command-model param, declared ParamSpecs, or legacy
+        positional names.  Message triggers with a regex: build args from the
+        first match.  Returns ``{}`` when nothing matches.
+        """
+        trigger = loaded.descriptor.trigger
+        if isinstance(trigger, CommandTrigger):
+            param_specs = loaded.descriptor.param_specs
+            for command_name in [trigger.command, *trigger.aliases]:
+                remainder = match_command_name(event.text, command_name)
+                if remainder is not None:
+                    model_param = resolve_command_model_param(loaded.callable)
+                    if model_param is not None:
+                        # Defer parsing of the remainder to the command model.
+                        return {
+                            "__command_model_remainder__": remainder,
+                            "__command_name__": command_name,
+                        }
+                    if param_specs:
+                        return build_command_args(param_specs, remainder)
+                    return build_command_args(
+                        [
+                            ParamSpec(name=name, type="str")
+                            for name in legacy_arg_parameter_names(loaded.callable)
+                        ],
+                        remainder,
+                    )
+            return {}
+        if isinstance(trigger, MessageTrigger) and trigger.regex:
+            match = re.search(trigger.regex, event.text)
+            if match is None:
+                return {}
+            if loaded.descriptor.param_specs:
+                return build_regex_args(loaded.descriptor.param_specs, match)
+            return build_regex_args(
+                [
+                    ParamSpec(name=name, type="str")
+                    for name in legacy_arg_parameter_names(loaded.callable)
+                ],
+                match,
+            )
+        return {}
+
+    def _build_args(
+        self,
+        handler,
+        event: MessageEvent,
+        ctx: Context,
+        args: dict[str, Any] | None = None,
+        *,
+        plugin_id: str | None = None,
+        handler_ref: str | None = None,
+        schedule_context: ScheduleContext | None = None,
+        conversation_session: ConversationSession | None = None,
+        injected_payloads: _InjectedEventPayloads | None = None,
+    ) -> list[Any]:
+        """Build the positional argument list for *handler*.
+
+        Injection priority: (1) type annotation, (2) parameter name,
+        (3) value from *args*.  Parameters with defaults may be left
+        uninjected; a required parameter that cannot be injected raises
+        TypeError with a descriptive message.
+
+        NOTE(review): an explicit ``None`` value in *args* is treated the same
+        as "missing" by the ``injected is None`` checks — confirm intended.
+        """
+        from .._internal.sdk_logger import logger
+
+        signature = inspect.signature(handler)
+        injected_args: list[Any] = []
+        args = args or {}
+
+        type_hints: dict[str, Any] = {}
+        try:
+            type_hints = get_type_hints(handler)
+        except Exception:
+            # Unresolvable annotations (e.g. forward refs) fall back to name injection.
+            pass
+
+        for parameter in signature.parameters.values():
+            if parameter.kind not in (
+                inspect.Parameter.POSITIONAL_ONLY,
+                inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            ):
+                continue
+
+            injected = None
+
+            # 1. Prefer injection by type annotation.
+            param_type = type_hints.get(parameter.name)
+            if param_type is not None:
+                injected = self._inject_by_type(
+                    param_type,
+                    event,
+                    ctx,
+                    schedule_context,
+                    conversation_session,
+                    injected_payloads=injected_payloads,
+                )
+
+            # 2. Fall back to injection by parameter name.
+            if injected is None:
+                if parameter.name == "event":
+                    injected = event
+                elif parameter.name in {"ctx", "context"}:
+                    injected = ctx
+                elif parameter.name in {"sched", "schedule"}:
+                    injected = schedule_context
+                elif parameter.name in {"conversation", "conv"}:
+                    injected = conversation_session
+                elif parameter.name in args:
+                    injected = args[parameter.name]
+
+            # 3. Otherwise the parameter must carry a default value.
+            if injected is None:
+                if parameter.default is not parameter.empty:
+                    continue
+                logger.error(
+                    "Handler '{}' 的必填参数 '{}' 无法注入",
+                    handler.__name__,
+                    parameter.name,
+                )
+                raise TypeError(
+                    self._format_handler_injection_error(
+                        handler=handler,
+                        parameter_name=parameter.name,
+                        plugin_id=plugin_id,
+                        handler_ref=handler_ref,
+                        args=args,
+                    )
+                )
+            else:
+                injected_args.append(injected)
+
+        return injected_args
+
+    def _prepare_handler_args(
+        self,
+        loaded: LoadedHandler,
+        args: dict[str, Any],
+    ) -> tuple[dict[str, Any], str | None]:
+        """Parse raw *args* into handler kwargs; return (parsed, help_text).
+
+        ``help_text`` is non-None when command-model parsing produced usage
+        output instead of a model instance; the caller should reply with it
+        and skip the handler.
+        """
+        parsed_args = (
+            self._parse_handler_args(loaded.descriptor.param_specs, args)
+            if loaded.descriptor.param_specs
+            else {
+                key: value
+                for key, value in dict(args).items()
+                if not str(key).startswith("__command_")
+            }
+        )
+        if not isinstance(loaded.descriptor.trigger, CommandTrigger):
+            return parsed_args, None
+        model_param = resolve_command_model_param(loaded.callable)
+        if model_param is None:
+            return parsed_args, None
+        if "__command_model_remainder__" not in args:
+            return parsed_args, None
+        trigger = loaded.descriptor.trigger
+        command_name = str(args.get("__command_name__", "")) or (
+            trigger.command
+            if isinstance(trigger, CommandTrigger)
+            else loaded.descriptor.id.rsplit(".", 1)[-1]
+        )
+        result = parse_command_model_remainder(
+            remainder=str(args.get("__command_model_remainder__", "")),
+            model_param=model_param,
+            command_name=command_name,
+        )
+        if result.help_text is not None:
+            return parsed_args, result.help_text
+        if result.model is not None:
+            parsed_args[model_param.name] = result.model
+        return parsed_args, None
+
+    async def _start_conversation(
+        self,
+        loaded: LoadedHandler,
+        event: MessageEvent,
+        ctx: Context,
+        parsed_args: dict[str, Any],
+        *,
+        schedule_context: ScheduleContext | None,
+    ) -> dict[str, Any]:
+        """Start (or reject/replace) a conversation-mode handler for this session.
+
+        In "reject" mode a busy message is sent while an older conversation is
+        active; otherwise the old conversation is marked replaced, cancelled,
+        and awaited up to ``grace_period`` before the new one starts.  The
+        conversation body runs in a background task; this method returns
+        immediately with ``stop=True``.
+        """
+        assert loaded.conversation is not None
+        conversation_meta = loaded.conversation
+        summary = {"sent_message": False, "stop": True, "call_llm": False}
+        key = f"{self._resolve_plugin_id(loaded)}:{event.session_id}"
+        active = self._conversations.get(key)
+        if active is not None and not active.task.done():
+            if conversation_meta.mode == "reject":
+                await event.reply(
+                    conversation_meta.busy_message or DEFAULT_BUSY_MESSAGE
+                )
+                summary["sent_message"] = True
+                return summary
+            active.session.mark_replaced()
+            await self._session_waiters.fail(
+                active.session.session_key,
+                ConversationReplaced("conversation replaced by a newer session"),
+            )
+            # Let the failure propagate to the waiter before cancelling.
+            await asyncio.sleep(0)
+            active.task.cancel()
+            try:
+                # shield() keeps the grace-period timeout from re-cancelling the task.
+                await asyncio.wait_for(
+                    asyncio.shield(active.task),
+                    timeout=conversation_meta.grace_period,
+                )
+            except asyncio.TimeoutError:
+                cast(PluginLogger, ctx.logger).warning(
+                    "Conversation replacement grace period exceeded for handler {}",
+                    loaded.descriptor.id,
+                )
+            except asyncio.CancelledError:
+                pass
+            except Exception:
+                # Old conversation's own errors are irrelevant to the new one.
+                pass
+            finally:
+                if self._conversations.get(key) is active:
+                    self._conversations.pop(key, None)
+
+        conversation = ConversationSession(
+            ctx=ctx,
+            event=event,
+            waiter_manager=self._session_waiters,
+            timeout=conversation_meta.timeout,
+        )
+
+        async def _runner() -> None:
+            try:
+                await self._run_conversation_task(
+                    loaded,
+                    event,
+                    ctx,
+                    parsed_args,
+                    conversation,
+                    schedule_context=schedule_context,
+                )
+            finally:
+                if conversation.state == ConversationState.ACTIVE:
+                    conversation.close(ConversationState.COMPLETED)
+
+        def _cleanup_conversation() -> None:
+            # Only drop the registry entry if it still refers to this session.
+            current = self._conversations.get(key)
+            if current is not None and current.session is conversation:
+                self._conversations.pop(key, None)
+
+        task = asyncio.create_task(_runner())
+        conversation.bind_owner_task(task)
+        managed_task = _ManagedConversationTask(
+            task=task, cleanup=_cleanup_conversation
+        )
+        self._conversations[key] = _ActiveConversation(
+            session=conversation,
+            task=managed_task,
+        )
+        _mark_session_waiter_background_task(task)
+
+        def _on_done(done_task: asyncio.Task[Any]) -> None:
+            _cleanup_conversation()
+            _unmark_session_waiter_background_task(done_task)
+
+        task.add_done_callback(_on_done)
+        return summary
+
+    async def _run_conversation_task(
+        self,
+        loaded: LoadedHandler,
+        event: MessageEvent,
+        ctx: Context,
+        parsed_args: dict[str, Any],
+        conversation: ConversationSession,
+        *,
+        schedule_context: ScheduleContext | None,
+    ) -> None:
+        """Run the conversation handler body, sending yielded/returned results.
+
+        Cancellation closes the session as CANCELLED; ConversationReplaced /
+        ConversationClosed are expected terminations and swallowed; any other
+        exception goes through the owner's error hook (not re-raised).
+        """
+        owner = loaded.owner if isinstance(loaded.owner, Star) else None
+        args_with_conversation = dict(parsed_args)
+        args_with_conversation.setdefault("conversation", conversation)
+        try:
+            with bind_star_runtime(owner, ctx):
+                result = loaded.callable(
+                    *self._build_args(
+                        loaded.callable,
+                        event,
+                        ctx,
+                        args_with_conversation,
+                        plugin_id=self._resolve_plugin_id(loaded),
+                        handler_ref=loaded.descriptor.id,
+                        schedule_context=schedule_context,
+                        conversation_session=conversation,
+                    )
+                )
+            if inspect.isasyncgen(result):
+                async for item in result:
+                    await self._handle_result_item(item, event, ctx)
+                return
+            if inspect.isawaitable(result):
+                result = await result
+            if result is not None:
+                await self._handle_result_item(result, event, ctx)
+        except asyncio.CancelledError:
+            if conversation.state == ConversationState.ACTIVE:
+                conversation.close(ConversationState.CANCELLED)
+            raise
+        except (ConversationReplaced, ConversationClosed):
+            return
+        except Exception as exc:
+            await self._handle_error(
+                loaded.owner,
+                exc,
+                event,
+                ctx,
+                handler_name=loaded.callable.__name__,
+                plugin_id=self._resolve_plugin_id(loaded),
+            )
+
+    def _inject_by_type(
+        self,
+        param_type: Any,
+        event: MessageEvent,
+        ctx: Context,
+        schedule_context: ScheduleContext | None,
+        conversation_session: ConversationSession | None,
+        *,
+        injected_payloads: _InjectedEventPayloads | None = None,
+    ) -> Any:
+        """Resolve an injectable value from a parameter's type annotation.
+
+        Supports MessageEvent/Context (and subclasses), ScheduleContext,
+        ConversationSession, and the payload objects ProviderRequest /
+        LLMResponse / MessageEventResult.  Optional[...] is unwrapped first.
+        Returns None when the type is not injectable.
+        """
+        param_type, _is_optional = unwrap_optional(param_type)
+
+        # Inject MessageEvent and its subclasses.
+        if param_type is MessageEvent:
+            return event
+        if isinstance(param_type, type) and issubclass(param_type, MessageEvent):
+            if isinstance(event, param_type):
+                return event
+            # Subclass may adapt the generic event via from_message_event().
+            factory = getattr(param_type, "from_message_event", None)
+            if callable(factory):
+                return factory(event)
+            return event
+
+        # Inject Context and its subclasses.
+        if param_type is Context or (
+            isinstance(param_type, type) and issubclass(param_type, Context)
+        ):
+            return ctx
+        if param_type is ScheduleContext or (
+            isinstance(param_type, type) and issubclass(param_type, ScheduleContext)
+        ):
+            return schedule_context
+        if param_type is ConversationSession or (
+            isinstance(param_type, type) and issubclass(param_type, ConversationSession)
+        ):
+            return conversation_session
+        if param_type is ProviderRequest or (
+            isinstance(param_type, type) and issubclass(param_type, ProviderRequest)
+        ):
+            return self._inject_provider_request(event, injected_payloads)
+        if param_type is LLMResponse or (
+            isinstance(param_type, type) and issubclass(param_type, LLMResponse)
+        ):
+            return self._inject_llm_response(event, injected_payloads)
+        if param_type is MessageEventResult or (
+            isinstance(param_type, type) and issubclass(param_type, MessageEventResult)
+        ):
+            return self._inject_event_result(event, injected_payloads)
+
+        return None
+
+    @staticmethod
+    def _event_type_name(event: MessageEvent) -> str:
+        """Read the event type string from the raw payload ("" when absent)."""
+        raw = event.raw if isinstance(event.raw, dict) else {}
+        value = raw.get("event_type") or raw.get("type")
+        return str(value or "")
+
+    @staticmethod
+    def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None:
+        """Fetch a dict payload *key* from event.raw, or from its nested "raw" dict."""
+        raw = event.raw if isinstance(event.raw, dict) else {}
+        payload = raw.get(key)
+        if isinstance(payload, dict):
+            return payload
+        # Some transports wrap the original payload one level deeper.
+        nested_raw = raw.get("raw")
+        if isinstance(nested_raw, dict):
+            nested_payload = nested_raw.get(key)
+            if isinstance(nested_payload, dict):
+                return nested_payload
+        return None
+
+    def _inject_provider_request(
+        self,
+        event: MessageEvent,
+        injected_payloads: _InjectedEventPayloads | None,
+    ) -> ProviderRequest | None:
+        """Materialize the ProviderRequest payload, caching it on *injected_payloads*.
+
+        The cached instance is what _append_injected_payloads later serializes,
+        so handler mutations are echoed back to the host.
+        """
+        if injected_payloads is None:
+            payload = self._payload_from_event(event, "provider_request")
+            return (
+                ProviderRequest.from_payload(payload) if payload is not None else None
+            )
+        if injected_payloads.provider_request is None:
+            payload = self._payload_from_event(event, "provider_request")
+            if payload is None:
+                return None
+            injected_payloads.provider_request = ProviderRequest.from_payload(payload)
+        return injected_payloads.provider_request
+
+    def _inject_llm_response(
+        self,
+        event: MessageEvent,
+        injected_payloads: _InjectedEventPayloads | None,
+    ) -> LLMResponse | None:
+        """Materialize the LLMResponse payload, caching it on *injected_payloads*."""
+        if injected_payloads is None:
+            payload = self._payload_from_event(event, "llm_response")
+            return LLMResponse.model_validate(payload) if payload is not None else None
+        if injected_payloads.llm_response is None:
+            payload = self._payload_from_event(event, "llm_response")
+            if payload is None:
+                return None
+            injected_payloads.llm_response = LLMResponse.model_validate(payload)
+        return injected_payloads.llm_response
+
+    def _inject_event_result(
+        self,
+        event: MessageEvent,
+        injected_payloads: _InjectedEventPayloads | None,
+    ) -> MessageEventResult | None:
+        """Materialize the MessageEventResult payload, caching it on *injected_payloads*."""
+        if injected_payloads is None:
+            payload = self._payload_from_event(event, "event_result")
+            return (
+                MessageEventResult.from_payload(payload)
+                if payload is not None
+                else None
+            )
+        if injected_payloads.event_result is None:
+            payload = self._payload_from_event(event, "event_result")
+            if payload is None:
+                return None
+            injected_payloads.event_result = MessageEventResult.from_payload(payload)
+        return injected_payloads.event_result
+
+    @staticmethod
+    def _append_injected_payloads(
+        summary: dict[str, Any],
+        injected_payloads: _InjectedEventPayloads,
+        *,
+        event: MessageEvent,
+        event_type: str,
+    ) -> None:
+        """Serialize handler-visible payload objects back into *summary*.
+
+        Only the payload matching the hook's event_type is echoed, so the host
+        can pick up in-place mutations made by the handler.
+        """
+        if (
+            event_type == "llm_request"
+            and injected_payloads.provider_request is not None
+        ):
+            summary["provider_request"] = (
+                injected_payloads.provider_request.to_payload()
+            )
+        elif (
+            event_type in {"llm_response", "agent_done"}
+            and injected_payloads.llm_response is not None
+        ):
+            summary["llm_response"] = injected_payloads.llm_response.model_dump(
+                exclude_none=True
+            )
+        elif (
+            event_type in {"decorating_result", "streaming_delta"}
+            and injected_payloads.event_result is not None
+        ):
+            summary["event_result"] = injected_payloads.event_result.to_payload()
+        if event._should_serialize_sdk_local_extras():  # noqa: SLF001
+            summary["sdk_local_extras"] = event._sdk_local_extras_payload()  # noqa: SLF001
+
+    def _format_handler_injection_error(
+        self,
+        *,
+        handler,
+        parameter_name: str,
+        plugin_id: str | None,
+        handler_ref: str | None,
+        args: dict[str, Any],
+    ) -> str:
+        """Build the (Chinese) diagnostic message for a failed parameter injection.
+
+        NOTE(review): the message lists only MessageEvent/Context as type-based
+        injections, while _inject_by_type supports more — consider updating the
+        wording in a behavior-affecting change.
+        """
+        plugin_text = plugin_id or self._plugin_id
+        target = handler_ref or getattr(handler, "__name__", "")
+        arg_keys = sorted(str(key) for key in args.keys())
+        arg_keys_text = ", ".join(arg_keys) if arg_keys else ""
+        return (
+            f"插件 '{plugin_text}' 的 handler '{target}' 参数注入失败:"
+            f"必填参数 '{parameter_name}' 无法注入。"
+            f"签名: {getattr(handler, '__name__', '')}"
+            f"{self._callable_signature(handler)}。"
+            "当前支持按类型注入 MessageEvent / Context,"
+            "按参数名注入 event / ctx / context,"
+            f"以及 args 中现有键:{arg_keys_text}。"
+        )
+
+    @staticmethod
+    def _callable_signature(handler) -> str:
+        """Render the handler signature for diagnostics; "(...)" when unavailable."""
+        try:
+            return str(inspect.signature(handler))
+        except (TypeError, ValueError):
+            return "(...)"
+
+    async def _handle_result_item(
+        self,
+        item: Any,
+        event: MessageEvent,
+        ctx: Context | None = None,
+    ) -> dict[str, Any]:
+        """Send one handler result and summarise it (sent_message/stop/call_llm).
+
+        Only dict results may carry ``stop`` / ``call_llm`` flags.
+        """
+        sent_message = await self._send_result(item, event, ctx)
+        if isinstance(item, dict):
+            return {
+                "sent_message": sent_message,
+                "stop": bool(item.get("stop", False)),
+                "call_llm": bool(item.get("call_llm", False)),
+            }
+        return {
+            "sent_message": sent_message,
+            "stop": False,
+            "call_llm": False,
+        }
+
+    @staticmethod
+    def _merge_handler_summary(
+        target: dict[str, Any],
+        source: dict[str, Any],
+    ) -> None:
+        """OR-merge the three boolean flags from *source* into *target* in place."""
+        target["sent_message"] = bool(target.get("sent_message")) or bool(
+            source.get("sent_message")
+        )
+        target["stop"] = bool(target.get("stop")) or bool(source.get("stop"))
+        target["call_llm"] = bool(target.get("call_llm")) or bool(
+            source.get("call_llm")
+        )
+
+    async def _send_result(
+        self,
+        item: Any,
+        event: MessageEvent,
+        ctx: Context | None = None,
+    ) -> bool:
+        """Send a handler result item; return True when a message was sent.
+
+        Accepted forms, in order: str, dict with "text", MessageEventResult,
+        anything coercible to a MessageChain, a list of message components,
+        and objects exposing a str ``text`` attribute.
+        """
+        if isinstance(item, str):
+            await event.reply(item)
+            return True
+        if isinstance(item, dict) and "text" in item:
+            await event.reply(str(item["text"]))
+            return True
+        if isinstance(item, MessageEventResult):
+            chain = item.chain
+            if chain.components:
+                await event.reply_chain(chain)
+                return True
+            return False
+        chain = coerce_message_chain(item)
+        if chain is not None:
+            if chain.components:
+                await event.reply_chain(chain)
+                return True
+            return False
+        if isinstance(item, list) and all(
+            isinstance(component, BaseMessageComponent) for component in item
+        ):
+            await event.reply_chain(MessageChain(list(item)))
+            return True
+        # Support objects exposing a plain `text` attribute.
+        text = getattr(item, "text", None)
+        if isinstance(text, str):
+            await event.reply(text)
+            return True
+        return False
+
+    @staticmethod
+    def _parse_handler_args(
+        param_specs: Sequence[ParamSpec],
+        args: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Convert *args* per ParamSpec; missing required params raise TypeError."""
+        parsed: dict[str, Any] = {}
+        for spec in param_specs:
+            if spec.name not in args:
+                if spec.type == "optional":
+                    # Absent optional params surface as explicit None.
+                    parsed[spec.name] = None
+                    continue
+                if spec.required:
+                    raise TypeError(f"缺少参数: {spec.name}")
+                continue
+            parsed[spec.name] = HandlerDispatcher._convert_param(spec, args[spec.name])
+        return parsed
+
+    @staticmethod
+    def _convert_param(spec: ParamSpec, value: Any) -> Any:
+        """Coerce *value* to the declared ParamSpec type.
+
+        "optional" recurses with the inner type (default "str"); unknown types
+        pass the value through unchanged.  Raises TypeError for unparseable
+        booleans.
+        """
+        if spec.type in {"str", "greedy_str"}:
+            return str(value)
+        if spec.type == "int":
+            return int(str(value))
+        if spec.type == "float":
+            return float(str(value))
+        if spec.type == "bool":
+            normalized = str(value).strip().lower()
+            if normalized in {"true", "1", "yes", "on"}:
+                return True
+            if normalized in {"false", "0", "no", "off"}:
+                return False
+            raise TypeError(f"无法解析布尔参数 {spec.name}: {value!r}")
+        if spec.type == "optional":
+            if value is None:
+                return None
+            inner = ParamSpec(
+                name=spec.name,
+                type=spec.inner_type or "str",
+                required=False,
+            )
+            return HandlerDispatcher._convert_param(inner, value)
+        return value
+
+    @staticmethod
+    def _run_local_filters(
+        bindings: list[LocalFilterBinding],
+        *,
+        event: MessageEvent,
+        ctx: Context,
+    ) -> bool:
+        """Return True only when every local filter accepts the event (AND semantics)."""
+        for binding in bindings:
+            if not binding.evaluate(event=event, ctx=ctx):
+                return False
+        return True
+
+    @staticmethod
+    def _build_schedule_context(
+        loaded: LoadedHandler,
+        event_payload: dict[str, Any],
+    ) -> ScheduleContext | None:
+        """Build a ScheduleContext for schedule-triggered handlers; None otherwise.
+
+        Payload parsing failures are deliberately swallowed — the handler then
+        simply receives no schedule context.
+        """
+        if not isinstance(loaded.descriptor.trigger, ScheduleTrigger):
+            return None
+        try:
+            return ScheduleContext.from_payload(event_payload)
+        except Exception:
+            return None
+
+    async def _handle_error(
+        self,
+        owner: Any,
+        exc: Exception,
+        event: MessageEvent,
+        ctx: Context,
+        *,
+        handler_name: str = "",
+        plugin_id: str | None = None,
+    ) -> None:
+        """Route *exc* to the owner's on_error hook, else Star.default_on_error.
+
+        Works for owner=None too (hasattr check fails -> default path).
+        """
+        if hasattr(owner, "on_error") and callable(owner.on_error):
+            bound_owner = owner if isinstance(owner, Star) else None
+            with bind_star_runtime(bound_owner, ctx):
+                result = owner.on_error(exc, event, ctx)
+                if inspect.isawaitable(result):
+                    await result
+            return
+        await Star.default_on_error(exc, event, ctx)
+
+
+__all__ = ["CapabilityDispatcher", "HandlerDispatcher"]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py b/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py
new file mode 100644
index 0000000000..b32fe6e2da
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import time
+from collections import deque
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
+
+from ..decorators import LimiterMeta
+from ..errors import AstrBotError
+
+DEFAULT_RATE_LIMIT_MESSAGE = "操作过于频繁,请稍后再试。"
+DEFAULT_COOLDOWN_MESSAGE = "冷却中,请在 {remaining_seconds}s 后重试。"
+
+
+@dataclass(slots=True)
+class LimiterDecision:
+    """Outcome of a limiter check: allowed, or denied with an error and/or hint."""
+    allowed: bool
+    error: AstrBotError | None = None  # set for behavior == "error"
+    hint: str | None = None  # user-facing message for the default (reply) behavior
+
+
+class LimiterEngine:
+    """Sliding-window rate/cooldown limiter keyed by plugin, handler, and scope.
+
+    NOTE(review): window buckets are never evicted from self._windows, so the
+    dict grows with the number of distinct keys — confirm acceptable.
+    """
+
+    def __init__(self, *, clock: Callable[[], float] | None = None) -> None:
+        # Injectable monotonic clock keeps the engine testable.
+        self._clock = clock or time.monotonic
+        self._windows: dict[str, deque[float]] = {}
+
+    def evaluate(
+        self,
+        *,
+        plugin_id: str,
+        handler_id: str,
+        limiter: LimiterMeta,
+        event: Any,
+    ) -> LimiterDecision:
+        """Record one hit and decide whether the call may proceed.
+
+        Keeps at most ``limiter.limit`` timestamps inside the sliding
+        ``limiter.window``; on denial, honors the configured behavior
+        (silent / error / hint reply).
+        """
+        now = float(self._clock())
+        key = self._make_key(
+            plugin_id=plugin_id,
+            handler_id=handler_id,
+            scope=limiter.scope,
+            event=event,
+        )
+        bucket = self._windows.setdefault(key, deque())
+        # Drop timestamps that have aged out of the window.
+        threshold = now - limiter.window
+        while bucket and bucket[0] <= threshold:
+            bucket.popleft()
+
+        if len(bucket) < limiter.limit:
+            bucket.append(now)
+            return LimiterDecision(allowed=True)
+
+        # Time until the oldest hit leaves the window.
+        remaining = 0.0
+        if bucket:
+            remaining = max(0.0, limiter.window - (now - bucket[0]))
+        hint = self._hint_text(limiter, remaining)
+        details = {
+            "scope": limiter.scope,
+            "handler_id": handler_id,
+            "remaining_seconds": round(remaining, 3),
+        }
+        if limiter.behavior == "silent":
+            return LimiterDecision(allowed=False)
+        if limiter.behavior == "error":
+            if limiter.kind == "cooldown":
+                return LimiterDecision(
+                    allowed=False,
+                    error=AstrBotError.cooldown_active(hint=hint, details=details),
+                )
+            return LimiterDecision(
+                allowed=False,
+                error=AstrBotError.rate_limited(hint=hint, details=details),
+            )
+        return LimiterDecision(allowed=False, hint=hint)
+
+    @staticmethod
+    def _make_key(
+        *,
+        plugin_id: str,
+        handler_id: str,
+        scope: str,
+        event: Any,
+    ) -> str:
+        """Build the window key; unknown scopes degrade to the global key."""
+        prefix = f"{plugin_id}:{handler_id}"
+        if scope == "global":
+            return prefix
+        if scope == "session":
+            return f"{prefix}:{getattr(event, 'session_id', '')}"
+        if scope == "user":
+            return (
+                f"{prefix}:{getattr(event, 'platform_id', '')}"
+                f":{getattr(event, 'user_id', '')}"
+            )
+        if scope == "group":
+            return (
+                f"{prefix}:{getattr(event, 'platform_id', '')}"
+                f":{getattr(event, 'group_id', '')}"
+            )
+        return prefix
+
+    @staticmethod
+    def _hint_text(limiter: LimiterMeta, remaining: float) -> str:
+        """Format the denial hint; remaining seconds are ceiled to at least 1."""
+        if limiter.message:
+            return limiter.message.format(
+                remaining_seconds=max(1, int(remaining + 0.999))
+            )
+        if limiter.kind == "cooldown":
+            return DEFAULT_COOLDOWN_MESSAGE.format(
+                remaining_seconds=max(1, int(remaining + 0.999))
+            )
+        return DEFAULT_RATE_LIMIT_MESSAGE
+
+
+__all__ = [
+ "DEFAULT_COOLDOWN_MESSAGE",
+ "DEFAULT_RATE_LIMIT_MESSAGE",
+ "LimiterDecision",
+ "LimiterEngine",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/loader.py b/astrbot-sdk/src/astrbot_sdk/runtime/loader.py
new file mode 100644
index 0000000000..07294d2797
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/loader.py
@@ -0,0 +1,1536 @@
+"""插件加载模块。
+
+定义插件发现、环境管理和加载的核心逻辑。
+仅支持 astrbot-sdk 新版 Star 组件。
+
+核心概念:
+ PluginSpec: 插件规范,描述插件的基本信息
+ PluginDiscoveryResult: 插件发现结果,包含成功和跳过的插件
+ PluginEnvironmentManager: 插件虚拟环境管理器
+ LoadedHandler: 加载后的处理器,包含描述符和可调用对象
+ LoadedPlugin: 加载后的插件,包含处理器和实例
+
+插件发现流程:
+ 1. 扫描 plugins_dir 下的子目录
+ 2. 检查 plugin.yaml 和 requirements.txt
+ 3. 解析 manifest_data 获取插件信息
+ 4. 验证必要字段(name, components, runtime.python)
+ 5. 返回 PluginDiscoveryResult
+
+环境管理流程:
+ 1. 对插件集合做共享环境规划
+ 2. 按 Python 版本和依赖兼容性构建环境分组
+ 3. 为每个分组生成 lock/source/metadata 工件
+ 4. 必要时重建或同步分组虚拟环境
+ 5. 将单个插件映射到所属分组环境
+
+插件加载流程:
+ 1. 将插件目录添加到 sys.path
+ 2. 遍历 components 列表
+ 3. 动态导入组件类
+ 4. 直接实例化(无参构造函数)
+ 5. 扫描处理器方法
+ 6. 构建 HandlerDescriptor
+
+plugin.yaml 格式:
+ name: my_plugin
+ author: author_name
+ repo: my_plugin
+ desc: Plugin description
+ version: 1.0.0
+ runtime:
+ python: "3.11"
+ components:
+ - class: my_plugin.main:MyComponent
+
+`loader` 是 runtime 与插件代码之间的边界层,负责三件事:
+
+- 从 `plugin.yaml` 解析出可运行的 `PluginSpec`
+- 用 `uv` 为插件准备独立环境
+- 把组件实例和 handler 元数据整理成 `LoadedPlugin`
+"""
+
+from __future__ import annotations
+
+import builtins
+import contextlib
+import copy
+import hashlib
+import importlib
+import importlib.abc
+import inspect
+import json
+import os
+import re
+import shutil
+import sys
+import threading
+import types
+import typing
+from collections.abc import Sequence
+from dataclasses import dataclass, field, replace
+from importlib import import_module
+from importlib.machinery import ModuleSpec, PathFinder
+from pathlib import Path
+from typing import Any, Literal, TypeAlias, TypeVar, cast
+
+import yaml
+
+from .._internal.command_model import resolve_command_model_param
+from .._internal.injected_params import is_framework_injected_parameter
+from .._internal.invocation_context import caller_plugin_scope, current_caller_plugin_id
+from .._internal.plugin_ids import (
+ capability_belongs_to_plugin,
+ plugin_capability_prefix,
+ validate_plugin_id,
+)
+from .._internal.sdk_logger import logger
+from .._internal.typing_utils import unwrap_optional
+from ..decorators import (
+ ConversationMeta,
+ LimiterMeta,
+ get_agent_meta,
+ get_capability_meta,
+ get_handler_meta,
+ get_llm_tool_meta,
+)
+from ..llm.agents import AgentSpec
+from ..llm.entities import LLMToolSpec
+from ..protocol.descriptors import (
+ CapabilityDescriptor,
+ HandlerDescriptor,
+ ParamSpec,
+ ScheduleTrigger,
+)
+from ..types import GreedyStr
+from .environment_groups import (
+ EnvironmentGroup,
+ EnvironmentPlanner,
+ EnvironmentPlanResult,
+ GroupEnvironmentManager,
+)
+
# File and attribute names that make up the on-disk plugin contract.
PLUGIN_MANIFEST_FILE = "plugin.yaml"
STATE_FILE_NAME = ".astrbot-worker-state.json"
CONFIG_SCHEMA_FILE = "_conf_schema.json"
PLUGIN_METADATA_ATTR = "__astrbot_plugin_metadata__"
# Narrow string unions used by descriptors and discovery reports.
ParamTypeName: TypeAlias = Literal[
    "str", "int", "float", "bool", "optional", "greedy_str"
]
OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None
HandlerKind: TypeAlias = Literal["handler", "hook", "tool", "session"]
DiscoverySeverity: TypeAlias = Literal["warning", "error"]
DiscoveryPhase: TypeAlias = Literal["discovery", "load", "lifecycle", "reload"]
# NOTE(review): presumably serializes plugin imports because the scoped import
# hook mutates process-global state -- confirm at the lock's use sites.
_PLUGIN_IMPORT_LOCK = threading.RLock()
_VALID_HANDLER_KINDS: tuple[HandlerKind, ...] = ("handler", "hook", "tool", "session")
# Prefix of the synthetic top-level package each plugin is imported under.
_PLUGIN_PACKAGE_PREFIX = "astrbot_ext_"
# Accepted manifest `repo` forms: bare name, owner/repo slug, full GitHub URL.
_GITHUB_REPO_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
_GITHUB_REPO_SLUG_RE = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$")
_GITHUB_REPO_URL_RE = re.compile(
    r"^https://github\.com/[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+/?$",
    re.IGNORECASE,
)
# Process-wide registry and bookkeeping for the plugin-scoped import hook.
_PLUGIN_IMPORT_NAMESPACES: dict[str, _PluginImportNamespace] = {}
_ORIGINAL_BUILTIN_IMPORT = builtins.__import__
_PLUGIN_IMPORT_HOOK_INSTALLED = False
_PLUGIN_IMPORT_META_FINDER: _PluginScopedMetaPathFinder | None = None
_PLUGIN_IMPORT_ALIAS_STATE = threading.local()
# Metadata dataclass types the loader clones per handler; see _copy_meta.
_TMeta = TypeVar("_TMeta", LimiterMeta, ConversationMeta)
+
+
+def _default_python_version() -> str:
+ return f"{sys.version_info.major}.{sys.version_info.minor}"
+
+
def _is_valid_github_repo_ref(value: str) -> bool:
    """Check whether *value* is an acceptable manifest ``repo`` reference.

    Accepts a bare repository name, an ``owner/repo`` slug, or a full
    ``https://github.com/owner/repo`` URL; surrounding whitespace is ignored
    and an empty string is rejected.
    """
    candidate = value.strip()
    if not candidate:
        return False
    patterns = (_GITHUB_REPO_NAME_RE, _GITHUB_REPO_SLUG_RE, _GITHUB_REPO_URL_RE)
    return any(pattern.fullmatch(candidate) for pattern in patterns)
+
+
+def _venv_python_path(venv_dir: Path) -> Path:
+ if os.name == "nt":
+ return venv_dir / "Scripts" / "python.exe"
+ return venv_dir / "bin" / "python"
+
+
@dataclass(slots=True)
class PluginSpec:
    """Parsed identity of one plugin directory (from plugin.yaml)."""

    name: str  # manifest `name`, falls back to the directory name
    plugin_dir: Path  # resolved plugin root directory
    manifest_path: Path  # plugin.yaml location
    requirements_path: Path  # requirements.txt location (may not exist)
    python_version: str  # requested runtime, e.g. "3.11"
    manifest_data: dict[str, Any]  # raw parsed manifest mapping
+
+
@dataclass(slots=True)
class PluginDiscoveryResult:
    """Outcome of scanning a plugins directory."""

    plugins: list[PluginSpec]  # specs that parsed and validated
    skipped_plugins: dict[str, str]  # plugin/dir name -> skip reason
    issues: list[PluginDiscoveryIssue] = field(default_factory=list)  # structured reports
+
+
+@dataclass(slots=True)
+class PluginDiscoveryIssue:
+ severity: DiscoverySeverity
+ phase: DiscoveryPhase
+ plugin_id: str
+ message: str
+ details: str = ""
+ hint: str = ""
+
+ def to_payload(self) -> dict[str, str]:
+ return {
+ "severity": self.severity,
+ "phase": self.phase,
+ "plugin_id": self.plugin_id,
+ "message": self.message,
+ "details": self.details,
+ "hint": self.hint,
+ }
+
+
@dataclass(slots=True)
class LoadedHandler:
    """A bound handler plus the metadata the runtime routes on."""

    descriptor: HandlerDescriptor  # transport-facing description (id, trigger, filters, ...)
    callable: Any  # bound method to invoke
    owner: Any  # component instance the method is bound to
    plugin_id: str = ""
    local_filters: list[Any] = field(default_factory=list)  # in-process-only filters
    limiter: LimiterMeta | None = None  # per-handler rate/cooldown config
    conversation: ConversationMeta | None = None  # conversation/session config
+
+
@dataclass(slots=True)
class LoadedCapability:
    """A bound capability export discovered on a component."""

    descriptor: CapabilityDescriptor  # capability name/contract advertised to callers
    callable: Any  # bound method implementing the capability
    owner: Any  # component instance the method is bound to
    plugin_id: str = ""
+
+
@dataclass(slots=True)
class LoadedLLMTool:
    """A bound LLM tool discovered on a component."""

    spec: LLMToolSpec  # tool schema handed to the LLM layer
    callable: Any  # bound method executed when the tool is invoked
    owner: Any  # component instance the method is bound to
    plugin_id: str = ""
+
+
@dataclass(slots=True)
class LoadedAgent:
    """An agent runner class discovered alongside a component."""

    spec: AgentSpec  # agent definition from the decorator metadata
    runner_class: type[Any]  # class the agent runtime instantiates
    owner: Any | None = None  # not bound to a component instance at load time
    plugin_id: str = ""
+
+
@dataclass(slots=True)
class LoadedPlugin:
    """Everything the runtime gets back from loading one plugin."""

    plugin: PluginSpec  # the spec the plugin was loaded from
    handlers: list[LoadedHandler]
    capabilities: list[LoadedCapability] = field(default_factory=list)
    llm_tools: list[LoadedLLMTool] = field(default_factory=list)
    agents: list[LoadedAgent] = field(default_factory=list)
    instances: list[Any] = field(default_factory=list)  # live component instances
+
+
@dataclass(slots=True)
class _ResolvedComponent:
    """A component class resolved from a manifest components[] entry."""

    cls: type[Any]  # imported component class
    class_path: str  # original "module:Class" string, kept for error messages
    index: int  # position within the manifest components list
+
@dataclass(slots=True)
class _PluginImportNamespace:
    """Import-hook bookkeeping for one plugin's synthetic package."""

    plugin_id: str
    plugin_dir: Path
    package_name: str  # synthetic top-level package the plugin imports under
+
+
@dataclass(slots=True)
class _ParamTypeInfo:
    """Result of mapping an annotation onto the loader's parameter model."""

    type_name: ParamTypeName
    inner_type: OptionalInnerType  # scalar inside Optional[...], else None
    required: bool
+
+
class _PluginScopedAliasLoader(importlib.abc.Loader):
    """Loader that aliases an import name to an already-importable target module.

    Used by the scoped meta-path finder so plugin code importing a bare module
    name receives the plugin's package-qualified module object; alias and
    target then share one module identity.
    """

    def __init__(self, *, alias_name: str, target_name: str) -> None:
        self.alias_name = alias_name
        self.target_name = target_name

    def create_module(self, spec: ModuleSpec) -> types.ModuleType:
        # Return the *target* module itself instead of a fresh module object.
        del spec
        module = sys.modules.get(self.target_name)
        if not isinstance(module, types.ModuleType):
            module = import_module(self.target_name)
        # NOTE(review): helper defined elsewhere in this module; presumably
        # records the alias for later cleanup -- confirm at its definition.
        _record_plugin_import_alias(self.alias_name)
        return module

    def exec_module(self, module: types.ModuleType) -> None:
        # Target was already executed by its own loader; nothing to run here.
        del module
+
+
class _PluginScopedMetaPathFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that rewrites imports issued from plugin code.

    When the current caller is plugin code, the requested name is rewritten to
    the plugin's package-qualified module name and served through
    ``_PluginScopedAliasLoader`` so both names resolve to one module object.
    Non-plugin callers and names that do not rewrite fall through to the
    regular import machinery.
    """

    def find_spec(
        self,
        fullname: str,
        path: Sequence[str] | None = None,
        target: types.ModuleType | None = None,
        /,
    ) -> ModuleSpec | None:
        del path, target
        # NOTE(review): both helpers are defined elsewhere in this module;
        # presumably they resolve the active plugin and its name mapping.
        namespace = _plugin_import_namespace_for_current_caller()
        if namespace is None:
            return None
        rewritten_name = _rewrite_plugin_import_name(namespace, fullname)
        if rewritten_name is None:
            return None
        parent_name, _, _ = rewritten_name.rpartition(".")
        parent_search_path = None
        if parent_name:
            # Import the parent package (if needed) to get its search path
            # for PathFinder below.
            parent_module = sys.modules.get(parent_name)
            if not isinstance(parent_module, types.ModuleType):
                parent_module = import_module(parent_name)
            parent_search_path = getattr(parent_module, "__path__", None)
        target_spec = PathFinder.find_spec(
            rewritten_name,
            parent_search_path,
        )
        if target_spec is None:
            return None
        # Mirror the target spec under the alias name so the alias behaves
        # like the real module (origin, cache, package-ness, search paths).
        alias_spec = ModuleSpec(
            fullname,
            _PluginScopedAliasLoader(
                alias_name=fullname,
                target_name=rewritten_name,
            ),
            is_package=target_spec.submodule_search_locations is not None,
        )
        alias_spec.origin = target_spec.origin
        alias_spec.cached = target_spec.cached
        alias_spec.has_location = target_spec.has_location
        if target_spec.submodule_search_locations is not None:
            alias_spec.submodule_search_locations = list(
                target_spec.submodule_search_locations
            )
        return alias_spec
+
+
+def _sanitize_package_component(plugin_id: str) -> str:
+ sanitized = re.sub(r"[^A-Za-z0-9_]+", "_", plugin_id).strip("_")
+ return sanitized or "plugin"
+
+
def _plugin_package_name(plugin_id: str) -> str:
    """Derive the unique synthetic package name for a plugin.

    Combines the sanitized id with a short SHA-256 digest of the raw id so
    distinct ids that sanitize identically still get distinct packages.
    """
    suffix = hashlib.sha256(plugin_id.encode("utf-8")).hexdigest()[:8]
    stem = _sanitize_package_component(plugin_id)
    return "".join((_PLUGIN_PACKAGE_PREFIX, stem, "_", suffix))
+
+
+def _plugin_module_name(package_name: str, module_name: str) -> str:
+ normalized = module_name.strip()
+ return f"{package_name}.{normalized}" if normalized else package_name
+
+
+def _iter_handler_names(instance: Any) -> list[str]:
+ handler_names = getattr(instance.__class__, "__handlers__", ())
+ if handler_names:
+ return list(handler_names)
+ return list(dir(instance))
+
+
def _iter_discoverable_names(instance: Any) -> list[str]:
    """List attribute names to scan: declared handlers first, then the rest.

    Declared handler names keep their original order (first occurrence wins);
    the remaining ``dir()`` names follow in sorted order.
    """
    ordered_handlers = list(dict.fromkeys(_iter_handler_names(instance)))
    seen = set(ordered_handlers)
    remainder = sorted(name for name in dir(instance) if name not in seen)
    return ordered_handlers + remainder
+
+
def _validate_loaded_capability_namespace(
    plugin: PluginSpec,
    *,
    resolved_component: _ResolvedComponent,
    attribute_name: str,
    capability_name: str,
) -> None:
    """Require exported capabilities to use the owning plugin's name prefix.

    Raises:
        ValueError: when *capability_name* falls outside *plugin*'s
            capability namespace, naming the expected prefix.
    """
    if capability_belongs_to_plugin(capability_name, plugin.name):
        return
    expected_prefix = plugin_capability_prefix(plugin.name)
    raise ValueError(
        f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} "
        f"方法 {attribute_name!r} 导出的 capability {capability_name!r} 必须使用当前插件名前缀 "
        f"{expected_prefix!r},例如 {expected_prefix}"
    )
+
+
+def _register_loaded_capability_name(
+ seen_capability_sources: dict[str, str],
+ *,
+ capability_name: str,
+ source_ref: str,
+) -> None:
+ existing_source = seen_capability_sources.get(capability_name)
+ if existing_source is not None:
+ raise ValueError(
+ f"capability {capability_name!r} 重复定义:{existing_source} 与 {source_ref}"
+ )
+ seen_capability_sources[capability_name] = source_ref
+
+
def _is_injected_parameter(annotation: Any, parameter_name: str) -> bool:
    """Return True when the parameter is supplied by the framework, not by the user.

    Note the delegate takes (name, annotation) -- argument order is swapped here.
    """
    return is_framework_injected_parameter(parameter_name, annotation)
+
+
def _param_type_name(annotation: Any) -> _ParamTypeInfo:
    """Map a parameter annotation onto the loader's parameter type model.

    ``GreedyStr`` consumes the rest of the command line; plain scalar types
    map directly; ``Optional[...]`` scalars become optional with an inner
    type; any other annotation (including none) is treated as a string, and
    optionality alone downgrades it to an optional string.
    """
    base_type, is_optional = unwrap_optional(annotation)
    if base_type is GreedyStr:
        return _ParamTypeInfo("greedy_str", None, False)
    if base_type in {int, float, bool, str}:
        scalar_name = cast(
            Literal["str", "int", "float", "bool"], base_type.__name__
        )
        if is_optional:
            return _ParamTypeInfo("optional", scalar_name, False)
        return _ParamTypeInfo(scalar_name, None, True)
    if is_optional:
        return _ParamTypeInfo("optional", "str", False)
    return _ParamTypeInfo("str", None, True)
+
+
def _build_param_specs(handler: Any) -> list[ParamSpec]:
    """Derive command parameter specs from a handler's signature.

    Handlers using a command-model parameter publish no positional specs.
    Framework-injected parameters are skipped; a parameter with a default is
    never required. A GreedyStr parameter must be the final positional
    parameter, which also implies at most one GreedyStr per handler.

    Raises:
        ValueError: if any GreedyStr parameter is not the last spec.
    """
    model_param = resolve_command_model_param(handler)
    if model_param is not None:
        return []
    try:
        signature = inspect.signature(handler)
    except (TypeError, ValueError):
        return []
    try:
        type_hints = typing.get_type_hints(handler)
    except Exception as exc:
        logger.warning(
            "Failed to resolve type hints for handler {}: {}",
            getattr(handler, "__qualname__", repr(handler)),
            exc,
        )
        type_hints = {}

    specs: list[ParamSpec] = []
    for parameter in signature.parameters.values():
        if parameter.kind not in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            continue
        annotation = type_hints.get(parameter.name)
        if _is_injected_parameter(annotation, parameter.name):
            continue
        type_info = _param_type_name(annotation)
        required = type_info.required
        if parameter.default is not inspect.Parameter.empty:
            required = False
        specs.append(
            ParamSpec(
                name=parameter.name,
                type=type_info.type_name,
                required=required,
                inner_type=type_info.inner_type,
            )
        )

    greedy_indexes = [
        index for index, spec in enumerate(specs) if spec.type == "greedy_str"
    ]
    # Every greedy parameter must be the final spec. Checking only the *last*
    # greedy index (as before) let an earlier, misplaced GreedyStr slip
    # through whenever another GreedyStr happened to sit in final position.
    misplaced = [index for index in greedy_indexes if index != len(specs) - 1]
    if misplaced:
        greedy_spec = specs[misplaced[0]]
        raise ValueError(f"参数 '{greedy_spec.name}' (GreedyStr) 必须是最后一个参数。")
    return specs
+
+
+def _validate_schedule_signature(handler: Any) -> None:
+ try:
+ signature = inspect.signature(handler)
+ except (TypeError, ValueError):
+ return
+ allowed_names = {"ctx", "context", "sched", "schedule"}
+ invalid = [
+ parameter.name
+ for parameter in signature.parameters.values()
+ if parameter.kind
+ in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ )
+ and parameter.name not in allowed_names
+ ]
+ if invalid:
+ raise ValueError(
+ "Schedule handler 只允许注入 ctx/context 和 sched/schedule 参数。"
+ )
+
+
+def _plugin_context(plugin: PluginSpec) -> str:
+ return f"插件 '{plugin.name}'({plugin.manifest_path})"
+
+
def _component_context(plugin: PluginSpec, *, class_path: str, index: int) -> str:
    """Label pointing at one components[index] entry of *plugin* for errors."""
    entry = f"components[{index}].class='{class_path}'"
    return f"{_plugin_context(plugin)} 的 {entry}"
+
+
+def _resolve_candidate(
+ instance: Any,
+ name: str,
+ meta_getter: typing.Callable[[Any], Any | None],
+ *,
+ predicate: typing.Callable[[Any], bool] | None = None,
+) -> tuple[Any, Any] | None:
+ try:
+ raw = inspect.getattr_static(instance, name)
+ except AttributeError:
+ return None
+
+ candidates = [raw]
+ wrapped = getattr(raw, "__func__", None)
+ if wrapped is not None:
+ candidates.append(wrapped)
+
+ for candidate in candidates:
+ meta = meta_getter(candidate)
+ if meta is None:
+ continue
+ if predicate is not None and not predicate(meta):
+ continue
+ try:
+ return getattr(instance, name), meta
+ except AttributeError:
+ return None
+ return None
+
+
def _resolve_handler_candidate(instance: Any, name: str) -> tuple[Any, Any] | None:
    """Resolve a handler method without triggering unrelated descriptor side effects.

    Metadata must both exist and declare a trigger to count as a handler.
    """
    def _has_trigger(meta: Any) -> bool:
        return meta.trigger is not None

    return _resolve_candidate(instance, name, get_handler_meta, predicate=_has_trigger)
+
+
def _resolve_capability_candidate(instance: Any, name: str) -> tuple[Any, Any] | None:
    """Resolve attribute *name* as a capability export, if it carries capability metadata."""
    return _resolve_candidate(instance, name, get_capability_meta)


def _resolve_llm_tool_candidate(instance: Any, name: str) -> tuple[Any, Any] | None:
    """Resolve attribute *name* as an LLM tool, if it carries tool metadata."""
    return _resolve_candidate(instance, name, get_llm_tool_meta)
+
+
def _iter_agent_candidates(component_cls: type[Any]) -> list[tuple[type[Any], Any]]:
    """Collect agent-decorated classes visible from a component class.

    Scans the component's defining module first and the component class body
    second, de-duplicating by fully qualified class name while preserving
    discovery order. Returns ``(agent_class, agent_meta)`` pairs.
    """
    module = import_module(component_cls.__module__)
    found: dict[str, tuple[type[Any], Any]] = {}

    for candidate in [*vars(module).values(), *vars(component_cls).values()]:
        if not inspect.isclass(candidate):
            continue
        meta = get_agent_meta(candidate)
        if meta is None:
            continue
        qualified = f"{candidate.__module__}.{candidate.__qualname__}"
        if qualified not in found:
            found[qualified] = (candidate, meta)
    return list(found.values())
+
+
def _read_yaml(path: Path) -> dict[str, Any]:
    """Parse a YAML file; empty or non-mapping documents yield an empty dict."""
    loaded = yaml.safe_load(path.read_text(encoding="utf-8"))
    if isinstance(loaded, dict):
        return loaded
    return {}
+
+
+def _read_requirements_text(path: Path) -> str:
+ if not path.exists():
+ return ""
+ return path.read_text(encoding="utf-8")
+
+
+def _plugin_config_dir(plugin_dir: Path) -> Path:
+ if plugin_dir.parent.name == "plugins" and plugin_dir.parent.parent.exists():
+ return plugin_dir.parent.parent / "config"
+ return plugin_dir / "data" / "config"
+
+
def _plugin_config_path(plugin_dir: Path, plugin_name: str) -> Path:
    """Path of the persisted JSON config file for *plugin_name*."""
    return _plugin_config_dir(plugin_dir) / f"{plugin_name}_config.json"
+
+
def _read_json_object(
    path: Path,
    *,
    parse_error_message: str,
    read_error_message: str,
    non_object_message: str | None = None,
) -> dict[str, Any]:
    """Read a JSON file that is expected to contain an object.

    All failure modes degrade to an empty dict: JSON syntax errors and OS
    errors are logged with the caller-supplied ``{}``-style templates, and a
    top-level value that is not an object is logged only when
    *non_object_message* is provided (callers that tolerate it stay silent).
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        logger.warning(parse_error_message, path, exc)
        return {}
    except OSError as exc:
        logger.warning(read_error_message, path, exc)
        return {}
    if isinstance(payload, dict):
        return payload
    if non_object_message is not None:
        logger.warning(non_object_message, path, type(payload).__name__)
    return {}
+
+
+def _schema_default(field_schema: dict[str, Any]) -> Any:
+ if "default" in field_schema:
+ return copy.deepcopy(field_schema["default"])
+
+ field_type = str(field_schema.get("type") or "string")
+ if field_type == "object":
+ items = field_schema.get("items")
+ if isinstance(items, dict):
+ return {
+ key: _normalize_config_value(child_schema, None)
+ for key, child_schema in items.items()
+ if isinstance(child_schema, dict)
+ }
+ return {}
+ if field_type in {"list", "template_list", "file"}:
+ return []
+ if field_type == "dict":
+ return {}
+ if field_type == "int":
+ return 0
+ if field_type == "float":
+ return 0.0
+ if field_type == "bool":
+ return False
+ return ""
+
+
def _normalize_config_value(field_schema: dict[str, Any], value: Any) -> Any:
    """Coerce *value* to match one schema field, falling back to the default.

    Values of the wrong shape or type are replaced by the schema default;
    bools are explicitly rejected where ints/floats are expected; ``object``
    schemas recurse per child field; unknown types pass any non-None value
    through as a deep copy.
    """
    field_type = str(field_schema.get("type") or "string")
    fallback = _schema_default(field_schema)

    if field_type == "object":
        items = field_schema.get("items")
        if not isinstance(items, dict):
            return fallback
        source = value if isinstance(value, dict) else {}
        return {
            key: _normalize_config_value(child_schema, source.get(key))
            for key, child_schema in items.items()
            if isinstance(child_schema, dict)
        }
    if field_type in {"list", "template_list", "file"}:
        return copy.deepcopy(value) if isinstance(value, list) else fallback
    if field_type == "dict":
        return copy.deepcopy(value) if isinstance(value, dict) else fallback
    if field_type == "int":
        is_true_int = isinstance(value, int) and not isinstance(value, bool)
        return value if is_true_int else fallback
    if field_type == "float":
        is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
        return value if is_number else fallback
    if field_type == "bool":
        return value if isinstance(value, bool) else fallback
    if field_type in {"string", "text"}:
        return value if isinstance(value, str) else fallback
    return copy.deepcopy(value) if value is not None else fallback
+
+
def load_plugin_config_schema(plugin: PluginSpec) -> dict[str, Any]:
    """Load a plugin's config schema; parse failures are logged and yield {}."""
    schema_path = plugin.plugin_dir / CONFIG_SCHEMA_FILE
    if not schema_path.exists():
        return {}
    messages = {
        "parse_error_message": "Failed to parse SDK plugin config schema {}: {}",
        "read_error_message": "Failed to read SDK plugin config schema {}: {}",
        "non_object_message": "SDK plugin config schema {} must be a JSON object, got {}",
    }
    return _read_json_object(schema_path, **messages)
+
+
def save_plugin_config(
    plugin: PluginSpec,
    payload: dict[str, Any],
    *,
    schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Normalize *payload* against the schema and persist it as the plugin config.

    Keys absent from the schema are dropped and missing values fall back to
    schema defaults. Parent directories are created as needed. Returns the
    normalized mapping that was written.
    """
    active_schema = (
        dict(schema) if schema is not None else load_plugin_config_schema(plugin)
    )
    normalized: dict[str, Any] = {}
    for key, field_schema in active_schema.items():
        if isinstance(field_schema, dict):
            normalized[key] = _normalize_config_value(field_schema, payload.get(key))

    config_path = _plugin_config_path(plugin.plugin_dir, plugin.name)
    config_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(normalized, ensure_ascii=False, indent=2)
    config_path.write_text(serialized, encoding="utf-8")
    return normalized
+
+
def load_plugin_config(
    plugin: PluginSpec,
    *,
    schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Load (and self-heal) a plugin's config as a plain dict.

    Without a schema there is nothing to load. An existing config file is
    normalized against the schema; if the file is missing or its content
    drifted from the normalized form, the normalized config is written back.
    """
    active_schema = (
        dict(schema) if schema is not None else load_plugin_config_schema(plugin)
    )
    if not active_schema:
        return {}

    config_path = _plugin_config_path(plugin.plugin_dir, plugin.name)
    file_exists = config_path.exists()
    existing: dict[str, Any] = {}
    if file_exists:
        existing = _read_json_object(
            config_path,
            parse_error_message="Failed to parse SDK plugin config {}: {}",
            read_error_message="Failed to read SDK plugin config {}: {}",
        )
    normalized: dict[str, Any] = {}
    for key, field_schema in active_schema.items():
        if isinstance(field_schema, dict):
            normalized[key] = _normalize_config_value(field_schema, existing.get(key))

    if not file_exists or normalized != existing:
        save_plugin_config(plugin, normalized, schema=active_schema)
    return normalized
+
+
+def _is_new_star_component(cls: type[Any]) -> bool:
+ """检查组件类是否为 astrbot-sdk 新版 Star。"""
+ return bool(getattr(cls, "__astrbot_is_new_star__", False))
+
+
def _plugin_component_classes(plugin: PluginSpec) -> list[_ResolvedComponent]:
    """Resolve the component classes declared in the plugin manifest.

    Returns an empty list when ``components`` is not a list (validation
    reports that case separately). Each entry must be a mapping with a
    ``module:Class`` string under ``class``.

    Raises:
        ValueError: for malformed entries, import failures, exports that are
            not classes, or a manifest declaring no loadable component.
    """
    components = plugin.manifest_data.get("components") or []
    if not isinstance(components, list):
        return []

    classes: list[_ResolvedComponent] = []
    for index, component in enumerate(components):
        if not isinstance(component, dict):
            raise ValueError(
                f"{_plugin_context(plugin)} 的 components[{index}] 必须是 object。"
            )
        class_path = component.get("class")
        if not isinstance(class_path, str) or ":" not in class_path:
            raise ValueError(
                f"{_plugin_context(plugin)} 的 components[{index}].class "
                "必须是 ':'。"
            )
        try:
            # NOTE(review): _import_plugin_string is defined elsewhere in this
            # module; presumably imports "module:Class" within the plugin's
            # scoped namespace -- confirm at its definition.
            cls = _import_plugin_string(class_path, plugin)
        except Exception as exc:
            raise ValueError(
                f"{_component_context(plugin, class_path=class_path, index=index)} "
                f"加载失败:{exc}"
            ) from exc
        if not isinstance(cls, type):
            raise ValueError(
                f"{_component_context(plugin, class_path=class_path, index=index)} "
                "解析结果不是类,请检查导出名称。"
            )
        classes.append(
            _ResolvedComponent(
                cls=cls,
                class_path=class_path,
                index=index,
            )
        )
    if not classes:
        raise ValueError(
            f"{_plugin_context(plugin)} 未声明任何可加载组件。"
            "请检查 plugin.yaml 中的 components 配置。"
        )
    return classes
+
+
def load_plugin_spec(plugin_dir: Path) -> PluginSpec:
    """Load a PluginSpec from a plugin directory.

    Reads plugin.yaml (required) and records the requirements.txt path even
    when that file does not exist. The runtime Python version falls back to
    the interpreter currently running the loader.

    Raises:
        ValueError: if plugin.yaml is missing.
    """
    plugin_dir = plugin_dir.resolve()
    manifest_path = plugin_dir / PLUGIN_MANIFEST_FILE
    requirements_path = plugin_dir / "requirements.txt"

    if not manifest_path.exists():
        raise ValueError(f"插件目录 '{plugin_dir}' 缺少 {PLUGIN_MANIFEST_FILE}。")

    manifest_data = _read_yaml(manifest_path)
    runtime = manifest_data.get("runtime")
    # A malformed manifest may declare `runtime` as a scalar or list; treat
    # anything that is not a mapping as absent instead of raising
    # AttributeError on `.get` below (validate_plugin_spec reports the error).
    if not isinstance(runtime, dict):
        runtime = {}
    python_version = runtime.get("python") or _default_python_version()

    return PluginSpec(
        name=str(manifest_data.get("name") or plugin_dir.name),
        plugin_dir=plugin_dir,
        manifest_path=manifest_path,
        requirements_path=requirements_path,
        python_version=str(python_version),
        manifest_data=manifest_data,
    )
+
+
def validate_plugin_spec(plugin: PluginSpec) -> None:
    """Validate one plugin spec; shared by the CLI and the discovery flow.

    Checks name (must be a legal plugin id), runtime shape and
    runtime.python, author, repo (GitHub name / owner/repo / URL), and the
    components list shape.

    Raises:
        ValueError: describing the first violated requirement.
    """
    manifest_data = plugin.manifest_data
    manifest_label = f"插件 '{plugin.name}'({plugin.manifest_path})"

    raw_name = manifest_data.get("name")
    if not isinstance(raw_name, str) or not raw_name:
        raise ValueError(f"{manifest_label} 缺少 name。")
    try:
        validate_plugin_id(raw_name)
    except ValueError as exc:
        raise ValueError(f"{manifest_label} 的 name 不合法:{exc}") from exc

    raw_runtime = manifest_data.get("runtime") or {}
    # Guard against scalar/list `runtime` values, which previously escaped as
    # AttributeError from `.get` instead of a readable validation error.
    if not isinstance(raw_runtime, dict):
        raise ValueError(f"{manifest_label} 的 runtime 必须是 object。")
    raw_python = raw_runtime.get("python")
    if not isinstance(raw_python, str) or not raw_python:
        raise ValueError(f"{manifest_label} 缺少 runtime.python。")

    raw_author = manifest_data.get("author")
    if not isinstance(raw_author, str) or not raw_author.strip():
        raise ValueError(f"{manifest_label} 缺少 author。")

    raw_repo = manifest_data.get("repo")
    if not isinstance(raw_repo, str) or not raw_repo.strip():
        raise ValueError(f"{manifest_label} 缺少 repo。")
    if not _is_valid_github_repo_ref(raw_repo):
        raise ValueError(
            f"{manifest_label} 的 repo 不合法:"
            "请填写 GitHub 仓库名(repo)、owner/repo,或 https://github.com/owner/repo。"
        )

    components = manifest_data.get("components")
    if not isinstance(components, list):
        raise ValueError(f"{manifest_label} 的 components 必须是数组。")

    for index, component in enumerate(components):
        if not isinstance(component, dict):
            raise ValueError(f"{manifest_label} 的 components[{index}] 必须是 object。")
        class_path = component.get("class")
        if not isinstance(class_path, str) or ":" not in class_path:
            raise ValueError(
                f"{manifest_label} 的 components[{index}].class "
                "必须是 ':'。"
            )
+
+
# TODO: plugin/command name collisions can still happen; revisit if it matters.
def discover_plugins(plugins_dir: Path) -> PluginDiscoveryResult:
    """Scan *plugins_dir* and discover every loadable plugin.

    Directories without a plugin.yaml (and dot-directories) are ignored.
    Manifests that fail to parse/validate, nameless plugins, and duplicate
    plugin names are skipped and reported via ``skipped_plugins``/``issues``
    instead of aborting the scan. Entries are processed in sorted directory
    order, so the first holder of a name wins.
    """
    plugins_root = plugins_dir.resolve()
    skipped_plugins: dict[str, str] = {}
    issues: list[PluginDiscoveryIssue] = []
    plugins: list[PluginSpec] = []
    # plugin name -> plugin_dir, so duplicate names can report both paths.
    seen_name_sources: dict[str, Path] = {}

    if not plugins_root.exists():
        return PluginDiscoveryResult([], {}, [])

    for entry in sorted(plugins_root.iterdir()):
        if not entry.is_dir() or entry.name.startswith("."):
            continue
        manifest_path = entry / PLUGIN_MANIFEST_FILE
        if not manifest_path.exists():
            continue

        plugin: PluginSpec | None = None
        try:
            plugin = load_plugin_spec(entry)
            validate_plugin_spec(plugin)
        except Exception as exc:
            # Prefer the manifest-declared name for reporting when the spec
            # parsed far enough to expose one.
            skip_key = entry.name
            if plugin is not None:
                raw_name = plugin.manifest_data.get("name")
                if isinstance(raw_name, str) and raw_name:
                    skip_key = raw_name
            details = str(exc)
            skipped_plugins[skip_key] = f"failed to parse plugin manifest: {details}"
            issues.append(
                PluginDiscoveryIssue(
                    severity="error",
                    phase="discovery",
                    plugin_id=skip_key,
                    message="插件发现失败",
                    details=details,
                )
            )
            continue

        plugin_name = plugin.name
        if not isinstance(plugin_name, str) or not plugin_name:
            skipped_plugins[entry.name] = "plugin name is required"
            issues.append(
                PluginDiscoveryIssue(
                    severity="error",
                    phase="discovery",
                    plugin_id=entry.name,
                    message="插件缺少名称",
                    details="plugin name is required",
                )
            )
            continue
        if plugin_name in seen_name_sources:
            # Membership was just checked, so index directly instead of
            # re-fetching with a fallback default.
            existing_source = seen_name_sources[plugin_name]
            skipped_plugins[plugin_name] = "duplicate plugin name"
            issues.append(
                PluginDiscoveryIssue(
                    severity="error",
                    phase="discovery",
                    plugin_id=plugin_name,
                    message="插件名称重复",
                    details=f"冲突的插件目录:{existing_source} 与 {plugin.plugin_dir}",
                    hint="请修改其中一个插件的名称后重试",
                )
            )
            continue
        seen_name_sources[plugin_name] = plugin.plugin_dir
        plugins.append(plugin)

    return PluginDiscoveryResult(
        plugins=plugins,
        skipped_plugins=skipped_plugins,
        issues=issues,
    )
+
+
class PluginEnvironmentManager:
    """Facade the runtime uses to reach the group-environment machinery.

    The historical ``prepare_environment(plugin)`` entry point is kept, but
    the implementation is now a two-phase model:

    1. ``plan()`` resolves cross-plugin grouping and shared artifacts
    2. ``prepare_environment()`` maps one plugin onto its group environment
    """

    def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None:
        self.repo_root = repo_root.resolve()
        self.uv_binary = uv_binary
        self.cache_dir = self.repo_root / ".uv-cache"
        self._planner = EnvironmentPlanner(self.repo_root, uv_binary=uv_binary)
        self._group_manager = GroupEnvironmentManager(
            self.repo_root, uv_binary=uv_binary
        )
        # NOTE(review): overwrites the assignment above -- presumably the
        # planner resolves a concrete uv binary path; confirm and consider
        # dropping the earlier assignment.
        self.uv_binary = self._planner.uv_binary
        self._plan_result: EnvironmentPlanResult | None = None

    def plan(self, plugins: list[PluginSpec]) -> EnvironmentPlanResult:
        """Build and cache the shared environment plan for the plugin set."""
        plan_result = self._planner.plan(plugins)
        self._plan_result = plan_result
        return plan_result

    def prepare_group_environment(self, group: EnvironmentGroup) -> Path:
        """Return the interpreter path for the given group's environment."""
        if self._plan_result is None:
            # Keep a minimal plan around so later calls see consistent state.
            self._plan_result = EnvironmentPlanResult(groups=[group])
        return self._group_manager.prepare(group)

    def prepare_environment(self, plugin: PluginSpec) -> Path:
        """Return the interpreter path of the plugin's group environment.

        If the caller has not planned the whole batch yet, a minimal plan
        containing at least this plugin is created automatically so the old
        single-plugin call pattern keeps working.

        Raises:
            RuntimeError: when the plan skipped the plugin (with the recorded
                reason) or the plan unexpectedly lacks the plugin.
        """
        if (
            self._plan_result is None
            or plugin.name not in self._plan_result.plugin_to_group
        ):
            planned_plugins = (
                list(self._plan_result.plugins) if self._plan_result else []
            )
            if plugin.name not in {item.name for item in planned_plugins}:
                planned_plugins.append(plugin)
            self.plan(planned_plugins)

        assert self._plan_result is not None
        group = self._plan_result.plugin_to_group.get(plugin.name)
        if group is None:
            reason = self._plan_result.skipped_plugins.get(plugin.name)
            if reason is not None:
                raise RuntimeError(reason)
            raise RuntimeError(f"environment plan missing plugin: {plugin.name}")

        return self.prepare_group_environment(group)

    @staticmethod
    def _fingerprint(plugin: PluginSpec) -> str:
        """Stable JSON fingerprint of a plugin's Python version + requirements."""
        requirements = _read_requirements_text(plugin.requirements_path)
        payload = {
            "python_version": plugin.python_version,
            "requirements": requirements,
        }
        return json.dumps(payload, ensure_ascii=True, sort_keys=True)

    @staticmethod
    def _load_state(state_path: Path) -> dict[str, Any]:
        """Read the persisted worker-state JSON; missing/broken files yield {}."""
        if not state_path.exists():
            return {}
        return _read_json_object(
            state_path,
            parse_error_message="Failed to parse plugin worker state {}: {}",
            read_error_message="Failed to read plugin worker state {}: {}",
        )

    @staticmethod
    def _write_state(state_path: Path, plugin: PluginSpec, fingerprint: str) -> None:
        """Persist the worker-state JSON for *plugin* with its fingerprint."""
        state_path.write_text(
            json.dumps(
                {
                    "plugin": plugin.name,
                    "python_version": plugin.python_version,
                    "fingerprint": fingerprint,
                },
                ensure_ascii=True,
                indent=2,
                sort_keys=True,
            ),
            encoding="utf-8",
        )

    @staticmethod
    def _matches_python_version(venv_dir: Path, version: str) -> bool:
        """Check whether an existing venv was built for ``major.minor`` *version*.

        Parses the ``version = X.Y.Z`` line of pyvenv.cfg; unreadable or
        missing files count as a mismatch so the venv gets rebuilt.
        """
        pyvenv_cfg = venv_dir / "pyvenv.cfg"
        if not pyvenv_cfg.exists():
            return False
        try:
            content = pyvenv_cfg.read_text(encoding="utf-8")
        except OSError:
            return False
        match = re.search(r"version\s*=\s*(\d+\.\d+)\.\d+", content, re.IGNORECASE)
        return match is not None and match.group(1) == version
+
+
+def _copy_meta(meta: _TMeta | None) -> _TMeta | None:
+ if meta is None:
+ return None
+ # Use dataclass-level cloning so metadata schema changes do not silently
+ # drift away from the loader's copy helpers.
+ return replace(meta)
+
+
def _validate_handler_kind(
    plugin: PluginSpec,
    *,
    resolved_component: _ResolvedComponent,
    attribute_name: str,
    kind: str,
) -> HandlerKind:
    """Narrow a raw handler-kind string to the HandlerKind literal.

    Raises:
        ValueError: when *kind* is not one of the supported handler kinds,
            naming the offending component method.
    """
    if kind not in _VALID_HANDLER_KINDS:
        raise ValueError(
            f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} "
            f"方法 {attribute_name!r} 的 handler kind {kind!r} 不合法;"
            f"允许的值为 {', '.join(_VALID_HANDLER_KINDS)}。"
        )
    return cast(HandlerKind, kind)
+
+
def _load_component_instance(
    plugin: PluginSpec,
    resolved_component: _ResolvedComponent,
) -> Any:
    """Instantiate one resolved component class via its no-arg constructor.

    Raises:
        ValueError: when the class is not an astrbot-sdk new-style Star
            component, or when instantiation itself fails.
    """
    component_cls = resolved_component.cls
    if not _is_new_star_component(component_cls):
        raise ValueError(
            f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} "
            f"解析到的类 {component_cls.__module__}.{component_cls.__qualname__} "
            "不是 astrbot-sdk Star 组件。请继承 astrbot_sdk.Star。"
        )
    try:
        # Components must be constructible without arguments by contract.
        instance = component_cls()
    except Exception as exc:
        raise ValueError(
            f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} "
            f"实例化失败:{exc}"
        ) from exc
    logger.debug(
        "Instantiated SDK plugin component {} for plugin {}",
        resolved_component.class_path,
        plugin.name,
    )
    return instance
+
+
def _collect_component_agents(
    plugin: PluginSpec,
    component_cls: type[Any],
    *,
    seen_agents: set[str],
) -> list[LoadedAgent]:
    """Wrap agent classes reachable from *component_cls* into LoadedAgent records.

    *seen_agents* is shared across all components of one plugin and updated
    in place, so each runner class is registered only once. Agent specs are
    deep-copied to decouple them from decorator-held state.
    """
    collected: list[LoadedAgent] = []
    for runner_class, meta in _iter_agent_candidates(component_cls):
        runner_key = f"{runner_class.__module__}.{runner_class.__qualname__}"
        if runner_key in seen_agents:
            continue
        seen_agents.add(runner_key)
        loaded = LoadedAgent(
            spec=meta.spec.model_copy(deep=True),
            runner_class=runner_class,
            owner=None,
            plugin_id=plugin.name,
        )
        collected.append(loaded)
    return collected
+
+
def _build_loaded_handler(
    plugin: PluginSpec,
    *,
    resolved_component: _ResolvedComponent,
    instance: Any,
    attribute_name: str,
    bound: Any,
    meta: Any,
) -> LoadedHandler:
    """Assemble a LoadedHandler from a bound method and its decorator metadata.

    Validates the declared handler kind, enforces the schedule-handler
    signature for schedule triggers, derives positional parameter specs, and
    deep-copies every mutable piece of metadata so the live descriptor cannot
    alias the decorator's shared state.
    """
    handler_kind = _validate_handler_kind(
        plugin,
        resolved_component=resolved_component,
        attribute_name=attribute_name,
        kind=meta.kind,
    )
    # Stable handler id: <plugin>:<module>.<Class>.<method>
    handler_id = (
        f"{plugin.name}:{instance.__class__.__module__}.{instance.__class__.__name__}."
        f"{attribute_name}"
    )
    if isinstance(meta.trigger, ScheduleTrigger):
        _validate_schedule_signature(bound)
    param_specs = _build_param_specs(bound)
    return LoadedHandler(
        descriptor=HandlerDescriptor(
            id=handler_id,
            trigger=meta.trigger,
            kind=handler_kind,
            contract=meta.contract,
            description=meta.description,
            priority=meta.priority,
            # model_copy(deep=True) isolates descriptor state from the
            # decorator-held originals.
            permissions=meta.permissions.model_copy(deep=True),
            filters=[item.model_copy(deep=True) for item in meta.filters],
            param_specs=[item.model_copy(deep=True) for item in param_specs],
            command_route=(
                meta.command_route.model_copy(deep=True)
                if meta.command_route is not None
                else None
            ),
        ),
        callable=bound,
        owner=instance,
        plugin_id=plugin.name,
        local_filters=list(meta.local_filters),
        limiter=_copy_meta(meta.limiter),
        conversation=_copy_meta(meta.conversation),
    )
+
+
def _collect_component_members(
    plugin: PluginSpec,
    *,
    resolved_component: _ResolvedComponent,
    instance: Any,
    seen_capability_sources: dict[str, str],
) -> tuple[list[LoadedHandler], list[LoadedCapability], list[LoadedLLMTool]]:
    """Scan one component instance for handlers, capabilities and LLM tools.

    A single attribute may carry more than one decoration, so all three
    candidate resolutions run independently for every discoverable name.
    *seen_capability_sources* is shared across a plugin's components so
    duplicate capability names can be detected during registration.
    """
    handlers: list[LoadedHandler] = []
    capabilities: list[LoadedCapability] = []
    llm_tools: list[LoadedLLMTool] = []

    for name in _iter_discoverable_names(instance):
        resolved = _resolve_handler_candidate(instance, name)
        capability = _resolve_capability_candidate(instance, name)
        llm_tool = _resolve_llm_tool_candidate(instance, name)
        if resolved is None and capability is None and llm_tool is None:
            continue
        if capability is not None:
            bound_capability, capability_meta = capability
            capability_name = capability_meta.descriptor.name
            _validate_loaded_capability_namespace(
                plugin,
                resolved_component=resolved_component,
                attribute_name=name,
                capability_name=capability_name,
            )
            _register_loaded_capability_name(
                seen_capability_sources,
                capability_name=capability_name,
                source_ref=f"{resolved_component.class_path}.{name}",
            )
            capabilities.append(
                LoadedCapability(
                    # Deep copy so runtime state cannot leak back into the
                    # plugin's declared descriptor.
                    descriptor=capability_meta.descriptor.model_copy(deep=True),
                    callable=bound_capability,
                    owner=instance,
                    plugin_id=plugin.name,
                )
            )
        if llm_tool is not None:
            bound_tool, tool_meta = llm_tool
            llm_tools.append(
                LoadedLLMTool(
                    spec=tool_meta.spec.model_copy(deep=True),
                    callable=bound_tool,
                    owner=instance,
                    plugin_id=plugin.name,
                )
            )
        if resolved is not None:
            bound_handler, handler_meta = resolved
            handlers.append(
                _build_loaded_handler(
                    plugin,
                    resolved_component=resolved_component,
                    instance=instance,
                    attribute_name=name,
                    bound=bound_handler,
                    meta=handler_meta,
                )
            )
    return handlers, capabilities, llm_tools
+
+
def load_plugin(plugin: PluginSpec) -> LoadedPlugin:
    """Load a plugin, returning its handlers, capabilities, tools and agents.

    Only new-style astrbot-sdk Star components (zero-argument constructors)
    are supported.
    """
    with _PLUGIN_IMPORT_LOCK:
        logger.debug("Loading SDK plugin {} from {}", plugin.name, plugin.plugin_dir)
        _ensure_plugin_import_hook_installed()
        namespace = _register_plugin_import_namespace(plugin)
        # Reload hygiene: drop stale bytecode and cached modules first so a
        # reload always executes the plugin's current source.
        _purge_plugin_bytecode(plugin.plugin_dir)
        _purge_plugin_package(namespace.package_name)
        _purge_plugin_modules(plugin.plugin_dir)
        _prepare_plugin_import(plugin.plugin_dir)
        _ensure_plugin_package(namespace)
        importlib.invalidate_caches()

        instances: list[Any] = []
        handlers: list[LoadedHandler] = []
        capabilities: list[LoadedCapability] = []
        llm_tools: list[LoadedLLMTool] = []
        agents: list[LoadedAgent] = []
        seen_agents: set[str] = set()
        seen_capability_sources: dict[str, str] = {}
        # Component resolution may import plugin modules; attribute those
        # imports to this plugin so the scoped import hook can rewrite them.
        with caller_plugin_scope(plugin.name):
            resolved_components = _plugin_component_classes(plugin)

            for resolved_component in resolved_components:
                instance = _load_component_instance(plugin, resolved_component)
                instances.append(instance)
                agents.extend(
                    _collect_component_agents(
                        plugin,
                        resolved_component.cls,
                        seen_agents=seen_agents,
                    )
                )
                component_handlers, component_capabilities, component_tools = (
                    _collect_component_members(
                        plugin,
                        resolved_component=resolved_component,
                        instance=instance,
                        seen_capability_sources=seen_capability_sources,
                    )
                )
                handlers.extend(component_handlers)
                capabilities.extend(component_capabilities)
                llm_tools.extend(component_tools)

        logger.debug(
            "Loaded SDK plugin {}: {} components, {} handlers, {} capabilities, {} llm tools, {} agents",
            plugin.name,
            len(resolved_components),
            len(handlers),
            len(capabilities),
            len(llm_tools),
            len(agents),
        )
        return LoadedPlugin(
            plugin=plugin,
            handlers=handlers,
            capabilities=capabilities,
            llm_tools=llm_tools,
            agents=agents,
            instances=instances,
        )
+
+
+def _path_within_root(path: Path, root: Path) -> bool:
+ try:
+ path.resolve().relative_to(root.resolve())
+ except ValueError:
+ return False
+ return True
+
+
+def _plugin_defines_module_root(plugin_dir: Path, root_name: str) -> bool:
+ return (plugin_dir / f"{root_name}.py").exists() or (
+ plugin_dir / root_name
+ ).exists()
+
+
def _register_plugin_import_namespace(plugin: PluginSpec) -> _PluginImportNamespace:
    """Create or refresh the import-namespace record for *plugin*.

    The synthetic package name is sticky: once assigned for a plugin id it is
    reused across reloads, so previously issued module names stay valid.
    """
    previous = _PLUGIN_IMPORT_NAMESPACES.get(plugin.name)
    if previous is None:
        package_name = _plugin_package_name(plugin.name)
    else:
        package_name = previous.package_name
    namespace = _PluginImportNamespace(
        plugin_id=plugin.name,
        plugin_dir=plugin.plugin_dir.resolve(),
        package_name=package_name,
    )
    _PLUGIN_IMPORT_NAMESPACES[plugin.name] = namespace
    return namespace
+
+
+def _ensure_plugin_package(namespace: _PluginImportNamespace) -> types.ModuleType:
+ existing = sys.modules.get(namespace.package_name)
+ if isinstance(existing, types.ModuleType):
+ existing.__path__ = [str(namespace.plugin_dir)]
+ existing.__package__ = namespace.package_name
+ return existing
+
+ module = types.ModuleType(namespace.package_name)
+ module.__file__ = str(namespace.plugin_dir)
+ module.__package__ = namespace.package_name
+ module.__path__ = [str(namespace.plugin_dir)]
+ module.__loader__ = None
+ spec = ModuleSpec(
+ namespace.package_name,
+ loader=None,
+ is_package=True,
+ )
+ spec.submodule_search_locations = [str(namespace.plugin_dir)]
+ module.__spec__ = spec
+ sys.modules[namespace.package_name] = module
+ return module
+
+
+def _prepare_plugin_import(plugin_dir: Path) -> None:
+ plugin_path = str(plugin_dir.resolve())
+ sys.path[:] = [entry for entry in sys.path if entry != plugin_path]
+ sys.path.insert(0, plugin_path)
+
+
def _module_belongs_to_plugin(module: Any, plugin_dir: Path) -> bool:
    """Return True when the module's file or package path lives under plugin_dir."""
    module_file = getattr(module, "__file__", None)
    if isinstance(module_file, str):
        if _path_within_root(Path(module_file), plugin_dir):
            return True

    search_paths = getattr(module, "__path__", None)
    if search_paths is None:
        return False
    for entry in search_paths:
        if isinstance(entry, str) and _path_within_root(Path(entry), plugin_dir):
            return True
    return False
+
+
def _purge_plugin_modules(plugin_dir: Path) -> None:
    """Evict every cached module whose code lives inside the plugin directory."""
    plugin_root = plugin_dir.resolve()
    stale_names = [
        name
        for name, module in list(sys.modules.items())
        if module is not None and _module_belongs_to_plugin(module, plugin_root)
    ]
    for name in stale_names:
        sys.modules.pop(name, None)
+
+
+def _purge_plugin_package(package_name: str) -> None:
+ for module_name in list(sys.modules):
+ if module_name == package_name or module_name.startswith(f"{package_name}."):
+ sys.modules.pop(module_name, None)
+
+
+def _purge_plugin_bytecode(plugin_dir: Path) -> None:
+ plugin_root = plugin_dir.resolve()
+ for path in plugin_root.rglob("*"):
+ try:
+ if path.is_dir() and path.name == "__pycache__":
+ shutil.rmtree(path, ignore_errors=True)
+ continue
+ if path.is_file() and path.suffix in {".pyc", ".pyo"}:
+ path.unlink(missing_ok=True)
+ except OSError:
+ continue
+
+
def _import_plugin_string(path: str, plugin: PluginSpec) -> Any:
    """Resolve a "module:attr" reference inside the plugin's import namespace."""
    module_name, attr = path.split(":", 1)
    namespace = _PLUGIN_IMPORT_NAMESPACES.get(plugin.name)
    if namespace is None:
        raise RuntimeError(f"plugin import namespace missing: {plugin.name}")
    qualified_name = _plugin_module_name(namespace.package_name, module_name)
    target_module = import_module(qualified_name)
    return getattr(target_module, attr)
+
+
def _plugin_import_namespace_for_current_caller() -> _PluginImportNamespace | None:
    """Look up the import namespace of the plugin currently on the call stack."""
    caller_id = current_caller_plugin_id()
    return _PLUGIN_IMPORT_NAMESPACES.get(caller_id) if caller_id else None
+
+
def _rewrite_plugin_import_name(
    namespace: _PluginImportNamespace,
    name: str,
) -> str | None:
    """Map a plugin-local import name onto its namespaced module name.

    Returns None when the name is empty, already namespaced, or does not
    correspond to a module root shipped by the plugin.
    """
    candidate = name.strip()
    if not candidate or candidate.startswith(_PLUGIN_PACKAGE_PREFIX):
        return None
    root = candidate.split(".", 1)[0]
    if _plugin_defines_module_root(namespace.plugin_dir, root):
        return _plugin_module_name(namespace.package_name, candidate)
    return None
+
+
def _plugin_import_alias_buckets() -> list[set[str]]:
    """Return the thread-local stack of alias buckets, creating it lazily."""
    existing = getattr(_PLUGIN_IMPORT_ALIAS_STATE, "buckets", None)
    if existing is not None:
        return existing
    fresh: list[set[str]] = []
    _PLUGIN_IMPORT_ALIAS_STATE.buckets = fresh
    return fresh
+
+
def _push_plugin_import_alias_bucket() -> set[str]:
    """Open a new alias-recording scope and return its bucket."""
    new_bucket: set[str] = set()
    _plugin_import_alias_buckets().append(new_bucket)
    return new_bucket
+
+
def _pop_plugin_import_alias_bucket(bucket: set[str]) -> set[str]:
    """Close an alias-recording scope, tolerating out-of-order teardown."""
    stack = _plugin_import_alias_buckets()
    if stack and stack[-1] is bucket:
        stack.pop()
        return bucket
    # Out-of-order pop: remove the bucket wherever it sits, if still present.
    with contextlib.suppress(ValueError):
        stack.remove(bucket)
    return bucket
+
+
def _record_plugin_import_alias(alias_name: str) -> None:
    """Record an alias import into the innermost open bucket, if one exists."""
    candidate = alias_name.strip()
    if not candidate or candidate.startswith(_PLUGIN_PACKAGE_PREFIX):
        return
    stack = _plugin_import_alias_buckets()
    if stack:
        stack[-1].add(candidate)
+
+
+def _cleanup_plugin_import_aliases(alias_names: set[str]) -> None:
+ for alias_name in sorted(
+ alias_names, key=lambda item: item.count("."), reverse=True
+ ):
+ sys.modules.pop(alias_name, None)
+
+
def _plugin_scoped_import(
    name: str,
    globals: dict[str, Any] | None = None,
    locals: dict[str, Any] | None = None,
    fromlist: tuple[Any, ...] | list[Any] = (),
    level: int = 0,
) -> Any:
    """Replacement for builtins.__import__ that scrubs plugin import aliases.

    Delegates to the original __import__ under the global plugin import lock;
    alias module names recorded during the import are popped from sys.modules
    afterwards so plugin-local names do not linger in the shared module cache.
    """
    with _PLUGIN_IMPORT_LOCK:
        alias_bucket = _push_plugin_import_alias_bucket()
        try:
            return _ORIGINAL_BUILTIN_IMPORT(name, globals, locals, fromlist, level)
        finally:
            # Always clean up, even if the import itself raised.
            _cleanup_plugin_import_aliases(
                _pop_plugin_import_alias_bucket(alias_bucket)
            )
+
+
def _ensure_plugin_import_meta_finder_installed() -> None:
    """Install the plugin-scoped meta path finder, idempotently.

    Re-installs when the cached finder was removed from sys.meta_path by
    external code (e.g. a test teardown).
    """
    global _PLUGIN_IMPORT_META_FINDER
    if (
        _PLUGIN_IMPORT_META_FINDER is not None
        and _PLUGIN_IMPORT_META_FINDER in sys.meta_path
    ):
        return
    finder = _PluginScopedMetaPathFinder()
    # Prepend so this finder is consulted before the default finders.
    sys.meta_path.insert(0, finder)
    _PLUGIN_IMPORT_META_FINDER = finder
+
+
def _ensure_plugin_import_hook_installed() -> None:
    """Install the plugin-scoped builtins.__import__ hook, idempotently."""
    global _PLUGIN_IMPORT_HOOK_INSTALLED
    _ensure_plugin_import_meta_finder_installed()
    # Defensive check: if the hook is already in place, just repair the flag
    # instead of installing again.
    if builtins.__import__ is _plugin_scoped_import:
        _PLUGIN_IMPORT_HOOK_INSTALLED = True
        return
    # The flag claims the hook is installed but the builtin was replaced by
    # external code (e.g. a test framework monkeypatch); reset the flag so
    # the hook gets re-installed below.
    if (
        _PLUGIN_IMPORT_HOOK_INSTALLED
        and builtins.__import__ is not _plugin_scoped_import
    ):
        _PLUGIN_IMPORT_HOOK_INSTALLED = False
    if _PLUGIN_IMPORT_HOOK_INSTALLED:
        return
    builtins.__import__ = _plugin_scoped_import
    _PLUGIN_IMPORT_HOOK_INSTALLED = True
+
+
def _restore_plugin_import_hook() -> None:
    """Restore the builtin __import__; used on plugin unload or test teardown.

    Only swaps the builtin back when it is still our hook, so an unrelated
    replacement installed later is left untouched. Also removes the meta path
    finder and resets the installed flag.
    """
    global _PLUGIN_IMPORT_HOOK_INSTALLED, _PLUGIN_IMPORT_META_FINDER
    if builtins.__import__ is _plugin_scoped_import:
        builtins.__import__ = _ORIGINAL_BUILTIN_IMPORT
    if _PLUGIN_IMPORT_META_FINDER is not None:
        with contextlib.suppress(ValueError):
            sys.meta_path.remove(_PLUGIN_IMPORT_META_FINDER)
        _PLUGIN_IMPORT_META_FINDER = None
    _PLUGIN_IMPORT_HOOK_INSTALLED = False
+
+
def import_string(path: str, plugin_dir: Path | None = None) -> Any:
    """Import an object referenced as "module:attr".

    The *plugin_dir* parameter is accepted for interface compatibility but is
    not consulted by this implementation.
    """
    with _PLUGIN_IMPORT_LOCK:
        module_name, attr_name = path.split(":", 1)
        target_module = import_module(module_name)
        return getattr(target_module, attr_name)
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/peer.py b/astrbot-sdk/src/astrbot_sdk/runtime/peer.py
new file mode 100644
index 0000000000..45594a4a5a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/peer.py
@@ -0,0 +1,921 @@
+"""协议对等端模块。
+
+定义 Peer 类,封装双向传输通道上的消息收发、初始化握手、能力调用、
+流式事件转发与取消处理。这里的 peer 指"通信对端/本端"这一网络协议概念,
+而不是业务上的用户、群聊或会话对象。
+
+核心职责:
+ - 消息序列化/反序列化
+ - 初始化握手协议
+ - 能力调用(同步/流式)
+ - 取消处理
+ - 连接生命周期管理
+消息处理:
+ 入站:
+ ResultMessage -> 唤醒等待的 Future
+ EventMessage -> 投递到流式队列
+ InitializeMessage -> 调用 initialize_handler
+ InvokeMessage -> 创建任务调用 invoke_handler
+ CancelMessage -> 取消对应的任务
+
+ 出站:
+ initialize() -> InitializeMessage
+ invoke() -> InvokeMessage(stream=False)
+ invoke_stream() -> InvokeMessage(stream=True)
+ cancel() -> CancelMessage
+
+使用示例:
+ # 作为客户端发起调用
+ peer = Peer(transport=transport, peer_info=PeerInfo(...))
+ await peer.start()
+ output = await peer.initialize(handlers)
+ result = await peer.invoke("llm.chat", {"prompt": "hello"})
+
+ # 作为服务端处理调用
+ peer.set_invoke_handler(my_handler)
+ await peer.start()
+
+消息处理流程:
+ 入站消息:
+ ResultMessage -> 唤醒等待的 Future
+ EventMessage -> 投递到流式队列
+ InitializeMessage -> 调用 _initialize_handler
+ InvokeMessage -> 创建任务调用 _invoke_handler
+ CancelMessage -> 取消对应的任务
+
+ 出站消息:
+ initialize() -> InitializeMessage
+ invoke() -> InvokeMessage(stream=False)
+ invoke_stream() -> InvokeMessage(stream=True)
+ cancel() -> CancelMessage
+
+取消机制:
+ - CancelToken 用于检查取消状态
+ - 入站任务在收到 CancelMessage 时被取消
+ - 早到取消:在任务执行前检查 cancel_token,避免竞态条件
+
+`Peer` 把 `Transport` 和 s5r 协议消息模型接起来,负责:
+
+- 握手与远端元数据缓存
+- 请求 ID 关联
+- 非流式 / 流式调用分发
+- 取消传播
+- 连接异常时的统一收口
+
+它本身不做业务路由,真正的执行逻辑交给 `CapabilityRouter` 或
+`HandlerDispatcher`。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
+from typing import Any
+
+from .._internal.invocation_context import (
+ caller_plugin_scope,
+ current_caller_plugin_id,
+)
+from .._internal.sdk_logger import logger
+from ..context import CancelToken
+from ..errors import AstrBotError, ErrorCodes
+from ..protocol.codec import JsonProtocolCodec, MsgpackProtocolCodec, ProtocolCodec
+from ..protocol.messages import (
+ CancelMessage,
+ ErrorPayload,
+ EventMessage,
+ InitializeMessage,
+ InitializeOutput,
+ InvokeMessage,
+ PeerInfo,
+ ResultMessage,
+)
+from .capability_router import StreamExecution
+
# Callback signatures pluggable into a Peer via the set_*_handler methods.
InitializeHandler = Callable[[InitializeMessage], Awaitable[InitializeOutput]]
InvokeHandler = Callable[
    [InvokeMessage, CancelToken], Awaitable[dict[str, Any] | StreamExecution]
]
CancelHandler = Callable[[str], Awaitable[None]]

# Handshake-metadata keys exchanged during initialize().
SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY = "supported_protocol_versions"
NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY = "negotiated_protocol_version"
WIRE_CODEC_METADATA_KEY = "wire_codec"
# Upper bound on inbound protocol message size (8 MB). Messages above this
# threshold are rejected outright so a malicious or malformed giant message
# cannot exhaust memory or stall parsing.
MAX_INBOUND_MESSAGE_BYTES = 8 * 1024 * 1024
+
+
def _wire_codec_name(codec: ProtocolCodec) -> str:
    """Return the canonical wire-format name for a codec instance."""
    known_codecs = (
        (JsonProtocolCodec, "json"),
        (MsgpackProtocolCodec, "msgpack"),
    )
    for codec_cls, codec_name in known_codecs:
        if isinstance(codec, codec_cls):
            return codec_name
    # Unknown codec implementations fall back to their class name.
    return type(codec).__name__
+
+
def _validate_wire_codec_metadata(
    metadata: dict[str, Any],
    *,
    expected_wire_codec: str,
) -> None:
    """Ensure the remote handshake metadata declares the expected wire codec.

    Raises:
        AstrBotError: when the codec name is missing, empty, or differs from
            *expected_wire_codec*.
    """
    declared = metadata.get(WIRE_CODEC_METADATA_KEY)
    if not (isinstance(declared, str) and declared):
        raise AstrBotError.protocol_error("wire_codec metadata missing")
    if declared == expected_wire_codec:
        return
    raise AstrBotError.protocol_error(
        f"wire_codec mismatch: expected {expected_wire_codec}, got {declared}"
    )
+
+
+def _dedupe_protocol_versions(
+ versions: Sequence[str] | None, *, preferred_version: str
+) -> list[str]:
+ ordered_versions: list[str] = [preferred_version]
+ if versions is not None:
+ ordered_versions.extend(versions)
+ deduped: list[str] = []
+ for version in ordered_versions:
+ if not isinstance(version, str) or not version:
+ continue
+ if version not in deduped:
+ deduped.append(version)
+ return deduped
+
+
+def _parse_protocol_version(version: str) -> tuple[int, int] | None:
+ major, dot, minor = version.partition(".")
+ if not dot or not major.isdigit() or not minor.isdigit():
+ return None
+ return int(major), int(minor)
+
+
+def _select_negotiated_protocol_version(
+ requested_version: str,
+ remote_metadata: dict[str, Any],
+ local_supported_versions: Sequence[str],
+) -> str | None:
+ """从双方支持的版本中选出最佳兼容版本。
+
+ 协商策略:优先精确匹配,否则在同主版本号范围内选双方都支持的最高版本。
+ 排除比请求版本更高的候选,因为远端能提供高于我们请求的版本说明我们本地
+ 尚未实现该版本协议,无法正确处理对应的协议消息。
+ """
+ if requested_version in local_supported_versions:
+ return requested_version
+ requested_key = _parse_protocol_version(requested_version)
+ if requested_key is None:
+ return None
+ remote_supported = remote_metadata.get(SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY)
+ if not isinstance(remote_supported, (list, tuple)):
+ return None
+ local_supported_set = set(local_supported_versions)
+ compatible_versions: list[tuple[tuple[int, int], str]] = []
+ for version in remote_supported:
+ if not isinstance(version, str) or version not in local_supported_set:
+ continue
+ parsed_version = _parse_protocol_version(version)
+ if parsed_version is None:
+ continue
+ if parsed_version[0] != requested_key[0] or parsed_version > requested_key:
+ continue
+ compatible_versions.append((parsed_version, version))
+ if not compatible_versions:
+ return None
+ compatible_versions.sort(reverse=True)
+ return compatible_versions[0][1]
+
+
+class Peer:
+ """表示协议连接中的一个对等端。
+
+ `Peer` 封装一条双向传输通道上的消息收发、初始化握手、能力调用、
+ 流式事件转发与取消处理。这里的 `peer` 指“通信对端/本端”这一网络
+ 协议概念,而不是业务上的用户、群聊或会话对象。
+ """
+
    def __init__(
        self,
        *,
        transport,
        peer_info: PeerInfo,
        protocol_version: str = "1.0",
        supported_protocol_versions: Sequence[str] | None = None,
        wire_codec: ProtocolCodec | None = None,
    ) -> None:
        """Create a protocol peer bound to *transport*.

        Args:
            transport: Bidirectional transport providing start/stop/send and a
                raw-message handler hook.
            peer_info: Identity advertised to the remote side during handshake.
            protocol_version: Preferred protocol version for negotiation.
            supported_protocol_versions: Additional acceptable versions; the
                preferred version always comes first after de-duplication.
            wire_codec: Message codec; defaults to msgpack.
        """
        self.transport = transport
        self.peer_info = peer_info
        self.protocol_version = protocol_version
        self.wire_codec = wire_codec or MsgpackProtocolCodec()
        self.wire_codec_name = _wire_codec_name(self.wire_codec)
        self.supported_protocol_versions = _dedupe_protocol_versions(
            supported_protocol_versions,
            preferred_version=protocol_version,
        )
        # Remote-side state cached from the initialize handshake.
        self.negotiated_protocol_version: str | None = None
        self.remote_peer: PeerInfo | None = None
        self.remote_handlers = []
        self.remote_provided_capabilities = []
        self.remote_capabilities = []
        self.remote_capability_map: dict[str, Any] = {}
        self.remote_provided_capability_map: dict[str, Any] = {}
        self.remote_metadata: dict[str, Any] = {}

        self._initialize_handler: InitializeHandler | None = None
        self._invoke_handler: InvokeHandler | None = None
        self._cancel_handler: CancelHandler | None = None
        # Monotonic counter backing _next_id().
        self._counter = 0
        self._closed = asyncio.Event()
        self._unusable = False
        self._stopping = False
        # In-flight request bookkeeping, keyed by request id.
        self._pending_results: dict[str, asyncio.Future[ResultMessage]] = {}
        self._pending_streams: dict[str, asyncio.Queue[Any]] = {}
        self._inbound_tasks: dict[
            str, tuple[asyncio.Task[None], CancelToken, asyncio.Event]
        ] = {}
        self._remote_initialized = asyncio.Event()
        self._remote_initialized_successfully = False
        self._transport_watch_task: asyncio.Task[None] | None = None
        # The task currently running stop(), used to guard stop() against
        # concurrent re-entry.
        self._stop_task: asyncio.Task[None] | None = None
+
    def set_initialize_handler(self, handler: InitializeHandler) -> None:
        """Register the handshake handler for remote `initialize` requests."""
        self._initialize_handler = handler
+
    def set_invoke_handler(self, handler: InvokeHandler) -> None:
        """Register the capability-call handler for remote `invoke` requests."""
        self._invoke_handler = handler
+
    def set_cancel_handler(self, handler: CancelHandler) -> None:
        """Register the callback invoked for remote `cancel` requests."""
        self._cancel_handler = handler
+
    async def start(self) -> None:
        """Start the transport and bind raw inbound messages to this peer.

        Resets per-connection state first so a peer instance can be restarted
        after a previous stop().
        """
        self._closed.clear()
        self._unusable = False
        self._stopping = False
        self.negotiated_protocol_version = None
        self._remote_initialized.clear()
        self._remote_initialized_successfully = False
        self.transport.set_message_handler(self._handle_raw_message)
        await self.transport.start()
        # Watch for the transport closing unexpectedly underneath us.
        self._transport_watch_task = asyncio.create_task(self._watch_transport_closed())
+
    async def stop(self) -> None:
        """Close the peer and clean up pending requests, streams and inbound tasks.

        Re-entrancy: transport.stop() tears down the underlying connection,
        which can trigger the raw-message handler's failure path; that path
        calls _fail_connection() -> _schedule_stop() -> stop(), forming an
        indirect recursion. The _stopping flag and the _stop_task reference
        together prevent the cleanup from running twice.

        asyncio.shield is used when awaiting another task's stop() because:
        if the current task is cancelled while waiting, the shield keeps the
        inner stop task from being cancelled along with it, which would leave
        the peer stuck half-closed.
        """
        if self._closed.is_set():
            return
        current_task = asyncio.current_task()
        if self._stopping:
            # Guard against concurrent re-entry: if stop() is already running
            # in another coroutine, wait for it instead of cleaning up again.
            stop_task = self._stop_task
            if stop_task is not None and stop_task is not current_task:
                await asyncio.shield(stop_task)
            return
        self._stopping = True
        # Record the owning task for re-entry detection and _schedule_stop().
        if current_task is not None and self._stop_task is None:
            self._stop_task = current_task
        try:
            # Fail every pending RPC so callers never hang forever.
            for future in list(self._pending_results.values()):
                if not future.done():
                    future.set_exception(AstrBotError.internal_error("连接已关闭"))
            self._pending_results.clear()

            for queue in list(self._pending_streams.values()):
                await queue.put(AstrBotError.internal_error("连接已关闭"))
            self._pending_streams.clear()

            # Cancel all inbound tasks.
            for task, token, _started in list(self._inbound_tasks.values()):
                token.cancel()
                task.cancel()
            self._inbound_tasks.clear()

            await self.transport.stop()
            self._closed.set()
        finally:
            # Clear the reference only when this task is the recorded stop
            # task, to avoid erasing another task's record. Scenario: task A
            # is inside stop(), task B also entered stop() and is waiting for
            # A; if B cleared _stop_task in its finally, A would lose the
            # reference before it finished.
            if self._stop_task is current_task:
                self._stop_task = None
+
    async def wait_closed(self) -> None:
        """Block until the underlying transport is fully closed."""
        await self.transport.wait_closed()
+
    async def _watch_transport_closed(self) -> None:
        """Watch for unexpected transport closure and fail pending calls.

        Does nothing when the closure was initiated by stop(); otherwise the
        whole connection is failed with a retryable network error.
        """
        try:
            await self.transport.wait_closed()
            if self._closed.is_set() or self._stopping:
                return
            await self._fail_connection(
                AstrBotError(
                    code=ErrorCodes.NETWORK_ERROR,
                    message="连接已关闭",
                    hint="请检查对端进程或传输连接",
                    retryable=True,
                )
            )
        finally:
            # Drop the self-reference, but only if it still points at us.
            current_task = asyncio.current_task()
            if self._transport_watch_task is current_task:
                self._transport_watch_task = None
+
    async def wait_until_remote_initialized(self, timeout: float | None = 30.0) -> None:
        """Wait for the remote side to complete the initialize handshake.

        Args:
            timeout: Seconds to wait; `None` waits indefinitely.

        Raises:
            TimeoutError: when the handshake does not finish within *timeout*.
            AstrBotError: when the connection closes before initialization
                succeeds.
        """
        init_waiter = asyncio.create_task(self._remote_initialized.wait())
        closed_waiter = asyncio.create_task(self.wait_closed())
        try:
            done, pending = await asyncio.wait(
                {init_waiter, closed_waiter},
                timeout=timeout,
                return_when=asyncio.FIRST_COMPLETED,
            )
            if not done:
                raise TimeoutError()
            if init_waiter in done and self._remote_initialized_successfully:
                return
            raise AstrBotError.protocol_error("连接在初始化完成前关闭")
        finally:
            # Cancel the still-running waiter and reap both tasks; awaiting an
            # already-done task just retrieves its result.
            for task in (init_waiter, closed_waiter):
                if not task.done():
                    task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
+
    async def initialize(
        self,
        handlers,
        *,
        provided_capabilities=None,
        metadata: dict[str, Any] | None = None,
    ) -> InitializeOutput:
        """Send the initialize request and cache the remote peer's declarations.

        Args:
            handlers: Handlers this endpoint declares it can receive.
            provided_capabilities: Capabilities this endpoint offers.
            metadata: Handshake metadata forwarded to the remote side.

        Returns:
            The initialize result returned by the remote peer.

        Raises:
            AstrBotError: on handshake rejection, protocol-version mismatch,
                or wire-codec mismatch; in those cases the connection is also
                stopped and marked unusable.
        """
        self._ensure_usable()
        request_id = self._next_id()
        handshake_metadata = dict(metadata or {})
        # Advertise every protocol version we can speak plus our wire codec.
        handshake_metadata[SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY] = list(
            self.supported_protocol_versions
        )
        handshake_metadata[WIRE_CODEC_METADATA_KEY] = self.wire_codec_name
        future = await self._send_pending_result_request(
            request_id,
            InitializeMessage(
                id=request_id,
                protocol_version=self.protocol_version,
                peer=self.peer_info,
                handlers=list(handlers),
                provided_capabilities=list(provided_capabilities or []),
                metadata=handshake_metadata,
            ),
        )
        result = await future
        if result.kind != "initialize_result":
            raise AstrBotError.protocol_error("initialize 必须收到 initialize_result")
        if not result.success:
            # Handshake rejected: the connection is useless from here on.
            self._unusable = True
            await self.stop()
            raise AstrBotError.from_payload(
                result.error.model_dump() if result.error else {}
            )
        output = InitializeOutput.model_validate(result.output)
        negotiated_protocol_version = (
            output.protocol_version
            or output.metadata.get(NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY)
            or self.protocol_version
        )
        if (
            not isinstance(negotiated_protocol_version, str)
            or negotiated_protocol_version not in self.supported_protocol_versions
        ):
            self._unusable = True
            await self.stop()
            raise AstrBotError.protocol_version_mismatch(
                f"对端返回了当前端点不支持的协商协议版本:{negotiated_protocol_version}"
            )
        _validate_wire_codec_metadata(
            output.metadata,
            expected_wire_codec=self.wire_codec_name,
        )
        # Cache remote declarations for later capability lookups.
        self.remote_peer = output.peer
        self.remote_capabilities = output.capabilities
        self.remote_capability_map = {item.name: item for item in output.capabilities}
        self.remote_metadata = output.metadata
        self.negotiated_protocol_version = negotiated_protocol_version
        self._remote_initialized_successfully = True
        self._remote_initialized.set()
        return output
+
    async def invoke(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        stream: bool = False,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Perform a non-streaming capability call and await the final result.

        Args:
            capability: Remote capability name.
            payload: Call input.
            stream: Must be `False`; use `invoke_stream()` for streaming.
            request_id: Optional request id; generated when omitted.

        Raises:
            ValueError: when `stream=True` is passed.
            AstrBotError: when the remote reports a failure.
        """
        self._ensure_usable()
        if stream:
            raise ValueError("stream=True 请使用 invoke_stream()")
        request_id = request_id or self._next_id()
        future = await self._send_pending_result_request(
            request_id,
            InvokeMessage(
                id=request_id,
                capability=capability,
                input=payload,
                stream=False,
                # Propagate the calling plugin so the remote can attribute work.
                caller_plugin_id=current_caller_plugin_id(),
            ),
        )
        result = await future
        if not result.success:
            raise AstrBotError.from_payload(
                result.error.model_dump() if result.error else {}
            )
        return result.output
+
    async def invoke_stream(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        request_id: str | None = None,
        include_completed: bool = False,
    ) -> AsyncIterator[EventMessage]:
        """Perform a streaming capability call and return an event iterator.

        The caller receives `delta` events; `started` is swallowed internally,
        `completed` ends the iteration by default (yielded only when
        *include_completed* is set), and `failed` is converted into a raised
        exception.

        Args:
            capability: Remote capability name.
            payload: Call input.
            request_id: Optional request id; generated when omitted.
            include_completed: Whether to also yield the `completed` event.
        """
        self._ensure_usable()
        request_id = request_id or self._next_id()
        queue = await self._send_pending_stream_request(
            request_id,
            InvokeMessage(
                id=request_id,
                capability=capability,
                input=payload,
                stream=True,
                caller_plugin_id=current_caller_plugin_id(),
            ),
        )

        async def iterator() -> AsyncIterator[EventMessage]:
            terminal_received = False
            try:
                while True:
                    item = await queue.get()
                    # Exceptions are injected into the queue on connection failure.
                    if isinstance(item, Exception):
                        raise item
                    if not isinstance(item, EventMessage):
                        raise AstrBotError.protocol_error("流式调用收到非法事件")
                    if item.phase == "started":
                        continue
                    if item.phase == "delta":
                        yield item
                        continue
                    if item.phase == "completed":
                        terminal_received = True
                        if include_completed:
                            yield item
                        break
                    if item.phase == "failed":
                        terminal_received = True
                        raise AstrBotError.from_payload(
                            item.error.model_dump() if item.error else {}
                        )
            finally:
                self._pending_streams.pop(request_id, None)
                # If the consumer abandoned the iterator before a terminal
                # event, ask the remote side to stop producing.
                if not terminal_received:
                    try:
                        await self.cancel(
                            request_id,
                            reason="consumer_closed_stream_early",
                        )
                    except Exception as exc:
                        # A failed cancel must not disrupt the stream teardown
                        # path; just log it.
                        logger.debug(
                            "Failed to cancel stream after consumer closed early: "
                            "request_id={} error={}",
                            request_id,
                            exc,
                        )

        return iterator()
+
    async def cancel(self, request_id: str, reason: str = "user_cancelled") -> None:
        """Ask the remote side to abort the in-flight call with *request_id*."""
        await self._send(CancelMessage(id=request_id, reason=reason))
+
+ def _next_id(self) -> str:
+ """生成当前连接内递增的消息 ID。"""
+ self._counter += 1
+ return f"msg_{self._counter:04d}"
+
+ def _ensure_usable(self) -> None:
+ """确保连接仍处于可用状态,否则立即抛出协议错误。"""
+ if self._unusable:
+ raise AstrBotError.protocol_error("连接已进入不可用状态")
+
    async def _send_pending_result_request(
        self,
        request_id: str,
        message,
    ) -> asyncio.Future[ResultMessage]:
        """Register a pending result future, reclaiming it if the send fails.

        Returns the future that will be resolved by the matching
        ResultMessage from the remote side.
        """
        future: asyncio.Future[ResultMessage] = (
            asyncio.get_running_loop().create_future()
        )
        self._pending_results[request_id] = future
        try:
            await self._send(message)
        except Exception:
            # Send failed: nothing will ever resolve this future, so drop it.
            self._pending_results.pop(request_id, None)
            if not future.done():
                future.cancel()
            raise
        return future
+
    async def _send_pending_stream_request(
        self,
        request_id: str,
        message,
    ) -> asyncio.Queue[Any]:
        """Register a pending stream queue, reclaiming it if the send fails.

        Returns the queue that will receive EventMessage items (or an
        Exception on connection failure) for this request.
        """
        queue: asyncio.Queue[Any] = asyncio.Queue()
        self._pending_streams[request_id] = queue
        try:
            await self._send(message)
        except Exception:
            # Send failed: nobody will ever feed this queue, so drop it.
            self._pending_streams.pop(request_id, None)
            raise
        return queue
+
    async def _handle_raw_message(self, payload: bytes) -> None:
        """Decode a raw inbound message and dispatch it to the matching branch.

        Any decode or dispatch failure marks the whole connection as failed;
        the exception is logged rather than re-raised so the transport read
        loop does not crash.
        """
        try:
            # Inbound size check: reject giant messages to prevent OOM and
            # parser stalls.
            if len(payload) > MAX_INBOUND_MESSAGE_BYTES:
                raise AstrBotError.protocol_error(
                    f"协议消息过大,已拒绝处理:"
                    f"当前 {len(payload) / 1024 / 1024:.1f} MB,"
                    f"上限 {MAX_INBOUND_MESSAGE_BYTES / 1024 / 1024:.0f} MB"
                )
            message = self.wire_codec.decode_message(payload)
            if isinstance(message, ResultMessage):
                await self._handle_result(message)
                return
            if isinstance(message, EventMessage):
                await self._handle_event(message)
                return
            if isinstance(message, InitializeMessage):
                await self._handle_initialize(message)
                return
            if isinstance(message, InvokeMessage):
                # Run the invoke in its own task so slow handlers cannot block
                # the read loop; track it for cancellation.
                token = CancelToken()
                started = asyncio.Event()
                task = asyncio.create_task(self._handle_invoke(message, token, started))
                self._inbound_tasks[message.id] = (task, token, started)

                def _on_invoke_done(
                    _task: asyncio.Task[None], request_id: str = message.id
                ) -> None:
                    self._inbound_tasks.pop(request_id, None)
                    if _task.cancelled():
                        return
                    exc = _task.exception()
                    if exc is None:
                        return
                    # Why fail the whole connection? Normally the invoke
                    # handler encodes errors into a ResultMessage for the
                    # remote side. If an exception still escapes, either the
                    # reply could not be sent (transport already down) or the
                    # handler implementation is buggy. Either way the
                    # connection's message-exchange contract is no longer
                    # reliable; continuing could leave the remote waiting
                    # forever or lose data. We fail the whole connection
                    # rather than isolating the single handler because the
                    # protocol layer cannot guarantee the correctness of any
                    # subsequent message.
                    logger.error(
                        "Peer inbound invoke task crashed unexpectedly: "
                        "request_id={} error={!r}",
                        request_id,
                        exc,
                    )
                    error = (
                        exc
                        if isinstance(exc, AstrBotError)
                        else AstrBotError(
                            code=ErrorCodes.NETWORK_ERROR,
                            message="处理入站调用响应时连接已失效",
                            hint=str(exc),
                            retryable=True,
                        )
                    )
                    asyncio.create_task(self._fail_connection(error))

                task.add_done_callback(_on_invoke_done)
                return
            if isinstance(message, CancelMessage):
                await self._handle_cancel(message)
                return
        except Exception as exc:
            if isinstance(exc, AstrBotError):
                error = exc
            else:
                error = AstrBotError.protocol_error(f"无法解析协议消息: {exc}")
            await self._fail_connection(error)
            # Do not re-raise: an unhandled exception in the transport read
            # loop would take down the whole connection abruptly.
            logger.warning(
                "Peer connection marked unusable after inbound message failure: {}",
                error,
            )
            return
+
    async def _handle_initialize(self, message: InitializeMessage) -> None:
        """Handle a remote-initiated handshake and reply with the result.

        Negotiates the protocol version, validates the wire codec on both the
        request and our own response metadata, and delegates the actual
        handshake output to the registered initialize handler.
        """
        # Cache the remote declarations before any validation so diagnostics
        # have context even when the handshake is rejected.
        self.remote_peer = message.peer
        self.remote_handlers = message.handlers
        self.remote_provided_capabilities = message.provided_capabilities
        self.remote_provided_capability_map = {
            item.name: item for item in message.provided_capabilities
        }
        self.remote_metadata = dict(message.metadata)
        if self._initialize_handler is None:
            await self._reject_initialize(
                message,
                AstrBotError.protocol_error("对端不接受 initialize"),
            )
            return

        negotiated_protocol_version = _select_negotiated_protocol_version(
            message.protocol_version,
            self.remote_metadata,
            self.supported_protocol_versions,
        )
        if negotiated_protocol_version is None:
            supported_versions = ", ".join(self.supported_protocol_versions)
            await self._reject_initialize(
                message,
                AstrBotError.protocol_version_mismatch(
                    "服务端支持协议版本 "
                    f"{supported_versions},客户端请求版本 {message.protocol_version}"
                ),
            )
            return
        try:
            _validate_wire_codec_metadata(
                self.remote_metadata,
                expected_wire_codec=self.wire_codec_name,
            )
        except AstrBotError as exc:
            await self._reject_initialize(message, exc)
            return

        self.negotiated_protocol_version = negotiated_protocol_version
        self.remote_metadata[NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY] = (
            negotiated_protocol_version
        )
        output = await self._initialize_handler(message)
        response_metadata = dict(output.metadata)
        # Our own response must also declare a matching wire codec.
        try:
            _validate_wire_codec_metadata(
                response_metadata,
                expected_wire_codec=self.wire_codec_name,
            )
        except AstrBotError as exc:
            await self._reject_initialize(message, exc)
            return
        response_metadata[NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY] = (
            negotiated_protocol_version
        )
        output = output.model_copy(
            update={
                "protocol_version": negotiated_protocol_version,
                "metadata": response_metadata,
            }
        )
        await self._send(
            ResultMessage(
                id=message.id,
                kind="initialize_result",
                success=True,
                output=output.model_dump(),
            )
        )
        self._remote_initialized_successfully = True
        self._remote_initialized.set()
+
    async def _handle_invoke(
        self,
        message: InvokeMessage,
        token: CancelToken,
        started: asyncio.Event,
    ) -> None:
        """Execute a remote capability invocation and reply per the call mode.

        Streaming calls emit ``started``/``delta``/``completed`` events; unary
        calls answer with a single ``result``. Every failure — including local
        cancellation — is converted into a protocol-level error reply rather
        than escaping to the task machinery.
        """
        try:
            # Signal the dispatcher that execution has begun so a later cancel
            # can safely task.cancel() us (see _handle_cancel).
            started.set()
            token.raise_if_cancelled()
            if self._invoke_handler is None:
                raise AstrBotError.capability_not_found(message.capability)
            with caller_plugin_scope(message.caller_plugin_id):
                execution = await self._invoke_handler(message, token)
            if inspect.isawaitable(execution):
                execution = await execution
            if message.stream:
                if not isinstance(execution, StreamExecution):
                    raise AstrBotError.protocol_error(
                        "stream=true 必须返回 StreamExecution"
                    )
                await self._send(EventMessage(id=message.id, phase="started"))
                collect_chunks = execution.collect_chunks
                chunks: list[dict[str, Any]] = []
                async for chunk in execution.iterator:
                    if collect_chunks:
                        chunks.append(chunk)
                    await self._send(
                        EventMessage(id=message.id, phase="delta", data=chunk)
                    )
                await self._send(
                    EventMessage(
                        id=message.id,
                        phase="completed",
                        output=execution.finalize(chunks),
                    )
                )
                return
            if isinstance(execution, StreamExecution):
                raise AstrBotError.protocol_error("stream=false 不能返回流式执行对象")
            await self._send(
                ResultMessage(id=message.id, success=True, output=execution)
            )
        except asyncio.CancelledError:
            # Local cancellation becomes the standardized "cancelled" reply.
            await self._send_cancelled_termination(message)
        except LookupError as exc:
            # LookupError raised by the handler is treated as invalid input.
            error = AstrBotError.invalid_input(str(exc))
            await self._send_error_result(message, error)
        except AstrBotError as exc:
            await self._send_error_result(message, exc)
        except Exception as exc:
            # Last-resort guard: surface unexpected bugs as internal errors.
            await self._send_error_result(
                message, AstrBotError.internal_error(str(exc))
            )
+
+ async def _handle_cancel(self, message: CancelMessage) -> None:
+ """处理远端取消请求并终止对应的入站任务。"""
+ inbound = self._inbound_tasks.get(message.id)
+ if inbound is None:
+ return
+ task, token, started = inbound
+ token.cancel()
+ if self._cancel_handler is not None:
+ await self._cancel_handler(message.id)
+ if started.is_set():
+ task.cancel()
+
+ async def _handle_result(self, message: ResultMessage) -> None:
+ """处理非流式结果消息并唤醒等待中的调用方。"""
+ future = self._pending_results.pop(message.id, None)
+ if future is None:
+ queue = self._pending_streams.get(message.id)
+ if queue is not None:
+ await queue.put(
+ AstrBotError.protocol_error("stream=true 调用不应收到 result")
+ )
+ return
+ # 检查 future 是否已完成(可能被调用方取消)
+ if not future.done():
+ future.set_result(message)
+
+ async def _handle_event(self, message: EventMessage) -> None:
+ """处理流式事件消息并投递到对应请求的事件队列。"""
+ queue = self._pending_streams.get(message.id)
+ if queue is None:
+ future = self._pending_results.get(message.id)
+ if future is not None and not future.done():
+ future.set_exception(
+ AstrBotError.protocol_error("stream=false 调用不应收到 event")
+ )
+ return
+ await queue.put(message)
+
+ async def _send_error_result(
+ self, message: InvokeMessage, error: AstrBotError
+ ) -> None:
+ """根据调用模式,将错误编码为 `result` 或失败事件发回远端。"""
+ if message.stream:
+ await self._send(
+ EventMessage(
+ id=message.id,
+ phase="failed",
+ error=ErrorPayload.model_validate(error.to_payload()),
+ )
+ )
+ return
+ await self._send(
+ ResultMessage(
+ id=message.id,
+ success=False,
+ error=ErrorPayload.model_validate(error.to_payload()),
+ )
+ )
+
+ async def _reject_initialize(
+ self, message: InitializeMessage, error: AstrBotError
+ ) -> None:
+ """拒绝一次初始化握手,并把连接标记为不可继续使用。"""
+ await self._send(
+ ResultMessage(
+ id=message.id,
+ kind="initialize_result",
+ success=False,
+ error=ErrorPayload.model_validate(error.to_payload()),
+ )
+ )
+ self._unusable = True
+ self._remote_initialized.set()
+ await self.stop()
+
+ async def _send_cancelled_termination(self, message: InvokeMessage) -> None:
+ """把本端取消执行转换为标准化的取消错误响应。"""
+ error = AstrBotError.cancelled()
+ await self._send_error_result(message, error)
+
    async def _fail_connection(self, error: AstrBotError) -> None:
        """Mark the connection unusable and fail every in-flight call with *error*."""
        if self._unusable:
            # Already failed once; keep the first error and do nothing more.
            return
        self._unusable = True
        # Unblock anyone still waiting for the initialize handshake.
        self._remote_initialized.set()

        # Fail pending unary calls that have not completed yet.
        for future in list(self._pending_results.values()):
            if not future.done():
                future.set_exception(error)
        self._pending_results.clear()

        # Push the error into every open stream so consumers wake up.
        for queue in list(self._pending_streams.values()):
            await queue.put(error)
        self._pending_streams.clear()

        # Cancel all inbound handler tasks still running (list() copies so the
        # done-callbacks popping entries cannot disturb iteration).
        for task, token, _started in list(self._inbound_tasks.values()):
            token.cancel()
            task.cancel()
        self._inbound_tasks.clear()

        self._schedule_stop()
+
+ def _schedule_stop(self) -> None:
+ """安全地调度 stop(),避免与正在执行的 stop() 产生并发冲突。"""
+ if self._closed.is_set():
+ return
+ # 已有 stop task 在跑则不重复创建,防止产生竞态条件
+ if self._stop_task is not None and not self._stop_task.done():
+ return
+ self._stop_task = asyncio.create_task(self.stop(), name="astrbot-sdk-peer-stop")
+
+ async def _send(self, message) -> None:
+ """序列化协议消息并通过底层传输发送出去。"""
+ encoded_message = self.wire_codec.encode_message(message)
+ await self.transport.send(encoded_message)
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py b/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py
new file mode 100644
index 0000000000..a454d176e8
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py
@@ -0,0 +1,1090 @@
+"""Supervisor 端运行时:SupervisorRuntime 管理多个 Worker 进程,WorkerSession 封装与单个 Worker 的通信。
+
+架构层次:
+ AstrBot Core (Python)
+ |
+ v
+ SupervisorRuntime (管理多插件)
+ |
+ +-- WorkerSession (插件 A) -- StdioTransport -- PluginWorkerRuntime (子进程)
+ |
+ +-- WorkerSession (插件 B, 插件 C) -- StdioTransport -- GroupWorkerRuntime (子进程)
+ |
+ +-- WorkerSession (插件 D) -- StdioTransport -- PluginWorkerRuntime (子进程)
+
+核心类:
+ SupervisorRuntime: 监管者运行时
+ - 发现并加载所有插件
+ - 为单个插件或兼容插件组启动 Worker 进程
+ - 聚合所有 handler 并向 Core 注册
+ - 路由 Core 的调用请求到对应 Worker
+ - 处理 Worker 进程崩溃和重连
+ - handler ID 冲突检测和警告
+
+ WorkerSession: Worker 会话
+ - 管理单个插件 Worker 进程
+ - 通过 Peer 与 Worker 通信
+ - 提供 invoke_handler 和 cancel 方法
+ - 处理连接关闭回调
+ - 自动清理已注册的 handlers
+
+信号处理:
+ - SIGTERM: 设置 stop_event,触发优雅关闭
+ - SIGINT: 设置 stop_event,触发优雅关闭
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import signal
+import sys
+from collections.abc import Callable
+from pathlib import Path
+from typing import IO, Any, cast
+
+from .._internal.plugin_ids import (
+ capability_belongs_to_plugin,
+ plugin_capability_prefix,
+)
+from .._internal.sdk_logger import logger
+from ..errors import AstrBotError
+from ..protocol.codec import JsonProtocolCodec, MsgpackProtocolCodec, ProtocolCodec
+from ..protocol.descriptors import CapabilityDescriptor
+from ..protocol.messages import EventMessage, InitializeOutput, PeerInfo
+from .capability_router import CapabilityRouter, StreamExecution
+from .environment_groups import EnvironmentGroup
+from .loader import (
+ PluginDiscoveryIssue,
+ PluginEnvironmentManager,
+ PluginSpec,
+ discover_plugins,
+ load_plugin_config,
+)
+from .peer import Peer
+from .transport import (
+ StdioTransport,
+ WebSocketClientTransport,
+ build_websocket_client_ssl_context,
+)
+from .workers_manifest import RemoteWorkerSpec, load_remote_workers_manifest
+
+__all__ = [
+ "SupervisorRuntime",
+ "WorkerSession",
+ "_install_signal_handlers",
+ "_prepare_stdio_transport",
+ "_sdk_source_dir",
+ "_wait_for_shutdown",
+]
+
+# Worker 进程初始化握手超时:60 秒内必须完成 initialize 协议交换,
+# 否则视为进程卡死或挂载过慢,直接报错让上层感知
+WORKER_INITIALIZE_TIMEOUT_SECONDS = 60.0
+
+
def _install_signal_handlers(stop_event: asyncio.Event) -> None:
    """Make SIGTERM and SIGINT set *stop_event* to trigger graceful shutdown."""
    loop = asyncio.get_running_loop()
    for signum in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(signum, stop_event.set)
        except NotImplementedError:
            # Some platforms/event loops (e.g. Windows) lack signal support.
            logger.debug("Signal handlers are not supported for {}", signum)
+
+
def _prepare_stdio_transport(
    stdin: IO[str] | None,
    stdout: IO[str] | None,
) -> tuple[IO[str], IO[str], IO[str] | None]:
    """Pick the stdio streams for the transport, hijacking stdout if needed.

    When both streams are supplied they are used as-is and nothing global is
    touched (third element is ``None``). Otherwise the process's real
    stdin/stdout back the transport and ``sys.stdout`` is redirected to stderr
    so stray ``print`` output cannot corrupt the protocol stream; the original
    stdout is returned so the caller can restore it later.
    """
    if stdin is not None and stdout is not None:
        return stdin, stdout, None
    transport_stdin = stdin or sys.stdin
    transport_stdout = stdout or sys.stdout
    original_stdout = sys.stdout
    # From here on anything printed goes to stderr, keeping the protocol
    # channel on the real stdout clean.
    sys.stdout = sys.stderr
    return transport_stdin, transport_stdout, original_stdout
+
+
+def _sdk_source_dir(repo_root: Path) -> Path:
+ candidate = repo_root.resolve() / "src"
+ if (candidate / "astrbot_sdk").exists():
+ return candidate
+ return Path(__file__).resolve().parents[2]
+
+
+async def _wait_for_shutdown(peer: Peer, stop_event: asyncio.Event) -> None:
+ stop_waiter = asyncio.create_task(stop_event.wait())
+ transport_waiter = asyncio.create_task(peer.wait_closed())
+ done, pending = await asyncio.wait(
+ {stop_waiter, transport_waiter},
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ for task in pending:
+ task.cancel()
+ for task in done:
+ if not task.cancelled():
+ task.result()
+
+
+def _plugin_name_from_handler_id(handler_id: str) -> str:
+ if ":" in handler_id:
+ return handler_id.split(":", 1)[0]
+ return handler_id
+
+
+def _metadata_string_list(value: Any) -> list[str]:
+ if not isinstance(value, list):
+ return []
+ return [item for item in value if isinstance(item, str)]
+
+
+def _metadata_string_dict(value: Any) -> dict[str, str]:
+ if not isinstance(value, dict):
+ return {}
+ return {
+ key: item
+ for key, item in value.items()
+ if isinstance(key, str) and isinstance(item, str)
+ }
+
+
+def _metadata_dict_list(
+ value: Any,
+ *,
+ require_name: bool = False,
+) -> list[dict[str, Any]]:
+ if not isinstance(value, list):
+ return []
+ records = [dict(item) for item in value if isinstance(item, dict)]
+ if not require_name:
+ return records
+ return [record for record in records if str(record.get("name", "")).strip()]
+
+
+def _group_records_by_plugin(
+ records: list[dict[str, Any]],
+) -> dict[str, list[dict[str, Any]]]:
+ grouped: dict[str, list[dict[str, Any]]] = {}
+ for item in records:
+ plugin_name = str(item.get("plugin_id", "")).strip()
+ if not plugin_name:
+ continue
+ grouped.setdefault(plugin_name, []).append(dict(item))
+ return grouped
+
+
+def _plugin_ids_from_worker_registry(entries: list[dict[str, Any]]) -> set[str]:
+ plugin_ids = {
+ str(item.get("name", "")).strip() for item in entries if isinstance(item, dict)
+ }
+ plugin_ids.discard("")
+ return plugin_ids
+
+
def _wire_codec_cli_name(codec: ProtocolCodec) -> str:
    """Map a codec instance to the CLI flag value used by worker subprocesses."""
    if isinstance(codec, MsgpackProtocolCodec):
        return "msgpack"
    if isinstance(codec, JsonProtocolCodec):
        return "json"
    # Local subprocess workers only understand the two codecs above.
    raise ValueError(
        f"unsupported wire codec for local worker subprocess: {type(codec).__name__}"
    )
+
+
+class WorkerSession:
    def __init__(
        self,
        *,
        plugin: PluginSpec | None = None,
        group: EnvironmentGroup | None = None,
        remote_worker: RemoteWorkerSpec | None = None,
        repo_root: Path,
        env_manager: PluginEnvironmentManager,
        capability_router: CapabilityRouter,
        on_closed: Callable[[], None] | None = None,
        wire_codec: ProtocolCodec | None = None,
    ) -> None:
        """Create a session for exactly one worker target.

        Exactly one of *plugin* (single local plugin), *group* (compatible
        local plugin group), or *remote_worker* (WebSocket worker) must be
        given; anything else raises ``ValueError``.
        """
        target_count = sum(item is not None for item in (plugin, group, remote_worker))
        if target_count != 1:
            raise ValueError(
                "WorkerSession requires exactly one of plugin, group, or remote_worker"
            )
        group_ref = group
        self.remote_worker = remote_worker
        self.is_remote = remote_worker is not None
        # The "primary" plugin names the worker for local targets; remote
        # workers carry no local plugin specs at all.
        if group_ref is not None:
            primary_plugin = group_ref.plugins[0]
        elif plugin is not None:
            primary_plugin = plugin
        else:
            primary_plugin = None
        self.group = group
        self.plugins = (
            list(group_ref.plugins)
            if group_ref is not None
            else ([primary_plugin] if primary_plugin is not None else [])
        )
        self.plugin = primary_plugin
        self.worker_id = (
            remote_worker.id
            if remote_worker is not None
            else (
                group_ref.id
                if group_ref is not None
                else cast(PluginSpec, primary_plugin).name
            )
        )
        self.repo_root = repo_root.resolve()
        self.env_manager = env_manager
        self.capability_router = capability_router
        self.on_closed = on_closed
        self.wire_codec = wire_codec or MsgpackProtocolCodec()
        self.peer: Peer | None = None
        # State below is populated from the worker's initialize metadata in
        # _sync_remote_state() after start().
        self.handlers = []
        self.provided_capabilities: list[CapabilityDescriptor] = []
        self.loaded_plugins: list[str] = []
        self.skipped_plugins: dict[str, str] = {}
        self.issues: list[PluginDiscoveryIssue] = []
        self.capability_sources: dict[str, str] = {}
        self.llm_tools: list[dict[str, Any]] = []
        self.agents: list[dict[str, Any]] = []
        self.worker_registry: list[dict[str, Any]] = []
        self._connection_watch_task: asyncio.Task[None] | None = None
+
    async def start(self) -> None:
        """Start the worker transport, await its handshake, and import its state.

        On any failure the session is stopped before the exception propagates,
        so a half-started worker process/connection is not leaked.
        """
        transport = self._build_transport()
        self.peer = Peer(
            transport=transport,
            peer_info=PeerInfo(name="astrbot-core", role="core", version="s5r"),
            wire_codec=self.wire_codec,
        )
        self.peer.set_initialize_handler(self._handle_initialize)
        self.peer.set_invoke_handler(self._handle_capability_invoke)
        try:
            await self.peer.start()
            await self._wait_until_initialized()
            self._sync_remote_state()
            self._validate_initialized_state()

        except Exception:
            await self.stop()
            raise
+
    def _build_transport(self):
        """Build the transport: WebSocket for remote workers, stdio subprocess otherwise."""
        if self.remote_worker is not None:
            ssl_context = build_websocket_client_ssl_context(
                ca_file=self.remote_worker.tls.ca_file,
                cert_file=self.remote_worker.tls.cert_file,
                key_file=self.remote_worker.tls.key_file,
            )
            return WebSocketClientTransport(
                url=self.remote_worker.url,
                ssl_context=ssl_context,
                server_hostname=self.remote_worker.tls.server_hostname,
            )

        python_path, command, cwd = self._worker_command()
        repo_src_dir = str(_sdk_source_dir(self.repo_root))
        env = os.environ.copy()
        existing_pythonpath = env.get("PYTHONPATH")
        # Prepend the SDK source dir so the child resolves astrbot_sdk from
        # the checkout first, while keeping any pre-existing PYTHONPATH.
        env["PYTHONPATH"] = (
            f"{repo_src_dir}{os.pathsep}{existing_pythonpath}"
            if existing_pythonpath
            else repo_src_dir
        )
        # Force UTF-8 stdio so protocol frames survive on every platform.
        env.setdefault("PYTHONIOENCODING", "utf-8")
        env.setdefault("PYTHONUTF8", "1")
        return StdioTransport(command=command, cwd=cwd, env=env)
+
    async def _wait_until_initialized(self) -> None:
        """Wait for the worker's initialize handshake, mapping failures to RuntimeError."""
        assert self.peer is not None
        try:
            await self.peer.wait_until_remote_initialized(
                timeout=WORKER_INITIALIZE_TIMEOUT_SECONDS
            )
        except TimeoutError as exc:
            raise RuntimeError(
                f"worker {self.worker_id} 初始化超时 "
                f"({WORKER_INITIALIZE_TIMEOUT_SECONDS:.0f}s);"
                "请检查 worker 日志中的 on_start / 装饰器初始化错误"
            ) from exc
        except AstrBotError as exc:
            # A protocol-level error here means the worker quit during init.
            raise RuntimeError(f"worker {self.worker_id} 在初始化阶段退出") from exc
+
    def _sync_remote_state(self) -> None:
        """Copy the worker's advertised handlers/capabilities/metadata onto this session."""
        assert self.peer is not None
        self.handlers = list(self.peer.remote_handlers)
        self.provided_capabilities = list(self.peer.remote_provided_capabilities)
        metadata = dict(self.peer.remote_metadata)
        # Workers that report no loaded_plugins are assumed to have loaded
        # everything they were configured with.
        self.loaded_plugins = _metadata_string_list(metadata.get("loaded_plugins")) or [
            plugin.name for plugin in self.plugins
        ]
        self.skipped_plugins = _metadata_string_dict(metadata.get("skipped_plugins"))
        self.capability_sources = _metadata_string_dict(
            metadata.get("capability_sources")
        )
        self.issues = self._parse_remote_issues(metadata.get("issues"))
        self.llm_tools = _metadata_dict_list(metadata.get("llm_tools"))
        self.agents = _metadata_dict_list(metadata.get("agents"))
        self.worker_registry = _metadata_dict_list(
            metadata.get("worker_registry"),
            require_name=True,
        )
+
+ def _parse_remote_issues(self, value: Any) -> list[PluginDiscoveryIssue]:
+ default_issue_owner = (
+ self.plugin.name if self.plugin is not None else self.worker_id
+ )
+ issues: list[PluginDiscoveryIssue] = []
+ for item in _metadata_dict_list(value):
+ issues.append(
+ PluginDiscoveryIssue(
+ severity=str(item.get("severity", "error")), # type: ignore[arg-type]
+ phase=str(item.get("phase", "load")), # type: ignore[arg-type]
+ plugin_id=str(item.get("plugin_id", default_issue_owner)),
+ message=str(item.get("message", "")),
+ details=str(item.get("details", "")),
+ hint=str(item.get("hint", "")),
+ )
+ )
+ return issues
+
    def _validate_initialized_state(self) -> None:
        """Cross-check everything the worker reported against its declared registry.

        Verifies the remote identity, then ensures every loaded/skipped plugin,
        capability source, handler, LLM tool, and agent belongs to a plugin the
        worker actually declared, raising ``RuntimeError`` on any mismatch.
        """
        assert self.peer is not None
        if self.remote_worker is not None and self.peer.remote_peer is not None:
            if self.peer.remote_peer.name != self.worker_id:
                raise RuntimeError(
                    "remote worker identity mismatch: "
                    f"expected {self.worker_id!r}, got {self.peer.remote_peer.name!r}"
                )

        # An empty plugin-id set disables the membership checks below (local
        # fallback when neither the registry nor plugin specs are available).
        plugin_ids = _plugin_ids_from_worker_registry(self.worker_registry)
        if not plugin_ids and self.plugins:
            plugin_ids = {plugin.name for plugin in self.plugins}
        if self.remote_worker is not None and not plugin_ids:
            # Remote workers have no local specs, so the registry is mandatory.
            raise RuntimeError(
                f"remote worker {self.worker_id} did not provide worker_registry"
            )

        for plugin_name in self.loaded_plugins:
            if plugin_ids and plugin_name not in plugin_ids:
                raise RuntimeError(
                    f"worker {self.worker_id} reported undeclared loaded plugin: "
                    f"{plugin_name}"
                )
        for plugin_name in self.skipped_plugins:
            if plugin_ids and plugin_name not in plugin_ids:
                raise RuntimeError(
                    f"worker {self.worker_id} reported undeclared skipped plugin: "
                    f"{plugin_name}"
                )
        for capability_name, plugin_name in self.capability_sources.items():
            if plugin_ids and plugin_name not in plugin_ids:
                raise RuntimeError(
                    f"worker {self.worker_id} returned capability source outside "
                    f"worker_registry: {capability_name} -> {plugin_name}"
                )
        for handler in self.handlers:
            owner_plugin = _plugin_name_from_handler_id(handler.id)
            if plugin_ids and owner_plugin not in plugin_ids:
                raise RuntimeError(
                    f"worker {self.worker_id} returned handler outside worker_registry: "
                    f"{handler.id}"
                )
        for item in self.llm_tools:
            plugin_name = str(item.get("plugin_id", "")).strip()
            if plugin_ids and plugin_name and plugin_name not in plugin_ids:
                raise RuntimeError(
                    f"worker {self.worker_id} returned llm tool outside worker_registry: "
                    f"{plugin_name}"
                )
        for item in self.agents:
            plugin_name = str(item.get("plugin_id", "")).strip()
            if plugin_ids and plugin_name and plugin_name not in plugin_ids:
                raise RuntimeError(
                    f"worker {self.worker_id} returned agent outside worker_registry: "
                    f"{plugin_name}"
                )
+
    def _worker_command(self) -> tuple[Path, list[str], str]:
        """Resolve the interpreter, argv, and cwd for a local worker subprocess.

        Group workers run ``python -m astrbot_sdk worker --group-metadata ...``
        from the repo root; single-plugin workers run with ``--plugin-dir ...``
        from the plugin's own directory.
        """
        wire_codec = _wire_codec_cli_name(self.wire_codec)
        if self.group is not None:
            # Older environment managers may not support group environments;
            # fall back to the first plugin's environment in that case.
            prepare_group = getattr(self.env_manager, "prepare_group_environment", None)
            if callable(prepare_group):
                python_path = cast(Path, prepare_group(self.group))
            else:
                python_path = self.env_manager.prepare_environment(self.plugins[0])
            return (
                python_path,
                [
                    str(python_path),
                    "-m",
                    "astrbot_sdk",
                    "worker",
                    "--wire-codec",
                    wire_codec,
                    "--group-metadata",
                    str(self.group.metadata_path),
                ],
                str(self.repo_root),
            )

        assert self.plugin is not None
        plugin = self.plugin
        python_path = self.env_manager.prepare_environment(plugin)
        return (
            python_path,
            [
                str(python_path),
                "-m",
                "astrbot_sdk",
                "worker",
                "--wire-codec",
                wire_codec,
                "--plugin-dir",
                str(plugin.plugin_dir),
            ],
            str(plugin.plugin_dir),
        )
+
+ def start_close_watch(self) -> None:
+ if (
+ self.on_closed is None
+ or self.peer is None
+ or self._connection_watch_task is not None
+ ):
+ return
+ self._connection_watch_task = asyncio.create_task(self._watch_connection())
+
    async def _watch_connection(self) -> None:
        """Watch for the worker connection closing and fire the cleanup callback."""
        try:
            if self.peer is not None:
                await self.peer.wait_closed()
            if self.on_closed is not None:
                try:
                    self.on_closed()
                except Exception:
                    # The callback must not take down the watcher; just log.
                    logger.exception(
                        "on_closed callback failed for worker {}", self.worker_id
                    )
        finally:
            # Clear the task slot only if we are still the registered watcher.
            current_task = asyncio.current_task()
            if self._connection_watch_task is current_task:
                self._connection_watch_task = None
+
+ async def stop(self) -> None:
+ if self.peer is not None:
+ await self.peer.stop()
+
+ async def invoke_handler(
+ self,
+ handler_id: str,
+ event_payload: dict[str, Any],
+ *,
+ request_id: str,
+ args: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ if self.peer is None:
+ raise RuntimeError("worker session is not running")
+ return await self.peer.invoke(
+ "handler.invoke",
+ {
+ "handler_id": handler_id,
+ "event": event_payload,
+ "args": dict(args or {}),
+ },
+ request_id=request_id,
+ )
+
+ async def invoke_capability(
+ self,
+ capability_name: str,
+ payload: dict[str, Any],
+ *,
+ request_id: str,
+ ) -> dict[str, Any]:
+ if self.peer is None:
+ raise RuntimeError("worker session is not running")
+ return await self.peer.invoke(
+ capability_name,
+ payload,
+ request_id=request_id,
+ )
+
    async def invoke_capability_stream(
        self,
        capability_name: str,
        payload: dict[str, Any],
        *,
        request_id: str,
    ):
        """Yield the worker's streaming events for a capability invocation.

        Includes the terminal "completed" event so callers can read the final
        output. Raises ``RuntimeError`` when the session has not been started.
        """
        if self.peer is None:
            raise RuntimeError("worker session is not running")
        event_stream = await self.peer.invoke_stream(
            capability_name,
            payload,
            request_id=request_id,
            include_completed=True,
        )
        async for event in event_stream:
            yield event
+
+ async def cancel(self, request_id: str) -> None:
+ if self.peer is None:
+ return
+ await self.peer.cancel(request_id)
+
+ async def _handle_initialize(self, _message) -> InitializeOutput:
+ if self.peer is None:
+ raise RuntimeError("worker session is not running")
+ return InitializeOutput(
+ peer=PeerInfo(name="astrbot-supervisor", role="core", version="s5r"),
+ capabilities=self.capability_router.all_descriptors(),
+ metadata={
+ "worker_id": self.worker_id,
+ "plugins": [plugin.name for plugin in self.plugins],
+ "wire_codec": self.peer.wire_codec_name,
+ },
+ )
+
+ async def _handle_capability_invoke(self, message, cancel_token):
+ return await self.capability_router.execute(
+ message.capability,
+ message.input,
+ stream=message.stream,
+ cancel_token=cancel_token,
+ request_id=message.id,
+ )
+
+ def describe(self) -> dict[str, Any]:
+ return {
+ "worker_id": self.worker_id,
+ "plugins": [plugin.name for plugin in self.plugins],
+ "loaded_plugins": list(self.loaded_plugins),
+ "skipped_plugins": dict(self.skipped_plugins),
+ "issues": [issue.to_payload() for issue in self.issues],
+ }
+
+
+class SupervisorRuntime:
    def __init__(
        self,
        *,
        transport,
        plugins_dir: Path,
        env_manager: PluginEnvironmentManager | None = None,
        workers_manifest: Path | None = None,
        wire_codec: ProtocolCodec | None = None,
    ) -> None:
        """Wire up the supervisor-side peer, capability router, and routing tables.

        Args:
            transport: Transport connecting the supervisor to AstrBot Core.
            plugins_dir: Directory scanned for plugin packages.
            env_manager: Plugin environment manager; defaults to one rooted at
                the repository root when omitted.
            workers_manifest: Optional manifest of remote workers to connect.
            wire_codec: Protocol codec; defaults to msgpack.
        """
        self.transport = transport
        self.plugins_dir = plugins_dir.resolve()
        # This file lives at <repo>/src/astrbot_sdk/runtime/, so three parents
        # above the package directory is the repository root.
        self.repo_root = Path(__file__).resolve().parents[3]
        self.env_manager = env_manager or PluginEnvironmentManager(self.repo_root)
        self.workers_manifest = workers_manifest.resolve() if workers_manifest else None
        self.wire_codec = wire_codec or MsgpackProtocolCodec()
        self.capability_router = CapabilityRouter()
        self.peer = Peer(
            transport=self.transport,
            peer_info=PeerInfo(name="astrbot-supervisor", role="plugin", version="s5r"),
            wire_codec=self.wire_codec,
        )
        self.peer.set_invoke_handler(self._handle_upstream_invoke)
        self.peer.set_cancel_handler(self._handle_upstream_cancel)
        # Routing tables mapping ids to the worker sessions that own them.
        self.worker_sessions: dict[str, WorkerSession] = {}
        self.handler_to_worker: dict[str, WorkerSession] = {}
        self.capability_to_worker: dict[str, WorkerSession] = {}
        self.plugin_to_worker_session: dict[str, WorkerSession] = {}
        self._handler_sources: dict[str, str] = {}  # handler_id -> plugin_name
        self._capability_sources: dict[str, str] = {}  # capability_name -> plugin_name
        self.active_requests: dict[str, WorkerSession] = {}
        self.loaded_plugins: list[str] = []
        self.skipped_plugins: dict[str, str] = {}
        self.issues: list[PluginDiscoveryIssue] = []
        self._register_internal_capabilities()
+
    def _publish_plugin_registry_snapshot(
        self,
        plugins: list[PluginSpec],
        *,
        enabled_plugins: set[str],
    ) -> None:
        """Upsert manifest metadata and config for *plugins* into the router registry.

        The ``enabled`` flag is derived purely from membership in
        *enabled_plugins*.
        """
        for plugin in plugins:
            manifest = plugin.manifest_data
            self.capability_router.upsert_plugin(
                metadata={
                    "name": plugin.name,
                    "display_name": str(manifest.get("display_name") or plugin.name),
                    # Older manifests use "desc", newer ones "description".
                    "description": str(
                        manifest.get("desc") or manifest.get("description") or ""
                    ),
                    "repo": str(manifest.get("repo") or ""),
                    "author": str(manifest.get("author") or ""),
                    "version": str(manifest.get("version") or "0.0.0"),
                    "enabled": plugin.name in enabled_plugins,
                },
                config=load_plugin_config(plugin),
            )
+
    def _publish_discovered_plugin_registry(self, plugins: list[PluginSpec]) -> None:
        """Publish static metadata for discovered plugins.

        This runs before any worker starts: the supervisor already knows each
        plugin's manifest and config but not which plugins will actually finish
        loading, so everything is exposed to the metadata capability with
        ``enabled=False`` for now.
        """
        self._publish_plugin_registry_snapshot(plugins, enabled_plugins=set())
+
    def _publish_loaded_plugin_registry(self, plugins: list[PluginSpec]) -> None:
        """Refresh plugin enabled flags once workers have finished starting."""
        self._publish_plugin_registry_snapshot(
            plugins,
            enabled_plugins=set(self.loaded_plugins),
        )
+
+ def _publish_worker_registry(self, entries: list[dict[str, Any]]) -> None:
+ for item in entries:
+ plugin_name = str(item.get("name", "")).strip()
+ if not plugin_name:
+ continue
+ config = item.get("config")
+ metadata = dict(item)
+ metadata.pop("config", None)
+ self.capability_router.upsert_plugin(
+ metadata=metadata,
+ config=dict(config) if isinstance(config, dict) else {},
+ )
+
+ def _publish_session_runtime_metadata(self, session: WorkerSession) -> None:
+ self._publish_worker_registry(session.worker_registry)
+ for plugin_name, items in _group_records_by_plugin(session.llm_tools).items():
+ self.capability_router.set_plugin_llm_tools(plugin_name, items)
+ for plugin_name, items in _group_records_by_plugin(session.agents).items():
+ self.capability_router.set_plugin_agents(plugin_name, items)
+
+ @staticmethod
+ def _session_plugin_ids(session: WorkerSession) -> set[str]:
+ plugin_ids = _plugin_ids_from_worker_registry(session.worker_registry)
+ if plugin_ids:
+ return plugin_ids
+ return {plugin.name for plugin in session.plugins}
+
+ def _validate_remote_session_plugins(
+ self,
+ session: WorkerSession,
+ *,
+ local_plugin_ids: set[str],
+ ) -> None:
+ if not session.is_remote:
+ return
+ conflicts = self._session_plugin_ids(session) & (
+ local_plugin_ids | set(self.plugin_to_worker_session)
+ )
+ if not conflicts:
+ return
+ names = ", ".join(sorted(conflicts))
+ raise RuntimeError(
+ f"remote worker {session.worker_id} conflicts with existing plugins: {names}"
+ )
+
+ def _record_session_start_failure(
+ self,
+ session: WorkerSession,
+ exc: Exception,
+ ) -> None:
+ if session.plugins:
+ for plugin in session.plugins:
+ self.skipped_plugins[plugin.name] = str(exc)
+ self.issues.append(
+ PluginDiscoveryIssue(
+ severity="error",
+ phase="load",
+ plugin_id=plugin.name,
+ message="插件 worker 启动失败",
+ details=str(exc),
+ )
+ )
+ return
+ self.issues.append(
+ PluginDiscoveryIssue(
+ severity="error",
+ phase="load",
+ plugin_id=session.worker_id,
+ message="远程 worker 连接失败",
+ details=str(exc),
+ )
+ )
+
    def _register_started_session(self, session: WorkerSession) -> None:
        """Fold a successfully started session into the supervisor's routing state.

        Registers its plugins, handlers, and capabilities, publishes runtime
        metadata, and finally arms the connection-close watcher.
        """
        self.worker_sessions[session.worker_id] = session
        self.skipped_plugins.update(session.skipped_plugins)
        self.issues.extend(session.issues)
        self._publish_session_runtime_metadata(session)
        for plugin_name in session.loaded_plugins:
            self.plugin_to_worker_session[plugin_name] = session
            if plugin_name not in self.loaded_plugins:
                self.loaded_plugins.append(plugin_name)
        for handler in session.handlers:
            self._register_handler(
                handler,
                session,
                _plugin_name_from_handler_id(handler.id),
            )
        for descriptor in session.provided_capabilities:
            # Resolve the owning plugin: explicit source map first, then the
            # sole loaded plugin, finally the worker id itself.
            plugin_name = session.capability_sources.get(descriptor.name)
            if plugin_name is None and len(session.loaded_plugins) == 1:
                plugin_name = session.loaded_plugins[0]
            if plugin_name is None:
                plugin_name = session.worker_id
            self._register_plugin_capability(descriptor, session, plugin_name)
        session.start_close_watch()
+
    def _register_internal_capabilities(self) -> None:
        """Register framework-internal capabilities on the router.

        Currently only ``handler.invoke`` (unexposed), which forwards handler
        invocations from Core to the owning plugin worker.
        """
        self.capability_router.register(
            CapabilityDescriptor(
                name="handler.invoke",
                description="框架内部:转发到插件 handler",
                input_schema={
                    "type": "object",
                    "properties": {
                        "handler_id": {"type": "string"},
                        "event": {"type": "object"},
                    },
                    "required": ["handler_id", "event"],
                },
                output_schema={
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
                cancelable=True,
            ),
            call_handler=self._route_handler_invoke,
            exposed=False,
        )
+
    def _register_handler(
        self, handler, session: WorkerSession, plugin_name: str
    ) -> None:
        """Register a handler route, warning (not failing) on duplicate IDs.

        Args:
            handler: Handler descriptor reported by the worker.
            session: Worker session that owns the handler.
            plugin_name: Plugin the handler belongs to.
        """
        handler_id = handler.id
        existing_plugin = self._handler_sources.get(handler_id)

        if existing_plugin is not None:
            # Last registration wins; surface the clash so operators notice.
            logger.warning(
                f"Handler ID 冲突:'{handler_id}' 已被插件 '{existing_plugin}' 注册,"
                f"现在被插件 '{plugin_name}' 覆盖。"
            )

        self.handler_to_worker[handler_id] = session
        self._handler_sources[handler_id] = plugin_name
+
    def _register_plugin_capability(
        self,
        descriptor: CapabilityDescriptor,
        session: WorkerSession,
        plugin_name: str,
    ) -> None:
        """Validate and register a plugin-provided capability.

        Raises:
            ValueError: When the capability name is not namespaced under the
                owning plugin id.
            RuntimeError: When the name is already registered, i.e. protocol
                data disagrees with the local router state.
        """
        capability_name = descriptor.name
        if not capability_belongs_to_plugin(capability_name, plugin_name):
            expected_prefix = plugin_capability_prefix(plugin_name)
            raise ValueError(
                "插件导出的 capability 必须使用 plugin_id 作为公开命名空间前缀:"
                f" plugin={plugin_name!r}, capability={capability_name!r}, "
                f"expected_prefix={expected_prefix!r}"
            )
        # The worker-side loader already validated namespaces; if a collision
        # still shows up here, protocol data and local routing state disagree,
        # and silently renaming would only mask the problem.
        if self.capability_router.contains(capability_name):
            existing_plugin = self._capability_sources.get(capability_name, "")
            raise RuntimeError(
                "duplicate capability registration detected after worker load validation: "
                f"{capability_name!r} already registered by {existing_plugin!r}, "
                f"cannot register again for {plugin_name!r}"
            )
        self._do_register_capability(descriptor, session, capability_name, plugin_name)
+
+ def _do_register_capability(
+ self,
+ descriptor: CapabilityDescriptor,
+ session: WorkerSession,
+ capability_name: str,
+ plugin_name: str,
+ ) -> None:
+ """实际执行 capability 注册。"""
+ self.capability_router.register(
+ descriptor,
+ call_handler=self._make_plugin_capability_caller(session, capability_name),
+ stream_handler=(
+ self._make_plugin_capability_streamer(session, capability_name)
+ if descriptor.supports_stream
+ else None
+ ),
+ )
+ self.capability_to_worker[capability_name] = session
+ self._capability_sources[capability_name] = plugin_name
+
+ def _make_plugin_capability_caller(
+ self,
+ session: WorkerSession,
+ capability_name: str,
+ ):
+ async def call_handler(
+ request_id: str,
+ payload: dict[str, Any],
+ _cancel_token,
+ ) -> dict[str, Any]:
+ self.active_requests[request_id] = session
+ try:
+ return await session.invoke_capability(
+ capability_name,
+ payload,
+ request_id=request_id,
+ )
+ finally:
+ self.active_requests.pop(request_id, None)
+
+ return call_handler
+
    def _make_plugin_capability_streamer(
        self,
        session: WorkerSession,
        capability_name: str,
    ):
        """Build a streaming handler that proxies *capability_name* to *session*.

        The returned handler relays the worker's "delta" events as chunks and
        captures the "completed" event's output, which ``finalize`` prefers
        over the collected chunks.
        """
        async def stream_handler(
            request_id: str,
            payload: dict[str, Any],
            _cancel_token,
        ):
            # Shared between the iterator and finalize; filled in when the
            # worker emits its "completed" event.
            completed_output: dict[str, Any] = {}

            async def iterator():
                # Track the request so upstream cancellation can route it to
                # the owning worker; always untrack, even on error/cancel.
                self.active_requests[request_id] = session
                try:
                    async for event in session.invoke_capability_stream(
                        capability_name,
                        payload,
                        request_id=request_id,
                    ):
                        if not isinstance(event, EventMessage):
                            raise AstrBotError.protocol_error(
                                "插件 worker 返回了非法的流式事件"
                            )
                        if event.phase == "delta":
                            yield event.data or {}
                            continue
                        if event.phase == "completed":
                            completed_output.clear()
                            completed_output.update(event.output or {})
                finally:
                    self.active_requests.pop(request_id, None)

            return StreamExecution(
                iterator=iterator(),
                finalize=lambda chunks: completed_output or {"items": chunks},
            )

        return stream_handler
+
+    async def start(self) -> None:
+        """Discover plugins, start worker sessions, and initialize the peer.
+
+        On any failure after partial startup, ``stop()`` is invoked so worker
+        processes are not leaked, then the original exception is re-raised.
+        """
+        discovery = discover_plugins(self.plugins_dir)
+        self.skipped_plugins = dict(discovery.skipped_plugins)
+        self.issues = list(discovery.issues)
+        local_plugin_ids = {plugin.name for plugin in discovery.plugins}
+        plan_result = self.env_manager.plan(discovery.plugins)
+        remote_workers = (
+            load_remote_workers_manifest(self.workers_manifest)
+            if self.workers_manifest is not None
+            else []
+        )
+        self.skipped_plugins.update(plan_result.skipped_plugins)
+        self.issues.extend(
+            PluginDiscoveryIssue(
+                severity="error",
+                phase="load",
+                plugin_id=plugin_name,
+                message="插件环境规划失败",
+                details=str(reason),
+            )
+            for plugin_name, reason in plan_result.skipped_plugins.items()
+        )
+        # Publish static plugin metadata first so the supervisor side can read
+        # configs/manifests while workers are still starting up.
+        self._publish_discovered_plugin_registry(discovery.plugins)
+        try:
+            planned_sessions: list[WorkerSession] = []
+            if plan_result.groups:
+                for group in plan_result.groups:
+                    planned_sessions.append(
+                        WorkerSession(
+                            group=group,
+                            repo_root=self.repo_root,
+                            env_manager=self.env_manager,
+                            capability_router=self.capability_router,
+                            wire_codec=self.wire_codec,
+                            # Bind the id as a default arg so each closure
+                            # captures its own id, not the loop variable.
+                            on_closed=lambda worker_id=group.id: (
+                                self._handle_worker_closed(worker_id)
+                            ),
+                        )
+                    )
+            else:
+                for plugin in plan_result.plugins:
+                    planned_sessions.append(
+                        WorkerSession(
+                            plugin=plugin,
+                            repo_root=self.repo_root,
+                            env_manager=self.env_manager,
+                            capability_router=self.capability_router,
+                            wire_codec=self.wire_codec,
+                            on_closed=lambda worker_id=plugin.name: (
+                                self._handle_worker_closed(worker_id)
+                            ),
+                        )
+                    )
+            for remote_worker in remote_workers:
+                planned_sessions.append(
+                    WorkerSession(
+                        remote_worker=remote_worker,
+                        repo_root=self.repo_root,
+                        env_manager=self.env_manager,
+                        capability_router=self.capability_router,
+                        wire_codec=self.wire_codec,
+                        on_closed=lambda worker_id=remote_worker.id: (
+                            self._handle_worker_closed(worker_id)
+                        ),
+                    )
+                )
+
+            for session in planned_sessions:
+                try:
+                    await session.start()
+                    self._validate_remote_session_plugins(
+                        session,
+                        local_plugin_ids=local_plugin_ids,
+                    )
+                except Exception as exc:
+                    # One failed session is recorded and skipped; the remaining
+                    # sessions still get a chance to start.
+                    self._record_session_start_failure(session, exc)
+                    await session.stop()
+                    continue
+                self._register_started_session(session)
+
+            # After workers start, refresh the enabled state from the actual
+            # load results — an explicit two-phase publish.
+            self._publish_loaded_plugin_registry(discovery.plugins)
+
+            aggregated_handlers = list(self.handler_to_worker.keys())
+            logger.info(
+                "Loaded plugins: {}", ", ".join(sorted(self.loaded_plugins)) or "none"
+            )
+
+            await self.peer.start()
+            await self.peer.initialize(
+                [
+                    handler
+                    for session in self.worker_sessions.values()
+                    for handler in session.handlers
+                ],
+                provided_capabilities=self.capability_router.descriptors(),
+                metadata={
+                    "plugins": sorted(self.loaded_plugins),
+                    "skipped_plugins": self.skipped_plugins,
+                    "issues": [issue.to_payload() for issue in self.issues],
+                    "aggregated_handler_ids": aggregated_handlers,
+                    "workers": [
+                        session.describe() for session in self.worker_sessions.values()
+                    ],
+                    "worker_count": len(self.worker_sessions),
+                },
+            )
+        except Exception:
+            # Undo partial startup so worker processes are not leaked.
+            await self.stop()
+            raise
+
+    def _handle_worker_closed(self, worker_id: str) -> None:
+        """Cleanup callback invoked when a worker connection closes.
+
+        Removes the session's handlers, capabilities, plugin registrations,
+        and any in-flight request bookkeeping still pointing at the session.
+        """
+        session = self.worker_sessions.pop(worker_id, None)
+        if session is None:
+            return
+        # Remove handlers this plugin registered, but only while the recorded
+        # source still matches — never evict a newer registration.
+        for handler in session.handlers:
+            source_plugin = self._handler_sources.get(handler.id)
+            if source_plugin == _plugin_name_from_handler_id(handler.id) or (
+                source_plugin == worker_id
+            ):
+                self.handler_to_worker.pop(handler.id, None)
+                self._handler_sources.pop(handler.id, None)
+        for descriptor in session.provided_capabilities:
+            source_plugin = self._capability_sources.get(descriptor.name)
+            capability_plugin = session.capability_sources.get(descriptor.name)
+            # Unregister only when this session is the recorded owner (either
+            # directly or via the worker id / loaded-plugin fallback).
+            if source_plugin == capability_plugin or (
+                capability_plugin is None
+                and (
+                    source_plugin == worker_id
+                    or source_plugin in session.loaded_plugins
+                )
+            ):
+                self.capability_to_worker.pop(descriptor.name, None)
+                self._capability_sources.pop(descriptor.name, None)
+                self.capability_router.unregister(descriptor.name)
+        # Fall back to the worker id when the session exposes no plugin list.
+        session_loaded_plugins = getattr(session, "loaded_plugins", None)
+        if not isinstance(session_loaded_plugins, list):
+            session_loaded_plugins = [worker_id]
+        for plugin_name in session_loaded_plugins:
+            if plugin_name in self.loaded_plugins:
+                self.loaded_plugins.remove(plugin_name)
+            self.plugin_to_worker_session.pop(plugin_name, None)
+            self.capability_router.set_plugin_enabled(plugin_name, False)
+            self.capability_router.remove_http_apis_for_plugin(plugin_name)
+        # Drop request bookkeeping that still references the dead session.
+        stale_requests = [
+            request_id
+            for request_id, active_session in self.active_requests.items()
+            if active_session is session
+        ]
+        for request_id in stale_requests:
+            self.active_requests.pop(request_id, None)
+        logger.warning("worker {} 连接已关闭,已清理相关 handlers", worker_id)
+
+ async def stop(self) -> None:
+ for session in list(self.worker_sessions.values()):
+ await session.stop()
+ await self.peer.stop()
+
+    async def _handle_upstream_invoke(self, message, cancel_token):
+        """Route an upstream invoke through the capability router."""
+        return await self.capability_router.execute(
+            message.capability,
+            message.input,
+            stream=message.stream,
+            cancel_token=cancel_token,
+            request_id=message.id,
+        )
+
+    async def _route_handler_invoke(
+        self,
+        request_id: str,
+        payload: dict[str, Any],
+        _cancel_token,
+    ) -> dict[str, Any]:
+        """Dispatch a handler invocation to the worker session owning it.
+
+        Raises:
+            AstrBotError: ``invalid_input`` when no session owns ``handler_id``.
+        """
+        handler_id = str(payload.get("handler_id", ""))
+        session = self.handler_to_worker.get(handler_id)
+        if session is None:
+            raise AstrBotError.invalid_input(f"handler not found: {handler_id}")
+        # Track the request so upstream cancels can reach the right session.
+        self.active_requests[request_id] = session
+        try:
+            return await session.invoke_handler(
+                handler_id,
+                payload.get("event", {}),
+                request_id=request_id,
+                args=payload.get("args", {}),
+            )
+        finally:
+            self.active_requests.pop(request_id, None)
+
+ async def _handle_upstream_cancel(self, request_id: str) -> None:
+ session = self.active_requests.get(request_id)
+ if session is not None:
+ await session.cancel(request_id)
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/transport.py b/astrbot-sdk/src/astrbot_sdk/runtime/transport.py
new file mode 100644
index 0000000000..1b09beac05
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/transport.py
@@ -0,0 +1,523 @@
+"""传输层抽象模块。
+
+定义 Transport 抽象基类及其实现,负责底层的消息传输。
+传输层只关心"发送 opaque bytes"和"接收 opaque bytes",不处理协议细节。
+传输实现:
+ Transport: 抽象基类,定义 start/stop/send/wait_closed 接口
+ StdioTransport: 标准输入输出传输
+ - 进程模式: 通过 command 参数启动子进程
+ - 文件模式: 通过 stdin/stdout 参数指定文件描述符
+
+传输类型:
+ Transport: 抽象基类,定义 start/stop/send 接口
+ StdioTransport: 标准输入输出传输,支持进程模式和文件模式
+ WebSocketServerTransport: WebSocket 服务端传输
+ - 单连接限制,支持心跳配置
+ - 通过 port 属性获取实际监听端口
+ - 自动重连需要外部实现
+
+使用示例:
+ # 子进程模式
+ transport = StdioTransport(
+ command=["python", "-m", "my_plugin"],
+ cwd="/path/to/plugin",
+ )
+
+ # 标准输入输出模式
+ transport = StdioTransport(stdin=sys.stdin, stdout=sys.stdout)
+
+ # WebSocket 服务端
+ transport = WebSocketServerTransport(host="0.0.0.0", port=8765)
+
+ # WebSocket 客户端
+ transport = WebSocketClientTransport(url="ws://localhost:8765")
+
+ # 统一接口
+ transport.set_message_handler(my_handler)
+ await transport.start()
+ await transport.send(json_bytes)
+ await transport.stop()
+
+`Transport` 只处理“opaque bytes 发出去 / opaque bytes 收进来”这件事,不做协议解析,也不关心
+能力、handler 或迁移适配策略。当前实现包括:
+
+- `StdioTransport`: 子进程或文件对象上的长度前缀字节帧传输
+- `WebSocketServerTransport`: 单连接 WebSocket 服务端
+- `WebSocketClientTransport`: WebSocket 客户端
+
+自动重连、消息重放等策略不在这里实现,统一留给更上层编排。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import ssl
+import sys
+from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable, Sequence
+from pathlib import Path
+from typing import IO, Any
+
+from .._internal.sdk_logger import logger
+
+MessageHandler = Callable[[bytes], Awaitable[None]]
+STDIO_SUBPROCESS_STREAM_LIMIT = 8 * 1024 * 1024
+
+
+def build_websocket_server_ssl_context(
+ *,
+ ca_file: str | Path,
+ cert_file: str | Path,
+ key_file: str | Path,
+) -> ssl.SSLContext:
+ """Build a mutual-TLS server SSL context for websocket workers."""
+ context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(cafile=str(ca_file))
+ context.load_cert_chain(certfile=str(cert_file), keyfile=str(key_file))
+ return context
+
+
+def build_websocket_client_ssl_context(
+ *,
+ ca_file: str | Path,
+ cert_file: str | Path,
+ key_file: str | Path,
+) -> ssl.SSLContext:
+ """Build a mutual-TLS client SSL context for websocket supervisor sessions."""
+ context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=str(ca_file))
+ context.load_cert_chain(certfile=str(cert_file), keyfile=str(key_file))
+ return context
+
+
+def _get_aiohttp():
+    # Lazy import: importing this module does not require aiohttp unless a
+    # websocket transport is actually used.
+    import aiohttp
+
+    return aiohttp
+
+
+def _get_web():
+    # Lazy import of aiohttp's server-side `web` submodule (see _get_aiohttp).
+    from aiohttp import web
+
+    return web
+
+
+def _frame_stdio_payload(payload: bytes | bytearray | memoryview) -> bytes:
+ body = bytes(payload)
+ return f"{len(body)}\n".encode("ascii") + body
+
+
+def _parse_stdio_header(raw_header: bytes) -> int:
+ header = raw_header.decode("ascii").strip()
+ if not header:
+ raise ValueError("STDIO frame header is empty")
+ try:
+ size = int(header)
+ except ValueError as exc:
+ raise ValueError(f"Invalid STDIO frame header: {header!r}") from exc
+ # 拒绝负数 size,防止子进程写入畸形 header 导致 readexactly 行为异常
+ if size < 0:
+ raise ValueError(f"STDIO frame size must be non-negative: {size}")
+ return size
+
+
+def _is_windows_access_denied(error: BaseException) -> bool:
+    """Return True for the Windows ERROR_ACCESS_DENIED case (winerror 5)."""
+    if sys.platform != "win32":
+        return False
+    if not isinstance(error, PermissionError):
+        return False
+    return getattr(error, "winerror", None) == 5
+
+
+class Transport(ABC):
+    """Abstract transport: moves opaque byte frames in and out.
+
+    Subclasses implement start/stop/send; inbound frames are delivered to the
+    callback registered via :meth:`set_message_handler`.
+    """
+
+    def __init__(self) -> None:
+        self._handler: MessageHandler | None = None
+        # Set once the transport has entered its closed state.
+        self._closed = asyncio.Event()
+
+    def set_message_handler(self, handler: MessageHandler) -> None:
+        """Register the callback invoked for every received raw byte frame."""
+        self._handler = handler
+
+    @abstractmethod
+    async def start(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def stop(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def send(self, payload: bytes) -> None:
+        raise NotImplementedError
+
+    async def wait_closed(self) -> None:
+        """Block until the transport has entered the closed state."""
+        await self._closed.wait()
+
+    async def _dispatch(self, payload: bytes) -> None:
+        """Hand a received raw byte payload to the registered handler, if any."""
+        if self._handler is not None:
+            await self._handler(payload)
+
+    async def _dispatch_safely(self, payload: bytes, *, source: str) -> None:
+        """Dispatch one frame, swallowing all non-cancellation exceptions so a
+        single bad frame cannot take down the whole read loop."""
+        try:
+            await self._dispatch(payload)
+        except asyncio.CancelledError:
+            # CancelledError must propagate or graceful shutdown breaks.
+            raise
+        except Exception:
+            # Log and keep reading rather than letting the read loop crash and
+            # render the whole transport unusable.
+            logger.exception("Dropping inbound transport frame from {}", source)
+
+
+class StdioTransport(Transport):
+    """Length-prefixed byte-frame transport over stdio.
+
+    Two mutually exclusive modes:
+      - process mode: ``command`` spawns a subprocess and frames flow over its
+        stdin/stdout pipes
+      - file mode: ``stdin``/``stdout`` file objects (defaulting to
+        ``sys.stdin``/``sys.stdout``) carry the frames
+    """
+
+    def __init__(
+        self,
+        *,
+        stdin: IO[str] | None = None,
+        stdout: IO[str] | None = None,
+        command: Sequence[str] | None = None,
+        cwd: str | None = None,
+        env: dict[str, str] | None = None,
+    ) -> None:
+        super().__init__()
+        self._stdin = stdin
+        self._stdout = stdout
+        self._command = list(command) if command is not None else None
+        self._cwd = cwd
+        self._env = env
+        self._process: asyncio.subprocess.Process | None = None
+        self._reader_task: asyncio.Task[None] | None = None
+
+    async def start(self) -> None:
+        """Start the subprocess (process mode) or the stdio reader (file mode)."""
+        self._closed.clear()
+        if self._command is not None:
+            self._process = await self._start_subprocess_with_retry()
+            self._reader_task = asyncio.create_task(self._read_process_loop())
+            return
+
+        # File mode: fall back to this process's own stdio streams.
+        self._stdin = self._stdin or sys.stdin
+        self._stdout = self._stdout or sys.stdout
+        self._reader_task = asyncio.create_task(self._read_file_loop())
+
+    async def _start_subprocess_with_retry(self) -> asyncio.subprocess.Process:
+        """Spawn the worker subprocess, retrying transient Windows access-denied errors."""
+        assert self._command is not None  # type narrowing: start() guarantees non-None
+        delays = [0.15, 0.35, 0.75]
+        last_error: BaseException | None = None
+        for attempt, delay in enumerate([0.0, *delays], start=1):
+            if delay:
+                await asyncio.sleep(delay)
+            try:
+                return await asyncio.create_subprocess_exec(
+                    *self._command,
+                    cwd=self._cwd,
+                    env=self._env,
+                    stdin=asyncio.subprocess.PIPE,
+                    stdout=asyncio.subprocess.PIPE,
+                    stderr=sys.stderr,
+                    limit=STDIO_SUBPROCESS_STREAM_LIMIT,
+                )
+            except Exception as exc:
+                last_error = exc
+                # Only retry the Windows access-denied case; anything else (or
+                # exhausting all attempts) propagates immediately.
+                if not _is_windows_access_denied(exc) or attempt == len(delays) + 1:
+                    raise
+                logger.warning(
+                    "Windows denied access while starting freshly prepared worker "
+                    "interpreter, retrying attempt {}/{}: {}",
+                    attempt,
+                    len(delays) + 1,
+                    exc,
+                )
+        assert last_error is not None
+        raise last_error
+
+    async def stop(self) -> None:
+        """Cancel the reader, terminate the subprocess if any, and mark closed."""
+        if self._reader_task is not None:
+            self._reader_task.cancel()
+            try:
+                await self._reader_task
+            except asyncio.CancelledError:
+                pass
+            self._reader_task = None
+
+        if self._process is not None:
+            if self._process.returncode is None:
+                self._process.terminate()
+            try:
+                await asyncio.wait_for(self._process.wait(), timeout=5)
+            except asyncio.TimeoutError:
+                # Escalate to kill if the worker ignores terminate().
+                self._process.kill()
+                await self._process.wait()
+            self._process = None
+        self._closed.set()
+
+    async def send(self, payload: bytes) -> None:
+        """Frame *payload* and write it to the subprocess stdin or local stdout."""
+        frame = _frame_stdio_payload(payload)
+        if self._process is not None:
+            if self._process.stdin is None:
+                raise RuntimeError("STDIO subprocess stdin 不可用")
+            self._process.stdin.write(frame)
+            await self._process.stdin.drain()
+            return
+
+        if self._stdout is None:
+            raise RuntimeError("STDIO stdout 不可用")
+
+        def _write() -> None:
+            assert self._stdout is not None
+            binary_stdout = getattr(self._stdout, "buffer", None)
+            if binary_stdout is None:
+                raise RuntimeError("STDIO stdout 必须提供可写入 bytes 的 buffer")
+            binary_stdout.write(frame)
+            binary_stdout.flush()
+
+        # Blocking file I/O runs in a thread to avoid stalling the event loop.
+        await asyncio.to_thread(_write)
+
+    async def _read_process_loop(self) -> None:
+        """Continuously read STDIO frames from the subprocess stdout; a single
+        failing frame dispatch does not abort the loop."""
+        assert self._process is not None
+        assert self._process.stdout is not None
+        try:
+            while True:
+                try:
+                    raw_header = await self._process.stdout.readline()
+                    if not raw_header:
+                        break
+                    payload_size = _parse_stdio_header(raw_header)
+                    raw = await self._process.stdout.readexactly(payload_size)
+                    # Use _dispatch_safely rather than _dispatch so an upstream
+                    # per-frame handler error cannot terminate the read loop.
+                    await self._dispatch_safely(
+                        raw,
+                        source="stdio-process",
+                    )
+                except asyncio.CancelledError:
+                    raise
+                except asyncio.IncompleteReadError:
+                    # A truncated frame means the subprocess died abnormally;
+                    # the read loop must terminate.
+                    logger.warning("STDIO subprocess frame truncated before completion")
+                    break
+                except ValueError as exc:
+                    # After a failed header parse there is no reliable way to
+                    # find the next frame boundary; continuing would leave the
+                    # protocol stream permanently out of sync.
+                    logger.warning(
+                        "Stopping STDIO subprocess read loop after malformed frame: {}",
+                        exc,
+                    )
+                    break
+        finally:
+            self._closed.set()
+
+    async def _read_file_loop(self) -> None:
+        """Continuously read STDIO frames from local stdin (file mode); a single
+        failing frame dispatch does not abort the loop."""
+        assert self._stdin is not None
+        try:
+            while True:
+                try:
+                    binary_stdin = getattr(self._stdin, "buffer", None)
+                    if binary_stdin is None:
+                        raise RuntimeError("STDIO stdin 必须提供可读取 bytes 的 buffer")
+                    raw_header = await asyncio.to_thread(binary_stdin.readline)
+                    if not raw_header:
+                        break
+                    payload_size = _parse_stdio_header(raw_header)
+                    raw = await asyncio.to_thread(binary_stdin.read, payload_size)
+                    if len(raw) != payload_size:
+                        raise EOFError("STDIO frame truncated before payload completed")
+                    await self._dispatch_safely(
+                        raw,
+                        source="stdio-file",
+                    )
+                except asyncio.CancelledError:
+                    raise
+                except EOFError as exc:
+                    # Truncation means the upstream side closed; stop reading.
+                    logger.warning("{}", exc)
+                    break
+                except ValueError as exc:
+                    # File mode cannot recover the frame boundary after a bad
+                    # header either; terminating the read is the safe option.
+                    logger.warning(
+                        "Stopping STDIO file read loop after malformed frame: {}", exc
+                    )
+                    break
+        finally:
+            self._closed.set()
+
+
+class WebSocketServerTransport(Transport):
+    """Single-connection WebSocket server transport built on aiohttp.
+
+    Only one concurrent client is accepted; extra connections are rejected
+    with close code 1008. Reconnection policy is left to the caller.
+    """
+
+    def __init__(
+        self,
+        *,
+        host: str = "127.0.0.1",
+        port: int = 8765,
+        path: str = "/",
+        heartbeat: float = 30.0,
+        ssl_context: ssl.SSLContext | None = None,
+    ) -> None:
+        super().__init__()
+        self._host = host
+        self._port = port
+        # Filled in after bind; differs from _port when port 0 was requested.
+        self._actual_port: int | None = None
+        self._path = path
+        self._heartbeat = heartbeat
+        self._ssl_context = ssl_context
+        self._app: Any | None = None
+        self._runner: Any | None = None
+        self._site: Any | None = None
+        self._ws: Any | None = None
+        # Serializes send_bytes calls on the single active connection.
+        self._write_lock = asyncio.Lock()
+        self._connected = asyncio.Event()
+
+    async def start(self) -> None:
+        """Bind the aiohttp site and begin accepting a websocket connection."""
+        web = _get_web()
+        self._closed.clear()
+        self._connected.clear()
+        self._app = web.Application()
+        self._app.router.add_get(self._path, self._handle_socket)
+        self._runner = web.AppRunner(self._app)
+        await self._runner.setup()
+        self._site = web.TCPSite(
+            self._runner,
+            self._host,
+            self._port,
+            ssl_context=self._ssl_context,
+        )
+        await self._site.start()
+        # NOTE(review): reaches into aiohttp's private TCPSite._server to learn
+        # the bound port — fragile across aiohttp versions; confirm on upgrade.
+        if self._site._server and getattr(self._site._server, "sockets", None):
+            socket = self._site._server.sockets[0]
+            self._actual_port = socket.getsockname()[1]
+
+    async def stop(self) -> None:
+        """Close the active connection (if any) and tear down the server."""
+        self._connected.clear()
+        if self._ws is not None and not self._ws.closed:
+            await self._ws.close()
+        if self._site is not None:
+            await self._site.stop()
+            self._site = None
+        if self._runner is not None:
+            await self._runner.cleanup()
+            self._runner = None
+        self._closed.set()
+
+    async def send(self, payload: bytes) -> None:
+        """Send *payload* to the connected client, waiting up to 30s for one."""
+        if self._ws is None or self._ws.closed:
+            # Give a late-connecting client a chance before failing.
+            await asyncio.wait_for(self._connected.wait(), timeout=30.0)
+        if self._ws is None or self._ws.closed:
+            raise RuntimeError("WebSocket 尚未连接")
+        async with self._write_lock:
+            await self._ws.send_bytes(payload)
+
+    async def _handle_socket(self, request) -> Any:
+        """aiohttp GET handler: accept one websocket and pump inbound frames."""
+        web = _get_web()
+        aiohttp = _get_aiohttp()
+        if self._ws is not None and not self._ws.closed:
+            # A client is already attached: accept then immediately close with
+            # policy-violation code 1008.
+            ws = web.WebSocketResponse()
+            await ws.prepare(request)
+            await ws.close(code=1008, message=b"only one websocket connection allowed")
+            return ws
+
+        ws = web.WebSocketResponse(
+            heartbeat=self._heartbeat if self._heartbeat > 0 else None
+        )
+        await ws.prepare(request)
+        self._ws = ws
+        self._connected.set()
+        try:
+            async for msg in ws:
+                if msg.type == aiohttp.WSMsgType.TEXT:
+                    await self._dispatch_safely(
+                        msg.data.encode("utf-8"), source="websocket-server-text"
+                    )
+                elif msg.type == aiohttp.WSMsgType.BINARY:
+                    await self._dispatch_safely(
+                        bytes(msg.data),
+                        source="websocket-server-binary",
+                    )
+                elif msg.type == aiohttp.WSMsgType.ERROR:
+                    logger.error("websocket server error: {}", ws.exception())
+                    break
+        finally:
+            self._connected.clear()
+            self._closed.set()
+            self._ws = None
+        return ws
+
+    @property
+    def port(self) -> int:
+        # Prefer the actually-bound port (relevant when port 0 was requested).
+        return self._actual_port or self._port
+
+    @property
+    def url(self) -> str:
+        scheme = "wss" if self._ssl_context is not None else "ws"
+        return f"{scheme}://{self._host}:{self.port}{self._path}"
+
+
+class WebSocketClientTransport(Transport):
+    """WebSocket client transport built on aiohttp's ClientSession."""
+
+    def __init__(
+        self,
+        *,
+        url: str,
+        heartbeat: float = 30.0,
+        ssl_context: ssl.SSLContext | None = None,
+        server_hostname: str | None = None,
+    ) -> None:
+        super().__init__()
+        self._url = url
+        self._heartbeat = heartbeat
+        self._ssl_context = ssl_context
+        # Hostname override for TLS verification, e.g. when dialing by IP.
+        self._server_hostname = server_hostname
+        self._session: Any | None = None
+        self._ws: Any | None = None
+        self._reader_task: asyncio.Task[None] | None = None
+
+    async def start(self) -> None:
+        """Connect to the configured URL and start the inbound read loop."""
+        aiohttp = _get_aiohttp()
+        self._closed.clear()
+        self._session = aiohttp.ClientSession()
+        # NOTE(review): the ssl_context/server_hostname keyword names of
+        # ws_connect are aiohttp-version-sensitive (newer aiohttp spells the
+        # first one `ssl`); confirm against the pinned aiohttp version.
+        self._ws = await self._session.ws_connect(
+            self._url,
+            heartbeat=self._heartbeat if self._heartbeat > 0 else None,
+            ssl_context=self._ssl_context,
+            server_hostname=self._server_hostname,
+        )
+        self._reader_task = asyncio.create_task(self._read_loop())
+
+    async def stop(self) -> None:
+        """Cancel the read loop and close the websocket and HTTP session."""
+        if self._reader_task is not None:
+            self._reader_task.cancel()
+            try:
+                await self._reader_task
+            except asyncio.CancelledError:
+                pass
+            self._reader_task = None
+        if self._ws is not None and not self._ws.closed:
+            await self._ws.close()
+        if self._session is not None:
+            await self._session.close()
+        self._ws = None
+        self._session = None
+        self._closed.set()
+
+    async def send(self, payload: bytes) -> None:
+        """Send *payload* as one binary websocket frame."""
+        if self._ws is None or self._ws.closed:
+            raise RuntimeError("WebSocket client 尚未连接")
+        await self._ws.send_bytes(payload)
+
+    async def _read_loop(self) -> None:
+        """Pump inbound frames to the handler until the connection ends."""
+        assert self._ws is not None
+        aiohttp = _get_aiohttp()
+        try:
+            async for msg in self._ws:
+                if msg.type == aiohttp.WSMsgType.TEXT:
+                    await self._dispatch_safely(
+                        msg.data.encode("utf-8"), source="websocket-client-text"
+                    )
+                elif msg.type == aiohttp.WSMsgType.BINARY:
+                    await self._dispatch_safely(
+                        bytes(msg.data),
+                        source="websocket-client-binary",
+                    )
+                elif msg.type == aiohttp.WSMsgType.ERROR:
+                    logger.error("websocket client error: {}", self._ws.exception())
+                    break
+        finally:
+            self._closed.set()
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/worker.py b/astrbot-sdk/src/astrbot_sdk/runtime/worker.py
new file mode 100644
index 0000000000..9715f248d8
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/worker.py
@@ -0,0 +1,516 @@
+"""Worker 端运行时:PluginWorkerRuntime 运行单个插件,GroupWorkerRuntime 在同一进程中运行多个插件。
+
+核心类:
+ GroupWorkerRuntime: 组 Worker 运行时
+ - 在同一进程中加载并运行多个插件
+ - 聚合所有插件的 handlers 和 capabilities
+ - 统一处理 invoke 和 cancel 请求
+ - 管理每个插件的生命周期回调
+
+ PluginWorkerRuntime: 单插件 Worker 运行时
+ - 加载单个插件
+ - 通过 Peer 与 Supervisor 通信
+ - 分发 handler 调用
+ - 处理生命周期回调 (on_start, on_stop)
+
+启动流程:
+ Worker 启动:
+ 1. load_plugin_spec() 加载插件规范
+ 2. load_plugin() 加载插件组件
+ 3. 创建 Peer 并设置处理器
+ 4. 向 Supervisor 发送 initialize
+ 5. 等待 Supervisor 的 initialize_result
+ 6. 执行 on_start 生命周期回调
+"""
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from .._internal.decorator_lifecycle import run_lifecycle_with_decorators
+from .._internal.invocation_context import caller_plugin_scope
+from .._internal.sdk_logger import logger
+from ..context import Context as RuntimeContext
+from ..errors import AstrBotError
+from ..protocol.codec import MsgpackProtocolCodec, ProtocolCodec
+from ..protocol.messages import PeerInfo
+from .handler_dispatcher import CapabilityDispatcher, HandlerDispatcher
+from .loader import (
+ LoadedPlugin,
+ PluginDiscoveryIssue,
+ PluginSpec,
+ load_plugin,
+ load_plugin_config,
+ load_plugin_spec,
+)
+from .peer import Peer
+
+__all__ = [
+ "GroupPluginRuntimeState",
+ "GroupWorkerRuntime",
+ "PluginWorkerRuntime",
+ "_load_plugin_specs",
+ "_load_group_plugin_specs",
+]
+
+
+@dataclass(slots=True)
+class GroupPluginRuntimeState:
+    """Per-plugin runtime state tracked by GroupWorkerRuntime."""
+
+    # Static spec loaded from the plugin directory.
+    plugin: PluginSpec
+    # Imported plugin components (handlers, capabilities, instances, tools, agents).
+    loaded_plugin: LoadedPlugin
+    # Context passed to the plugin's lifecycle hooks.
+    lifecycle_context: RuntimeContext
+
+
+def _load_group_plugin_specs(group_metadata_path: Path) -> tuple[str, list[PluginSpec]]:
+    """Load the plugin specs listed in a worker-group metadata JSON file.
+
+    Returns:
+        ``(group_id, specs)`` — the group id falls back to the metadata file's
+        stem when the payload carries no usable ``group_id`` string.
+
+    Raises:
+        RuntimeError: if the file is unreadable or structurally invalid.
+    """
+    try:
+        payload = json.loads(group_metadata_path.read_text(encoding="utf-8"))
+    except Exception as exc:
+        raise RuntimeError(
+            f"failed to read worker group metadata: {group_metadata_path}"
+        ) from exc
+
+    if not isinstance(payload, dict):
+        raise RuntimeError(f"invalid worker group metadata: {group_metadata_path}")
+
+    entries = payload.get("plugin_entries")
+    if not isinstance(entries, list) or not entries:
+        raise RuntimeError(
+            f"worker group metadata missing plugin_entries: {group_metadata_path}"
+        )
+
+    plugins: list[PluginSpec] = []
+    for entry in entries:
+        if not isinstance(entry, dict):
+            raise RuntimeError(
+                f"worker group metadata contains invalid plugin entry: {group_metadata_path}"
+            )
+        plugin_dir = entry.get("plugin_dir")
+        if not isinstance(plugin_dir, str) or not plugin_dir:
+            raise RuntimeError(
+                f"worker group metadata contains invalid plugin_dir: {group_metadata_path}"
+            )
+        plugins.append(load_plugin_spec(Path(plugin_dir)))
+
+    group_id = payload.get("group_id")
+    if not isinstance(group_id, str) or not group_id:
+        group_id = group_metadata_path.stem
+    return group_id, plugins
+
+
+def _load_plugin_specs(plugin_dirs: list[Path]) -> list[PluginSpec]:
+ if not plugin_dirs:
+ raise RuntimeError("worker requires at least one plugin directory")
+ return [load_plugin_spec(plugin_dir) for plugin_dir in plugin_dirs]
+
+
+def _build_worker_registry_entry(
+ plugin: PluginSpec,
+ *,
+ enabled: bool,
+) -> dict[str, Any]:
+ manifest = plugin.manifest_data
+ return {
+ "name": plugin.name,
+ "display_name": str(manifest.get("display_name") or plugin.name),
+ "description": str(manifest.get("desc") or manifest.get("description") or ""),
+ "repo": str(manifest.get("repo") or ""),
+ "author": str(manifest.get("author") or ""),
+ "version": str(manifest.get("version") or "0.0.0"),
+ "enabled": enabled,
+ "config": load_plugin_config(plugin),
+ }
+
+
+def _build_worker_initialize_metadata(
+    *,
+    worker_id: str,
+    plugins: list[PluginSpec],
+    loaded_plugins: list[tuple[PluginSpec, LoadedPlugin]],
+    skipped_plugins: dict[str, str],
+    issues: list[PluginDiscoveryIssue],
+) -> dict[str, Any]:
+    """Assemble the metadata payload sent with the worker's initialize message.
+
+    Aggregates capability sources, LLM tools, and agents across the
+    successfully loaded plugins, plus the registry of all planned plugins
+    (marked enabled only when actually loaded).
+    """
+    loaded_plugin_names = [plugin.name for plugin, _loaded_plugin in loaded_plugins]
+    enabled_plugins = set(loaded_plugin_names)
+    capability_sources: dict[str, str] = {}
+    llm_tools: list[dict[str, Any]] = []
+    agents: list[dict[str, Any]] = []
+
+    for plugin, loaded_plugin in loaded_plugins:
+        plugin_name = plugin.name
+        capability_sources.update(
+            {
+                capability.descriptor.name: plugin_name
+                for capability in loaded_plugin.capabilities
+            }
+        )
+        # Tag each tool/agent payload with the plugin id that owns it.
+        llm_tools.extend(
+            {
+                **tool.spec.to_payload(),
+                "plugin_id": plugin_name,
+            }
+            for tool in loaded_plugin.llm_tools
+        )
+        agents.extend(
+            {
+                **agent.spec.to_payload(),
+                "plugin_id": plugin_name,
+            }
+            for agent in loaded_plugin.agents
+        )
+
+    return {
+        "worker_id": worker_id,
+        "plugins": [plugin.name for plugin in plugins],
+        "loaded_plugins": loaded_plugin_names,
+        "skipped_plugins": dict(skipped_plugins),
+        "worker_registry": [
+            _build_worker_registry_entry(
+                plugin,
+                enabled=plugin.name in enabled_plugins,
+            )
+            for plugin in plugins
+        ],
+        "capability_sources": capability_sources,
+        "issues": [issue.to_payload() for issue in issues],
+        "llm_tools": llm_tools,
+        "agents": agents,
+    }
+
+
+async def run_plugin_lifecycle(
+    instances: list[Any],
+    method_name: str,
+    context: RuntimeContext,
+) -> None:
+    """Run a lifecycle method (e.g. on_start/on_stop) on every plugin instance.
+
+    Missing or non-callable hooks are tolerated: the decorator runner is still
+    invoked with ``hook=None`` so decorator-registered lifecycle logic runs.
+    """
+    for instance in instances:
+        method = getattr(instance, method_name, None)
+        # Scope the invocation so capability calls made inside the hook are
+        # attributed to this plugin.
+        with caller_plugin_scope(context.plugin_id):
+            await run_lifecycle_with_decorators(
+                instance=instance,
+                hook=method if callable(method) else None,
+                method_name=method_name,
+                context=context,
+            )
+
+
+class GroupWorkerRuntime:
+ def __init__(
+ self,
+ *,
+ transport,
+ group_metadata_path: Path | None = None,
+ plugin_dirs: list[Path] | None = None,
+ worker_id: str | None = None,
+ wire_codec: ProtocolCodec | None = None,
+ ) -> None:
+ if group_metadata_path is None and not plugin_dirs:
+ raise ValueError("group_metadata_path or plugin_dirs is required")
+ if group_metadata_path is not None and plugin_dirs:
+ raise ValueError(
+ "group_metadata_path and plugin_dirs are mutually exclusive"
+ )
+ self.group_metadata_path = (
+ group_metadata_path.resolve() if group_metadata_path is not None else None
+ )
+ if self.group_metadata_path is not None:
+ default_worker_id, plugins = _load_group_plugin_specs(
+ self.group_metadata_path
+ )
+ else:
+ assert plugin_dirs is not None
+ plugins = _load_plugin_specs([path.resolve() for path in plugin_dirs])
+ default_worker_id = plugins[0].name
+ self.plugins = plugins
+ self.worker_id = str(worker_id or default_worker_id)
+ self.transport = transport
+ self.wire_codec = wire_codec or MsgpackProtocolCodec()
+ self.peer = Peer(
+ transport=self.transport,
+ peer_info=PeerInfo(name=self.worker_id, role="plugin", version="s5r"),
+ wire_codec=self.wire_codec,
+ )
+ self.skipped_plugins: dict[str, str] = {}
+ self.issues: list[PluginDiscoveryIssue] = []
+ self._plugin_states: list[GroupPluginRuntimeState] = []
+ self._active_plugin_states: list[GroupPluginRuntimeState] = []
+ self._load_plugins()
+ self._refresh_dispatchers()
+ self.peer.set_invoke_handler(self._handle_invoke)
+ self.peer.set_cancel_handler(self._handle_cancel)
+
+ def _load_plugins(self) -> None:
+ for plugin in self.plugins:
+ try:
+ loaded_plugin = load_plugin(plugin)
+ except Exception as exc:
+ self.skipped_plugins[plugin.name] = str(exc)
+ self.issues.append(
+ PluginDiscoveryIssue(
+ severity="error",
+ phase="load",
+ plugin_id=plugin.name,
+ message="插件加载失败",
+ details=str(exc),
+ )
+ )
+ logger.exception(
+ "worker {} 中插件 {} 加载失败,启动时将跳过",
+ self.worker_id,
+ plugin.name,
+ )
+ continue
+
+ lifecycle_context = RuntimeContext(peer=self.peer, plugin_id=plugin.name)
+ self._plugin_states.append(
+ GroupPluginRuntimeState(
+ plugin=plugin,
+ loaded_plugin=loaded_plugin,
+ lifecycle_context=lifecycle_context,
+ )
+ )
+ self._active_plugin_states = list(self._plugin_states)
+
+ def _refresh_dispatchers(self) -> None:
+ handlers = [
+ handler
+ for state in self._active_plugin_states
+ for handler in state.loaded_plugin.handlers
+ ]
+ capabilities = [
+ capability
+ for state in self._active_plugin_states
+ for capability in state.loaded_plugin.capabilities
+ ]
+ self.dispatcher = HandlerDispatcher(
+ plugin_id=self.worker_id,
+ peer=self.peer,
+ handlers=handlers,
+ )
+ self.capability_dispatcher = CapabilityDispatcher(
+ plugin_id=self.worker_id,
+ peer=self.peer,
+ capabilities=capabilities,
+ llm_tools=[
+ tool
+ for state in self._active_plugin_states
+ for tool in state.loaded_plugin.llm_tools
+ ],
+ )
+
+ async def start(self) -> None:
+ await self.peer.start()
+ started_states: list[GroupPluginRuntimeState] = []
+ try:
+ active_states: list[GroupPluginRuntimeState] = []
+ for state in self._plugin_states:
+ try:
+ await self._run_lifecycle(state, "on_start")
+ except Exception as exc:
+ self.skipped_plugins[state.plugin.name] = str(exc)
+ self.issues.append(
+ PluginDiscoveryIssue(
+ severity="error",
+ phase="lifecycle",
+ plugin_id=state.plugin.name,
+ message="插件 on_start 失败",
+ details=str(exc),
+ )
+ )
+ logger.exception(
+ "worker {} 中插件 {} on_start 失败,启动时将跳过",
+ self.worker_id,
+ state.plugin.name,
+ )
+ continue
+ active_states.append(state)
+ started_states.append(state)
+
+ self._active_plugin_states = active_states
+ self._refresh_dispatchers()
+ if not self._active_plugin_states:
+ raise RuntimeError(f"worker {self.worker_id} has no active plugins")
+
+ await self.peer.initialize(
+ [
+ handler.descriptor
+ for state in self._active_plugin_states
+ for handler in state.loaded_plugin.handlers
+ ],
+ provided_capabilities=[
+ capability.descriptor
+ for state in self._active_plugin_states
+ for capability in state.loaded_plugin.capabilities
+ ],
+ metadata=self._initialize_metadata(),
+ )
+ except Exception:
+ for state in reversed(started_states):
+ try:
+ await self._run_lifecycle(state, "on_stop")
+ except Exception:
+ logger.exception(
+ "worker {} 在启动失败清理插件 {} on_stop 时发生异常",
+ self.worker_id,
+ state.plugin.name,
+ )
+ await self.peer.stop()
+ raise
+
+ async def stop(self) -> None:
+ first_error: Exception | None = None
+ try:
+ for state in reversed(self._active_plugin_states):
+ try:
+ await self._run_lifecycle(state, "on_stop")
+ except Exception as exc:
+ if first_error is None:
+ first_error = exc
+ logger.exception(
+ "worker {} 停止插件 {} 时发生异常",
+ self.worker_id,
+ state.plugin.name,
+ )
+ finally:
+ await self.peer.stop()
+ if first_error is not None:
+ raise first_error
+
+ async def _handle_invoke(self, message, cancel_token):
+ if message.capability == "handler.invoke":
+ return await self.dispatcher.invoke(message, cancel_token)
+ try:
+ return await self.capability_dispatcher.invoke(message, cancel_token)
+ except LookupError as exc:
+ raise AstrBotError.capability_not_found(message.capability) from exc
+
    async def _handle_cancel(self, request_id: str) -> None:
        # Fan the cancel out to both dispatchers; whichever one owns the
        # request drops it (presumably the other ignores unknown ids —
        # confirm in the dispatcher implementations).
        await self.dispatcher.cancel(request_id)
        await self.capability_dispatcher.cancel(request_id)
+
+ def _initialize_metadata(self) -> dict[str, Any]:
+ return _build_worker_initialize_metadata(
+ worker_id=self.worker_id,
+ plugins=self.plugins,
+ loaded_plugins=[
+ (state.plugin, state.loaded_plugin)
+ for state in self._active_plugin_states
+ ],
+ skipped_plugins=self.skipped_plugins,
+ issues=self.issues,
+ )
+
    async def _run_lifecycle(
        self,
        state: GroupPluginRuntimeState,
        method_name: str,
    ) -> None:
        # Delegate to the shared lifecycle runner with this plugin's own
        # lifecycle context (each plugin in the group has its own).
        await run_plugin_lifecycle(
            state.loaded_plugin.instances, method_name, state.lifecycle_context
        )
+
+
class PluginWorkerRuntime:
    """Single-plugin worker runtime.

    Loads one plugin from *plugin_dir*, wires its handlers/capabilities into
    dispatchers, and bridges them to the supervisor through a :class:`Peer`.
    Unlike the group runtime, a lifecycle failure here is fatal — with only
    one plugin there is nothing to skip and continue with.
    """

    def __init__(
        self,
        *,
        plugin_dir: Path,
        transport,
        worker_id: str | None = None,
        wire_codec: ProtocolCodec | None = None,
    ) -> None:
        self.plugin = load_plugin_spec(plugin_dir)
        # Worker id defaults to the plugin name when not supplied.
        self.worker_id = str(worker_id or self.plugin.name)
        self.transport = transport
        self.wire_codec = wire_codec or MsgpackProtocolCodec()
        self.loaded_plugin = load_plugin(self.plugin)
        self.peer = Peer(
            transport=self.transport,
            peer_info=PeerInfo(name=self.worker_id, role="plugin", version="s5r"),
            wire_codec=self.wire_codec,
        )
        self.dispatcher = HandlerDispatcher(
            plugin_id=self.plugin.name,
            peer=self.peer,
            handlers=self.loaded_plugin.handlers,
        )
        self.capability_dispatcher = CapabilityDispatcher(
            plugin_id=self.plugin.name,
            peer=self.peer,
            capabilities=self.loaded_plugin.capabilities,
            llm_tools=self.loaded_plugin.llm_tools,
        )
        self._lifecycle_context = RuntimeContext(
            peer=self.peer, plugin_id=self.plugin.name
        )
        self.issues: list[PluginDiscoveryIssue] = []
        # Route incoming peer invokes/cancels through this runtime instance.
        self.peer.set_invoke_handler(self._handle_invoke)
        self.peer.set_cancel_handler(self._handle_cancel)

    async def start(self) -> None:
        """Start the peer, run ``on_start``, then report to the supervisor.

        On failure the log message distinguishes which phase failed (the
        supervisor might otherwise only observe an initialize timeout). If the
        lifecycle had started, a best-effort ``on_stop`` cleanup runs; the peer
        is stopped and the original exception re-raised either way.
        """
        await self.peer.start()
        lifecycle_started = False
        try:
            await self._run_lifecycle("on_start")
            lifecycle_started = True
            await self.peer.initialize(
                [item.descriptor for item in self.loaded_plugin.handlers],
                provided_capabilities=[
                    item.descriptor for item in self.loaded_plugin.capabilities
                ],
                metadata=_build_worker_initialize_metadata(
                    worker_id=self.worker_id,
                    plugins=[self.plugin],
                    loaded_plugins=[(self.plugin, self.loaded_plugin)],
                    skipped_plugins={},
                    issues=self.issues,
                ),
            )
        except Exception:
            if lifecycle_started:
                logger.exception(
                    "插件 {} 在向 supervisor 上报 initialize 时失败",
                    self.plugin.name,
                )
            else:
                logger.exception(
                    "插件 {} 在 on_start / 装饰器初始化阶段失败;"
                    "supervisor 可能随后只看到初始化超时,请优先检查这条异常",
                    self.plugin.name,
                )
            if lifecycle_started:
                try:
                    await self._run_lifecycle("on_stop")
                except Exception:
                    logger.exception(
                        "插件 {} 在启动失败清理 on_stop 时发生异常",
                        self.plugin.name,
                    )
            await self.peer.stop()
            raise

    async def stop(self) -> None:
        """Run ``on_stop``; the peer is stopped even when ``on_stop`` raises."""
        try:
            await self._run_lifecycle("on_stop")
        finally:
            await self.peer.stop()

    async def _handle_invoke(self, message, cancel_token):
        # "handler.invoke" targets a message handler; everything else is a
        # capability lookup, with unknown names mapped to a typed error.
        if message.capability == "handler.invoke":
            return await self.dispatcher.invoke(message, cancel_token)
        try:
            return await self.capability_dispatcher.invoke(message, cancel_token)
        except LookupError as exc:
            raise AstrBotError.capability_not_found(message.capability) from exc

    async def _handle_cancel(self, request_id: str) -> None:
        # Fan the cancel out to both dispatchers; the non-owner is presumably
        # a no-op for unknown request ids — confirm in the dispatchers.
        await self.dispatcher.cancel(request_id)
        await self.capability_dispatcher.cancel(request_id)

    async def _run_lifecycle(self, method_name: str) -> None:
        # Shared lifecycle runner over all plugin instances with this
        # worker's single runtime context.
        await run_plugin_lifecycle(
            self.loaded_plugin.instances, method_name, self._lifecycle_context
        )
diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/workers_manifest.py b/astrbot-sdk/src/astrbot_sdk/runtime/workers_manifest.py
new file mode 100644
index 0000000000..724ffa247b
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/workers_manifest.py
@@ -0,0 +1,120 @@
+"""Supervisor-side manifest for remote websocket workers."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from urllib.parse import urlparse
+
+import yaml
+
+
@dataclass(slots=True)
class RemoteWorkerTLSConfig:
    """Mutual-TLS material for one remote worker connection."""

    # Paths are absolute after manifest loading (relative entries are
    # resolved against the manifest file's directory by the loader).
    ca_file: Path
    cert_file: Path
    key_file: Path
    # Optional hostname override for certificate verification; None presumably
    # means "derive from the URL" — confirm at the websocket connect site.
    server_hostname: str | None = None
+
+
@dataclass(slots=True)
class RemoteWorkerSpec:
    """One validated entry from the remote workers manifest."""

    id: str
    # Always a wss:// URL — enforced by load_remote_workers_manifest.
    url: str
    tls: RemoteWorkerTLSConfig
+
+
def load_remote_workers_manifest(manifest_path: Path) -> list[RemoteWorkerSpec]:
    """Parse and validate the supervisor's remote-workers YAML manifest.

    Relative TLS file paths are resolved against the manifest's directory.

    Raises:
        ValueError: on any structural problem — non-mapping document, missing
            ``workers`` list, unsupported keys, empty/duplicate ids, non-wss
            or host-less URLs, or an invalid ``tls`` section.
    """
    resolved = manifest_path.resolve()
    document = yaml.safe_load(resolved.read_text(encoding="utf-8")) or {}
    if not isinstance(document, dict):
        raise ValueError("workers manifest must be a mapping")

    raw_workers = document.get("workers")
    if not isinstance(raw_workers, list):
        raise ValueError("workers manifest must define a 'workers' list")

    specs: list[RemoteWorkerSpec] = []
    known_ids: set[str] = set()
    for position, raw in enumerate(raw_workers):
        if not isinstance(raw, dict):
            raise ValueError(f"workers[{position}] must be an object")
        _reject_unsupported_worker_keys(raw, index=position)

        worker_id = str(raw.get("id", "")).strip()
        if not worker_id:
            raise ValueError(f"workers[{position}].id must be a non-empty string")
        if worker_id in known_ids:
            raise ValueError(f"duplicate worker id in workers manifest: {worker_id}")
        known_ids.add(worker_id)

        url = str(raw.get("url", "")).strip()
        parts = urlparse(url)
        if parts.scheme != "wss":
            raise ValueError(
                f"workers[{position}].url must use wss:// for mutual TLS: {url!r}"
            )
        if not parts.netloc:
            raise ValueError(f"workers[{position}].url must include a host: {url!r}")

        tls_section = raw.get("tls")
        if not isinstance(tls_section, dict):
            raise ValueError(f"workers[{position}].tls must be an object")
        specs.append(
            RemoteWorkerSpec(
                id=worker_id,
                url=url,
                tls=_load_tls_config(
                    tls_section,
                    manifest_dir=resolved.parent,
                    prefix=f"workers[{position}].tls",
                ),
            )
        )
    return specs
+
+
+def _reject_unsupported_worker_keys(entry: dict[str, object], *, index: int) -> None:
+ unsupported = {"group_id", "plugins"} & set(entry)
+ if unsupported:
+ names = ", ".join(sorted(unsupported))
+ raise ValueError(
+ f"workers[{index}] must not declare {names}; websocket host config only "
+ "accepts worker connection settings"
+ )
+
+
def _load_tls_config(
    payload: dict[str, object],
    *,
    manifest_dir: Path,
    prefix: str,
) -> RemoteWorkerTLSConfig:
    """Validate one worker's ``tls`` section, resolving relative file paths.

    *prefix* (e.g. ``workers[0].tls``) is only used to build error messages.
    """

    def required(key: str) -> Path:
        # All three certificate files are mandatory.
        return _resolve_required_path(
            payload.get(key), manifest_dir, f"{prefix}.{key}"
        )

    hostname_raw = payload.get("server_hostname")
    # A missing or blank hostname normalizes to None.
    hostname = None if hostname_raw is None else str(hostname_raw).strip()
    if not hostname:
        hostname = None
    return RemoteWorkerTLSConfig(
        ca_file=required("ca_file"),
        cert_file=required("cert_file"),
        key_file=required("key_file"),
        server_hostname=hostname,
    )
+
+
+def _resolve_required_path(value: object, base_dir: Path, field_name: str) -> Path:
+ text = str(value or "").strip()
+ if not text:
+ raise ValueError(f"{field_name} must be a non-empty path")
+ path = Path(text)
+ if not path.is_absolute():
+ path = (base_dir / path).resolve()
+ return path
diff --git a/astrbot-sdk/src/astrbot_sdk/schedule.py b/astrbot-sdk/src/astrbot_sdk/schedule.py
new file mode 100644
index 0000000000..5daccdd78a
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/schedule.py
@@ -0,0 +1,93 @@
+"""Schedule-specific SDK types.
+
+本模块定义定时任务相关的 SDK 类型,主要为 ScheduleContext 提供数据结构。
+
+ScheduleContext 包含:
+- schedule_id: 调度任务唯一标识
+- job_id: core cron_jobs 表中的任务 ID
+- plugin_id: 所属插件 ID
+- handler_id: 对应 handler 的标识
+- name: 调度任务名称
+- description: 调度任务说明
+- job_type: core cron job 类型(basic / active_agent)
+- trigger_kind: 触发类型(cron / interval / once)
+- cron: cron 表达式(仅 cron 类型)
+- interval_seconds: 间隔秒数(仅 interval 类型)
+- timezone: IANA 时区名称(仅声明了时区时存在)
+- scheduled_at: 计划执行时间(仅 once 类型)
+
+使用方式:
+通过 @on_schedule 装饰器注册的 handler 可通过参数注入获取 ScheduleContext。
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass(slots=True)
+class ScheduleContext:
+ schedule_id: str
+ plugin_id: str
+ handler_id: str
+ trigger_kind: str
+ job_id: str | None = None
+ name: str | None = None
+ description: str | None = None
+ job_type: str | None = None
+ cron: str | None = None
+ interval_seconds: int | None = None
+ timezone: str | None = None
+ scheduled_at: str | None = None
+
+ @classmethod
+ def from_payload(cls, payload: dict[str, Any]) -> ScheduleContext:
+ schedule = payload.get("schedule")
+ if not isinstance(schedule, dict):
+ raise ValueError("schedule payload is required")
+ return cls(
+ schedule_id=str(schedule.get("schedule_id", "")),
+ job_id=(
+ str(schedule["job_id"])
+ if isinstance(schedule.get("job_id"), str)
+ else None
+ ),
+ plugin_id=str(schedule.get("plugin_id", "")),
+ handler_id=str(schedule.get("handler_id", "")),
+ name=(
+ str(schedule["name"]) if isinstance(schedule.get("name"), str) else None
+ ),
+ description=(
+ str(schedule["description"])
+ if isinstance(schedule.get("description"), str)
+ else None
+ ),
+ job_type=(
+ str(schedule["job_type"])
+ if isinstance(schedule.get("job_type"), str)
+ else None
+ ),
+ trigger_kind=str(schedule.get("trigger_kind", "")),
+ cron=(
+ str(schedule["cron"]) if isinstance(schedule.get("cron"), str) else None
+ ),
+ interval_seconds=(
+ int(schedule["interval_seconds"])
+ if isinstance(schedule.get("interval_seconds"), int)
+ else None
+ ),
+ timezone=(
+ str(schedule["timezone"])
+ if isinstance(schedule.get("timezone"), str)
+ else None
+ ),
+ scheduled_at=(
+ str(schedule["scheduled_at"])
+ if isinstance(schedule.get("scheduled_at"), str)
+ else None
+ ),
+ )
+
+
+__all__ = ["ScheduleContext"]
diff --git a/astrbot-sdk/src/astrbot_sdk/session_waiter.py b/astrbot-sdk/src/astrbot_sdk/session_waiter.py
new file mode 100644
index 0000000000..4b7b92972d
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/session_waiter.py
@@ -0,0 +1,665 @@
+"""Session-based conversational flow management.
+
+本模块实现会话等待器 (session_waiter),用于构建多轮对话流程。
+
+核心组件:
+- SessionController: 控制会话生命周期,支持超时管理、会话保持、历史记录
+- SessionWaiterManager: 管理活跃的会话等待器,处理事件分发和注册/注销
+- @session_waiter 装饰器: 将普通 handler 转换为会话式 handler
+
+使用场景:
+当需要在用户首次触发后继续监听后续消息(如分步表单、问答游戏),
+可使用 @session_waiter 装饰器自动管理会话状态和超时。
+
+注意事项:
+在当前桥接设计中,不应在普通 SDK handler 内直接 await session_waiter,
+这会导致首次 dispatch 保持打开直到下一条消息到达。
+推荐写法是 `await ctx.register_task(waiter(...), "...")`,让 waiter 在后台任务中
+承接后续消息;直接 await 仅适用于你明确需要保持当前 dispatch 挂起的场景。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import time
+import weakref
+from collections.abc import Awaitable, Callable, Coroutine
+from contextvars import ContextVar
+from dataclasses import dataclass, field
+from functools import wraps
+from typing import Any, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload
+
+from ._internal.invocation_context import current_caller_plugin_id
+from ._internal.sdk_logger import logger
+from .events import MessageEvent
+
# Typing helpers reused by the @session_waiter decorator overloads.
_OwnerT = TypeVar("_OwnerT")
_P = ParamSpec("_P")
_ResultT = TypeVar("_ResultT")
# A waiter is identified by (plugin_id, session_key).
_WaiterKey = tuple[str, str]

# Weak task registries used to detect a direct `await session_waiter(...)`
# inside a handler dispatch task (warned against in the module docstring)
# without keeping those tasks alive.
_HANDLER_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
_REGISTERED_BACKGROUND_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
# Tasks that already received the direct-await warning (warn once per task).
_WARNED_DIRECT_WAIT_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
# Set while a waiter's follow-up handler runs, so a nested register() can tell
# it is replacing the waiter that is currently executing (and restore it on
# registration failure).
_ACTIVE_WAITER_KEY: ContextVar[_WaiterKey | None] = ContextVar(
    "astrbot_sdk_active_waiter_key",
    default=None,
)
+
+
+class _TaskReentrantLock:
+ def __init__(self) -> None:
+ self._lock = asyncio.Lock()
+ self._owner: asyncio.Task[Any] | None = None
+ self._depth = 0
+
+ async def acquire(self) -> None:
+ current_task = asyncio.current_task()
+ if current_task is None:
+ raise RuntimeError("session waiter lock requires an active asyncio task")
+ if self._owner is current_task:
+ self._depth += 1
+ return
+ await self._lock.acquire()
+ self._owner = current_task
+ self._depth = 1
+
+ def release(self) -> None:
+ current_task = asyncio.current_task()
+ if current_task is None or self._owner is not current_task:
+ raise RuntimeError("session waiter lock released by a non-owner task")
+ self._depth -= 1
+ if self._depth > 0:
+ return
+ self._owner = None
+ self._lock.release()
+
+ async def __aenter__(self) -> _TaskReentrantLock:
+ await self.acquire()
+ return self
+
+ async def __aexit__(self, *_exc_info: object) -> None:
+ self.release()
+
+
def _mark_session_waiter_handler_task(task: asyncio.Task[Any]) -> None:
    # Record a task as a handler-dispatch task (for direct-await detection).
    _HANDLER_TASKS.add(task)


def _unmark_session_waiter_handler_task(task: asyncio.Task[Any]) -> None:
    _HANDLER_TASKS.discard(task)


def _mark_session_waiter_background_task(task: asyncio.Task[Any]) -> None:
    # Record a task registered via ctx.register_task; the direct-await warning
    # is suppressed for these tasks.
    _REGISTERED_BACKGROUND_TASKS.add(task)


def _unmark_session_waiter_background_task(task: asyncio.Task[Any]) -> None:
    _REGISTERED_BACKGROUND_TASKS.discard(task)
+
+
class _SessionWaiterDecorator(Protocol):
    """Typing protocol for the decorator returned by ``session_waiter()``.

    The two overloads cover the supported call shapes: a free function taking
    ``(controller, event, ...)`` and a bound method taking
    ``(self, controller, event, ...)``. In both cases the
    :class:`SessionController` parameter is stripped from the decorated
    callable's public signature — the runtime injects it.
    """

    @overload
    def __call__(
        self,
        func: Callable[
            Concatenate[SessionController, MessageEvent, _P],
            Awaitable[_ResultT],
        ],
        /,
    ) -> Callable[Concatenate[MessageEvent, _P], Coroutine[Any, Any, _ResultT]]: ...

    @overload
    def __call__(
        self,
        func: Callable[
            Concatenate[_OwnerT, SessionController, MessageEvent, _P],
            Awaitable[_ResultT],
        ],
        /,
    ) -> Callable[
        Concatenate[_OwnerT, MessageEvent, _P],
        Coroutine[Any, Any, _ResultT],
    ]: ...
+
+
+@dataclass(slots=True)
+class SessionController:
+ future: asyncio.Future[Any] = field(default_factory=asyncio.Future)
+ current_event: asyncio.Event | None = None
+ ts: float | None = None
+ timeout: float | None = None
+ history_chains: list[list[dict[str, Any]]] = field(default_factory=list)
+
+ def stop(self, error: Exception | None = None) -> None:
+ if self.future.done():
+ return
+ if error is not None:
+ self.future.set_exception(error)
+ else:
+ self.future.set_result(None)
+
+ def keep(self, timeout: float = 0, reset_timeout: bool = False) -> None:
+ new_ts = time.time()
+ if reset_timeout:
+ if timeout <= 0:
+ self.stop()
+ return
+ else:
+ if self.timeout is None or self.ts is None:
+ raise RuntimeError(
+ "session waiter keep(reset_timeout=False) requires an active timeout"
+ )
+ left_timeout = self.timeout - (new_ts - self.ts)
+ timeout = left_timeout + timeout
+ if timeout <= 0:
+ self.stop()
+ return
+
+ if self.current_event and not self.current_event.is_set():
+ self.current_event.set()
+
+ current_event = asyncio.Event()
+ self.current_event = current_event
+ self.ts = new_ts
+ self.timeout = timeout
+ asyncio.create_task(self._holding(current_event, timeout))
+
+ async def _holding(self, event: asyncio.Event, timeout: float) -> None:
+ try:
+ await asyncio.wait_for(event.wait(), timeout)
+ except asyncio.TimeoutError as exc:
+ self.stop(exc)
+ except asyncio.CancelledError:
+ return
+
+ def get_history_chains(self) -> list[list[dict[str, Any]]]:
+ return list(self.history_chains)
+
+
@dataclass(slots=True)
class _WaiterEntry:
    """Bookkeeping for one registered waiter (one plugin on one session)."""

    session_key: str
    plugin_id: str
    handler: Callable[[SessionController, MessageEvent], Awaitable[Any]]
    controller: SessionController
    record_history_chains: bool
    # Cleared when the entry is superseded or rolled back, so the registering
    # coroutine skips the normal unregister path in its finally block.
    unregister_enabled: bool = True
+
+
class SessionWaiterManager:
    """Tracks active session waiters for one worker and dispatches follow-ups.

    Entries are keyed by ``(session_key, plugin_id)``; each key also gets a
    task-re-entrant lock so follow-up handling, unregistration, and failure
    injection are serialized per waiter while a waiter's own handler can still
    re-enter (e.g. to re-register) without deadlocking.
    """

    def __init__(self, *, plugin_id: str, peer) -> None:
        # Fallback plugin identity when neither the caller scope nor the
        # event context provides one.
        self._plugin_id = plugin_id
        self._peer = peer
        # session_key -> plugin_id -> entry
        self._entries: dict[str, dict[str, _WaiterEntry]] = {}
        self._locks: dict[_WaiterKey, _TaskReentrantLock] = {}

    @staticmethod
    def _make_key(*, plugin_id: str, session_key: str) -> _WaiterKey:
        return (plugin_id, session_key)

    async def register(
        self,
        *,
        event: MessageEvent,
        handler: Callable[[SessionController, MessageEvent], Awaitable[Any]],
        timeout: int,
        record_history_chains: bool,
    ) -> Any:
        """Install *handler* as this session's waiter and await its outcome.

        Replaces any existing waiter for the same (session, plugin); when the
        replacement happens from inside the previous waiter's own handler, the
        previous entry is kept restorable in case host registration fails.
        Returns the controller future's result; the timeout surfaces as
        ``asyncio.TimeoutError``.
        """
        if event._context is None:
            raise RuntimeError("session_waiter requires runtime context")
        self._warn_if_direct_wait_in_handler(event)
        session_key = event.unified_msg_origin
        plugin_id = self._resolve_plugin_id(event)
        entry = _WaiterEntry(
            session_key=session_key,
            plugin_id=plugin_id,
            handler=handler,
            controller=SessionController(),
            record_history_chains=record_history_chains,
        )
        previous = self._entries.setdefault(session_key, {}).get(plugin_id)
        restorable_previous: _WaiterEntry | None = None
        self._entries[session_key][plugin_id] = entry
        # Ensure the per-key lock exists before any concurrent dispatch races us.
        self._lock_for(session_key, plugin_id)
        if previous is not None:
            previous.unregister_enabled = False
            if _ACTIVE_WAITER_KEY.get() == self._make_key(
                plugin_id=plugin_id,
                session_key=session_key,
            ):
                # Re-registering from inside the previous waiter's own handler:
                # keep it restorable instead of failing it outright.
                restorable_previous = previous
            else:
                self._finish_entry(
                    previous,
                    RuntimeError("session waiter replaced by a newer waiter"),
                )
                logger.warning(
                    "Session waiter replaced: plugin_id={} session_key={}",
                    plugin_id,
                    session_key,
                )
        try:
            await self._invoke_system_waiter(
                "system.session_waiter.register",
                session_key=session_key,
                plugin_id=plugin_id,
            )
            entry.controller.keep(timeout, reset_timeout=True)
        except Exception:
            # Host-side registration failed: remove our entry, restore the
            # previous waiter when applicable, and re-raise.
            entry.unregister_enabled = False
            await self._remove_entry(entry)
            if restorable_previous is not None:
                self._entries.setdefault(session_key, {})[plugin_id] = (
                    restorable_previous
                )
                restorable_previous.unregister_enabled = True
                self._lock_for(session_key, plugin_id)
            raise
        try:
            return await entry.controller.future
        finally:
            if entry.unregister_enabled:
                await self.unregister(session_key, plugin_id=plugin_id)

    def _warn_if_direct_wait_in_handler(self, event: MessageEvent) -> None:
        """Warn (once per task) when register() is awaited directly inside a
        handler-dispatch task instead of a registered background task."""
        current_task = asyncio.current_task()
        if current_task is None:
            return
        if current_task not in _HANDLER_TASKS:
            return
        if current_task in _REGISTERED_BACKGROUND_TASKS:
            return
        if current_task in _WARNED_DIRECT_WAIT_TASKS:
            return
        _WARNED_DIRECT_WAIT_TASKS.add(current_task)
        logger.warning(
            "Direct await on session_waiter blocks the current handler dispatch; "
            'prefer `await ctx.register_task(waiter(...), "...")`: '
            "plugin_id={} session_key={}",
            event._context.plugin_id,
            event.unified_msg_origin,
        )

    async def wait_for_event(
        self,
        *,
        event: MessageEvent,
        timeout: int,
        record_history_chains: bool = False,
    ) -> MessageEvent:
        """Wait for exactly one follow-up event on this session and return it."""
        future: asyncio.Future[MessageEvent] = (
            asyncio.get_running_loop().create_future()
        )

        async def _handler(
            controller: SessionController,
            waiter_event: MessageEvent,
        ) -> None:
            # First follow-up resolves the future and ends the session.
            if not future.done():
                future.set_result(waiter_event)
            controller.stop()

        await self.register(
            event=event,
            handler=_handler,
            timeout=timeout,
            record_history_chains=record_history_chains,
        )
        return future.result()

    async def unregister(
        self,
        session_key: str,
        *,
        plugin_id: str | None = None,
    ) -> None:
        """Remove a waiter entry and notify the host (best-effort)."""
        target_plugin_id = self._resolve_unregister_plugin_id(
            session_key,
            plugin_id=plugin_id,
        )
        if target_plugin_id is None:
            return
        lock_key = (session_key, target_plugin_id)
        lock = self._lock_for(session_key, target_plugin_id)
        removed = False
        async with lock:
            session_entries = self._entries.get(session_key)
            if session_entries is None:
                return
            removed = session_entries.pop(target_plugin_id, None) is not None
            if not session_entries:
                self._entries.pop(session_key, None)
            # Drop the lock from the registry only if it is still ours —
            # a concurrent register() may have re-created the key.
            if self._locks.get(lock_key) is lock:
                self._locks.pop(lock_key, None)
        if not removed:
            return
        try:
            await self._invoke_system_waiter(
                "system.session_waiter.unregister",
                session_key=session_key,
                plugin_id=target_plugin_id,
            )
        except Exception:
            # Host-side unregister is best-effort; failure is only logged.
            logger.debug(
                "Failed to unregister session waiter: plugin_id={} session_key={}",
                target_plugin_id,
                session_key,
            )

    async def fail(
        self,
        session_key: str,
        error: Exception,
        *,
        plugin_id: str | None = None,
    ) -> bool:
        """Inject *error* into the session's waiter. Returns True when a live
        waiter was actually failed."""
        resolved_plugin_id = plugin_id
        if resolved_plugin_id is None:
            # Prefer the invoking plugin's identity when available.
            caller_plugin_id = current_caller_plugin_id()
            if caller_plugin_id:
                resolved_plugin_id = caller_plugin_id
        entry = self._select_entry(
            session_key,
            plugin_id=resolved_plugin_id,
            allow_ambiguous=False,
            missing_result=None,
        )
        if entry is None:
            return False
        lock = self._lock_for(session_key, entry.plugin_id)
        async with lock:
            # Re-check under the lock: the entry may have finished or been
            # replaced while we were waiting.
            current = self._get_entry(session_key, entry.plugin_id)
            if current is None or current.controller.future.done():
                return False
            self._finish_entry(current, error)
            return True

    def has_active_waiter(self, event: MessageEvent) -> bool:
        """True when a live waiter exists for this event's session (and, when
        the event carries a plugin identity, for that plugin specifically)."""
        session_key = event.unified_msg_origin
        event_plugin_id = self._event_plugin_id(event)
        if event_plugin_id is not None:
            entry = self._get_entry(session_key, event_plugin_id)
            return entry is not None and not entry.controller.future.done()
        return bool(self.get_waiter_plugin_ids(session_key))

    def has_waiter(self, event: MessageEvent) -> bool:
        # Backwards-compatible alias.
        return self.has_active_waiter(event)

    def get_waiter_plugin_ids(self, session_key: str) -> list[str]:
        """Sorted plugin ids with a live waiter on *session_key*."""
        return sorted(
            plugin_id
            for plugin_id, entry in self._entries.get(session_key, {}).items()
            if not entry.controller.future.done()
        )

    async def dispatch(
        self,
        event: MessageEvent,
        *,
        plugin_id: str | None = None,
    ) -> dict[str, Any]:
        """Feed a follow-up event to the session's waiter.

        Returns a small result dict (``sent_message`` / ``stop`` /
        ``call_llm``); when no live waiter matches, all flags are False.
        Multiple live waiters without an explicit *plugin_id* is an error.
        """
        if event._context is None:
            raise RuntimeError("session_waiter dispatch requires runtime context")
        session_key = event.unified_msg_origin
        entry = self._select_entry(
            session_key,
            plugin_id=plugin_id,
            allow_ambiguous=False,
            missing_result=None,
            ambiguous_error=LookupError(
                f"session waiter dispatch for session '{session_key}' requires explicit plugin identity"
            ),
        )
        if entry is None:
            return {"sent_message": False, "stop": False, "call_llm": False}
        lock = self._lock_for(session_key, entry.plugin_id)
        async with lock:
            # Re-check under the lock; the waiter may have finished meanwhile.
            current = self._get_entry(session_key, entry.plugin_id)
            if current is None or current.controller.future.done():
                return {"sent_message": False, "stop": False, "call_llm": False}
            waiter_event = self._build_waiter_event(current, event)
            if current.record_history_chains:
                chain = []
                raw_chain = (
                    waiter_event.raw.get("chain")
                    if isinstance(waiter_event.raw, dict)
                    else None
                )
                if isinstance(raw_chain, list):
                    # Copy each component dict so later mutation of the event
                    # payload does not rewrite recorded history.
                    chain = [dict(item) for item in raw_chain if isinstance(item, dict)]
                current.controller.history_chains.append(chain)
            active_key_token = _ACTIVE_WAITER_KEY.set(
                self._make_key(
                    plugin_id=current.plugin_id,
                    session_key=current.session_key,
                )
            )
            try:
                # Keep follow-up handler execution serialized per waiter while still
                # allowing nested waiter cleanup in the same task to re-enter safely.
                await current.handler(current.controller, waiter_event)
            finally:
                _ACTIVE_WAITER_KEY.reset(active_key_token)
            return {
                "sent_message": False,
                "stop": waiter_event.is_stopped(),
                "call_llm": False,
            }

    def _resolve_plugin_id(self, event: MessageEvent) -> str:
        # Priority: invoking-plugin scope > event context > manager default.
        caller_plugin_id = current_caller_plugin_id()
        if caller_plugin_id:
            return caller_plugin_id
        context = event._context
        if context is not None and context.plugin_id.strip():
            return context.plugin_id
        return self._plugin_id

    @staticmethod
    def _event_plugin_id(event: MessageEvent) -> str | None:
        # Plugin identity carried by the event itself, or None.
        context = event._context
        if context is None:
            return None
        plugin_id = context.plugin_id.strip()
        return plugin_id or None

    def _resolve_unregister_plugin_id(
        self,
        session_key: str,
        *,
        plugin_id: str | None,
    ) -> str | None:
        # Explicit id wins; otherwise only an unambiguous (single-entry)
        # session can be unregistered implicitly.
        if plugin_id is not None:
            normalized = str(plugin_id).strip()
            return normalized or None
        session_entries = self._entries.get(session_key, {})
        if len(session_entries) != 1:
            return None
        return next(iter(session_entries))

    def _select_entry(
        self,
        session_key: str,
        *,
        plugin_id: str | None,
        allow_ambiguous: bool,
        missing_result: _WaiterEntry | None,
        ambiguous_error: Exception | None = None,
    ) -> _WaiterEntry | None:
        """Pick a waiter entry for *session_key*.

        With an explicit *plugin_id* the exact entry (live or not) is
        returned. Otherwise only live entries are considered: none yields
        *missing_result*; more than one either raises *ambiguous_error* or
        falls back to *missing_result* unless *allow_ambiguous* is set.
        """
        if plugin_id is not None:
            return self._get_entry(session_key, plugin_id)
        active_entries = [
            entry
            for entry in self._entries.get(session_key, {}).values()
            if not entry.controller.future.done()
        ]
        if not active_entries:
            return missing_result
        if len(active_entries) > 1 and not allow_ambiguous:
            if ambiguous_error is not None:
                raise ambiguous_error
            return missing_result
        return active_entries[0]

    def _get_entry(self, session_key: str, plugin_id: str) -> _WaiterEntry | None:
        return self._entries.get(session_key, {}).get(plugin_id)

    def _lock_for(self, session_key: str, plugin_id: str) -> _TaskReentrantLock:
        # Create-on-demand; same key always yields the same lock instance
        # until it is retired by cleanup.
        return self._locks.setdefault((session_key, plugin_id), _TaskReentrantLock())

    async def _remove_entry(self, entry: _WaiterEntry) -> None:
        """Remove *entry* (and its lock) only if it is still the current one."""
        lock_key = (entry.session_key, entry.plugin_id)
        lock = self._lock_for(entry.session_key, entry.plugin_id)
        async with lock:
            session_entries = self._entries.get(entry.session_key)
            if session_entries is None:
                return
            current = session_entries.get(entry.plugin_id)
            if current is not entry:
                # A newer waiter replaced us; leave it untouched.
                return
            session_entries.pop(entry.plugin_id, None)
            if not session_entries:
                self._entries.pop(entry.session_key, None)
            if self._locks.get(lock_key) is lock:
                self._locks.pop(lock_key, None)

    @staticmethod
    def _finish_entry(entry: _WaiterEntry, error: Exception | None = None) -> None:
        # Resolve the waiter and release its timeout holder.
        entry.controller.stop(error)
        if (
            entry.controller.current_event is not None
            and not entry.controller.current_event.is_set()
        ):
            entry.controller.current_event.set()

    async def _invoke_system_waiter(
        self,
        capability: str,
        *,
        session_key: str,
        plugin_id: str,
    ) -> None:
        from ._internal.invocation_context import caller_plugin_scope

        # Invoke the host capability under the registering plugin's identity.
        with caller_plugin_scope(plugin_id):
            await self._peer.invoke(
                capability,
                {"session_key": session_key},
            )

    def _build_waiter_event(
        self,
        entry: _WaiterEntry,
        event: MessageEvent,
    ) -> MessageEvent:
        from .context import Context

        source_payload = self._source_payload_from_event(event)
        cancel_token = (
            event._context.cancel_token if event._context is not None else None
        )
        waiter_context = Context(
            peer=self._peer,
            plugin_id=entry.plugin_id,
            request_id=(
                event._context.request_id if event._context is not None else None
            ),
            cancel_token=cancel_token,
            source_event_payload=source_payload,
        )
        # Rebuild the event so the waiter always sees the registering plugin identity
        # and the exact source payload that triggered the follow-up dispatch.
        return MessageEvent.from_payload(
            source_payload,
            context=waiter_context,
        )

    @staticmethod
    def _source_payload_from_event(event: MessageEvent) -> dict[str, Any]:
        # Prefer the raw payload when it looks like a full source event
        # (has the required keys); otherwise serialize the event back out.
        raw_payload = event.raw if isinstance(event.raw, dict) else None
        if raw_payload is not None and {
            "text",
            "session_id",
            "platform",
        }.issubset(raw_payload):
            return dict(raw_payload)
        return event.to_payload()
+
+
def session_waiter(
    timeout: int = 30,
    *,
    record_history_chains: bool = False,
) -> _SessionWaiterDecorator:
    """Turn an async function into a session waiter registration call.

    The decorated callable accepts the same arguments minus the
    :class:`SessionController` (which the runtime injects). Calling it
    registers the function as the session's waiter and awaits the session's
    completion.

    Args:
        timeout: Initial session timeout in seconds.
        record_history_chains: When True, message chains of follow-up events
            are accumulated on the controller.
    """

    def decorator(
        func: Callable[..., Awaitable[Any]],
    ) -> Callable[..., Coroutine[Any, Any, Any]]:
        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Detect the call shape: free function (event first) vs. bound
            # method (owner object first, event second).
            owner = None
            event: MessageEvent | None = None
            trailing_args: tuple[Any, ...] = ()
            if args and isinstance(args[0], MessageEvent):
                event = args[0]
                trailing_args = args[1:]
            elif len(args) >= 2 and isinstance(args[1], MessageEvent):
                owner = args[0]
                event = args[1]
                trailing_args = args[2:]
            if event is None:
                raise RuntimeError("session_waiter requires a MessageEvent argument")
            if event._context is None:
                raise RuntimeError("session_waiter requires runtime context")
            # The manager is attached to the peer by the worker runtime.
            manager = getattr(event._context.peer, "_session_waiter_manager", None)
            if manager is None:
                raise RuntimeError("session_waiter manager is unavailable")

            if owner is None:
                free_func = cast(Callable[..., Awaitable[Any]], func)

                async def bound_handler(
                    controller: SessionController,
                    waiter_event: MessageEvent,
                ) -> Any:
                    # Re-inject the controller ahead of the caller's extras.
                    return await free_func(
                        controller,
                        waiter_event,
                        *trailing_args,
                        **kwargs,
                    )
            else:
                method_func = cast(Callable[..., Awaitable[Any]], func)

                async def bound_handler(
                    controller: SessionController,
                    waiter_event: MessageEvent,
                ) -> Any:
                    # Bound-method form: pass the owner back as `self`.
                    return await method_func(
                        owner,
                        controller,
                        waiter_event,
                        *trailing_args,
                        **kwargs,
                    )

            return await manager.register(
                event=event,
                handler=bound_handler,
                timeout=timeout,
                record_history_chains=record_history_chains,
            )

        return wrapper

    return cast(_SessionWaiterDecorator, decorator)
+
+
# NOTE: listing underscored names in __all__ deliberately re-exports the
# typing helpers through `from ... import *` (an explicit __all__ overrides
# the default underscore filtering).
__all__ = [
    "_OwnerT",
    "_P",
    "_ResultT",
    "SessionController",
    "SessionWaiterManager",
    "session_waiter",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/star.py b/astrbot-sdk/src/astrbot_sdk/star.py
new file mode 100644
index 0000000000..3d4457efc4
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/star.py
@@ -0,0 +1,122 @@
+"""astrbot-sdk 原生插件基类。"""
+
+from __future__ import annotations
+
+import traceback
+from contextvars import ContextVar, Token
+from typing import TYPE_CHECKING, Any, cast
+
+from ._internal.sdk_logger import logger
+from .errors import AstrBotError
+from .plugin_kv import PluginKVStoreMixin
+
+if TYPE_CHECKING:
+ from .context import Context
+
+
class Star(PluginKVStoreMixin):
    """Native plugin base class for astrbot-sdk.

    Subclasses declare handlers with the SDK decorators; the names of
    decorated methods are collected into ``__handlers__`` at subclass
    creation. The runtime context is kept in a per-instance ContextVar so it
    is only visible to the task currently executing a lifecycle hook, a
    handler, or a registered LLM tool.
    """

    # Names of decorated handler methods, filled by __init_subclass__.
    __handlers__: tuple[str, ...] = ()

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        from .decorators import get_handler_meta

        # Walk the MRO base-first; the dict preserves first-insertion order
        # while collapsing names overridden in subclasses.
        handlers: dict[str, None] = {}
        for base in reversed(cls.__mro__):
            for name, attr in getattr(base, "__dict__", {}).items():
                # Unwrap classmethod/staticmethod to reach the function the
                # decorator stored metadata on.
                func = getattr(attr, "__func__", attr)
                meta = get_handler_meta(func)
                if meta is not None and meta.trigger is not None:
                    handlers[name] = None
        cls.__handlers__ = tuple(handlers.keys())

    @property
    def context(self) -> Context | None:
        """Current runtime context, or None outside an execution scope."""
        return self._context_var().get()

    def _require_runtime_context(self) -> Context:
        # Raise a descriptive error instead of letting a None propagate.
        ctx = self.context
        if ctx is None:
            raise RuntimeError(
                "Star runtime context is only available during lifecycle, "
                "handler, and registered LLM tool execution"
            )
        return ctx

    def _context_var(self) -> ContextVar[Context | None]:
        # Lazily create one ContextVar per instance (stored on the instance)
        # so independent plugin instances never share runtime context.
        existing_context_var = getattr(self, "__astrbot_context_var__", None)
        if isinstance(existing_context_var, ContextVar):
            return cast("ContextVar[Context | None]", existing_context_var)
        created_context_var: ContextVar[Context | None] = ContextVar(
            f"astrbot_sdk_star_context_{id(self)}",
            default=None,
        )
        setattr(self, "__astrbot_context_var__", created_context_var)
        return created_context_var

    def _bind_runtime_context(self, ctx: Context | None) -> Token[Context | None]:
        # Called by the runtime before invoking user code; pair with
        # _reset_runtime_context using the returned token.
        return self._context_var().set(ctx)

    def _reset_runtime_context(self, token: Token[Context | None]) -> None:
        self._context_var().reset(token)

    async def on_start(self, ctx: Any | None = None) -> None:
        # Lifecycle entry point; delegates to the user-overridable hook.
        await self.initialize()

    async def on_stop(self, ctx: Any | None = None) -> None:
        await self.terminate()

    async def initialize(self) -> None:
        """Override to run plugin startup logic."""
        return None

    async def terminate(self) -> None:
        """Override to run plugin shutdown logic."""
        return None

    async def text_to_image(
        self,
        text: str,
        *,
        return_url: bool = True,
    ) -> str:
        """Render *text* to an image via the runtime context."""
        return await self._require_runtime_context().text_to_image(
            text,
            return_url=return_url,
        )

    async def html_render(
        self,
        tmpl: str,
        data: dict[str, Any],
        *,
        return_url: bool = True,
        options: dict[str, Any] | None = None,
    ) -> str:
        """Render the HTML template *tmpl* with *data* via the runtime context."""
        return await self._require_runtime_context().html_render(
            tmpl,
            data,
            return_url=return_url,
            options=options,
        )

    @staticmethod
    async def default_on_error(error: Exception, event, ctx) -> None:
        """Default error reporter: replies with hint/docs for AstrBotError,
        a generic message otherwise, and always logs the traceback."""
        del ctx
        if isinstance(error, AstrBotError):
            lines = [error.hint or error.message]
            if error.docs_url:
                lines.append(f"文档:{error.docs_url}")
            if error.details:
                lines.append(f"详情:{error.details!r}")
            await event.reply("\n".join(lines))
        else:
            await event.reply("出了点问题,请联系插件作者")
        logger.error("handler 执行失败\n{}", traceback.format_exc())

    async def on_error(self, error: Exception, event, ctx) -> None:
        # Overridable hook; defaults to the shared reporter above.
        await Star.default_on_error(error, event, ctx)

    @classmethod
    def __astrbot_is_new_star__(cls) -> bool:
        # Marker flag; presumably used by the loader to distinguish SDK-native
        # Star classes from legacy ones — confirm in the plugin loader.
        return True
diff --git a/astrbot-sdk/src/astrbot_sdk/star_tools.py b/astrbot-sdk/src/astrbot_sdk/star_tools.py
new file mode 100644
index 0000000000..fe7aa451c0
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/star_tools.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable, Sequence
+from typing import TYPE_CHECKING, Any
+
+from ._internal.star_runtime import current_star_context
+from .context import Context
+from .message.components import BaseMessageComponent
+from .message.result import MessageChain
+from .message.session import MessageSession
+
+if TYPE_CHECKING:
+ from .clients.skills import SkillRegistration
+ from .llm.tools import LLMToolManager
+
+
class _StarToolsContextDescriptor:
    """Non-data descriptor exposing the context-local star ``Context``.

    Resolved on every access so class-level reads always see the context
    bound to the current execution.
    """

    def __get__(self, _instance: object, _owner: type[object]) -> Context | None:
        return current_star_context()
+
+
class StarTools:
    """Class-method facade over the star runtime context.

    Every method routes through the context-local :class:`Context`, which is
    only bound while lifecycle hooks, handlers, and registered LLM tools run.
    """

    _context = _StarToolsContextDescriptor()

    @classmethod
    def _get_context(cls) -> Context | None:
        """Return the active runtime context, or ``None`` outside execution."""
        return cls._context

    @classmethod
    def _require_context(cls) -> Context:
        """Return the active runtime context or raise ``RuntimeError``."""
        current = current_star_context()
        if current is None:
            raise RuntimeError(
                "StarTools context is only available during lifecycle, "
                "handler, and registered LLM tool execution"
            )
        return current

    @classmethod
    def get_llm_tool_manager(cls) -> LLMToolManager:
        """Return the LLM tool manager of the current context."""
        runtime = cls._require_context()
        return runtime.get_llm_tool_manager()

    @classmethod
    async def activate_llm_tool(cls, name: str) -> bool:
        """Activate the LLM tool *name* via the current context."""
        runtime = cls._require_context()
        return await runtime.activate_llm_tool(name)

    @classmethod
    async def deactivate_llm_tool(cls, name: str) -> bool:
        """Deactivate the LLM tool *name* via the current context."""
        runtime = cls._require_context()
        return await runtime.deactivate_llm_tool(name)

    @classmethod
    async def send_message(
        cls,
        session: str | MessageSession,
        content: (
            str
            | MessageChain
            | Sequence[BaseMessageComponent]
            | Sequence[dict[str, Any]]
        ),
    ) -> dict[str, Any]:
        """Send *content* to *session* through the current context."""
        runtime = cls._require_context()
        return await runtime.send_message(session, content)

    @classmethod
    async def send_message_by_id(
        cls,
        type: str,
        id: str,
        content: (
            str
            | MessageChain
            | Sequence[BaseMessageComponent]
            | Sequence[dict[str, Any]]
        ),
        *,
        platform: str,
    ) -> dict[str, Any]:
        """Send *content* to a target addressed by *type*/*id* on *platform*.

        ``type`` and ``id`` shadow builtins but are part of the public API.
        """
        runtime = cls._require_context()
        return await runtime.send_message_by_id(
            type,
            id,
            content,
            platform=platform,
        )

    @classmethod
    async def register_llm_tool(
        cls,
        name: str,
        parameters_schema: dict[str, Any],
        desc: str,
        func_obj: Callable[..., Awaitable[Any]] | Callable[..., Any],
        *,
        active: bool = True,
    ) -> list[str]:
        """Register *func_obj* as an LLM tool on the current context."""
        runtime = cls._require_context()
        return await runtime.register_llm_tool(
            name,
            parameters_schema,
            desc,
            func_obj,
            active=active,
        )

    @classmethod
    async def unregister_llm_tool(cls, name: str) -> bool:
        """Remove the LLM tool *name* from the current context."""
        runtime = cls._require_context()
        return await runtime.unregister_llm_tool(name)

    @classmethod
    async def register_skill(
        cls,
        *,
        name: str,
        path: str,
        description: str = "",
    ) -> SkillRegistration:
        """Register a skill on the current context's skill client."""
        runtime = cls._require_context()
        return await runtime.skills.register(
            name=name,
            path=path,
            description=description,
        )

    @classmethod
    async def unregister_skill(cls, name: str) -> bool:
        """Remove the skill *name* from the current context's skill client."""
        runtime = cls._require_context()
        return await runtime.skills.unregister(name)
diff --git a/astrbot-sdk/src/astrbot_sdk/templates/project_notes/AGENTS.md b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/AGENTS.md
new file mode 100644
index 0000000000..33bb5548f5
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/AGENTS.md
@@ -0,0 +1,12 @@
+# AGENTS.md
+
+## AstrBot Plugin Notes
+
+- Prefer raising `AstrBotError` from `astrbot_sdk.errors` for expected failures.
+- Reuse stable `ErrorCodes` and factory helpers instead of inventing ad-hoc `{"error": ...}` payloads.
+- Validate the generated plugin with `astrbot-sdk validate --plugin-dir .` before packaging or sharing it.
+- Run `python -m pytest tests/test_plugin.py -v` after changing plugin behavior so the sample harness contract stays honest.
+- `astrbot-sdk build --plugin-dir .` should create the release zip without development-only files such as `AGENTS.md`, `CLAUDE.md`, `.claude/`, `.agents/`, or `.opencode/`.
+- Exported capability names should be namespaced with `.` separators (for example `{plugin_id}.capability`), and HTTP routes should use `/{plugin_id}` or `/{plugin_id}/...` so the plugin stays collision-safe inside `GroupWorkerRuntime`.
+
+- 除非有充分理由,插件的直接依赖应声明已验证的最低兼容版本。若已知存在不兼容的大版本或问题版本,应同时补充上界或排除约束
diff --git a/astrbot-sdk/src/astrbot_sdk/templates/project_notes/CLAUDE.md b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/CLAUDE.md
new file mode 100644
index 0000000000..6df0e003b9
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/templates/project_notes/CLAUDE.md
@@ -0,0 +1,12 @@
+# CLAUDE.md
+
+## AstrBot Plugin Notes
+
+- Prefer raising `AstrBotError` from `astrbot_sdk.errors` for expected failures.
+- Reuse stable `ErrorCodes` and factory helpers instead of inventing ad-hoc `{"error": ...}` payloads.
+- Validate the generated plugin with `astrbot-sdk validate --plugin-dir .` before packaging or sharing it.
+- Run `python -m pytest tests/test_plugin.py -v` after changing plugin behavior so the sample harness contract stays honest.
+- `astrbot-sdk build --plugin-dir .` should create the release zip without development-only files such as `AGENTS.md`, `CLAUDE.md`, `.claude/`, `.agents/`, or `.opencode/`.
+- Exported capability names should be namespaced with `.` separators (for example `{plugin_id}.capability`), and HTTP routes should use `/{plugin_id}` or `/{plugin_id}/...` so the plugin stays collision-safe inside `GroupWorkerRuntime`.
+
+- 除非有充分理由,插件的直接依赖应声明已验证的最低兼容版本。若已知存在不兼容的大版本或问题版本,应同时补充上界或排除约束
diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/SKILL.md b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/SKILL.md
new file mode 100644
index 0000000000..b811cdcf65
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/SKILL.md
@@ -0,0 +1,29 @@
+---
+name: {{skill_name}}
+description: Work on the {{display_name}} plugin scaffold with {{agent_display_name}}.
+---
+
+# {{display_name}} Plugin Guide
+
+Use this skill when working inside the plugin created by `astr init --agents {{agent_name}}`.
+
+## Workspace
+- Plugin root: `{{plugin_root}}`
+- Skill directory: `{{skill_dir_name}}`
+- Plugin package: `{{plugin_name}}`
+- Main class: `{{class_name}}`
+
+## Expectations
+- Read `{{plugin_root}}/plugin.yaml` and `{{plugin_root}}/main.py` before editing behavior.
+- Keep handler names, config keys, and user-facing command text stable unless the user asks to change them.
+- Prefer focused changes that match the generated plugin layout instead of broad rewrites.
+- Run the smallest relevant validation after behavior changes.
+
+## Validation
+- `uv run astr validate --plugin-dir {{plugin_root}}`
+- Add or run focused tests when the request changes behavior.
+- Keep new comments in English.
+
+## Delivery
+- Summarize what changed, why it changed, and which checks were run.
+- Call out any follow-up work or remaining risks if the requested change cannot be completed fully.
diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/agents/openai.yaml b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/agents/openai.yaml
new file mode 100644
index 0000000000..6a95224239
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/agents/openai.yaml
@@ -0,0 +1,6 @@
+model: gpt-5.4-mini
+reasoning_effort: medium
+instructions: |
+ Use the {{skill_name}} skill when editing the {{plugin_name}} plugin.
+ Start from {{plugin_root}}/plugin.yaml and {{plugin_root}}/main.py.
+ Keep changes aligned with the generated plugin scaffold.
diff --git a/astrbot-sdk/src/astrbot_sdk/testing.py b/astrbot-sdk/src/astrbot_sdk/testing.py
new file mode 100644
index 0000000000..c257c8aca5
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/testing.py
@@ -0,0 +1,849 @@
+"""本地开发与插件测试辅助。
+
+`astrbot_sdk.testing` 是面向插件作者的稳定开发入口:
+
+- `PluginHarness` 负责复用现有 loader / dispatcher 执行链
+- `MockCapabilityRouter` 提供进程内 mock core 能力
+- `MockPeer` 让 `Context` 客户端继续走真实的 capability 调用路径
+- `StdoutPlatformSink` / `RecordedSend` 提供可观测的发送记录
+
+这个模块刻意不暴露 runtime 内部编排数据结构,只封装本地开发/测试真正
+需要的最小稳定面。
+"""
+
+from __future__ import annotations
+
+import asyncio
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from ._internal.decorator_lifecycle import run_lifecycle_with_decorators
+from ._internal.testing_support import (
+ InMemoryDB,
+ InMemoryMemory,
+ MockCapabilityRouter,
+ MockContext,
+ MockLLMClient,
+ MockMessageEvent,
+ MockPeer,
+ MockPlatformClient,
+ RecordedSend,
+ StdoutPlatformSink,
+)
+from ._message_types import normalize_message_type
+from .context import CancelToken
+from .context import Context as RuntimeContext
+from .errors import AstrBotError
+from .events import MessageEvent
+from .protocol.descriptors import (
+ CommandTrigger,
+ CompositeFilterSpec,
+ EventTrigger,
+ LocalFilterRefSpec,
+ MessageTrigger,
+ MessageTypeFilterSpec,
+ PlatformFilterSpec,
+ ScheduleTrigger,
+)
+from .protocol.messages import InvokeMessage
+from .runtime._command_matching import (
+ build_command_args,
+ build_regex_args,
+ command_root_name,
+ match_command_name,
+)
+from .runtime._streaming import StreamExecution
+from .runtime.handler_dispatcher import CapabilityDispatcher, HandlerDispatcher
+from .runtime.loader import (
+ LoadedHandler,
+ LoadedPlugin,
+ PluginSpec,
+ load_plugin,
+ load_plugin_config,
+ load_plugin_spec,
+ validate_plugin_spec,
+)
+from .star import Star
+
+
class _PluginLoadError(RuntimeError):
    """Known plugin load failure raised during local harness initialization."""
+
+
class _PluginExecutionError(RuntimeError):
    """Known plugin exception raised while the local harness executes plugin code."""
+
+
def _plugin_metadata_from_spec(
    plugin: PluginSpec,
    *,
    enabled: bool,
) -> dict[str, Any]:
    """Project a ``PluginSpec`` manifest into the mock router's metadata dict."""
    manifest = plugin.manifest_data

    # Only string entries from a list-typed manifest field are kept.
    raw_platforms = manifest.get("support_platforms")
    if isinstance(raw_platforms, list):
        platforms = [str(entry) for entry in raw_platforms if isinstance(entry, str)]
    else:
        platforms = []

    raw_sdk_version = manifest.get("astrbot_version")
    sdk_version = str(raw_sdk_version) if raw_sdk_version is not None else None

    return {
        "name": plugin.name,
        "display_name": str(manifest.get("display_name") or plugin.name),
        "description": str(manifest.get("desc") or manifest.get("description") or ""),
        "repo": str(manifest.get("repo") or ""),
        "author": str(manifest.get("author") or ""),
        "version": str(manifest.get("version") or "0.0.0"),
        "enabled": enabled,
        "reserved": bool(manifest.get("reserved", False)),
        "support_platforms": platforms,
        "astrbot_version": sdk_version,
    }
+
+
def _handler_metadata_from_loaded(
    plugin_id: str, loaded: LoadedHandler
) -> dict[str, Any]:
    """Project a ``LoadedHandler`` into the mock router's handler metadata dict."""
    descriptor = loaded.descriptor
    trigger = descriptor.trigger

    if isinstance(trigger, EventTrigger):
        trigger_type = trigger.type
        event_types = [trigger.type]
    else:
        # Fall back to the trigger kind; ``trigger.type`` is the getattr default.
        trigger_type = str(getattr(trigger, "kind", trigger.type))
        event_types = []

    route = descriptor.command_route
    group_path = list(route.group_path) if route is not None else []

    return {
        "plugin_name": plugin_id,
        "handler_full_name": descriptor.id,
        "trigger_type": trigger_type,
        "event_types": event_types,
        "enabled": True,
        "group_path": group_path,
        "require_admin": descriptor.permissions.require_admin,
        "required_role": descriptor.permissions.required_role,
    }
+
+
@dataclass(slots=True)
class LocalRuntimeConfig:
    """Stable configuration object for the local harness."""

    # Directory containing the plugin manifest and source to load.
    plugin_dir: Path
    # Default identity used when dispatch calls omit an override.
    session_id: str = "local-session"
    user_id: str = "local-user"
    platform: str = "test"
    group_id: str | None = None
    event_type: str = "message"
+
+
+@dataclass(slots=True)
+class MockClock:
+ now: float = 0.0
+
+ def time(self) -> float:
+ return self.now
+
+ def advance(self, seconds: float) -> float:
+ self.now += float(seconds)
+ return self.now
+
+
+@dataclass(slots=True)
+class SDKTestEnvironment:
+ root: Path
+
+ @property
+ def plugins_dir(self) -> Path:
+ path = self.root / "plugins"
+ path.mkdir(parents=True, exist_ok=True)
+ return path
+
+ def plugin_dir(self, name: str) -> Path:
+ path = self.plugins_dir / name
+ path.mkdir(parents=True, exist_ok=True)
+ return path
+
+
class PluginHarness:
    """Local plugin message pump.

    Reuses the real loader / dispatcher execution chain and is only
    responsible for:
    - assembling a single-plugin runtime on one event loop
    - maintaining the local mock core and the record of sent messages
    - feeding follow-up messages into the same dispatcher
    """

    def __init__(
        self,
        config: LocalRuntimeConfig,
        *,
        platform_sink: StdoutPlatformSink | None = None,
    ) -> None:
        """Build an unstarted harness; ``start()`` performs the actual loading."""
        self.config = config
        self.platform_sink = platform_sink or StdoutPlatformSink()
        self.router = MockCapabilityRouter(platform_sink=self.platform_sink)
        self.peer = MockPeer(self.router)
        # Populated by start(); None until then.
        self.plugin: PluginSpec | None = None
        self.loaded_plugin: LoadedPlugin | None = None
        self.dispatcher: HandlerDispatcher | None = None
        self.capability_dispatcher: CapabilityDispatcher | None = None
        self.lifecycle_context: RuntimeContext | None = None
        self._request_counter = 0
        self._started = False

    @classmethod
    def from_plugin_dir(
        cls,
        plugin_dir: str | Path,
        *,
        session_id: str = "local-session",
        user_id: str = "local-user",
        platform: str = "test",
        group_id: str | None = None,
        event_type: str = "message",
        platform_sink: StdoutPlatformSink | None = None,
    ) -> PluginHarness:
        """Convenience constructor that wraps the arguments in a ``LocalRuntimeConfig``."""
        return cls(
            LocalRuntimeConfig(
                plugin_dir=Path(plugin_dir),
                session_id=session_id,
                user_id=user_id,
                platform=platform,
                group_id=group_id,
                event_type=event_type,
            ),
            platform_sink=platform_sink,
        )

    async def __aenter__(self) -> PluginHarness:
        """Start the harness when entering an ``async with`` block."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Stop the harness when leaving an ``async with`` block."""
        await self.stop()

    @property
    def sent_messages(self) -> list[RecordedSend]:
        """Snapshot copy of everything recorded by the platform sink so far."""
        return list(self.platform_sink.records)

    def clear_sent_messages(self) -> None:
        """Drop all previously recorded sends."""
        self.platform_sink.clear()

    async def start(self) -> None:
        """Load the plugin, register it with the mock router, and run ``on_start``.

        Idempotent: repeated calls after a successful start are no-ops.
        """
        if self._started:
            return
        try:
            self.plugin = load_plugin_spec(self.config.plugin_dir)
            validate_plugin_spec(self.plugin)
            self.loaded_plugin = load_plugin(self.plugin)
        except Exception as exc:  # pragma: no cover - covered by CLI/integration tests
            raise _PluginLoadError(str(exc)) from exc
        self.dispatcher = HandlerDispatcher(
            plugin_id=self.plugin.name,
            peer=self.peer,
            handlers=self.loaded_plugin.handlers,
        )
        self.capability_dispatcher = CapabilityDispatcher(
            plugin_id=self.plugin.name,
            peer=self.peer,
            capabilities=self.loaded_plugin.capabilities,
            llm_tools=self.loaded_plugin.llm_tools,
        )
        self.lifecycle_context = RuntimeContext(
            peer=self.peer,
            plugin_id=self.plugin.name,
        )
        # Mirror what the real core would know about this plugin into the mock router.
        plugin_metadata = _plugin_metadata_from_spec(self.plugin, enabled=True)
        self.router.upsert_plugin(
            metadata=plugin_metadata,
            config=load_plugin_config(self.plugin),
        )
        self.router.set_plugin_handlers(
            self.plugin.name,
            [
                _handler_metadata_from_loaded(self.plugin.name, handler)
                for handler in self.loaded_plugin.handlers
            ],
        )
        self.router.set_plugin_llm_tools(
            self.plugin.name,
            [tool.spec.to_payload() for tool in self.loaded_plugin.llm_tools],
        )
        self.router.set_plugin_agents(
            self.plugin.name,
            [agent.spec.to_payload() for agent in self.loaded_plugin.agents],
        )
        try:
            await self._run_lifecycle("on_start")
        except AstrBotError:
            raise
        except Exception as exc:  # pragma: no cover - covered by CLI/integration tests
            raise _PluginExecutionError(str(exc)) from exc
        self._started = True

    async def stop(self) -> None:
        """Run ``on_stop`` and deregister the plugin from the mock router.

        Router cleanup runs even if the lifecycle hook raises.
        """
        if (
            not self._started
            or self.loaded_plugin is None
            or self.lifecycle_context is None
        ):
            return
        try:
            await self._run_lifecycle("on_stop")
        finally:
            if self.plugin is not None:
                self.router.set_plugin_enabled(self.plugin.name, False)
                self.router.set_plugin_handlers(self.plugin.name, [])
                self.router.remove_dynamic_command_routes_for_plugin(self.plugin.name)
                self.router.remove_http_apis_for_plugin(self.plugin.name)
            self._started = False

    async def dispatch_text(
        self,
        text: str,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        platform: str | None = None,
        group_id: str | None = None,
        event_type: str | None = None,
        request_id: str | None = None,
    ) -> list[RecordedSend]:
        """Build an event payload from *text* (with config defaults) and dispatch it."""
        payload = self.build_event_payload(
            text=text,
            session_id=session_id,
            user_id=user_id,
            platform=platform,
            group_id=group_id,
            event_type=event_type,
            request_id=request_id,
        )
        return await self.dispatch_event(payload, request_id=request_id)

    async def dispatch_event(
        self,
        event_payload: dict[str, Any],
        *,
        request_id: str | None = None,
    ) -> list[RecordedSend]:
        """Dispatch one event and return the sends recorded during this dispatch.

        Active session waiters take precedence over fresh handler matching;
        otherwise handlers are matched, with a group-root help reply when only
        a command group prefix matched.
        """
        await self.start()
        assert self.loaded_plugin is not None
        assert self.dispatcher is not None

        start_index = len(self.platform_sink.records)
        if self._has_waiter_for_event(event_payload):
            await self._invoke_session_waiter(
                event_payload,
                request_id=request_id,
            )
            await self._wait_for_followup_side_effects(
                start_index=start_index,
                event_payload=event_payload,
            )
            return self.platform_sink.records[start_index:]

        matches = self._match_handlers(event_payload)
        help_text = self._build_group_root_help(event_payload)
        if help_text is not None and not any(
            isinstance(loaded.descriptor.trigger, CommandTrigger)
            for loaded, _args in matches
        ):
            assert self.lifecycle_context is not None
            await self.lifecycle_context.platform.send(
                str(event_payload.get("session_id", "")),
                help_text,
            )
            return self.platform_sink.records[start_index:]
        if not matches:
            raise AstrBotError.invalid_input("未找到匹配的 handler")
        for loaded, args in matches:
            result = await self._invoke_handler(
                loaded,
                event_payload,
                args=args,
                request_id=request_id,
            )
            # Mirror the runtime dispatcher contract: once a handler explicitly
            # stops the event, later matches in the same dispatch should not run.
            if bool(result.get("stop", False)):
                break
        return self.platform_sink.records[start_index:]

    async def invoke_capability(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        request_id: str | None = None,
        stream: bool = False,
    ) -> dict[str, Any] | StreamExecution:
        """Invoke a plugin capability directly through the capability dispatcher."""
        await self.start()
        assert self.capability_dispatcher is not None
        message = InvokeMessage(
            id=request_id or self._next_request_id("cap"),
            capability=capability,
            input=dict(payload),
            stream=stream,
        )
        try:
            return await self.capability_dispatcher.invoke(message, CancelToken())
        except AstrBotError:
            raise
        except Exception as exc:  # pragma: no cover - covered by CLI/integration tests
            raise _PluginExecutionError(str(exc)) from exc

    def build_event_payload(
        self,
        *,
        text: str,
        session_id: str | None = None,
        user_id: str | None = None,
        platform: str | None = None,
        group_id: str | None = None,
        event_type: str | None = None,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Assemble a raw event dict, filling unset fields from the harness config."""
        session_value = session_id or self.config.session_id
        group_value = group_id if group_id is not None else self.config.group_id
        event_type_value = event_type or self.config.event_type
        payload = {
            "type": event_type_value,
            "event_type": event_type_value,
            "text": text,
            "session_id": session_value,
            "user_id": user_id or self.config.user_id,
            "platform": platform or self.config.platform,
            "platform_id": platform or self.config.platform,
            "group_id": group_value,
            "self_id": f"{platform or self.config.platform}-bot",
            "sender_name": str(user_id or self.config.user_id or ""),
            "is_admin": False,
            "raw": {
                "trace_id": request_id or self._next_request_id("trace"),
                "event_type": event_type_value,
            },
        }
        # Message type is derived: a group id means "group", else a user id
        # means "private", otherwise "other".
        if group_value:
            payload["message_type"] = "group"
        elif payload["user_id"]:
            payload["message_type"] = "private"
        else:
            payload["message_type"] = "other"
        return payload

    async def _invoke_handler(
        self,
        loaded: LoadedHandler,
        event_payload: dict[str, Any],
        *,
        args: dict[str, Any],
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Invoke one matched handler through the real dispatcher and return its result."""
        assert self.dispatcher is not None
        message = InvokeMessage(
            id=request_id or self._next_request_id("msg"),
            capability="handler.invoke",
            input={
                "handler_id": loaded.descriptor.id,
                "event": dict(event_payload),
                "args": dict(args),
            },
        )
        try:
            result = await self.dispatcher.invoke(message, CancelToken())
            return dict(result)
        except AstrBotError:
            raise
        except Exception as exc:  # pragma: no cover - covered by CLI/integration tests
            raise _PluginExecutionError(str(exc)) from exc

    async def _invoke_session_waiter(
        self,
        event_payload: dict[str, Any],
        *,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Deliver the event to the dispatcher's session-waiter pseudo handler."""
        assert self.dispatcher is not None
        message = InvokeMessage(
            id=request_id or self._next_request_id("msg"),
            capability="handler.invoke",
            input={
                "handler_id": "__sdk_session_waiter__",
                "event": dict(event_payload),
                "args": {},
            },
        )
        try:
            result = await self.dispatcher.invoke(message, CancelToken())
            return dict(result)
        except AstrBotError:
            raise
        except Exception as exc:  # pragma: no cover - covered by CLI/integration tests
            raise _PluginExecutionError(str(exc)) from exc

    async def _wait_for_followup_side_effects(
        self,
        *,
        start_index: int,
        event_payload: dict[str, Any],
    ) -> None:
        """Yield to the event loop until a send appears or waiter activity settles.

        Bounded at 20 loop turns; "settled" means three consecutive turns with
        no active waiter for this event and no new sink records.
        """
        settled_rounds = 0
        for _ in range(20):
            if len(self.platform_sink.records) > start_index:
                return
            await asyncio.sleep(0)
            if self._has_waiter_for_event(event_payload):
                settled_rounds = 0
                continue
            settled_rounds += 1
            if settled_rounds >= 3:
                return

    async def _run_lifecycle(self, method_name: str) -> None:
        """Run one lifecycle hook (``on_start``/``on_stop``) on every plugin instance."""
        assert self.loaded_plugin is not None
        assert self.lifecycle_context is not None

        for instance in self.loaded_plugin.instances:
            hook = self._resolve_lifecycle_hook(instance, method_name)
            await run_lifecycle_with_decorators(
                instance=instance,
                hook=hook,
                method_name=method_name,
                context=self.lifecycle_context,
            )

    def _match_handlers(
        self,
        event_payload: dict[str, Any],
    ) -> list[tuple[LoadedHandler, dict[str, Any]]]:
        """Collect static and dynamic handler matches, ordered by priority then load order."""
        assert self.loaded_plugin is not None
        ranked: list[tuple[int, int, LoadedHandler, dict[str, Any]]] = []
        for index, loaded in enumerate(self.loaded_plugin.handlers):
            args = self._match_handler(loaded, event_payload)
            if args is None:
                continue
            ranked.append((loaded.descriptor.priority, index, loaded, args))
        for dynamic in self._match_dynamic_handlers(event_payload):
            ranked.append(dynamic)
        # Higher priority first; stable by registration order within a priority.
        ranked.sort(key=lambda item: (-item[0], item[1]))
        return [(loaded, args) for _priority, _index, loaded, args in ranked]

    def _match_dynamic_handlers(
        self,
        event_payload: dict[str, Any],
    ) -> list[tuple[int, int, LoadedHandler, dict[str, Any]]]:
        """Match command routes registered at runtime through the mock router."""
        assert self.loaded_plugin is not None
        assert self.plugin is not None
        ranked: list[tuple[int, int, LoadedHandler, dict[str, Any]]] = []
        routes = self.router.list_dynamic_command_routes(self.plugin.name)
        handler_map = {
            loaded.descriptor.id: loaded for loaded in self.loaded_plugin.handlers
        }
        # Offset dynamic entries after the static handlers so sorting stays stable.
        base_order = len(self.loaded_plugin.handlers)
        for index, route in enumerate(routes):
            if not isinstance(route, dict):
                continue
            handler_full_name = str(route.get("handler_full_name", "")).strip()
            loaded = handler_map.get(handler_full_name)
            if loaded is None:
                continue
            args = self._match_dynamic_route(loaded, route, event_payload)
            if args is None:
                continue
            priority = route.get("priority", loaded.descriptor.priority)
            if not isinstance(priority, int) or isinstance(priority, bool):
                priority = loaded.descriptor.priority
            ranked.append((priority, base_order + index, loaded, args))
        return ranked

    def _match_dynamic_route(
        self,
        loaded: LoadedHandler,
        route: dict[str, Any],
        event_payload: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Match one dynamic route; returns parsed args or ``None`` if it does not apply."""
        if not self._passes_filters(loaded, event_payload):
            return None
        command_name = str(route.get("command_name", "")).strip()
        if not command_name:
            return None
        text = str(event_payload.get("text", ""))
        if bool(route.get("use_regex", False)):
            match = re.search(command_name, text)
            if match is None:
                return None
            return build_regex_args(loaded.descriptor.param_specs, match)
        remainder = match_command_name(text, command_name)
        if remainder is None:
            return None
        return build_command_args(loaded.descriptor.param_specs, remainder)

    def _match_handler(
        self,
        loaded: LoadedHandler,
        event_payload: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Match a statically-declared handler; returns its args dict or ``None``."""
        if not self._passes_permissions(loaded, event_payload):
            return None
        trigger = loaded.descriptor.trigger
        if isinstance(trigger, CommandTrigger):
            return self._match_command_trigger(loaded, trigger, event_payload)
        if isinstance(trigger, MessageTrigger):
            return self._match_message_trigger(loaded, trigger, event_payload)
        if isinstance(trigger, EventTrigger):
            current_type = str(
                event_payload.get("event_type")
                or event_payload.get("type")
                or "message"
            )
            # NOTE(review): metadata helpers read ``trigger.type`` while this
            # compares ``trigger.event_type`` — confirm both attributes agree.
            if current_type != trigger.event_type:
                return None
            return {}
        if isinstance(trigger, ScheduleTrigger):
            if (
                str(event_payload.get("event_type") or event_payload.get("type"))
                == "schedule"
            ):
                schedule_payload = event_payload.get("schedule")
                if isinstance(schedule_payload, dict):
                    target_handler_id = str(
                        schedule_payload.get("handler_id", "")
                    ).strip()
                    if target_handler_id and target_handler_id != loaded.descriptor.id:
                        return None
                return {}
            return None
        return None

    def _match_command_trigger(
        self,
        loaded: LoadedHandler,
        trigger: CommandTrigger,
        event_payload: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Try the command name and each alias against the event text."""
        if not self._passes_filters(loaded, event_payload):
            return None
        text = str(event_payload.get("text", "")).strip()
        for command_name in [trigger.command, *trigger.aliases]:
            if not command_name:
                continue
            match = match_command_name(text, command_name)
            if match is None:
                continue
            return build_command_args(loaded.descriptor.param_specs, match)
        return None

    def _build_group_root_help(self, event_payload: dict[str, Any]) -> str | None:
        """Build a help listing when the text matches only a command-group root.

        Returns ``None`` when the text has no root word or no group commands
        share that root.
        """
        assert self.loaded_plugin is not None
        root_name = command_root_name(str(event_payload.get("text", "")))
        if not root_name:
            return None
        entries: list[tuple[str, str | None]] = []
        seen_commands: set[str] = set()
        for loaded in self.loaded_plugin.handlers:
            descriptor = loaded.descriptor
            trigger = descriptor.trigger
            if not isinstance(trigger, CommandTrigger):
                continue
            if not self._passes_filters(loaded, event_payload):
                continue
            route = descriptor.command_route
            root_candidates: list[str] = []
            if route is not None and route.group_path:
                group_root = str(route.group_path[0]).strip()
                if group_root:
                    root_candidates.append(group_root)
            # Multi-word command names also contribute their first word as a root.
            for name in [trigger.command, *trigger.aliases]:
                normalized = str(name).strip()
                if " " not in normalized:
                    continue
                command_root = normalized.split()[0].strip()
                if command_root:
                    root_candidates.append(command_root)
            if root_name not in dict.fromkeys(root_candidates):
                continue
            display_command = (
                str(route.display_command).strip()
                if route is not None and str(route.display_command).strip()
                else str(trigger.command).strip()
            )
            if not display_command or display_command in seen_commands:
                continue
            seen_commands.add(display_command)
            description = (
                str(descriptor.description or "").strip()
                or str(trigger.description or "").strip()
                or None
            )
            entries.append((display_command, description))
        if not entries:
            return None
        lines = [f"{root_name}命令:"]
        for command_name, description in entries:
            line = f"- /{command_name}"
            if description:
                line += f": {description}"
            lines.append(line)
        return "\n".join(lines)

    def _match_message_trigger(
        self,
        loaded: LoadedHandler,
        trigger: MessageTrigger,
        event_payload: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Match a message trigger by regex or keyword list against the event text."""
        if not self._passes_filters(loaded, event_payload):
            return None
        text = str(event_payload.get("text", ""))
        if trigger.regex:
            match = re.search(trigger.regex, text)
            if match is None:
                return None
            return build_regex_args(loaded.descriptor.param_specs, match)
        if trigger.keywords and not any(
            keyword in text for keyword in trigger.keywords
        ):
            return None
        return {}

    @staticmethod
    def _passes_permissions(
        loaded: LoadedHandler,
        event_payload: dict[str, Any],
    ) -> bool:
        """Check handler permissions; only the admin role is enforced locally."""
        permissions = loaded.descriptor.permissions
        required_role = permissions.required_role
        if required_role is None and permissions.require_admin:
            required_role = "admin"
        if required_role == "admin":
            return bool(event_payload.get("is_admin", False))
        return True

    def _passes_filters(
        self,
        loaded: LoadedHandler,
        event_payload: dict[str, Any],
    ) -> bool:
        """Evaluate declared filters; local filter refs are not evaluated here."""
        for filter_spec in loaded.descriptor.filters:
            if isinstance(filter_spec, PlatformFilterSpec):
                if str(event_payload.get("platform", "")) not in filter_spec.platforms:
                    return False
            elif isinstance(filter_spec, MessageTypeFilterSpec):
                if (
                    self._message_type_name(event_payload)
                    not in filter_spec.message_types
                ):
                    return False
            elif isinstance(filter_spec, CompositeFilterSpec):
                if not self._passes_composite_filter(filter_spec, event_payload):
                    return False
            elif isinstance(filter_spec, LocalFilterRefSpec):
                # Local filters run inside the plugin; the harness treats them as passing.
                continue
        return True

    def _passes_composite_filter(
        self,
        filter_spec: CompositeFilterSpec,
        event_payload: dict[str, Any],
    ) -> bool:
        """Recursively evaluate an and/or composite filter; local refs count as True."""
        results: list[bool] = []
        for child in filter_spec.children:
            if isinstance(child, PlatformFilterSpec):
                results.append(
                    str(event_payload.get("platform", "")) in child.platforms
                )
            elif isinstance(child, MessageTypeFilterSpec):
                results.append(
                    self._message_type_name(event_payload) in child.message_types
                )
            elif isinstance(child, LocalFilterRefSpec):
                results.append(True)
            elif isinstance(child, CompositeFilterSpec):
                results.append(self._passes_composite_filter(child, event_payload))
        if filter_spec.kind == "and":
            return all(results)
        return any(results)

    def _has_waiter_for_event(self, event_payload: dict[str, Any]) -> bool:
        """Probe the dispatcher for an active session waiter matching this event.

        Prefers a public ``has_active_waiter`` API when present, then falls back
        to the dispatcher's private ``_session_waiters`` state.
        """
        assert self.dispatcher is not None
        probe_event = MessageEvent.from_payload(
            event_payload,
            context=self.lifecycle_context,
        )
        public_probe = getattr(self.dispatcher, "has_active_waiter", None)
        if callable(public_probe):
            return bool(public_probe(probe_event))
        session_waiters = getattr(self.dispatcher, "_session_waiters", None)
        if session_waiters is None:
            return False
        if hasattr(session_waiters, "has_waiter"):
            return session_waiters.has_waiter(probe_event)
        if isinstance(session_waiters, dict):
            return any(
                manager.has_waiter(probe_event)
                for manager in session_waiters.values()
                if hasattr(manager, "has_waiter")
            )
        return False

    @staticmethod
    def _message_type_name(event_payload: dict[str, Any]) -> str:
        """Normalize the event's message type, defaulting to ``"other"``."""
        return normalize_message_type(
            event_payload.get("message_type", ""),
            group_id=str(event_payload.get("group_id", "")).strip() or None,
            user_id=str(event_payload.get("user_id", "")).strip() or None,
            empty_default="other",
        )

    @staticmethod
    def _resolve_lifecycle_hook(instance: Any, method_name: str):
        """Pick the lifecycle hook to run for *instance*.

        An overridden hook wins; old-style stars fall back to the legacy
        ``initialize``/``terminate`` aliases; otherwise the default hook (if
        callable) or ``None`` is returned.
        """
        hook = getattr(instance, method_name, None)
        marker = getattr(instance.__class__, "__astrbot_is_new_star__", None)
        is_new_star = True
        if callable(marker):
            is_new_star = bool(marker())

        if hook is not None and callable(hook):
            bound_func = getattr(hook, "__func__", hook)
            star_default = getattr(Star, method_name, None)
            if star_default is None or bound_func is not star_default:
                return hook

        if not is_new_star:
            alias = {"on_start": "initialize", "on_stop": "terminate"}.get(method_name)
            if alias is not None:
                legacy_hook = getattr(instance, alias, None)
                if legacy_hook is not None and callable(legacy_hook):
                    return legacy_hook

        if hook is not None and callable(hook):
            return hook
        return None

    def _next_request_id(self, prefix: str) -> str:
        """Return a monotonically increasing request id such as ``msg_0001``."""
        self._request_counter += 1
        return f"{prefix}_{self._request_counter:04d}"
+
+
# Public surface of the stable testing helpers, kept alphabetically sorted so
# additions have an unambiguous place. Order is irrelevant to ``import *``.
__all__ = [
    "InMemoryDB",
    "InMemoryMemory",
    "LocalRuntimeConfig",
    "MockCapabilityRouter",
    "MockClock",
    "MockContext",
    "MockLLMClient",
    "MockMessageEvent",
    "MockPeer",
    "MockPlatformClient",
    "PluginHarness",
    "RecordedSend",
    "SDKTestEnvironment",
    "StdoutPlatformSink",
]
diff --git a/astrbot-sdk/src/astrbot_sdk/types.py b/astrbot-sdk/src/astrbot_sdk/types.py
new file mode 100644
index 0000000000..c2bc911ec7
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/types.py
@@ -0,0 +1,22 @@
"""SDK parameter helper types.

This module provides SDK parameter-type helpers used to extend command
argument parsing.

GreedyStr:
Marks a "greedy string" parameter: during command parsing, all remaining
text is captured as a single argument.
For example: /echo hello world this is a test
If the final parameter is annotated as GreedyStr, it receives
"hello world this is a test" instead of just "hello".

Usage:
Annotate the last parameter in a handler signature as GreedyStr;
_loader_support recognizes this type and adjusts argument parsing
accordingly.
"""

from __future__ import annotations


class GreedyStr(str):
    """Consume the remaining command text as one argument."""


__all__ = ["GreedyStr"]
diff --git a/pyproject.toml b/pyproject.toml
index da69c33116..027f091433 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
"httpx[socks]>=0.28.1",
"lark-oapi>=1.4.15",
"mcp>=1.8.0",
+ "msgpack>=1.1.1",
"openai>=1.78.0",
"ormsgpack>=1.9.1",
"pillow>=11.2.1",
diff --git a/requirements.txt b/requirements.txt
index a6eb84bb01..083ed69a3d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,6 +18,7 @@ google-genai>=1.56.0
httpx[socks]>=0.28.1
lark-oapi>=1.4.15
mcp>=1.8.0
+msgpack>=1.1.1
openai>=1.78.0
ormsgpack>=1.9.1
pillow>=11.2.1
@@ -53,4 +54,4 @@ shipyard-python-sdk>=0.2.4
shipyard-neo-sdk>=0.2.0
packaging>=24.2
qrcode>=8.2
-python-ripgrep==0.0.8
\ No newline at end of file
+python-ripgrep==0.0.8
diff --git a/scripts/sync-sdk.ps1 b/scripts/sync-sdk.ps1
new file mode 100644
index 0000000000..7099da197f
--- /dev/null
+++ b/scripts/sync-sdk.ps1
@@ -0,0 +1,182 @@
# Sync the vendored AstrBot SDK subtree from its source repository.
#
# Parameters:
#   -RemoteName    git remote hosting the SDK source (default: sdk-remote)
#   -RemoteBranch  branch holding the vendor snapshot (default: vendor-branch)
#   -Prefix        subtree directory inside this repo (default: astrbot-sdk)
#   -NoWait        skip the interactive "press any key" pause on exit
[CmdletBinding()]
param(
    [string]$RemoteName = "sdk-remote",
    [string]$RemoteBranch = "vendor-branch",
    [string]$Prefix = "astrbot-sdk",
    [switch]$NoWait
)

# Fail fast on any error and on use of undeclared variables.
Set-StrictMode -Version Latest
$ErrorActionPreference = "Stop"
+
# Run git with the given arguments and throw if it exits non-zero, so every
# git failure aborts the sync instead of being silently ignored.
function Invoke-Git {
    param(
        [Parameter(Mandatory = $true, ValueFromRemainingArguments = $true)]
        [string[]]$Arguments
    )

    & git @Arguments
    if ($LASTEXITCODE -ne 0) {
        throw "git $($Arguments -join ' ') failed with exit code $LASTEXITCODE."
    }
}
+
# Return $true when $Path exists as an object at $Revision, using
# `git cat-file -e <rev>:<path>`; stderr is suppressed because a missing
# path is an expected outcome, not an error worth printing.
function Test-GitObjectPath {
    param(
        [Parameter(Mandatory = $true)]
        [string]$Revision,
        [Parameter(Mandatory = $true)]
        [string]$Path
    )

    & git cat-file -e "$Revision`:$Path" 2>$null
    return $LASTEXITCODE -eq 0
}
+
# Throw unless a git remote named $Name is configured; the error message
# shows the exact `git remote add` command to fix it.
function Assert-RemoteExists {
    param(
        [Parameter(Mandatory = $true)]
        [string]$Name
    )

    $remoteNames = (& git remote)
    if ($LASTEXITCODE -ne 0) {
        throw "Failed to read git remotes."
    }

    # -notcontains treats a single-remote scalar result as a one-item list.
    if ($remoteNames -notcontains $Name) {
        throw "Git remote '$Name' is missing. Add it first, for example: git remote add $Name https://github.com/united-pooh/astrbot-sdk.git"
    }
}
+
# Throw if the worktree has any uncommitted changes; a subtree pull on a
# dirty tree would entangle local edits with the vendored snapshot.
function Assert-CleanWorktree {
    # $LASTEXITCODE still reflects git here: Out-String/Trim are
    # PowerShell-side operations and do not reset it.
    $statusOutput = (& git status --porcelain=v1 | Out-String).Trim()
    if ($LASTEXITCODE -ne 0) {
        throw "Failed to inspect git worktree status."
    }

    if ($statusOutput) {
        throw "Worktree is not clean. Commit or stash changes before syncing the vendored SDK.`n$statusOutput"
    }
}
+
# Throw with $Reason appended when the filesystem path $Path is missing.
function Assert-LocalPath {
    param(
        [Parameter(Mandatory = $true)]
        [string]$Path,
        [Parameter(Mandatory = $true)]
        [string]$Reason
    )

    if (-not (Test-Path -LiteralPath $Path)) {
        throw "Expected local path '$Path' is missing. $Reason"
    }
}
+
# Throw with $Reason appended when $Path does not exist in the git tree at
# $Revision (checked via Test-GitObjectPath, i.e. without checking out).
function Assert-RemotePath {
    param(
        [Parameter(Mandatory = $true)]
        [string]$Revision,
        [Parameter(Mandatory = $true)]
        [string]$Path,
        [Parameter(Mandatory = $true)]
        [string]$Reason
    )

    if (-not (Test-GitObjectPath -Revision $Revision -Path $Path)) {
        throw "Remote snapshot '$Revision' is missing '$Path'. $Reason"
    }
}
+
# Decide whether to pause before the script exits. Precedence:
# the -NoWait switch, then the ASTRBOT_SYNC_SDK_NO_WAIT=1 env var, and
# finally a check that we are on an interactive console with neither
# stdin nor stdout redirected.
function Test-ShouldWaitBeforeExit {
    if ($NoWait.IsPresent) {
        return $false
    }

    if ($env:ASTRBOT_SYNC_SDK_NO_WAIT -eq "1") {
        return $false
    }

    try {
        return (
            [Environment]::UserInteractive -and
            -not [Console]::IsInputRedirected -and
            -not [Console]::IsOutputRedirected
        )
    } catch {
        # Hosts without a usable console can throw here; treat that as
        # non-interactive and skip the pause.
        return $false
    }
}
+
# In interactive runs, block on a single keypress so the terminal window
# does not close before the user can read the output.
function Wait-BeforeExit {
    if (-not (Test-ShouldWaitBeforeExit)) {
        return
    }

    Write-Host ""
    Write-Host "Press any key to close this window..."
    # ReadKey($true) swallows the key instead of echoing it.
    $null = [System.Console]::ReadKey($true)
}
+
try {
    # All subsequent paths are relative to the repository root.
    $repoRoot = (& git rev-parse --show-toplevel).Trim()
    if ($LASTEXITCODE -ne 0 -or [string]::IsNullOrWhiteSpace($repoRoot)) {
        throw "This script must run inside a git repository."
    }

    Set-Location -LiteralPath $repoRoot

    # Files the editable install of the SDK expects to exist locally.
    $localRequiredPaths = @(
        (Join-Path $Prefix "pyproject.toml"),
        (Join-Path $Prefix "README.md"),
        (Join-Path $Prefix "src/astrbot_sdk/__init__.py")
    )

    foreach ($requiredPath in $localRequiredPaths) {
        Assert-LocalPath -Path $requiredPath -Reason "The current AstrBot workspace expects '$Prefix' to keep the SDK's editable package layout."
    }

    Assert-RemoteExists -Name $RemoteName
    Assert-CleanWorktree

    Write-Host "Fetching $RemoteName/$RemoteBranch..."
    Invoke-Git fetch $RemoteName $RemoteBranch

    $remoteRef = "refs/remotes/$RemoteName/$RemoteBranch"
    $remoteCommit = (& git rev-parse $remoteRef).Trim()
    if ($LASTEXITCODE -ne 0 -or [string]::IsNullOrWhiteSpace($remoteCommit)) {
        throw "Unable to resolve remote ref '$remoteRef' after fetch."
    }

    # Fail fast if the source branch does not match the package layout the main repo
    # currently installs via `astrbot-sdk = { path = "./astrbot-sdk", editable = true }`.
    # Pulling an incompatible snapshot would silently break dependency resolution.
    $remoteRequiredPaths = @(
        "pyproject.toml",
        "README.md",
        "src/astrbot_sdk/__init__.py"
    )

    foreach ($requiredPath in $remoteRequiredPaths) {
        Assert-RemotePath -Revision $remoteRef -Path $requiredPath -Reason "The vendor branch must expose the full SDK package layout required by the main repo before subtree sync is allowed."
    }

    Write-Host "Pulling $RemoteName/$RemoteBranch into $Prefix with git subtree --squash..."
    Invoke-Git subtree pull "--prefix=$Prefix" $RemoteName $RemoteBranch --squash

    # Re-verify the local layout so a surprising squash result is caught now.
    foreach ($requiredPath in $localRequiredPaths) {
        Assert-LocalPath -Path $requiredPath -Reason "The subtree pull finished, but the local SDK layout is incomplete."
    }

    Write-Host ""
    Write-Host "SDK sync completed successfully."
    Write-Host "Review the result with:"
    Write-Host "  git status --short"
    Write-Host "  Get-ChildItem $Prefix"
    Write-Host "  Test-Path $Prefix\\pyproject.toml"
    Write-Host "  Test-Path $Prefix\\src\\astrbot_sdk\\__init__.py"
} finally {
    # Keep interactive terminal windows open so manual sync runs do not disappear
    # before the user can inspect success or failure output.
    Wait-BeforeExit
}
diff --git a/scripts/sync-sdk.sh b/scripts/sync-sdk.sh
new file mode 100644
index 0000000000..98af772919
--- /dev/null
+++ b/scripts/sync-sdk.sh
@@ -0,0 +1,151 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
# Print an error message on stderr and abort the script with status 1.
fail() {
  echo "$1" >&2
  exit 1
}
+
# ASTRBOT_SYNC_SDK_NO_WAIT=1 is the non-interactive default for --no-wait.
no_wait="${ASTRBOT_SYNC_SDK_NO_WAIT:-0}"

# Parse leading options; anything after them (or after `--`) is positional.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --no-wait)
      no_wait="1"
      shift
      ;;
    --)
      shift
      break
      ;;
    -*)
      fail "Unknown option: $1"
      ;;
    *)
      break
      ;;
  esac
done

# Positional arguments: remote name, vendor branch, subtree prefix.
remote_name="${1:-sdk-remote}"
remote_branch="${2:-vendor-branch}"
prefix="${3:-astrbot-sdk}"
+
# Run git with the given arguments; abort the script if it fails.
run_git() {
  git "$@" || fail "git $* failed."
}
+
# Succeed (exit 0) when $2 exists as an object at revision $1, checked via
# `git cat-file -e`; output is discarded because absence is expected.
test_git_object_path() {
  local revision="$1"
  local path="$2"

  git cat-file -e "${revision}:${path}" >/dev/null 2>&1
}
+
# Abort unless a git remote with the exact name $1 is configured.
assert_remote_exists() {
  local name="$1"

  # grep -Fxq: fixed-string, whole-line, quiet — exact remote-name match.
  if ! git remote | grep -Fxq "$name"; then
    fail "Git remote '$name' is missing. Add it first, for example: git remote add $name https://github.com/united-pooh/astrbot-sdk.git"
  fi
}
+
# Abort if the worktree has uncommitted changes; a subtree pull on a dirty
# tree would entangle local edits with the vendored snapshot.
assert_clean_worktree() {
  local status_output
  status_output="$(git status --porcelain=v1)"

  if [[ -n "$status_output" ]]; then
    fail "Worktree is not clean. Commit or stash changes before syncing the vendored SDK.
$status_output"
  fi
}
+
# Abort with reason $2 when the filesystem path $1 does not exist.
assert_local_path() {
  local path="$1"
  local reason="$2"

  [[ -e "$path" ]] || fail "Expected local path '$path' is missing. $reason"
}
+
# Abort with reason $3 when path $2 is absent from the git tree at
# revision $1 (checked without checking anything out).
assert_remote_path() {
  local revision="$1"
  local path="$2"
  local reason="$3"

  test_git_object_path "$revision" "$path" || fail "Remote snapshot '$revision' is missing '$path'. $reason"
}
+
# Succeed only when the exit pause is enabled and both stdin and stdout are
# terminals. Returns non-zero otherwise; callers invoke it only inside an
# `if !` condition, so the non-zero status does not trip `set -e`.
should_wait_before_exit() {
  [[ "$no_wait" != "1" ]] || return 1
  [[ -t 0 && -t 1 ]] || return 1
}
+
# Pause interactive runs before exit so the window stays readable.
# $1 is the script's exit code, echoed back on failure.
wait_before_exit() {
  local exit_code="$1"

  if ! should_wait_before_exit; then
    return
  fi

  echo
  if [[ "$exit_code" -eq 0 ]]; then
    printf 'Press any key to close this window...'
  else
    printf 'Script exited with code %s. Press any key to close this window...' "$exit_code"
  fi
  # Read a single raw keypress without echoing it.
  IFS= read -r -n 1 -s _
  echo
}
+
# Pause on every exit path (success, fail(), or set -e) when interactive.
trap 'wait_before_exit "$?"' EXIT

# All subsequent paths are relative to the repository root.
repo_root="$(git rev-parse --show-toplevel 2>/dev/null)" || fail "This script must run inside a git repository."
cd "$repo_root"

# Files the editable install of the SDK expects to exist locally.
local_required_paths=(
  "${prefix}/pyproject.toml"
  "${prefix}/README.md"
  "${prefix}/src/astrbot_sdk/__init__.py"
)

for required_path in "${local_required_paths[@]}"; do
  assert_local_path "$required_path" "The current AstrBot workspace expects '$prefix' to keep the SDK's editable package layout."
done

assert_remote_exists "$remote_name"
assert_clean_worktree

echo "Fetching ${remote_name}/${remote_branch}..."
run_git fetch "$remote_name" "$remote_branch"

remote_ref="refs/remotes/${remote_name}/${remote_branch}"
remote_commit="$(git rev-parse "$remote_ref" 2>/dev/null)" || fail "Unable to resolve remote ref '$remote_ref' after fetch."
[[ -n "$remote_commit" ]] || fail "Unable to resolve remote ref '$remote_ref' after fetch."

# Fail fast if the source branch does not match the package layout the main repo
# currently installs via `astrbot-sdk = { path = "./astrbot-sdk", editable = true }`.
# Pulling an incompatible snapshot would silently break dependency resolution.
remote_required_paths=(
  "pyproject.toml"
  "README.md"
  "src/astrbot_sdk/__init__.py"
)

for required_path in "${remote_required_paths[@]}"; do
  assert_remote_path "$remote_ref" "$required_path" "The vendor branch must expose the full SDK package layout required by the main repo before subtree sync is allowed."
done

echo "Pulling ${remote_name}/${remote_branch} into ${prefix} with git subtree --squash..."
run_git subtree pull "--prefix=${prefix}" "$remote_name" "$remote_branch" --squash

# Re-verify the local layout so a surprising squash result is caught now.
for required_path in "${local_required_paths[@]}"; do
  assert_local_path "$required_path" "The subtree pull finished, but the local SDK layout is incomplete."
done

echo
echo "SDK sync completed successfully."
echo "Review the result with:"
echo "  git status --short"
echo "  ls ${prefix}"
echo "  test -e ${prefix}/pyproject.toml"
echo "  test -e ${prefix}/src/astrbot_sdk/__init__.py"
diff --git a/tests/test_msgpack_dependency.py b/tests/test_msgpack_dependency.py
new file mode 100644
index 0000000000..3d6c01bb92
--- /dev/null
+++ b/tests/test_msgpack_dependency.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+
+import tomllib
+
# Repository root, derived from this test file's location (tests/ -> root).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
REQUIREMENTS_PATH = PROJECT_ROOT / "requirements.txt"
PYPROJECT_PATH = PROJECT_ROOT / "pyproject.toml"
# Exact specifier that must appear in both dependency-metadata files.
MSGPACK_DEPENDENCY = "msgpack>=1.1.1"
+
+
+def _read_requirements() -> list[str]:
+ entries = []
+ for line in REQUIREMENTS_PATH.read_text(encoding="utf-8").splitlines():
+ candidate = line.split("#", 1)[0].strip()
+ if candidate:
+ entries.append(candidate)
+ return entries
+
+
+def _read_pyproject_dependencies() -> list[str]:
+ with PYPROJECT_PATH.open("rb") as file:
+ pyproject = tomllib.load(file)
+ return pyproject["project"]["dependencies"]
+
+
def test_requirements_include_msgpack_dependency() -> None:
    """requirements.txt must carry the msgpack specifier the vendored SDK needs."""
    entries = _read_requirements()
    message = (
        "Expected msgpack dependency in requirements.txt for vendored SDK protocol "
        "codec support"
    )
    assert MSGPACK_DEPENDENCY in entries, message
+
+
def test_pyproject_declares_msgpack_dependency() -> None:
    """pyproject.toml must carry the msgpack specifier the vendored SDK needs."""
    declared = _read_pyproject_dependencies()
    message = (
        "Expected msgpack dependency in pyproject.toml for vendored SDK protocol "
        "codec support"
    )
    assert MSGPACK_DEPENDENCY in declared, message
diff --git a/tests/test_vendored_sdk_review_fixes.py b/tests/test_vendored_sdk_review_fixes.py
new file mode 100644
index 0000000000..7682c378bb
--- /dev/null
+++ b/tests/test_vendored_sdk_review_fixes.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+import pytest
+from pydantic import BaseModel
+
# Make the vendored SDK importable without installing it: these tests import
# astrbot_sdk straight from the subtree's src layout.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
SDK_SRC = PROJECT_ROOT / "astrbot-sdk" / "src"

if str(SDK_SRC) not in sys.path:
    sys.path.insert(0, str(SDK_SRC))
+
+from astrbot_sdk._internal.command_model import format_command_model_help
+from astrbot_sdk.cli import _render_init_agent_templates
+from astrbot_sdk.clients.managers import MessageHistoryPage, MessageHistoryRecord
+from astrbot_sdk.errors import AstrBotError
+from astrbot_sdk.filters import CustomFilter
+from astrbot_sdk.message.components import Plain
+from astrbot_sdk.message.session import MessageSession
+from astrbot_sdk.runtime._capability_router_builtins.capabilities.http import (
+ HttpCapabilityMixin,
+)
+
+
class _BooleanOptionsModel(BaseModel):
    """Minimal command model with one boolean flag to probe help rendering."""

    # snake_case on purpose: the rendered help must expose the kebab-case form.
    foo_bar: bool = False
+
+
class _HttpCapabilityHost(HttpCapabilityMixin):
    """Bare host for HttpCapabilityMixin with the caller-id check stubbed out."""

    def __init__(self) -> None:
        # Registered API records accumulate here; no real HTTP server is involved.
        self.http_api_store: list[dict[str, object]] = []

    @staticmethod
    def _require_caller_plugin_id(_capability_name: str) -> str:
        # Always authorize as plugin "demo" so tests bypass caller resolution.
        return "demo"
+
+
def test_command_model_help_uses_canonical_boolean_option_names() -> None:
    """Boolean options must render as kebab-case --flag / --no-flag pairs."""
    help_text = format_command_model_help("demo", _BooleanOptionsModel)

    assert "--foo-bar / --no-foo-bar" in help_text
    # The raw snake_case spelling must not leak into the rendered help.
    assert "--foo_bar / --no-foo_bar" not in help_text
+
+
def test_render_init_agent_templates_creates_codex_skill_scaffold(tmp_path: Path) -> None:
    """`astr init` with the codex agent must scaffold the plugin-dev skill files."""
    _render_init_agent_templates(
        target_dir=tmp_path,
        plugin_name="demo_plugin",
        display_name="Demo Plugin",
        agents=("codex",),
    )

    skill_dir = tmp_path / ".agents" / "skills" / "astrbot-plugin-dev"
    assert (skill_dir / "SKILL.md").exists()
    assert (skill_dir / "agents" / "openai.yaml").exists()
+
+
@pytest.mark.asyncio
async def test_http_register_api_rejects_empty_methods_after_normalization() -> None:
    """Registering an HTTP API with only blank method names must raise.

    The methods list contains an empty and a whitespace-only entry, so after
    normalization nothing remains and the mixin is expected to reject the
    request with an AstrBotError (message: "at least one non-empty HTTP
    method is required").
    """
    host = _HttpCapabilityHost()

    with pytest.raises(AstrBotError, match="至少需要一个非空 HTTP 方法"):
        await host._http_register_api(
            "req-1",
            {
                "methods": ["", " "],
                "route": "/demo",
                "handler_capability": "demo.handle",
            },
            None,
        )
+
+
def test_local_filter_binding_caches_signature_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
    """evaluate() must not re-inspect the callable's signature on each call.

    inspect.signature is patched to raise, so any repeated introspection
    after compile() fails the test; both evaluate() calls must succeed
    purely from cached signature metadata.
    """
    binding = CustomFilter(lambda event: bool(event)).compile()[1][0]

    def _unexpected_signature(_callable: object) -> object:
        raise AssertionError("evaluate() should not inspect signatures repeatedly")

    monkeypatch.setattr("astrbot_sdk.filters.inspect.signature", _unexpected_signature)

    assert binding.evaluate(event="payload") is True
    assert binding.evaluate(event="payload") is True
+
+
def test_message_history_record_preserves_component_instances() -> None:
    """Validation must keep the exact Plain component object, not re-serialize it.

    The identity assertion (``is``) is the point: pydantic validation of the
    record must pass message components through unchanged.
    """
    session = MessageSession(platform_id="qq", message_type="group", session_id="42")
    component = Plain("hello")

    record = MessageHistoryRecord.model_validate(
        {
            "id": 1,
            "session": session,
            "parts": [component],
        }
    )

    assert record.parts == [component]
    assert record.parts[0] is component
+
+
def test_message_history_page_preserves_record_instances() -> None:
    """Page validation must keep the exact record object, not a copied model.

    Mirrors the record-level test one layer up: validating a page from a
    pre-built MessageHistoryRecord must pass the instance through by identity.
    """
    record = MessageHistoryRecord.model_validate(
        {
            "id": 1,
            "session": MessageSession(
                platform_id="qq",
                message_type="group",
                session_id="42",
            ),
            "parts": [Plain("hello")],
        }
    )

    page = MessageHistoryPage.model_validate({"records": [record]})

    assert page.records == [record]
    assert page.records[0] is record