# Dockerfile.cn - AstrBot image built against mainland-China mirrors
# (Aliyun for apt and PyPI) to speed up dependency downloads.
FROM python:3.12-slim
WORKDIR /AstrBot

# Fail a RUN step when any command in a pipeline fails. Without pipefail a
# failed `curl ... | bash -` would exit 0 and the build would silently go on
# to install whatever nodejs the distribution ships instead of the LTS one.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# Use the Aliyun mirror for Debian packages.
RUN sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources

# System packages are installed BEFORE the source tree is copied so that this
# expensive layer stays cached when only application code changes.
RUN apt-get update && apt-get install -y --no-install-recommends \
        gcc \
        build-essential \
        python3-dev \
        libffi-dev \
        libssl-dev \
        ca-certificates \
        bash \
        ffmpeg \
        libavcodec-extra \
        curl \
        gnupg \
        git \
        ripgrep \
    && curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
    && apt-get install -y --no-install-recommends nodejs \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

# The project sources are needed from here on (uv reads the project metadata).
COPY . /AstrBot/

# The mirror is served over HTTPS, so certificate verification is left enabled
# (no --trusted-host).
RUN python -m pip install --no-cache-dir uv -i https://mirrors.aliyun.com/pypi/simple/ \
    && echo "3.12" > .python-version \
    && uv lock \
    && uv export --format requirements.txt --output-file requirements.txt --frozen \
    && uv pip install -r requirements.txt --no-cache-dir --system --index-url https://mirrors.aliyun.com/pypi/simple/ \
    && uv pip install socksio uv pilk --no-cache-dir --system --index-url https://mirrors.aliyun.com/pypi/simple/

EXPOSE 6185

