diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 09976f7c41..eac62984d9 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2141,6 +2141,7 @@ class ChatProviderTemplate(TypedDict): "description": "嵌入模型", "type": "string", "hint": "嵌入模型名称。", + "_special": "get_embedding_models", }, "embedding_api_key": { "description": "API Key", diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index efe9e2e47e..13b4b19608 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -334,6 +334,12 @@ def get_dim(self) -> int: """获取向量的维度""" ... + async def get_models(self) -> list[str]: + """Optional model discovery for embedding providers.""" + raise NotImplementedError( + "This embedding provider does not support model discovery", + ) + async def test(self) -> None: await self.get_embedding("astrbot") diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index 61ba9cadbe..47596bb327 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from google import genai from google.genai import types @@ -78,6 +78,26 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]: except APIError as e: raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") + async def get_models(self) -> list[str]: + try: + all_model_ids: list[str] = [] + embedding_model_ids: list[str] = [] + + async for model in await self.client.models.list(): + model_id = self._extract_model_id(model) + if not model_id: + continue + all_model_ids.append(model_id) + if self._supports_embedding(model, model_id): + embedding_model_ids.append(model_id) + + all_model_ids = sorted(dict.fromkeys(all_model_ids)) + embedding_model_ids = sorted(dict.fromkeys(embedding_model_ids)) + + return embedding_model_ids or all_model_ids + except Exception as e: + raise Exception(f"获取 Gemini 嵌入模型列表失败: {e!s}") from e + def get_dim(self) -> int: """获取向量的维度""" return int(self.provider_config.get("embedding_dimensions", 768)) @@ -85,3 +105,30 @@ def get_dim(self) -> int: async def terminate(self): if self.client: await self.client.aclose() + + @staticmethod + def _extract_model_id(model: Any) -> str: + model_name = getattr(model, "name", "") or getattr(model, "model", "") + if not model_name: + return "" + return str(model_name).removeprefix("models/") + + @classmethod + def _supports_embedding(cls, model: Any, model_id: str) -> bool: + supported_actions = getattr(model, "supported_actions", None) or getattr( + model, "supported_generation_methods", [] + ) + if isinstance(supported_actions, list): + normalized_actions = { + str(action).lower().replace("_", "").replace("-", "") + for action in supported_actions + } + if "embedcontent" in normalized_actions: + return True + + return cls._looks_like_embedding_model(model_id) + + @staticmethod + def _looks_like_embedding_model(model_id: str) -> bool: + normalized_model_id = model_id.lower() + return "embedding" in normalized_model_id or "embed" in normalized_model_id diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index ae531996ae..3b4d314fd3 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -14,6 +14,15 @@ provider_type=ProviderType.EMBEDDING, ) class OpenAIEmbeddingProvider(EmbeddingProvider): + _EMBEDDING_MODEL_HINTS = ( + "embedding", + "bge", + "gte", + "e5", + "m3e", + "multilingual-e5", + ) + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.provider_config = provider_config @@ -62,6 +71,29 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]: ) return [item.embedding for item in embeddings.data] + async def get_models(self) -> list[str]: + try: + models = await self.client.models.list() + all_model_ids = sorted( + {model.id for model in models.data if getattr(model, "id", None)} + ) + + embedding_model_ids = [ + model_id + for model_id in all_model_ids + if self._looks_like_embedding_model(model_id) + ] + + # Fall back to all model ids when no embedding-like names are detected. + return embedding_model_ids or all_model_ids + except Exception as e: + raise Exception(f"获取嵌入模型列表失败: {e!s}") from e + + @classmethod + def _looks_like_embedding_model(cls, model_id: str) -> bool: + normalized = model_id.lower() + return any(hint in normalized for hint in cls._EMBEDDING_MODEL_HINTS) + def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" kwargs = {} diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..b9bd8c92c5 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -372,6 +372,10 @@ def __init__( "/config/provider/list": ("GET", self.get_provider_config_list), "/config/provider/model_list": ("GET", self.get_provider_model_list), "/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim), + "/config/provider/get_embedding_models": ( + "POST", + self.get_embedding_models, + ), "/config/provider_sources/models": ( "GET", self.get_provider_source_models, @@ -844,6 +848,7 @@ async def get_embedding_dim(self): if not provider_config: return Response().error("缺少参数 provider_config").__dict__ + inst = None try: # 动态导入 EmbeddingProvider from astrbot.core.provider.provider import EmbeddingProvider @@ -907,6 +912,80 @@ async def get_embedding_dim(self): except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ + finally: + terminate_fn = getattr(inst, "terminate", None) if inst else None + if inspect.iscoroutinefunction(terminate_fn): + try: + await terminate_fn() + except Exception: + logger.warning("释放嵌入 provider 资源失败") + + async def get_embedding_models(self): + """根据临时 provider_config 获取可用嵌入模型列表""" + post_data = await request.json + provider_config = post_data.get("provider_config", None) + if not provider_config: + return Response().error("缺少参数 provider_config").__dict__ + + inst = None + try: + from astrbot.core.provider.provider import EmbeddingProvider + from astrbot.core.provider.register import provider_cls_map + + provider_type = provider_config.get("type", None) + if not provider_type: + return Response().error("provider_config 缺少 type 字段").__dict__ + + if provider_type not in provider_cls_map: + try: + self.core_lifecycle.provider_manager.dynamic_import_provider( + provider_type, + ) + except ImportError: + logger.error(traceback.format_exc()) + return Response().error("提供商适配器加载失败").__dict__ + + if provider_type not in provider_cls_map: + return ( + Response() + .error(f"未找到适用于 {provider_type} 的提供商适配器") + .__dict__ + ) + + provider_metadata = provider_cls_map[provider_type] + cls_type = provider_metadata.cls_type + if not cls_type: + return Response().error(f"无法找到 {provider_type} 的类").__dict__ + + inst = cls_type(provider_config, {}) + if not isinstance(inst, EmbeddingProvider): + return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ + + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() + + try: + models = await inst.get_models() + except NotImplementedError: + return ( + Response() + .error("当前提供商暂不支持自动获取模型列表,请手动填写模型 ID") + .__dict__ + ) + + models = sorted(dict.fromkeys(models or [])) + return Response().ok({"models": models}).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取嵌入模型列表失败: {e!s}").__dict__ + finally: + terminate_fn = getattr(inst, "terminate", None) if inst else None + if inspect.iscoroutinefunction(terminate_fn): + try: + await terminate_fn() + except Exception: + logger.warning("释放嵌入 provider 资源失败") async def get_provider_source_models(self): """获取指定 provider_source 支持的模型列表 diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 33273a36c9..7768e3a268 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -1,6 +1,6 @@