-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathembedder.py
More file actions
157 lines (119 loc) · 5.17 KB
/
embedder.py
File metadata and controls
157 lines (119 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Embedding providers — generate vector embeddings for nodes.
Supports any OpenAI-compatible endpoint:
- OpenAI API
- vLLM (--served-model-name)
- llama.cpp (server --embedding)
- Ollama (/v1/embeddings, or the native /api/embed endpoint)
- TEI (Text Embeddings Inference)
"""
from __future__ import annotations
import logging
from typing import Protocol
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class EmbeddingProvider(Protocol):
    """Structural interface for objects that turn text into embedding vectors.

    Declared as a ``typing.Protocol``: a class satisfies it by providing the
    two coroutines below, without having to inherit from it.
    """

    async def embed(self, text: str) -> list[float]:
        """Return the embedding vector for a single text."""
        ...

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return embedding vectors for a list of texts."""
        ...
class _EmbedFromBatchMixin:
    """Mixin that implements ``embed()`` by delegating to ``embed_batch()``.

    The host class must provide an ``embed_batch(texts)`` coroutine.
    """

    # Empty slots on purpose: a base class without ``__slots__`` gives every
    # subclass instance a ``__dict__``, which would cancel out the
    # ``__slots__`` that subclasses declare.
    __slots__ = ()

    async def embed(self, text: str) -> list[float]:
        """Embed a single text through the host's ``embed_batch()``.

        This is a best-effort call: it never raises for ordinary errors.

        Args:
            text: Text to embed. Empty or whitespace-only input is replaced
                by a single space before it is sent.

        Returns:
            The first vector returned by ``embed_batch([text])``; a zero
            vector of the same length if that vector contains NaN; or an
            empty list if the call (or the result handling) raised.
        """
        import math

        # Guard against empty / whitespace-only text.
        if not text or not text.strip():
            text = " "
        try:
            results = await self.embed_batch([text])  # type: ignore[attr-defined]
            vec = results[0]
            # Guard against NaN components in the returned vector.
            if any(math.isnan(v) for v in vec):
                logger.warning("NaN in embedding, returning zero vector")
                return [0.0] * len(vec)
            return vec
        except Exception:
            # Deliberate best-effort: log with traceback and signal failure
            # to the caller with an empty vector instead of raising.
            logger.warning("Embedding failed, returning empty", exc_info=True)
            return []
class MockEmbeddingProvider:
    """Mock embedding provider for testing. Returns deterministic vectors.

    Vectors are derived from a SHAKE-256 digest of the text rather than the
    built-in ``hash()``: ``hash()`` of a ``str`` is salted per interpreter
    process, so it is not reproducible between runs, and a 32-bit hash can
    only fill the first four components.
    """

    __slots__ = ("_dim",)

    def __init__(self, dim: int = 4) -> None:
        """Create a mock provider.

        Args:
            dim: Number of components in every returned vector.
        """
        self._dim = dim

    async def embed(self, text: str) -> list[float]:
        """Return a vector of ``dim`` floats in ``[0.0, 1.0]`` for *text*.

        The same text always maps to the same vector, in any process.
        """
        import hashlib

        # A non-positive dim yields an empty vector, as range(dim) would.
        size = max(self._dim, 0)
        # "surrogatepass" keeps lone surrogates encodable instead of raising.
        raw = text.encode("utf-8", "surrogatepass")
        # SHAKE-256 is an extendable-output function: ask for exactly one
        # byte per component, then scale each byte into [0, 1].
        digest = hashlib.shake_256(raw).digest(size)
        return [byte / 255.0 for byte in digest]

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return one ``embed()`` vector per entry of *texts*, in order."""
        return [await self.embed(t) for t in texts]
class OpenAIEmbeddingProvider(_EmbedFromBatchMixin):
    """OpenAI-compatible embedding provider.

    Works with any server implementing the /v1/embeddings endpoint:
    - OpenAI: api_base="https://api.openai.com/v1", model="text-embedding-3-small"
    - vLLM: api_base="http://localhost:8000/v1", model="BAAI/bge-m3"
    - llama.cpp: api_base="http://localhost:8080/v1", model="default"
    - Ollama: api_base="http://localhost:11434/v1", model="nomic-embed-text"
    - TEI: api_base="http://localhost:8080/v1", model="default"

    Uses aiohttp (zero extra deps — already pulled in by miniopy-async).
    """

    __slots__ = ("_api_base", "_api_key", "_model", "_timeout")

    def __init__(
        self,
        api_base: str = "http://localhost:8080/v1",
        *,
        api_key: str = "",
        model: str = "default",
        timeout: int = 60,
    ) -> None:
        """Store the connection settings; no network I/O happens here.

        Args:
            api_base: Base URL of the API; trailing slashes are stripped and
                ``/embeddings`` is appended per request.
            api_key: Sent as a ``Bearer`` token when non-empty.
            model: Value of the ``model`` field in the request payload.
            timeout: Passed as ``total`` to ``aiohttp.ClientTimeout``.
        """
        self._api_base = api_base.rstrip("/")
        self._api_key = api_key
        self._model = model
        self._timeout = timeout

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with a single POST to ``{api_base}/embeddings``.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without
                contacting the server.

        Returns:
            The ``embedding`` of each response item, ordered by the item's
            ``index`` field.

        Raises:
            RuntimeError: If the server answers with a non-200 status.
        """
        # Nothing to embed: skip the HTTP round trip, which a server may
        # reject for an empty input list.
        if not texts:
            return []

        import aiohttp

        url = f"{self._api_base}/embeddings"
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        payload = {"model": self._model, "input": texts}
        timeout = aiohttp.ClientTimeout(total=self._timeout)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    # Truncate the body so a large error page cannot flood logs.
                    msg = f"Embedding API error {resp.status}: {body[:200]}"
                    raise RuntimeError(msg)
                data = await resp.json()

        # Sort by "index" so the output order follows that field rather than
        # the order in which the server happened to list the items.
        ordered = sorted(data["data"], key=lambda item: item["index"])
        return [item["embedding"] for item in ordered]
class OllamaEmbeddingProvider(_EmbedFromBatchMixin):
    """Ollama native embedding endpoint (/api/embed).

    For Ollama servers that don't expose /v1/embeddings.

    Usage:
        provider = OllamaEmbeddingProvider(
            base_url="http://localhost:11434",
            model="nomic-embed-text",
        )
    """

    __slots__ = ("_base_url", "_model", "_timeout")

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        *,
        model: str = "nomic-embed-text",
        timeout: int = 60,
    ) -> None:
        """Store the connection settings; no network I/O happens here.

        Args:
            base_url: Server root URL; trailing slashes are stripped and
                ``/api/embed`` is appended per request.
            model: Value of the ``model`` field in the request payload.
            timeout: Passed as ``total`` to ``aiohttp.ClientTimeout``.
        """
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._timeout = timeout

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with a single POST to ``{base_url}/api/embed``.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without
                contacting the server.

        Returns:
            The ``embeddings`` field of the JSON response, unchanged.

        Raises:
            RuntimeError: If the server answers with a non-200 status.
        """
        # Nothing to embed: skip the HTTP round trip entirely.
        if not texts:
            return []

        import aiohttp

        url = f"{self._base_url}/api/embed"
        payload = {"model": self._model, "input": texts}
        timeout = aiohttp.ClientTimeout(total=self._timeout)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    # Truncate the body so a large error page cannot flood logs.
                    msg = f"Ollama embed error {resp.status}: {body[:200]}"
                    raise RuntimeError(msg)
                data = await resp.json()

        return data["embeddings"]  # type: ignore[no-any-return]