CMD ["python", "main.py"]
index b75999ea65..9a765d4721 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -109,6 +109,31 @@ "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.", ) +try: + import httpx as _httpx + + def _create_no_verify_httpx_client( + headers: dict[str, str] | None = None, + timeout: _httpx.Timeout | None = None, + auth: _httpx.Auth | None = None, + ) -> _httpx.AsyncClient: + kwargs: dict[str, Any] = { + "follow_redirects": True, + "verify": False, + } + if timeout is None: + kwargs["timeout"] = _httpx.Timeout(30, read=300) + else: + kwargs["timeout"] = timeout + if headers is not None: + kwargs["headers"] = headers + if auth is not None: + kwargs["auth"] = auth + return _httpx.AsyncClient(**kwargs) + +except (ModuleNotFoundError, ImportError): + _create_no_verify_httpx_client = None + def _prepare_config(config: dict) -> dict: """Prepare configuration, handle nested format""" @@ -237,13 +262,57 @@ def validate_mcp_stdio_config(config: dict) -> None: raise ValueError("MCP stdio env keys and values must be strings.") +def _get_certifi_ca_bundle() -> str | None: + """Try to locate the certifi CA bundle for SSL_CERT_FILE.""" + try: + import certifi + + return certifi.where() + except ImportError: + pass + # Fallback: look for certifi in common locations + for candidate in ( + os.path.join( + os.path.dirname(sys.executable), + "Lib", + "site-packages", + "certifi", + "cacert.pem", + ), + os.path.join( + os.path.dirname(sys.executable), + "..", + "Lib", + "site-packages", + "certifi", + "cacert.pem", + ), + ): + if os.path.isfile(candidate): + return candidate + return None + + def _prepare_stdio_env(config: dict) -> dict: - """Preserve Windows executable resolution for stdio subprocesses.""" - if sys.platform != "win32": - return config + """Prepare environment variables for stdio subprocesses. + + On Windows: + - Merges system environment variables (case-insensitive handling). 
+ - For uv/uvx commands, sets SSL_CERT_FILE from certifi to avoid + ``invalid peer certificate: UnknownIssuer`` errors caused by + uv's bundled TLS not trusting the system certificate store. + """ prepared = config.copy() env = dict(prepared.get("env") or {}) env = _merge_environment_variables(env) + + if sys.platform == "win32": + command_name = _normalize_stdio_command_name(config.get("command", "")) + if command_name in ("uv", "uvx") and "SSL_CERT_FILE" not in env: + ca_bundle = _get_certifi_ca_bundle() + if ca_bundle: + env["SSL_CERT_FILE"] = ca_bundle + prepared["env"] = env return prepared @@ -326,6 +395,18 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" +_NONSTANDARD_TYPE_MAP: dict[str, str] = { + "int": "integer", + "float": "number", + "double": "number", + "decimal": "number", + "bool": "boolean", + "str": "string", + "dict": "object", + "list": "array", +} + + def _normalize_mcp_input_schema(schema: dict[str, Any]) -> dict[str, Any]: """Normalize common non-standard MCP JSON Schema variants. @@ -334,6 +415,9 @@ def _normalize_mcp_input_schema(schema: dict[str, Any]) -> dict[str, Any]: parent object to declare `required` as an array of property names instead. We lift those booleans to the parent object so the schema remains usable without disabling validation entirely. + + Also normalizes non-standard type names (e.g. ``"int"`` → ``"integer"``, + ``"str"`` → ``"string"``) that some MCP servers emit. 
""" def _normalize(node: Any) -> Any: @@ -345,6 +429,16 @@ def _normalize(node: Any) -> Any: normalized = {key: _normalize(value) for key, value in node.items()} + # Normalize non-standard type names + type_val = normalized.get("type") + if isinstance(type_val, str) and type_val in _NONSTANDARD_TYPE_MAP: + normalized["type"] = _NONSTANDARD_TYPE_MAP[type_val] + elif isinstance(type_val, list): + normalized["type"] = [ + _NONSTANDARD_TYPE_MAP.get(t, t) if isinstance(t, str) else t + for t in type_val + ] + properties = normalized.get("properties") if isinstance(properties, dict): original_properties = ( @@ -439,14 +533,22 @@ def logging_callback( else: raise Exception("MCP connection config missing transport or type field") + _http_client_kwargs: dict[str, Any] = { + "url": cfg["url"], + "headers": cfg.get("headers", {}), + } + if _create_no_verify_httpx_client is not None: + _http_client_kwargs["httpx_client_factory"] = ( + _create_no_verify_httpx_client + ) + if transport_type != "streamable_http": # SSE transport method - self._streams_context = sse_client( - url=cfg["url"], - headers=cfg.get("headers", {}), - timeout=cfg.get("timeout", 5), - sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), + _http_client_kwargs["timeout"] = cfg.get("timeout", 5) + _http_client_kwargs["sse_read_timeout"] = cfg.get( + "sse_read_timeout", 60 * 5 ) + self._streams_context = sse_client(**_http_client_kwargs) streams = await self.exit_stack.enter_async_context( self._streams_context, ) @@ -461,17 +563,16 @@ def logging_callback( ), ) else: - timeout = timedelta(seconds=cfg.get("timeout", 30)) - sse_read_timeout = timedelta( + _http_client_kwargs["timeout"] = timedelta( + seconds=cfg.get("timeout", 30) + ) + _http_client_kwargs["sse_read_timeout"] = timedelta( seconds=cfg.get("sse_read_timeout", 60 * 5), ) - self._streams_context = streamablehttp_client( - url=cfg["url"], - headers=cfg.get("headers", {}), - timeout=timeout, - sse_read_timeout=sse_read_timeout, - 
terminate_on_close=cfg.get("terminate_on_close", True), + _http_client_kwargs["terminate_on_close"] = cfg.get( + "terminate_on_close", True ) + self._streams_context = streamablehttp_client(**_http_client_kwargs) read_s, write_s, _ = await self.exit_stack.enter_async_context( self._streams_context, ) diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index 93f8d3570d..39739dea60 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -298,7 +298,7 @@ async def parse_file(item: dict): case "video": return Comp.Video(file=item["url"]) case _: - return Comp.File(name=item["filename"], file=item["url"]) + return Comp.File(name=item["filename"], url=item["url"]) output = chunk["data"]["outputs"][self.workflow_output_key] chains = [] diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 792970cc3b..4a26b1871c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -1178,126 +1178,125 @@ async def build_main_agent( req.prompt = event.message_str[len(config.provider_wake_prefix) :] - # media files attachments - for comp in event.message_obj.message: - if isinstance(comp, Image): - path = await comp.convert_to_file_path() - image_path = await _compress_image_for_provider( - path, - config.provider_settings, - ) - if _is_generated_compressed_image_path(path, image_path): - event.track_temporary_local_file(image_path) - req.image_urls.append(image_path) - req.extra_user_content_parts.append( - TextPart(text=f"[Image Attachment: path {image_path}]") + conversation = await _get_session_conv(event, plugin_context) + req.conversation = conversation + req.contexts = json.loads(conversation.history) + event.set_extra("provider_request", req) + + # media files attachments (always process, regardless of req source) + for comp in event.message_obj.message: + if isinstance(comp, Image): 
+ path = await comp.convert_to_file_path() + image_path = await _compress_image_for_provider( + path, + config.provider_settings, + ) + if _is_generated_compressed_image_path(path, image_path): + event.track_temporary_local_file(image_path) + req.image_urls.append(image_path) + req.extra_user_content_parts.append( + TextPart(text=f"[Image Attachment: path {image_path}]") + ) + elif isinstance(comp, Record): + audio_path = await comp.convert_to_file_path() + req.audio_urls.append(audio_path) + _append_audio_attachment(req, audio_path) + elif isinstance(comp, File): + file_path = await comp.get_file() + file_name = comp.name or os.path.basename(file_path) + req.extra_user_content_parts.append( + TextPart( + text=f"[File Attachment: name {file_name}, path {file_path}]" ) - elif isinstance(comp, Record): - audio_path = await comp.convert_to_file_path() - req.audio_urls.append(audio_path) - _append_audio_attachment(req, audio_path) - elif isinstance(comp, File): - file_path = await comp.get_file() - file_name = comp.name or os.path.basename(file_path) - req.extra_user_content_parts.append( - TextPart( - text=f"[File Attachment: name {file_name}, path {file_path}]" + ) + elif isinstance(comp, Video): + await _append_video_attachment(req, comp) + # quoted message attachments + reply_comps = [ + comp for comp in event.message_obj.message if isinstance(comp, Reply) + ] + quoted_message_settings = _get_quoted_message_parser_settings( + config.provider_settings + ) + fallback_quoted_image_count = 0 + for comp in reply_comps: + has_embedded_image = False + if comp.chain: + for reply_comp in comp.chain: + if isinstance(reply_comp, Image): + has_embedded_image = True + path = await reply_comp.convert_to_file_path() + image_path = await _compress_image_for_provider( + path, + config.provider_settings, ) - ) - elif isinstance(comp, Video): - await _append_video_attachment(req, comp) - # quoted message attachments - reply_comps = [ - comp for comp in event.message_obj.message if 
isinstance(comp, Reply) - ] - quoted_message_settings = _get_quoted_message_parser_settings( - config.provider_settings - ) - fallback_quoted_image_count = 0 - for comp in reply_comps: - has_embedded_image = False - if comp.chain: - for reply_comp in comp.chain: - if isinstance(reply_comp, Image): - has_embedded_image = True - path = await reply_comp.convert_to_file_path() - image_path = await _compress_image_for_provider( - path, - config.provider_settings, - ) - if _is_generated_compressed_image_path(path, image_path): - event.track_temporary_local_file(image_path) - req.image_urls.append(image_path) - _append_quoted_image_attachment(req, image_path) - elif isinstance(reply_comp, Record): - audio_path = await reply_comp.convert_to_file_path() - req.audio_urls.append(audio_path) - _append_quoted_audio_attachment(req, audio_path) - elif isinstance(reply_comp, File): - file_path = await reply_comp.get_file() - file_name = reply_comp.name or os.path.basename(file_path) - req.extra_user_content_parts.append( - TextPart( - text=( - f"[File Attachment in quoted message: " - f"name {file_name}, path {file_path}]" - ) + if _is_generated_compressed_image_path(path, image_path): + event.track_temporary_local_file(image_path) + req.image_urls.append(image_path) + _append_quoted_image_attachment(req, image_path) + elif isinstance(reply_comp, Record): + audio_path = await reply_comp.convert_to_file_path() + req.audio_urls.append(audio_path) + _append_quoted_audio_attachment(req, audio_path) + elif isinstance(reply_comp, File): + file_path = await reply_comp.get_file() + file_name = reply_comp.name or os.path.basename(file_path) + req.extra_user_content_parts.append( + TextPart( + text=( + f"[File Attachment in quoted message: " + f"name {file_name}, path {file_path}]" ) ) - elif isinstance(reply_comp, Video): - await _append_video_attachment(req, reply_comp, quoted=True) - - # Fallback quoted image extraction for reply-id-only payloads, or when - # embedded reply chain only 
contains placeholders (e.g. [Forward Message], [Image]). - if not has_embedded_image: - try: - fallback_images = normalize_and_dedupe_strings( - await extract_quoted_message_images( - event, - comp, - settings=quoted_message_settings, - ) ) - remaining_limit = max( - config.max_quoted_fallback_images - - fallback_quoted_image_count, - 0, + elif isinstance(reply_comp, Video): + await _append_video_attachment(req, reply_comp, quoted=True) + + # Fallback quoted image extraction for reply-id-only payloads, or when + # embedded reply chain only contains placeholders (e.g. [Forward Message], [Image]). + if not has_embedded_image: + try: + fallback_images = normalize_and_dedupe_strings( + await extract_quoted_message_images( + event, + comp, + settings=quoted_message_settings, ) - if remaining_limit <= 0 and fallback_images: - logger.warning( - "Skip quoted fallback images due to limit=%d for umo=%s", - config.max_quoted_fallback_images, - event.unified_msg_origin, - ) - continue - if len(fallback_images) > remaining_limit: - logger.warning( - "Truncate quoted fallback images for umo=%s, reply_id=%s from %d to %d", - event.unified_msg_origin, - getattr(comp, "id", None), - len(fallback_images), - remaining_limit, - ) - fallback_images = fallback_images[:remaining_limit] - for image_ref in fallback_images: - if image_ref in req.image_urls: - continue - req.image_urls.append(image_ref) - fallback_quoted_image_count += 1 - _append_quoted_image_attachment(req, image_ref) - except Exception as exc: # noqa: BLE001 + ) + remaining_limit = max( + config.max_quoted_fallback_images - fallback_quoted_image_count, + 0, + ) + if remaining_limit <= 0 and fallback_images: + logger.warning( + "Skip quoted fallback images due to limit=%d for umo=%s", + config.max_quoted_fallback_images, + event.unified_msg_origin, + ) + continue + if len(fallback_images) > remaining_limit: logger.warning( - "Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s", + "Truncate quoted fallback 
images for umo=%s, reply_id=%s from %d to %d", event.unified_msg_origin, getattr(comp, "id", None), - exc, - exc_info=True, + len(fallback_images), + remaining_limit, ) - - conversation = await _get_session_conv(event, plugin_context) - req.conversation = conversation - req.contexts = json.loads(conversation.history) - event.set_extra("provider_request", req) + fallback_images = fallback_images[:remaining_limit] + for image_ref in fallback_images: + if image_ref in req.image_urls: + continue + req.image_urls.append(image_ref) + fallback_quoted_image_count += 1 + _append_quoted_image_attachment(req, image_ref) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s", + event.unified_msg_origin, + getattr(comp, "id", None), + exc, + exc_info=True, + ) if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 227e0b0242..ebac76e8c6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1781,6 +1781,8 @@ "embedding_api_base": "", "embedding_model": "", "embedding_dimensions": 1024, + "embedding_send_dimensions": True, + "embedding_input_type": "", "timeout": 20, "proxy": "", }, @@ -2130,6 +2132,16 @@ "type": "string", "hint": "嵌入模型名称。", }, + "embedding_send_dimensions": { + "description": "发送嵌入维度参数", + "type": "bool", + "hint": "是否在请求中发送 dimensions 参数。部分兼容 OpenAI 的服务(如 NVIDIA)不支持该参数,需要关闭,但 embedding_dimensions 仍会作为本地向量索引维度使用。", + }, + "embedding_input_type": { + "description": "嵌入输入类型", + "type": "string", + "hint": "部分嵌入服务需要 input_type 参数。例如 NVIDIA 的检索嵌入模型可填写 query。留空则不发送。", + }, "embedding_api_key": { "description": "API Key", "type": "string", diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index aa11bb601f..c604960bfe 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -5,6 +5,7 @@ from typing import 
TYPE_CHECKING, Any from zoneinfo import ZoneInfo +from apscheduler.executors.asyncio import AsyncIOExecutor from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.date import DateTrigger @@ -34,6 +35,11 @@ class CronJobManager: def __init__(self, db: BaseDatabase) -> None: self.db = db self.scheduler = AsyncIOScheduler() + # Bypass add_executor isinstance check — directly set the executor + # to avoid TypeError in certain packaged environments where + # _create_default_executor() fails the type check. + self._default_executor = AsyncIOExecutor() + self.scheduler._executors["default"] = self._default_executor self._basic_handlers: dict[str, Callable[..., Any]] = {} self._lock = asyncio.Lock() self._started = False @@ -151,6 +157,10 @@ def _remove_scheduled(self, job_id: str) -> None: def _schedule_job(self, job: CronJob) -> None: if not self._started: + # Ensure default executor exists before starting + if "default" not in self.scheduler._executors: + self._default_executor = AsyncIOExecutor() + self.scheduler._executors["default"] = self._default_executor self.scheduler.start() self._started = True try: diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index dc6977cf8a..1934f7a7e4 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -16,6 +16,7 @@ def __init__(self, dimension: int, path: str | None = None) -> None: self.index = None if path and os.path.exists(path): self.index = faiss.read_index(path) + self.dimension = self.index.d else: base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 0474683754..3022010eb1 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ 
async def on_closed(self, close_status_code, close_msg):
    """Handle a closed websocket, delegating to the base class at most once.

    Nothing is done while the owning client is shutting down, or when this
    socket's ``_reenqueued`` latch is already set (``on_error`` and
    ``on_closed`` may both fire for the same connection).

    Args:
        close_status_code: Close code forwarded unchanged to the base class.
        close_msg: Close message forwarded unchanged to the base class.
    """
    if self._client.is_shutting_down:
        logger.debug("[QQOfficial] Ignore websocket reconnect during shutdown.")
        return
    if self._reenqueued:
        logger.debug("[QQOfficial] Session already re-enqueued, skip on_closed.")
        return
    self._reenqueued = True
    try:
        await super().on_closed(close_status_code, close_msg)
    except BaseException:
        # BaseException, not Exception: asyncio.CancelledError must also
        # release the latch, otherwise a cancelled handler would block every
        # later recovery attempt for this socket.
        self._reenqueued = False
        raise


async def on_error(self, exception: BaseException) -> None:
    """Handle a websocket error, delegating to the base class at most once.

    Shares the ``_reenqueued`` latch with ``on_closed`` so that only the first
    of the two callbacks reaches the base-class handler.

    Args:
        exception: The error forwarded unchanged to the base class.
    """
    if self._reenqueued:
        logger.debug("[QQOfficial] Session already re-enqueued, skip on_error.")
        return
    self._reenqueued = True
    try:
        await super().on_error(exception)
    except BaseException:
        # See on_closed: cancellation must release the latch as well.
        self._reenqueued = False
        raise
starting.") + active_count = len(self._active_websockets) + if active_count > 0: + logger.warning( + "[QQOfficial] bot_connect called with %d existing active websocket(s). " + "This may indicate a reconnection storm.", + active_count, + ) + logger.info( + "[QQOfficial] Websocket session starting (active: %d).", active_count + 1 + ) websocket = ManagedBotWebSocket(session, self._connection, self) self._active_websockets.add(websocket) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index bc1e1a6bcd..f36e3e0e6e 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -112,11 +112,9 @@ async def _send( # save file to local file_path = await comp.get_file() original_name = comp.name or os.path.basename(file_path) - ext = os.path.splitext(original_name)[1] or "" - filename = f"{uuid.uuid4()!s}{ext}" - dest_path = os.path.join(attachments_dir, filename) + dest_path = os.path.join(attachments_dir, original_name) shutil.copy2(file_path, dest_path) - data = f"[FILE]{filename}" + data = f"[FILE]{original_name}" await web_chat_back_queue.put( { "type": "file", diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 31436ebf2e..025b5f019e 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -210,6 +210,10 @@ def __init__( self.client.__setattr__("API_BASE_URL", self.api_base_url) + # 消息去重 + self._seen_msg_ids: dict[str, float] = {} + self._DEDUP_TTL = 120 # 去重窗口,秒 + async def callback(msg: BaseMessage) -> None: if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": @@ -511,6 +515,20 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: await self.handle_msg(abm) async def handle_msg(self, message: AstrBotMessage) -> None: + # 消息去重检查 + msg_id 
= message.message_id + if msg_id: + now = time.monotonic() + expired = [ + k for k, ts in self._seen_msg_ids.items() if now - ts > self._DEDUP_TTL + ] + for k in expired: + del self._seen_msg_ids[k] + if msg_id in self._seen_msg_ids: + logger.debug(f"[WeCom] Duplicate message {msg_id}, skipping.") + return + self._seen_msg_ids[msg_id] = now + message_event = WecomPlatformEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 79fe6f8ed2..0b9c4dee48 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -152,6 +152,10 @@ def __init__( # 事件循环和关闭信号 self.shutdown_event = asyncio.Event() + # 消息去重:msgid -> monotonic 时间戳 + self._seen_msg_ids: dict[str, float] = {} + self._DEDUP_TTL = 120 # 去重窗口,秒 + # 队列管理器 self.queue_mgr = WecomAIQueueMgr() @@ -528,7 +532,9 @@ async def convert_message(self, payload: dict) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = self.bot_name abm.message_str = content or "[未知消息]" - abm.message_id = str(uuid.uuid4()) + # 使用企业微信平台提供的 msgid 而非随机 UUID,以支持去重 + platform_msgid = message_data.get("msgid") + abm.message_id = str(platform_msgid) if platform_msgid else str(uuid.uuid4()) abm.timestamp = int(time.time()) abm.raw_message = payload @@ -647,6 +653,20 @@ def meta(self) -> PlatformMetadata: async def handle_msg(self, message: AstrBotMessage) -> None: """处理消息,创建消息事件并提交到事件队列""" + # 消息去重检查 + msg_id = message.message_id + if msg_id: + now = time.monotonic() + expired = [ + k for k, ts in self._seen_msg_ids.items() if now - ts > self._DEDUP_TTL + ] + for k in expired: + del self._seen_msg_ids[k] + if msg_id in self._seen_msg_ids: + logger.debug(f"[WecomAI] Duplicate message {msg_id}, skipping.") + return + self._seen_msg_ids[msg_id] = now + try: message_event = WecomAIBotMessageEvent( 
async def _detect_mcp_transport(self, url: str) -> str:
    """Guess the MCP transport of *url* from the Content-Type of a GET probe.

    A response whose Content-Type contains ``text/event-stream`` is treated as
    an SSE endpoint; anything else is treated as Streamable HTTP. The probe is
    best-effort: on any failure the failure is logged and ``"streamable_http"``
    is returned.

    Args:
        url: The MCP server URL to probe.

    Returns:
        ``"sse"`` or ``"streamable_http"``.
    """
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.get(
                url,
                headers={"Accept": "application/json, text/event-stream"},
                timeout=aiohttp.ClientTimeout(total=10),
            ) as resp,
        ):
            content_type = resp.headers.get("Content-Type", "")
    except Exception as exc:  # noqa: BLE001 - the probe must never abort the caller
        # Surface the failure instead of hiding it: the returned transport is
        # only a guess in this case.
        logger.warning(
            f"Failed to probe MCP transport for {url}: {exc!s}; "
            "assuming streamable_http."
        )
        return "streamable_http"

    # Media types are case-insensitive, so compare in lower case.
    if "text/event-stream" in content_type.lower():
        return "sse"
    return "streamable_http"
def _configured_dimensions(self) -> int:
    """Parse ``embedding_dimensions`` from the provider config.

    Returns:
        The configured dimension as a positive integer, or ``0`` when the
        value is missing, empty, zero, negative or not an integer. Invalid
        and negative values are reported with a warning.
    """
    raw = self.provider_config.get("embedding_dimensions")
    if raw in (None, ""):
        return 0
    try:
        value = int(raw)
    except (ValueError, TypeError):
        logger.warning(
            f"embedding_dimensions in embedding configs is not a valid integer: '{raw}', ignored."
        )
        return 0
    if value < 0:
        # A negative size can never be a vector dimension; treat it as unset
        # rather than passing it on to the request or the vector index.
        logger.warning(
            f"embedding_dimensions in embedding configs must not be negative: '{raw}', ignored."
        )
        return 0
    return value


def _embedding_kwargs(self) -> dict:
    """Build the optional keyword arguments for an embeddings request.

    Returns:
        A dict that may contain ``dimensions`` (only when a positive dimension
        is configured and ``embedding_send_dimensions`` is truthy) and
        ``extra_body`` carrying ``input_type`` (only when
        ``embedding_input_type`` is non-empty).
    """
    kwargs = {}
    dimensions = self._configured_dimensions()
    if dimensions and self.provider_config.get("embedding_send_dimensions", True):
        kwargs["dimensions"] = dimensions

    input_type = self.provider_config.get("embedding_input_type")
    if input_type:
        kwargs["extra_body"] = {"input_type": input_type}
    return kwargs


def get_dim(self) -> int:
    """Return the configured vector dimension, or ``0`` when none is usable.

    Uses the same parsing as ``_embedding_kwargs`` so both agree on what
    counts as a valid dimension.
    """
    return self._configured_dimensions()
Some OpenAI-compatible services, such as NVIDIA, do not support it; disable this while keeping embedding_dimensions as the local vector index dimension." + }, + "embedding_input_type": { + "description": "Embedding input type", + "hint": "Some embedding services require an input_type parameter. For NVIDIA retrieval embedding models, use query. Leave empty to omit it." + }, "embedding_api_key": { "description": "API Key" }, diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index ec124eeeec..3fe66a976f 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -1306,6 +1306,14 @@ "description": "Модель эмбеддингов", "hint": "Имя модели эмбеддингов." }, + "embedding_send_dimensions": { + "description": "Отправлять параметр dimensions", + "hint": "Отправлять ли параметр dimensions в запросах embeddings. Некоторые OpenAI-совместимые сервисы, например NVIDIA, его не поддерживают; отключите этот параметр, но оставьте embedding_dimensions как локальную размерность векторного индекса." + }, + "embedding_input_type": { + "description": "Тип входа Embedding", + "hint": "Некоторым сервисам embeddings нужен параметр input_type. Для retrieval embedding моделей NVIDIA используйте query. Оставьте пустым, чтобы не отправлять." 
+ }, "embedding_api_key": { "description": "API Base URL" }, diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 73f6903bbe..91c50bc3a3 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1311,6 +1311,14 @@ "description": "嵌入模型", "hint": "嵌入模型名称。" }, + "embedding_send_dimensions": { + "description": "发送嵌入维度参数", + "hint": "是否在请求中发送 dimensions 参数。部分兼容 OpenAI 的服务(如 NVIDIA)不支持该参数,需要关闭,但嵌入维度仍会作为本地向量索引维度使用。" + }, + "embedding_input_type": { + "description": "嵌入输入类型", + "hint": "部分嵌入服务需要 input_type 参数。例如 NVIDIA 的检索嵌入模型可填写 query。留空则不发送。" + }, "embedding_api_key": { "description": "API Key" }, diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000..b375baf023 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,17 @@ +services: + astrbot: + build: + context: . + dockerfile: Dockerfile.cn + container_name: astrbot + restart: always + security_opt: + - no-new-privileges:true + ports: + - "6185:6185" # AstrBot WebUI + - "6199:6199" # Optional. OneBot v11 Napcat Websocket Port + environment: + - TZ=Asia/Shanghai + volumes: + - ./data:/AstrBot/data + - /etc/localtime:/etc/localtime:ro