Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,7 @@ class ChatProviderTemplate(TypedDict):
"description": "嵌入模型",
"type": "string",
"hint": "嵌入模型名称。",
"_special": "get_embedding_models",
},
"embedding_api_key": {
"description": "API Key",
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def get_dim(self) -> int:
"""获取向量的维度"""
...

async def get_models(self) -> list[str]:
"""Optional model discovery for embedding providers."""
raise NotImplementedError(
"This embedding provider does not support model discovery",
)

async def test(self) -> None:
await self.get_embedding("astrbot")

Expand Down
49 changes: 48 additions & 1 deletion astrbot/core/provider/sources/gemini_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast

from google import genai
from google.genai import types
Expand Down Expand Up @@ -78,10 +78,57 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")

async def get_models(self) -> list[str]:
try:
all_model_ids: list[str] = []
embedding_model_ids: list[str] = []

async for model in await self.client.models.list():
model_id = self._extract_model_id(model)
if not model_id:
continue
all_model_ids.append(model_id)
if self._supports_embedding(model, model_id):
embedding_model_ids.append(model_id)

all_model_ids = sorted(dict.fromkeys(all_model_ids))
embedding_model_ids = sorted(dict.fromkeys(embedding_model_ids))
Comment thread
Sisyphbaous-DT-Project marked this conversation as resolved.

return embedding_model_ids or all_model_ids
except Exception as e:
raise Exception(f"获取 Gemini 嵌入模型列表失败: {e!s}") from e

def get_dim(self) -> int:
"""获取向量的维度"""
return int(self.provider_config.get("embedding_dimensions", 768))

async def terminate(self):
if self.client:
await self.client.aclose()

@staticmethod
def _extract_model_id(model: Any) -> str:
model_name = getattr(model, "name", "") or getattr(model, "model", "")
if not model_name:
return ""
return str(model_name).removeprefix("models/")

@classmethod
def _supports_embedding(cls, model: Any, model_id: str) -> bool:
supported_actions = getattr(model, "supported_actions", None) or getattr(
model, "supported_generation_methods", []
)
if isinstance(supported_actions, list):
normalized_actions = {
str(action).lower().replace("_", "").replace("-", "")
for action in supported_actions
}
if "embedcontent" in normalized_actions:
return True

return cls._looks_like_embedding_model(model_id)

@staticmethod
def _looks_like_embedding_model(model_id: str) -> bool:
normalized_model_id = model_id.lower()
return "embedding" in normalized_model_id or "embed" in normalized_model_id
32 changes: 32 additions & 0 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
_EMBEDDING_MODEL_HINTS = (
"embedding",
"bge",
"gte",
"e5",
"m3e",
"multilingual-e5",
)

def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
Expand Down Expand Up @@ -62,6 +71,29 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
)
return [item.embedding for item in embeddings.data]

async def get_models(self) -> list[str]:
try:
models = await self.client.models.list()
all_model_ids = sorted(
{model.id for model in models.data if getattr(model, "id", None)}
)

embedding_model_ids = [
model_id
for model_id in all_model_ids
if self._looks_like_embedding_model(model_id)
]

# Fall back to all model ids when no embedding-like names are detected.
return embedding_model_ids or all_model_ids
except Exception as e:
raise Exception(f"获取嵌入模型列表失败: {e!s}") from e

@classmethod
def _looks_like_embedding_model(cls, model_id: str) -> bool:
normalized = model_id.lower()
return any(hint in normalized for hint in cls._EMBEDDING_MODEL_HINTS)

