-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmemory.py
More file actions
381 lines (314 loc) · 13.6 KB
/
memory.py
File metadata and controls
381 lines (314 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
"""In-memory storage backend for testing."""
from __future__ import annotations
import math
from collections.abc import Sequence
from difflib import SequenceMatcher
from synaptic.models import (
ConsolidationLevel,
Edge,
Node,
NodeKind,
)
class MemoryBackend:
"""Dict-based in-memory backend. No external dependencies."""
__slots__ = ("_edges", "_nodes")
def __init__(self) -> None:
self._nodes: dict[str, Node] = {}
self._edges: dict[str, Edge] = {}
async def connect(self) -> None:
pass
async def close(self) -> None:
self._nodes.clear()
self._edges.clear()
# --- Node CRUD ---
async def save_node(self, node: Node) -> None:
self._nodes[node.id] = node
async def get_node(self, node_id: str) -> Node | None:
return self._nodes.get(node_id)
async def update_node(self, node: Node) -> None:
if node.id in self._nodes:
self._nodes[node.id] = node
async def delete_node(self, node_id: str) -> None:
self._nodes.pop(node_id, None)
# Cascade delete edges
to_delete = [
eid
for eid, e in self._edges.items()
if e.source_id == node_id or e.target_id == node_id
]
for eid in to_delete:
del self._edges[eid]
async def list_nodes(
self,
*,
kind: str | NodeKind | None = None,
level: ConsolidationLevel | None = None,
limit: int = 100,
) -> list[Node]:
result: list[Node] = []
for node in self._nodes.values():
if kind is not None and node.kind != kind:
continue
if level is not None and node.level != level:
continue
result.append(node)
if len(result) >= limit:
break
return result
async def get_nodes_batch(self, node_ids: list[str]) -> list[Node]:
return [self._nodes[nid] for nid in node_ids if nid in self._nodes]
async def count_nodes(
self,
*,
kind: str | NodeKind | None = None,
category: str | None = None,
year: int | None = None,
) -> int:
count = 0
for node in self._nodes.values():
if kind is not None and node.kind != kind:
continue
props = node.properties or {}
if category and category.lower() not in (props.get("category") or "").lower():
continue
if year is not None and str(props.get("year")) != str(year):
continue
count += 1
return count
# --- Edge CRUD ---
async def save_edge(self, edge: Edge) -> None:
self._edges[edge.id] = edge
async def get_edges(self, node_id: str, *, direction: str = "both") -> list[Edge]:
result: list[Edge] = []
for edge in self._edges.values():
if direction in ("both", "outgoing") and edge.source_id == node_id:
result.append(edge)
elif direction in ("both", "incoming") and edge.target_id == node_id:
result.append(edge)
return result
async def update_edge(self, edge: Edge) -> None:
if edge.id in self._edges:
self._edges[edge.id] = edge
async def delete_edge(self, edge_id: str) -> None:
self._edges.pop(edge_id, None)
# --- Search ---
async def search_fts(self, query: str, *, limit: int = 20) -> list[Node]:
# Apply the same query-time noise-stripping used by SqliteBackend —
# Korean question-form verbs ("설명해주세요 / 무엇인가요 / 어떻게")
# are dropped for natural-language Korean queries (≥50 % Hangul),
# which measurably improves MRR on the Allganize RAG-ko / RAG-Eval
# benchmarks. English / code queries are unchanged.
from synaptic.backends.sqlite import _normalize_korean
normalized = _normalize_korean(query, query_mode=True)
query_lower = normalized.lower() if normalized else query.lower()
terms = [t for t in query_lower.split() if len(t) >= 1]
# Re-add original tokens that carry structural signal (digits /
# ASCII letters) — same guard as SqliteBackend.search_fts.
seen = set(terms)
for t in query.strip().split():
tl = t.lower()
if tl in seen:
continue
if any(c.isdigit() or ("a" <= c.lower() <= "z") for c in tl):
terms.append(tl)
seen.add(tl)
if not terms:
return []
# --- BM25 parameters ---
k1 = 1.5
b = 0.75
title_boost = 3.0
# Pre-compute corpus statistics for BM25
n_docs = len(self._nodes) # total documents
if n_docs == 0:
return []
# Document frequencies: how many docs contain each term (substring match)
doc_freq: dict[str, int] = {}
doc_texts: dict[str, str] = {} # node_id → full searchable text
doc_lengths: dict[str, int] = {} # node_id → word count
for node in self._nodes.values():
text = f"{node.title.lower()} {node.content.lower()}"
if node.tags:
text += " " + " ".join(node.tags).lower()
if node.properties:
kw = node.properties.get("_search_keywords", "")
if kw:
text += " " + kw.lower()
doc_texts[node.id] = text
doc_lengths[node.id] = len(text.split())
avgdl = sum(doc_lengths.values()) / n_docs if n_docs > 0 else 1.0
for t in terms:
count = 0
for text in doc_texts.values():
if t in text:
count += 1
doc_freq[t] = count
# Bigrams for phrase matching
bigrams: list[str] = []
if len(terms) >= 2:
for i in range(len(terms) - 1):
bigrams.append(f"{terms[i]} {terms[i + 1]}")
# --- Score each document (BM25 + substring hybrid) ---
scored: list[tuple[Node, float]] = []
for node in self._nodes.values():
title_lower = node.title.lower()
content_lower = node.content.lower()
full_text = doc_texts[node.id]
dl = doc_lengths[node.id]
bm25_score = 0.0
substr_score = 0.0
matched_terms = 0 # query term coverage 계산용
# Full query in title — 강한 신호 (모든 corpus 크기에서 유효)
if query_lower in title_lower:
substr_score += len(terms) * 3.0
for t in terms:
tf_content = content_lower.count(t)
tf_title = title_lower.count(t)
if tf_content == 0 and tf_title == 0:
continue
# --- BM25 component ---
df = doc_freq.get(t, 0)
idf = math.log((n_docs - df + 0.5) / (df + 0.5) + 1.0)
if tf_content > 0:
numerator = tf_content * (k1 + 1)
denominator = tf_content + k1 * (1 - b + b * dl / avgdl)
bm25_score += idf * numerator / denominator
if tf_title > 0:
bm25_score += idf * title_boost
# --- Substring component (corpus-size independent) ---
if tf_title > 0:
substr_score += 2.0
if tf_content > 0:
substr_score += 1.0
matched_terms += 1
# Bigram bonus
for bg in bigrams:
if bg in full_text:
bm25_score += 1.5
substr_score += 1.5
# Tag match
if node.tags:
tag_text = " ".join(node.tags).lower()
for t in terms:
if t in tag_text:
substr_score += 1.0
# Search keywords
if node.properties:
search_kw = node.properties.get("_search_keywords", "").lower()
if search_kw:
for t in terms:
if t in search_kw:
substr_score += 1.5
# Query term coverage bonus — 쿼리 단어 대부분 매칭 시 보너스
# coverage 80%+ → 보너스, 대규모 corpus에서 precision 향상
if len(terms) >= 2 and matched_terms > 0:
coverage = matched_terms / len(terms)
if coverage >= 0.8:
substr_score += len(terms) * 1.5 # 높은 coverage 보상
elif coverage >= 0.5:
substr_score += len(terms) * 0.5
# Hybrid: BM25 weight increases with corpus size
# N≤500: mostly substring, N=5000+: mostly BM25
bm25_weight = min(0.8, max(0.1, (n_docs - 500) / 5000))
score = bm25_score * bm25_weight + substr_score * (1 - bm25_weight)
if score > 0:
scored.append((node, score))
scored.sort(key=lambda x: x[1], reverse=True)
return [n for n, _ in scored[:limit]]
async def search_fuzzy(
self, query: str, *, limit: int = 20, threshold: float = 0.4
) -> list[Node]:
query_lower = query.lower()
# Deduplicate and cap query terms to avoid O(n*m) explosion on long queries
query_terms = list(dict.fromkeys(query_lower.split()))[:10]
scored: list[tuple[Node, float]] = []
for node in self._nodes.values():
title_lower = node.title.lower()
# Compare against title (short text → fair ratio)
title_ratio = SequenceMatcher(None, query_lower[:200], title_lower).ratio()
best = title_ratio
# Per-term fuzzy: match each query term against title words + content sample
if query_terms:
title_words = title_lower.split()
# Content: first 100 words for broader coverage
content_words = node.content.lower().split()[:100]
# Tag words too
tag_words = [t.lower() for t in (node.tags or [])]
text_words = title_words + content_words + tag_words
term_scores: list[float] = []
for qt in query_terms:
term_best = 0.0
for tw in text_words:
r = SequenceMatcher(None, qt, tw).ratio()
if r > term_best:
term_best = r
term_scores.append(term_best)
avg_term = sum(term_scores) / len(term_scores)
# Title term match bonus: boost when term is exactly in title
title_boost = sum(0.1 for qt in query_terms if qt in title_lower)
best = max(best, avg_term) + title_boost
if best >= threshold:
scored.append((node, best))
scored.sort(key=lambda x: x[1], reverse=True)
return [n for n, _ in scored[:limit]]
async def search_vector(self, embedding: list[float], *, limit: int = 20) -> list[Node]:
if not embedding:
return []
scored: list[tuple[Node, float]] = []
for node in self._nodes.values():
if not node.embedding:
continue
sim = _cosine_similarity(embedding, node.embedding)
scored.append((node, sim))
scored.sort(key=lambda x: x[1], reverse=True)
return [n for n, _ in scored[:limit]]
# --- Graph traversal ---
async def get_neighbors(self, node_id: str, *, depth: int = 1) -> list[tuple[Node, Edge]]:
result: list[tuple[Node, Edge]] = []
visited: set[str] = {node_id}
frontier: set[str] = {node_id}
for _ in range(depth):
next_frontier: set[str] = set()
for nid in frontier:
for edge in self._edges.values():
neighbor_id: str | None = None
if edge.source_id == nid and edge.target_id not in visited:
neighbor_id = edge.target_id
elif edge.target_id == nid and edge.source_id not in visited:
neighbor_id = edge.source_id
if neighbor_id is not None:
neighbor = self._nodes.get(neighbor_id)
if neighbor is not None:
result.append((neighbor, edge))
visited.add(neighbor_id)
next_frontier.add(neighbor_id)
frontier = next_frontier
return result
# --- Batch ---
async def save_nodes_batch(self, nodes: Sequence[Node]) -> None:
for node in nodes:
self._nodes[node.id] = node
async def save_edges_batch(self, edges: Sequence[Edge]) -> None:
for edge in edges:
self._edges[edge.id] = edge
# --- Maintenance ---
async def prune_edges(self, *, weight_below: float = 0.1) -> int:
to_delete = [eid for eid, e in self._edges.items() if e.weight < weight_below]
for eid in to_delete:
del self._edges[eid]
return len(to_delete)
async def decay_vitality(self, *, factor: float = 0.95) -> int:
count = 0
for node in self._nodes.values():
node.vitality *= factor
count += 1
return count
def _cosine_similarity(a: list[float], b: list[float]) -> float:
if len(a) != len(b) or not a:
return 0.0
dot = sum(x * y for x, y in zip(a, b, strict=True))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)