From 8568f3eab374a7dd13c05d4bf37a624a4b9a3a12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=E2=82=82=E2=82=82H=E2=82=82=E2=82=85NO=E2=82=86?= <96930391+Sisyphbaous-DT-Project@users.noreply.github.com> Date: Thu, 23 Apr 2026 10:51:20 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=20OpenAI=20?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=E5=B5=8C=E5=85=A5=E6=A8=A1=E5=9E=8B=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E5=8F=91=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 1 + astrbot/core/provider/provider.py | 6 + .../sources/openai_embedding_source.py | 32 ++++++ astrbot/dashboard/routes/config.py | 79 ++++++++++++++ .../src/components/shared/AstrBotConfig.vue | 43 +++++++- .../components/shared/ConfigItemRenderer.vue | 39 ++++++- tests/test_dashboard.py | 103 ++++++++++++++++++ tests/test_openai_embedding_source.py | 64 +++++++++++ 8 files changed, 365 insertions(+), 2 deletions(-) create mode 100644 tests/test_openai_embedding_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 09976f7c41..eac62984d9 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2141,6 +2141,7 @@ class ChatProviderTemplate(TypedDict): "description": "嵌入模型", "type": "string", "hint": "嵌入模型名称。", + "_special": "get_embedding_models", }, "embedding_api_key": { "description": "API Key", diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index efe9e2e47e..13b4b19608 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -334,6 +334,12 @@ def get_dim(self) -> int: """获取向量的维度""" ... + async def get_models(self) -> list[str]: + """Optional model discovery for embedding providers.""" + raise NotImplementedError( + "This embedding provider does not support model discovery", + ) + async def test(self) -> None: await self.get_embedding("astrbot") diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index ae531996ae..3b4d314fd3 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -14,6 +14,15 @@ provider_type=ProviderType.EMBEDDING, ) class OpenAIEmbeddingProvider(EmbeddingProvider): + _EMBEDDING_MODEL_HINTS = ( + "embedding", + "bge", + "gte", + "e5", + "m3e", + "multilingual-e5", + ) + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.provider_config = provider_config @@ -62,6 +71,29 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]: ) return [item.embedding for item in embeddings.data] + async def get_models(self) -> list[str]: + try: + models = await self.client.models.list() + all_model_ids = sorted( + {model.id for model in models.data if getattr(model, "id", None)} + ) + + embedding_model_ids = [ + model_id + for model_id in all_model_ids + if self._looks_like_embedding_model(model_id) + ] + + # Fall back to all model ids when no embedding-like names are detected. + return embedding_model_ids or all_model_ids + except Exception as e: + raise Exception(f"获取嵌入模型列表失败: {e!s}") from e + + @classmethod + def _looks_like_embedding_model(cls, model_id: str) -> bool: + normalized = model_id.lower() + return any(hint in normalized for hint in cls._EMBEDDING_MODEL_HINTS) + def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" kwargs = {} diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..b9bd8c92c5 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -372,6 +372,10 @@ def __init__( "/config/provider/list": ("GET", self.get_provider_config_list), "/config/provider/model_list": ("GET", self.get_provider_model_list), "/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim), + "/config/provider/get_embedding_models": ( + "POST", + self.get_embedding_models, + ), "/config/provider_sources/models": ( "GET", self.get_provider_source_models, @@ -844,6 +848,7 @@ async def get_embedding_dim(self): if not provider_config: return Response().error("缺少参数 provider_config").__dict__ + inst = None try: # 动态导入 EmbeddingProvider from astrbot.core.provider.provider import EmbeddingProvider @@ -907,6 +912,80 @@ async def get_embedding_dim(self): except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ + finally: + terminate_fn = getattr(inst, "terminate", None) if inst else None + if inspect.iscoroutinefunction(terminate_fn): + try: + await terminate_fn() + except Exception: + logger.warning("释放嵌入 provider 资源失败") + + async def get_embedding_models(self): + """根据临时 provider_config 获取可用嵌入模型列表""" + post_data = await request.json + provider_config = post_data.get("provider_config", None) + if not provider_config: + return Response().error("缺少参数 provider_config").__dict__ + + inst = None + try: + from astrbot.core.provider.provider import EmbeddingProvider + from astrbot.core.provider.register import provider_cls_map + + provider_type = provider_config.get("type", None) + if not provider_type: + return Response().error("provider_config 缺少 type 字段").__dict__ + + if provider_type not in provider_cls_map: + try: + self.core_lifecycle.provider_manager.dynamic_import_provider( + provider_type, + ) + except ImportError: + logger.error(traceback.format_exc()) + return Response().error("提供商适配器加载失败").__dict__ + + if provider_type not in provider_cls_map: + return ( + Response() + .error(f"未找到适用于 {provider_type} 的提供商适配器") + .__dict__ + ) + + provider_metadata = provider_cls_map[provider_type] + cls_type = provider_metadata.cls_type + if not cls_type: + return Response().error(f"无法找到 {provider_type} 的类").__dict__ + + inst = cls_type(provider_config, {}) + if not isinstance(inst, EmbeddingProvider): + return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ + + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() + + try: + models = await inst.get_models() + except NotImplementedError: + return ( + Response() + .error("当前提供商暂不支持自动获取模型列表,请手动填写模型 ID") + .__dict__ + ) + + models = sorted(dict.fromkeys(models or [])) + return Response().ok({"models": models}).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取嵌入模型列表失败: {e!s}").__dict__ + finally: + terminate_fn = getattr(inst, "terminate", None) if inst else None + if inspect.iscoroutinefunction(terminate_fn): + try: + await terminate_fn() + except Exception: + logger.warning("释放嵌入 provider 资源失败") async def get_provider_source_models(self): """获取指定 provider_source 支持的模型列表 diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 33273a36c9..7768e3a268 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -1,6 +1,6 @@