def _embedding_kwargs(self) -> dict:
"""构建嵌入请求的可选参数"""
kwargs = {}
Expand Down
79 changes: 79 additions & 0 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ def __init__(
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
"/config/provider/get_embedding_models": (
"POST",
self.get_embedding_models,
),
"/config/provider_sources/models": (
"GET",
self.get_provider_source_models,
Expand Down Expand Up @@ -844,6 +848,7 @@ async def get_embedding_dim(self):
if not provider_config:
return Response().error("缺少参数 provider_config").__dict__

inst = None
try:
# 动态导入 EmbeddingProvider
from astrbot.core.provider.provider import EmbeddingProvider
Expand Down Expand Up @@ -907,6 +912,80 @@ async def get_embedding_dim(self):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取嵌入维度失败: {e!s}").__dict__
finally:
terminate_fn = getattr(inst, "terminate", None) if inst else None
if inspect.iscoroutinefunction(terminate_fn):
try:
await terminate_fn()
except Exception:
logger.warning("释放嵌入 provider 资源失败")

async def get_embedding_models(self):
Comment thread
Sisyphbaous-DT-Project marked this conversation as resolved.
Comment thread
Sisyphbaous-DT-Project marked this conversation as resolved.
"""根据临时 provider_config 获取可用嵌入模型列表"""
post_data = await request.json
provider_config = post_data.get("provider_config", None)
if not provider_config:
return Response().error("缺少参数 provider_config").__dict__

inst = None
try:
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.register import provider_cls_map

provider_type = provider_config.get("type", None)
if not provider_type:
return Response().error("provider_config 缺少 type 字段").__dict__

if provider_type not in provider_cls_map:
try:
self.core_lifecycle.provider_manager.dynamic_import_provider(
provider_type,
)
except ImportError:
logger.error(traceback.format_exc())
return Response().error("提供商适配器加载失败").__dict__

if provider_type not in provider_cls_map:
return (
Response()
.error(f"未找到适用于 {provider_type} 的提供商适配器")
.__dict__
)

provider_metadata = provider_cls_map[provider_type]
cls_type = provider_metadata.cls_type
if not cls_type:
return Response().error(f"无法找到 {provider_type} 的类").__dict__

inst = cls_type(provider_config, {})
if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__

init_fn = getattr(inst, "initialize", None)
if inspect.iscoroutinefunction(init_fn):
await init_fn()

try:
models = await inst.get_models()
except NotImplementedError:
return (
Response()
.error("当前提供商暂不支持自动获取模型列表,请手动填写模型 ID")
.__dict__
)

models = sorted(dict.fromkeys(models or []))
Comment thread
Sisyphbaous-DT-Project marked this conversation as resolved.
return Response().ok({"models": models}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取嵌入模型列表失败: {e!s}").__dict__
finally:
terminate_fn = getattr(inst, "terminate", None) if inst else None
if inspect.iscoroutinefunction(terminate_fn):
try:
await terminate_fn()
except Exception:
logger.warning("释放嵌入 provider 资源失败")

Comment thread
Sisyphbaous-DT-Project marked this conversation as resolved.
async def get_provider_source_models(self):
"""获取指定 provider_source 支持的模型列表
Expand Down
43 changes: 42 additions & 1 deletion dashboard/src/components/shared/AstrBotConfig.vue
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<script setup>
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
import { ref, computed } from 'vue'
import { ref, computed, watch } from 'vue'
import ConfigItemRenderer from './ConfigItemRenderer.vue'
import TemplateListEditor from './TemplateListEditor.vue'
import { useI18n, useModuleI18n } from '@/i18n/composables'
Expand Down Expand Up @@ -88,6 +88,8 @@ const currentEditingLanguage = ref('json')
const currentEditingTheme = ref('vs-light')
let currentEditingKeyIterable = null
const loadingEmbeddingDim = ref(false)
const loadingEmbeddingModels = ref(false)
const availableEmbeddingModels = ref([])

function openEditorDialog(key, value, theme, language) {
currentEditingKey.value = key
Expand Down Expand Up @@ -125,6 +127,39 @@ async function getEmbeddingDimensions(providerConfig) {
}
}

async function getEmbeddingModels(providerConfig) {
if (loadingEmbeddingModels.value) return

loadingEmbeddingModels.value = true
try {
const response = await axios.post('/api/config/provider/get_embedding_models', {
provider_config: providerConfig
})

if (response.data.status !== 'error' && Array.isArray(response.data.data?.models)) {
availableEmbeddingModels.value = response.data.data.models
useToast().success(`Fetched: ${response.data.data.models.length}`)
} else {
useToast().error(response.data.message)
}
} catch (error) {
console.error('Error getting embedding models:', error)
} finally {
loadingEmbeddingModels.value = false
}
}

watch(
() => [
props.iterable?.type,
props.iterable?.embedding_api_key,
props.iterable?.embedding_api_base
],
() => {
availableEmbeddingModels.value = []
}
)

function getValueBySelector(obj, selector) {
const keys = selector.split('.')
let current = obj
Expand Down Expand Up @@ -266,8 +301,11 @@ function hasVisibleItemsAfter(items, currentIndex) {
:plugin-name="pluginName"
:config-key="getItemPath(key)"
:loading="loadingEmbeddingDim"
:loading-models="loadingEmbeddingModels"
:available-models="availableEmbeddingModels"
:show-fullscreen-btn="!!metadata[metadataKey].items[key]?.editor_mode"
@get-embedding-dim="getEmbeddingDimensions(iterable)"
@get-embedding-models="getEmbeddingModels(iterable)"
@open-fullscreen="openEditorDialog(key, iterable, metadata[metadataKey].items[key]?.editor_theme, metadata[metadataKey].items[key]?.editor_language)"
/>
</v-col>
Expand Down Expand Up @@ -311,6 +349,9 @@ function hasVisibleItemsAfter(items, currentIndex) {
:item-meta="metadata[metadataKey]"
:plugin-name="pluginName"
:config-key="getItemPath(metadataKey)"
:loading-models="loadingEmbeddingModels"
:available-models="availableEmbeddingModels"
@get-embedding-models="getEmbeddingModels(iterable)"
/>
</v-col>
</v-row>
Expand Down
39 changes: 38 additions & 1 deletion dashboard/src/components/shared/ConfigItemRenderer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,30 @@
</v-btn>
</div>
</template>
<template v-else-if="itemMeta?._special === 'get_embedding_models'">
<div class="d-flex align-center gap-2">
<v-combobox
:model-value="modelValue"
@update:model-value="emitUpdate"
:items="availableModels"
density="compact"
variant="outlined"
class="config-field"
hide-details
clearable
></v-combobox>
<v-btn
color="primary"
variant="tonal"
size="small"
@click="$emit('get-embedding-models')"
:loading="loadingModels"
class="ml-2"
>
{{ t('core.common.autoDetect') }}
</v-btn>
</div>
</template>

<div
v-else-if="itemMeta?.type === 'list' && itemMeta?.options && itemMeta?.render_type === 'checkbox'"
Expand Down Expand Up @@ -264,13 +288,26 @@ const props = defineProps({
type: Boolean,
default: false
},
loadingModels: {
type: Boolean,
default: false
},
availableModels: {
type: Array,
default: () => []
},
showFullscreenBtn: {
type: Boolean,
default: false
}
})

const emit = defineEmits(['update:modelValue', 'get-embedding-dim', 'open-fullscreen'])
const emit = defineEmits([
'update:modelValue',
'get-embedding-dim',
'get-embedding-models',
'open-fullscreen'
])
const { t } = useI18n()
const { getRaw } = useModuleI18n('features/config-metadata')

Expand Down
Loading
Loading